By Zibai
As Large Language Models (LLMs) continue to grow in size, a single GPU can no longer handle the entire model. For instance, the Qwen-14B-Chat model has weights totaling around 28 GB, while a single NVIDIA A10 GPU has only 24 GB of memory. To deploy the Qwen-14B-Chat model, we need to split it and distribute it across two A10 GPUs, with each GPU loading half of the model. This approach is known as distributed inference. Several frameworks support distributed inference, including vLLM, DeepSpeed-MII, and RTP-LLM.
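The memory argument above can be checked with a quick back-of-the-envelope calculation. This sketch assumes roughly 14 billion parameters stored in FP16/BF16 (2 bytes each). It counts weights only and ignores the KV cache and activations, which also need GPU memory. The helper names are hypothetical.

```python
def weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Approximate weight memory in decimal GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


def per_gpu_weight_gb(num_params: float, tp_size: int,
                      bytes_per_param: int = 2) -> float:
    """Weight memory per GPU when tensor parallelism splits weights evenly."""
    return weight_memory_gb(num_params, bytes_per_param) / tp_size


if __name__ == "__main__":
    params = 14e9        # assumption: ~14B parameters for Qwen-14B-Chat
    a10_memory_gb = 24   # memory of one NVIDIA A10
    for tp in (1, 2):
        shard = per_gpu_weight_gb(params, tp)
        print(f"tp={tp}: {shard:.1f} GB of weights per GPU, "
              f"fits on one A10: {shard < a10_memory_gb}")
```

With tp=2, each A10 holds about 14 GB of weights. The headroom left on each GPU is what vLLM can then use for activations and the KV cache.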
This article focuses on the vLLM framework, exploring how to implement distributed inference with vLLM and Ray from a source code perspective.
Download the Qwen-14B-Chat model, upload it to OSS, and create the corresponding PV and PVC in the cluster. Name the PVC llm-model.
kubectl apply -f- << EOF
apiVersion: v1
kind: Secret
metadata:
  name: oss-secret
stringData:
  akId: ${your-accesskey-id} # The AccessKey ID used to access OSS
  akSecret: ${your-accesskey-secret} # The AccessKey secret used to access OSS
---
apiVersion: v1
kind: PersistentVolume
metadata:
  name: llm-model
  labels:
    alicloud-pvname: llm-model
spec:
  capacity:
    storage: 30Gi
  accessModes:
    - ReadOnlyMany
  persistentVolumeReclaimPolicy: Retain
  csi:
    driver: ossplugin.csi.alibabacloud.com
    volumeHandle: model-oss
    nodePublishSecretRef:
      name: oss-secret
      namespace: default
    volumeAttributes:
      bucket: ${your-bucket-name}
      url: ${your-bucket-endpoint} # e.g. oss-cn-hangzhou.aliyuncs.com
      otherOpts: "-o umask=022 -o max_stat_cache_size=0 -o allow_other"
      path: "/"
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: llm-model
spec:
  accessModes:
    - ReadOnlyMany
  resources:
    requests:
      storage: 30Gi
  selector:
    matchLabels:
      alicloud-pvname: llm-model
EOF
Next, create a Deployment with two replicas. The required pod anti-affinity rule places the two pods, referred to below as Pod1 and Pod2, on different nodes, and each pod requests one GPU. The container only runs sleep, so Ray and vLLM are started manually in the following steps.
kubectl apply -f- <<EOF
apiVersion: apps/v1
kind: Deployment
metadata:
  name: vllm
  labels:
    app: vllm
spec:
  replicas: 2
  selector:
    matchLabels:
      app: vllm
  template:
    metadata:
      labels:
        app: vllm
    spec:
      affinity:
        podAntiAffinity:
          requiredDuringSchedulingIgnoredDuringExecution:
          - labelSelector:
              matchExpressions:
              - key: app
                operator: In
                values:
                - vllm
            topologyKey: kubernetes.io/hostname
      volumes:
      - name: model
        persistentVolumeClaim:
          claimName: llm-model
      containers:
      - name: vllm
        image: kube-ai-registry.cn-shanghai.cr.aliyuncs.com/kube-ai/vllm:0.4.1
        command:
        - "sh"
        - "-c"
        - "sleep 7d"
        ports:
        - containerPort: 8080
        readinessProbe:
          tcpSocket:
            port: 8080
          initialDelaySeconds: 30
          periodSeconds: 30
        resources:
          limits:
            nvidia.com/gpu: "1"
          requests:
            cpu: 4
            memory: 8Gi
            nvidia.com/gpu: "1"
        volumeMounts:
        - mountPath: /mnt/models
          name: model
EOF
● Start Ray. On Pod1, start the Ray head node:
ray start --head
# After Ray starts, the log displays the ray-head-address.
# On Pod2, set <ray-head-address> to the address displayed in the Pod1 log and join the cluster:
ray start --address=<ray-head-address>
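● On Pod2, run the following command to initialize the local model:
python3 model_init.py
where model_init.py loads the model configuration and tokenizer with trust_remote_code=True:
from transformers import AutoConfig, AutoTokenizer

# Load the model configuration; trust_remote_code allows Qwen's custom model code to run.
config = AutoConfig.from_pretrained(
    "/mnt/models/Qwen-14B-Chat",
    trust_remote_code=True)
# Load the tokenizer from the same model directory.
tokenizer = AutoTokenizer.from_pretrained("/mnt/models/Qwen-14B-Chat", trust_remote_code=True)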
● On Pod1, run the following command to start the Qwen model:
python3 -m vllm.entrypoints.openai.api_server \
--port 8080 \
--trust-remote-code \
--served-model-name qwen \
--model /mnt/models/Qwen-14B-Chat \
--gpu-memory-utilization 0.95 \
--tensor-parallel-size 2
● Log in to Pod1 and send a test request to the service:
kubectl -n <your-namespace> exec -it <pod1-name> -- bash
curl -H "Content-Type: application/json" \
  http://localhost:8080/v1/chat/completions -X POST \
  -d '{"model": "qwen", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 512, "temperature": 0.7, "top_p": 0.9, "seed": 10, "stop":["<|endoftext|>", "<|im_end|>", "<|im_start|>"]}'
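The same request can be built from Python with only the standard library, which is convenient for scripted smoke tests. This is a sketch: build_chat_request and send_chat_request are hypothetical helpers, the URL and model name simply mirror the curl command (the model name must match --served-model-name), and sending only works while the vLLM server on Pod1 is running.

```python
import json
import urllib.request


def build_chat_request(prompt: str,
                       url: str = "http://localhost:8080/v1/chat/completions",
                       model: str = "qwen") -> urllib.request.Request:
    """Build (but do not send) the chat-completion request used with curl."""
    payload = {
        "model": model,  # must match --served-model-name
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 512,
        "temperature": 0.7,
        "top_p": 0.9,
        "seed": 10,
        "stop": ["<|endoftext|>", "<|im_end|>", "<|im_start|>"],
    }
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send_chat_request(req: urllib.request.Request, timeout: float = 60.0) -> str:
    """Send the request and return the assistant message (needs a live server)."""
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]


if __name__ == "__main__":
    # Only prints the request body; call send_chat_request on Pod1 to send it.
    print(build_chat_request("Hello").data.decode("utf-8"))
```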
The analysis starts from the entry point, vllm.entrypoints.openai.api_server (simplified):
if __name__ == "__main__":
    # build engine args
    engine_args = AsyncEngineArgs.from_cli_args(args)
    # build engine
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
    openai_serving_chat = OpenAIServingChat(engine, served_model_names,
                                            args.response_role,
                                            args.lora_modules,
                                            args.chat_template)
    openai_serving_completion = OpenAIServingCompletion(
        engine, served_model_names, args.lora_modules)
    app.root_path = args.root_path
    uvicorn.run(app)
The first step builds the engine:
engine = AsyncLLMEngine.from_engine_args(
    engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
AsyncLLMEngine.from_engine_args, simplified:
@classmethod
def from_engine_args(cls, engine_args, start_engine_loop=True,
                     usage_context=UsageContext.ENGINE_CONTEXT):
    """Creates an async LLM engine from the engine arguments."""
    # Create the engine configs.
    engine_config = engine_args.create_engine_config()
    # Ray cluster initialization:
    # 1. ray.init()
    # 2. Set the Ray placement strategy based on the number of GPUs in the cluster and the TP degree.
    initialize_ray_cluster(engine_config.parallel_config)
    from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
    executor_class = RayGPUExecutorAsync
    # Create the async LLM engine (an AsyncLLMEngine instance).
    # Call chain: AsyncLLMEngine.__init__ -> self._init_engine -> _AsyncLLMEngine.__init__
    # -> LLMEngine.__init__ -> executor_class(), i.e. RayGPUExecutorAsync.__init__
    engine = cls(...)
    return engine
Setting up Ray workers involves initializing both the Ray cluster and the Ray workers themselves. During this process, the model is loaded in a distributed manner.
# RayGPUExecutorAsync inherits from RayGPUExecutor and ExecutorAsyncBase.
# During initialization, it calls RayGPUExecutor._init_executor.
def _init_executor(self) -> None:
    # Create the parallel GPU workers. Core code for initializing workers:
    self._init_workers_ray(placement_group)

def _init_workers_ray():
    # Define the worker. Each Ray actor is a RayWorkerWrapper that wraps
    # the Worker class from the vllm.worker.worker module.
    worker = ray.remote(
        num_cpus=0,
        num_gpus=num_gpus,
        scheduling_strategy=scheduling_strategy,
        **ray_remote_kwargs,
    )(RayWorkerWrapper).remote(
        worker_module_name="vllm.worker.worker",
        worker_class_name="Worker",
        trust_remote_code=self.model_config.trust_remote_code,
    )
    # Execute the following methods in sequence on the Ray workers.
    self._run_workers("get_node_and_gpu_ids",
                      use_dummy_driver=True)
    self._run_workers("update_environment_variables",
                      all_args=all_args_to_update_environment_variables)
    self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
    self._run_workers("init_device")
    self._run_workers(
        "load_model",
        max_concurrent_workers=self.parallel_config.
        max_parallel_loading_workers,
    )
def _run_workers():
    # Start the Ray workers first.
    ray_worker_outputs = [
        # Each worker is a handle to the RayWorkerWrapper actor created above.
        # RayWorkerWrapper.execute_method is called remotely and runs the
        # named method on the remote instance.
        worker.execute_method.remote(method, *worker_args,
                                     **worker_kwargs)
        for (worker, worker_args, worker_kwargs
             ) in zip(self.workers, all_worker_args, all_worker_kwargs)
    ]
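_run_workers follows a fan-out/gather pattern: the same method name is sent to every worker, each wrapper resolves it on its wrapped Worker by name, and the results are collected in worker order. The plain-Python sketch below imitates that pattern with a thread pool instead of Ray actors. ToyWorker, ToyWorkerWrapper, and run_workers are hypothetical names, not vLLM APIs.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, List


class ToyWorker:
    """Hypothetical stand-in for vllm.worker.worker.Worker."""

    def __init__(self, rank: int) -> None:
        self.rank = rank

    def init_device(self) -> str:
        return f"rank {self.rank}: device ready"

    def load_model(self, shard_count: int) -> str:
        return f"rank {self.rank}: loaded shard {self.rank + 1}/{shard_count}"


class ToyWorkerWrapper:
    """Mimics RayWorkerWrapper.execute_method: dispatch a method by name."""

    def __init__(self, worker: ToyWorker) -> None:
        self.worker = worker

    def execute_method(self, method: str, *args: Any, **kwargs: Any) -> Any:
        return getattr(self.worker, method)(*args, **kwargs)


def run_workers(wrappers: List[ToyWorkerWrapper], method: str,
                *args: Any, **kwargs: Any) -> List[Any]:
    """Start the call on every worker first, then gather results in order."""
    with ThreadPoolExecutor(max_workers=len(wrappers)) as pool:
        futures = [pool.submit(w.execute_method, method, *args, **kwargs)
                   for w in wrappers]
        return [f.result() for f in futures]


if __name__ == "__main__":
    wrappers = [ToyWorkerWrapper(ToyWorker(rank)) for rank in range(2)]
    print(run_workers(wrappers, "init_device"))
    print(run_workers(wrappers, "load_model", shard_count=2))
```

In vLLM, each .remote() call returns an object reference immediately, which is what lets all workers run concurrently; the thread pool plays that role here.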
def init_worker():
    # worker_module_name is vllm.worker.worker, which is passed in the _init_workers_ray method.
    mod = importlib.import_module(self.worker_module_name)
    # Worker
    worker_class = getattr(mod, self.worker_class_name)
    self.worker = worker_class(*args, **kwargs)
    # Worker.__init__ -> ModelRunner.__init__

def init_device():
    # Initialize the machine information for distributed inference
    """Initialize the distributed environment."""
    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank)

def load_model():
    self.model_runner.load_model()  # ModelRunner.load_model() -> vllm.model_executor.model_loader.loader.load_model
Below is the expected log output after load_model() runs. It shows two workers, one per pod, each loading 13.2845 GB of weights, roughly half of the model.
INFO 04-26 09:39:46 model_runner.py:173] Loading model weights took 13.2845 GB
(RayWorkerWrapper pid=3327, ip=192.168.12.132) INFO 04-26 09:39:51 model_runner.py:173] Loading model weights took 13.2845 GB
Create instances of OpenAIServingChat and OpenAIServingCompletion, then start uvicorn to provide these services externally.
# Serves requests routed to @app.post("/v1/chat/completions")
openai_serving_chat = OpenAIServingChat(engine, served_model_names,
                                        args.response_role,
                                        args.lora_modules,
                                        args.chat_template)
# Serves requests routed to @app.post("/v1/completions")
openai_serving_completion = OpenAIServingCompletion(
    engine, served_model_names, args.lora_modules)
app.root_path = args.root_path
uvicorn.run(app)
When --tensor-parallel-size is greater than 1, vLLM automatically switches to a Ray-based distributed deployment (note 'worker_use_ray': True in the configuration below).
# Ray cluster initialization
initialize_ray_cluster(engine_config.parallel_config)
In this example, parallel_config has pp = 1 and tp = 2, so world_size = pp × tp = 2:
{'pipeline_parallel_size': 1, 'tensor_parallel_size': 2, 'worker_use_ray': True, 'max_parallel_loading_workers': None, 'disable_custom_all_reduce': False, 'tokenizer_pool_config': None, 'ray_workers_use_nsight': False, 'placement_group': None, 'world_size': 2}
A placement_group is created for the worker processes during initialization:
(1) Determine the total number of GPUs in the Ray cluster.
(2) Based on the world size, request GPU resources with placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size), that is, one bundle with one GPU for each worker.
(3) Create the placement_group; Ray will then start an actor on the corresponding node.
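The three steps can be sketched in plain Python. The sketch only computes the bundle list and checks capacity; in vLLM the bundles are handed to Ray's placement_group API (ray.util.placement_group), which reserves them so that one worker actor can be scheduled per bundle. plan_placement_group and its error message are hypothetical.

```python
from typing import Dict, List


def plan_placement_group(pipeline_parallel_size: int,
                         tensor_parallel_size: int,
                         cluster_gpus: float) -> List[Dict[str, int]]:
    """Return one {"GPU": 1} bundle per worker."""
    # world_size = pp * tp, as in ParallelConfig.
    world_size = pipeline_parallel_size * tensor_parallel_size
    # (1) The Ray cluster must have at least one GPU per worker.
    if world_size > cluster_gpus:
        raise ValueError(
            f"world_size={world_size} needs {world_size} GPUs, "
            f"but the cluster only has {cluster_gpus}")
    # (2) One bundle per worker, equivalent to [{"GPU": 1}] * world_size
    # but with independent dict objects.
    # (3) Ray would later start one actor on the node holding each bundle.
    return [{"GPU": 1} for _ in range(world_size)]


if __name__ == "__main__":
    # Two pods with one A10 each: pp=1, tp=2.
    print(plan_placement_group(1, 2, cluster_gpus=2))
```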
# Obtain the node ID and the IDs of the GPUs assigned to this worker.
def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
    node_id = ray.get_runtime_context().get_node_id()
    gpu_ids = ray.get_gpu_ids()
    return node_id, gpu_ids
# The node and GPU information obtained in the previous step.
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                            use_dummy_driver=True)
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [({
    "CUDA_VISIBLE_DEVICES":
    ",".join(map(str, node_gpus[node_id])),
    "VLLM_INSTANCE_ID":
    VLLM_INSTANCE_ID,
    "VLLM_TRACE_FUNCTION":
    os.getenv("VLLM_TRACE_FUNCTION", "0"),
}, ) for (node_id, _) in worker_node_and_gpu_ids]
# Parameters used to start the workers
init_worker_all_kwargs = []
# worker_node_and_gpu_ids holds the node and GPU information of each worker obtained above.
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
    local_rank = node_workers[node_id].index(rank)
    init_worker_all_kwargs.append(
        collect_arg_helper_func(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            is_driver_worker=rank == 0,
        ))
def init_device(self) -> None:
    if self.device_config.device.type == "cuda":
        # torch.distributed.all_reduce does not free the input tensor until
        # the synchronization point. This causes the memory usage to grow
        # as the number of all_reduce calls increases. This env var disables
        # this behavior.
        # Related issue:
        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
        self.device = torch.device(f"cuda:{self.local_rank}")
        torch.cuda.set_device(self.device)
        _check_if_gpu_supports_dtype(self.model_config.dtype)
        torch.cuda.empty_cache()
        self.init_gpu_memory = torch.cuda.mem_get_info()[0]
    else:
        raise RuntimeError(
            f"Not support device type: {self.device_config.device}")
    # Initialize the distributed environment.
    init_worker_distributed_environment(self.parallel_config, self.rank,
                                        self.distributed_init_method,
                                        self.local_rank)
    # Set random seed.
    set_random_seed(self.model_config.seed)
The core method init_worker_distributed_environment constructs the world information for the distributed cluster, similar to world info in Horovod and DeepSpeed.
The parameters for this method are as follows:
worker1: self.rank=0, self.local_rank=0, self.distributed_init_method="tcp://192.168.12.120:42167" (the Ray head node)
{'pipeline_parallel_size': 1, 'tensor_parallel_size': 2, 'worker_use_ray': True, 'max_parallel_loading_workers': None, 'disable_custom_all_reduce': False, 'tokenizer_pool_config': None, 'ray_workers_use_nsight': False, 'placement_group': <ray.util.placement_group.PlacementGroup object at 0x7fdeaa896ce0>, 'world_size': 2}, {'id': PlacementGroupID(51489eb26a9335f31ed1bdb4eace04000000), 'bundle_cache': [{'GPU': 1.0}, {'GPU': 1.0}]}, self.rank=0, tcp://192.168.12.120:42167, self.local_rank=0
worker2: self.rank=1, self.local_rank=0, self.distributed_init_method="tcp://192.168.12.120:42167"
{'pipeline_parallel_size': 1, 'tensor_parallel_size': 2, 'worker_use_ray': True, 'max_parallel_loading_workers': None, 'disable_custom_all_reduce': False, 'tokenizer_pool_config': None, 'ray_workers_use_nsight': False, 'world_size': 2}, self.rank=1, tcp://192.168.12.120:42167, self.local_rank=0
self.rank is the global index of a worker across all pods, while self.local_rank is the index of the GPU within its pod. Because each pod here has a single GPU, both workers have local_rank 0.
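The rank bookkeeping performed by _init_workers_ray (node_workers, node_gpus, local_rank, and CUDA_VISIBLE_DEVICES) can be reproduced in a small standalone function. This is a sketch modeled on the excerpts above, not vLLM code: assign_ranks is a hypothetical helper, and the node IDs are made-up strings, whereas Ray returns hexadecimal node IDs.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def assign_ranks(
        worker_node_and_gpu_ids: List[Tuple[str, List[int]]]) -> List[Dict[str, object]]:
    """Derive rank, local_rank and CUDA_VISIBLE_DEVICES for each worker."""
    node_workers: Dict[str, List[int]] = defaultdict(list)  # node -> global ranks
    node_gpus: Dict[str, List[int]] = defaultdict(list)     # node -> GPU IDs
    for rank, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
        node_workers[node_id].append(rank)
        node_gpus[node_id].extend(gpu_ids)
    for node_id in node_gpus:
        node_gpus[node_id] = sorted(node_gpus[node_id])

    assignments = []
    for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
        assignments.append({
            "rank": rank,                                     # global index
            "local_rank": node_workers[node_id].index(rank),  # index within the node
            # Every worker on a node sees all GPUs of that node;
            # init_device then selects cuda:{local_rank}.
            "CUDA_VISIBLE_DEVICES": ",".join(map(str, node_gpus[node_id])),
        })
    return assignments


if __name__ == "__main__":
    # Two pods with one GPU each, as in this article.
    for a in assign_ranks([("pod1-node", [0]), ("pod2-node", [0])]):
        print(a)
```

For comparison, two workers on a single node with two GPUs would get local_rank 0 and 1 and both see CUDA_VISIBLE_DEVICES="0,1".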
load_model is used to load a distributed model. It is complex and is described in the following section.
def load_model():
    self.model_runner.load_model()
    # ModelRunner.load_model() -> vllm.model_executor.model_loader.loader.load_model
def load_model(self) -> None:
    with CudaMemoryProfiler() as m:
        # Use get_model to obtain the model.
        self.model = get_model(
            model_config=self.model_config,
            device_config=self.device_config,
            load_config=self.load_config,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
        )
    self.model_memory_usage = m.consumed_memory
    logger.info(f"Loading model weights took "
                f"{self.model_memory_usage / float(2**30):.4f} GB")
# get_model -> loader.load_model -> DefaultModelLoader.load_model
def load_model(self, *, model_config: ModelConfig,
               device_config: DeviceConfig,
               lora_config: Optional[LoRAConfig],
               vision_language_config: Optional[VisionLanguageConfig],
               parallel_config: ParallelConfig,
               scheduler_config: SchedulerConfig) -> nn.Module:
    with set_default_torch_dtype(model_config.dtype):
        with torch.device(device_config.device):
            """Initialize a model with the given configurations."""
            # Initialize the model.
            model = _initialize_model(model_config, self.load_config,
                                      lora_config, vision_language_config)
        # Call the load_weights method of the corresponding model.
        model.load_weights(
            self._get_weights_iterator(model_config.model,
                                       model_config.revision,
                                       fall_back_to_pt=getattr(
                                           model,
                                           "fall_back_to_pt_during_load",
                                           True)), )
        for _, module in model.named_modules():
            linear_method = getattr(module, "linear_method", None)
            if linear_method is not None:
                linear_method.process_weights_after_loading(module)
            if hasattr(module, "process_weights_after_loading"):
                module.process_weights_after_loading()
    return model.eval()
# Find the specific model class based on the model configuration.
def _initialize_model(
        model_config: ModelConfig, load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
    """Initialize a model with the given configurations."""
    # Look up the architectures field in the config.json file of Qwen-14B-Chat.
    model_class = get_model_architecture(model_config)[0]
    linear_method = _get_linear_method(model_config, load_config)
    return model_class(config=model_config.hf_config,
                       linear_method=linear_method,
                       **_get_model_initialization_kwargs(
                           model_class, lora_config, vision_language_config))
# model_class refers to <class 'vllm.model_executor.models.qwen.QWenLMHeadModel'>
# QWenLMHeadModel.load_weights
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "w2", 0),
        ("gate_up_proj", "w1", 1),
    ]
    # The weight of each layer of the model and its name.
    # self.named_parameters refers to model.named_parameters()
    params_dict = dict(self.named_parameters())
    for name, loaded_weight in weights:
        # name: transformer.h.27.mlp.c_proj.weight
        # loaded_weight: tensor(xxx)
        if "rotary_emb.inv_freq" in name:
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            # If found in stacked_params_mapping, shard_name should be replaced with param_name.
            # For example, if name is transformer.h.0.mlp.w1.weight, it should be changed to transformer.h.0.mlp.gate_up_proj.weight.
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Python's for-else syntax; reaching here means no break statement was executed in the loop.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            # Locate the corresponding weight_loader method for the name.
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
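The renaming done through stacked_params_mapping can be isolated in a small helper: the checkpoint stores the two MLP input projections separately as w1 and w2, while vLLM fuses them into a single gate_up_proj parameter, with w2 loaded as shard 0 and w1 as shard 1. This sketch mirrors only that part of the loop above; map_checkpoint_name is hypothetical, and the real method also skips rotary and GPTQ bias entries.

```python
from typing import List, Optional, Tuple

# (param_name, shard_name, shard_id), as in QWenLMHeadModel.load_weights
STACKED_PARAMS_MAPPING: List[Tuple[str, str, int]] = [
    ("gate_up_proj", "w2", 0),
    ("gate_up_proj", "w1", 1),
]


def map_checkpoint_name(name: str) -> Tuple[str, Optional[int]]:
    """Map a checkpoint weight name to (model parameter name, shard_id)."""
    for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
        # Same substring test as the original loop.
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None  # not stacked: loaded as a whole parameter


if __name__ == "__main__":
    for n in ("transformer.h.0.mlp.w1.weight",
              "transformer.h.0.mlp.w2.weight",
              "transformer.h.27.mlp.c_proj.weight"):
        print(n, "->", map_checkpoint_name(n))
```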
# param,weight_loader
lm_head.weight, weight_loader <bound method VocabParallelEmbedding.weight_loader of ParallelLMHead()>
transformer.h.0.attn.c_attn.weight, weight_loader <bound method QKVParallelLinear.weight_loader of QKVParallelLinear()>
transformer.h.0.attn.c_proj.weight, weight_loader <bound method RowParallelLinear.weight_loader of RowParallelLinear()>
transformer.h.0.ln_1.weight, weight_loader <function default_weight_loader at 0x7f66201ee0e0>
transformer.h.0.ln_2.weight, weight_loader <function default_weight_loader at 0x7f66201ee0e0>
transformer.h.0.mlp.c_proj.weight, weight_loader <bound method RowParallelLinear.weight_loader of RowParallelLinear()>
transformer.h.0.mlp.gate_up_proj.weight, weight_loader <bound method MergedColumnParallelLinear.weight_loader of MergedColumnParallelLinear()>
transformer.ln_f.weight, weight_loader <function default_weight_loader at 0x7f66201ee0e0>
transformer.wte.weight, weight_loader <bound method VocabParallelEmbedding.weight_loader of VocabParallelEmbedding()>
Each layer of the model has its own distributed loading method. For example, transformer.h.0.attn.c_proj.weight uses RowParallelLinear.weight_loader.
class RowParallelLinear(torch.nn.Module):
    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        # Obtain the tp_rank of the worker and calculate the weight range to be loaded based on the tp_rank.
        tp_rank = get_tensor_model_parallel_rank()
        input_dim = getattr(param, "input_dim", None)
        param_data = param.data
        if input_dim is not None:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
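The slicing in weight_loader can be reproduced without PyTorch. In this sketch, nested Python lists stand in for a [out_features, in_features] weight, and narrow_columns plays the role of torch.Tensor.narrow along input_dim = 1; both helper names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def narrow_columns(weight: Matrix, start: int, length: int) -> Matrix:
    """Equivalent of tensor.narrow(1, start, length) for a list-of-rows matrix."""
    return [row[start:start + length] for row in weight]


def row_parallel_shard(full_weight: Matrix, tp_rank: int, tp_size: int) -> Matrix:
    """Slice of a [out_features, in_features] weight kept by one TP rank.

    RowParallelLinear splits along the input dimension, so each rank keeps
    in_features // tp_size consecutive columns.
    """
    in_features = len(full_weight[0])
    assert in_features % tp_size == 0, "in_features must be divisible by tp_size"
    shard_size = in_features // tp_size   # plays the role of param_data.shape[input_dim]
    start_idx = tp_rank * shard_size      # same formula as weight_loader
    return narrow_columns(full_weight, start_idx, shard_size)


if __name__ == "__main__":
    w = [[1, 2, 3, 4],
         [5, 6, 7, 8]]
    for rank in range(2):
        print(f"tp_rank={rank}:", row_parallel_shard(w, rank, tp_size=2))
```

Concatenating the shards column-wise in rank order recovers the full checkpoint weight, which is why every worker can read the same file and keep only its own slice.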
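The model is split using the tensor parallelism scheme from Megatron-LM. For details, see the Megatron-LM paper [04]. Tensor parallelism relies on the following collective communication operations, described in the NCCL documentation:
https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#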
(1) Reduce: aggregates (for example, sums) the results of all GPUs onto one specific GPU.
(2) Broadcast: copies data from one GPU to all GPUs.
(3) AllReduce = Reduce + Broadcast: every GPU ends up with the aggregated result.
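These semantics can be simulated by treating each list entry as one GPU's buffer. This sketch models only the results: NCCL implements AllReduce with ring or tree algorithms rather than literally running Reduce followed by Broadcast. The function names are hypothetical.

```python
from typing import List

Vector = List[float]


def reduce_to(values: List[Vector], root: int = 0) -> List[Vector]:
    """Reduce (sum): only the root GPU ends up with the element-wise sum."""
    total = [sum(col) for col in zip(*values)]
    return [total if i == root else list(v) for i, v in enumerate(values)]


def broadcast_from(values: List[Vector], root: int = 0) -> List[Vector]:
    """Broadcast: every GPU receives a copy of the root GPU's data."""
    return [list(values[root]) for _ in values]


def all_reduce(values: List[Vector]) -> List[Vector]:
    """AllReduce = Reduce followed by Broadcast: every GPU gets the sum."""
    return broadcast_from(reduce_to(values, root=0), root=0)


if __name__ == "__main__":
    gpus = [[1.0, 2.0], [3.0, 4.0]]  # one buffer per simulated GPU
    print("reduce:   ", reduce_to(gpus))
    print("broadcast:", broadcast_from(gpus, root=1))
    print("allreduce:", all_reduce(gpus))
```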
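A Transformer layer consists of a self-attention block followed by a two-layer multi-layer perceptron (MLP).
As shown in the figure, the MLP consists of two parts: Y = GeLU(XA), followed by Z = Dropout(YB). GeLU is a nonlinear function, so GeLU(X1A1 + X2A2) ≠ GeLU(X1A1) + GeLU(X2A2): if A were split by rows, the partial results would have to be summed before GeLU. A therefore uses column parallelism, which lets each GPU apply GeLU to its own slice of Y independently.
B then uses row parallelism, so each GPU can consume its slice of Y directly. If B used column parallelism instead, the slices of Y would first have to be synchronized across GPUs.
Each GPU's output is only a partial sum of YB, so an all-reduce must combine the partial results before Dropout is applied.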
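The column-then-row split can be checked numerically. The sketch below uses pure Python lists (no GPUs, no communication library) to show that splitting A by columns and B by rows, then summing the per-GPU partial outputs, reproduces the serial MLP Z = GeLU(XA)B. All helper names are hypothetical, Dropout is omitted, and the final summation stands in for the all-reduce.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product of two list-of-rows matrices."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def gelu(m: Matrix) -> Matrix:
    """Exact GeLU applied element-wise: x * Phi(x)."""
    return [[0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0))) for x in row] for row in m]


def add_matrices(a: Matrix, b: Matrix) -> Matrix:
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def split_columns(m: Matrix, parts: int) -> List[Matrix]:
    width = len(m[0]) // parts
    return [[row[i * width:(i + 1) * width] for row in m] for i in range(parts)]


def split_rows(m: Matrix, parts: int) -> List[Matrix]:
    height = len(m) // parts
    return [m[i * height:(i + 1) * height] for i in range(parts)]


def mlp_serial(x: Matrix, a: Matrix, b: Matrix) -> Matrix:
    """Reference single-GPU MLP: Z = GeLU(X A) B."""
    return matmul(gelu(matmul(x, a)), b)


def mlp_tensor_parallel(x: Matrix, a: Matrix, b: Matrix, tp_size: int) -> Matrix:
    """Megatron-style MLP: A split by columns, B split by rows."""
    a_shards = split_columns(a, tp_size)  # A = [A_1, ..., A_tp]
    b_shards = split_rows(b, tp_size)     # B = [B_1; ...; B_tp]
    # Each simulated GPU computes GeLU(X A_i) B_i without any communication.
    partials = [matmul(gelu(matmul(x, a_i)), b_i)
                for a_i, b_i in zip(a_shards, b_shards)]
    # Summing the partial outputs is the single all-reduce.
    out = partials[0]
    for partial in partials[1:]:
        out = add_matrices(out, partial)
    return out
```

Comparing mlp_serial(x, a, b) with mlp_tensor_parallel(x, a, b, tp_size=2) for any x with 3 columns, a of shape 3×4, and b of shape 4×3 gives the same result up to floating-point rounding.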
In multi-head attention, each attention head has its own Q, K, and V projections, and each GPU computes only a subset of the attention heads. Therefore, the number of attention heads must be divisible by tp_size. Otherwise, the following error occurs (Qwen-14B with tp=3):
ValueError: Total number of attention heads (40) must be divisible by tensor parallel size (3).
Similarly, the attention output projection (c_proj) uses row parallelism, so an all-reduce is required before its Dropout.
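A sketch of the divisibility check, assuming Qwen-14B's 40 attention heads. heads_per_rank is a hypothetical helper that reproduces the error message quoted above; it is not the function vLLM uses internally.

```python
def heads_per_rank(total_num_heads: int, tensor_parallel_size: int) -> int:
    """Number of attention heads each tensor-parallel rank computes."""
    if total_num_heads % tensor_parallel_size != 0:
        raise ValueError(
            f"Total number of attention heads ({total_num_heads}) must be "
            f"divisible by tensor parallel size ({tensor_parallel_size}).")
    return total_num_heads // tensor_parallel_size


if __name__ == "__main__":
    # tp sizes up to 8 that pass this particular check for 40 heads.
    print([tp for tp in range(1, 9) if 40 % tp == 0])
    print(heads_per_rank(40, 2))  # heads per GPU in this article's setup
```

Other constraints may still rule out some of these sizes; this only covers the attention-head check.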
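As a result, each Transformer layer requires two all-reduce operations per forward pass. Qwen-14B has 40 Transformer layers, so one forward pass requires 81 all-reduce operations: 80 from the Transformer layers plus one from the vocabulary-parallel embedding layer. Because a forward pass runs for every generated token, network communication becomes a significant overhead when the inference service is deployed across nodes.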
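The count can be turned into a rough per-token estimate. This sketch assumes a hidden size of 5120 for Qwen-14B, FP16 activations (2 bytes), and one new token per decoding step. It counts only the payload being reduced and ignores the extra traffic of the all-reduce algorithm itself; the helper names are hypothetical.

```python
def all_reduces_per_forward(num_layers: int) -> int:
    """Two all-reduces per Transformer layer (attention output projection and
    the MLP's second linear layer) plus one for the vocabulary-parallel embedding."""
    return 2 * num_layers + 1


def all_reduce_payload_bytes(num_layers: int, hidden_size: int,
                             num_tokens: int = 1, bytes_per_value: int = 2) -> int:
    """Activation bytes reduced per forward pass (payload only)."""
    per_all_reduce = num_tokens * hidden_size * bytes_per_value
    return all_reduces_per_forward(num_layers) * per_all_reduce


if __name__ == "__main__":
    layers, hidden = 40, 5120  # assumption: Qwen-14B dimensions
    print("all-reduces per token:", all_reduces_per_forward(layers))
    print("payload bytes per token:", all_reduce_payload_bytes(layers, hidden))
```

The payload per decoding step is small, which suggests that across nodes the latency of 81 synchronizing network round trips per token, more than bandwidth, is likely what hurts.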
[01] Inference Process Analysis
https://zhuanlan.zhihu.com/p/649974825 (In Chinese)
[02] [Deep Learning] [Distributed Training] Pipeline Parallelism, Tensor Parallelism, and 3D Parallelism
https://zhuanlan.zhihu.com/p/617087561 (In Chinese)
[03] Efficient Training Techniques from Hugging Face (IV): Multi-GPU Distributed Training (DP, PP, TP, ZeRO)
https://blog.csdn.net/qq_56591814/article/details/134099476 (In Chinese)
[04] Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism
https://arxiv.org/pdf/1909.08053