Analyzing the Distributed Inference Process Using vLLM and Ray from the Perspective of Source Code

This article explores how to implement distributed inference with vLLM and Ray from a source code perspective.

By Zibai

1. Introduction

As Large Language Models (LLMs) continue to grow in size, a single GPU can no longer handle the entire model. For instance, the Qwen-14B-Chat model has weights totaling around 28 GB, while a single NVIDIA A10 GPU has only 24 GB of memory. To deploy the Qwen-14B-Chat model, we need to split it and distribute it across two A10 GPUs, with each GPU loading half of the model. This approach is known as distributed inference. Several frameworks support distributed inference, including vLLM, DeepSpeed-MII, and RTP-LLM.
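
The memory arithmetic behind this example can be checked with a short sketch. It assumes roughly 14 billion parameters stored in a 16-bit format (2 bytes per parameter); the constant and function names are illustrative and are not part of vLLM.

```python
# Rough weight-memory estimate. Assumptions: ~14e9 parameters, 2 bytes each (FP16/BF16).
QWEN_14B_PARAMS = 14e9   # approximate parameter count (assumption)
BYTES_PER_PARAM = 2      # 16-bit weights (assumption)
A10_MEMORY_GB = 24       # memory of one NVIDIA A10


def weight_size_gb(num_params: float, bytes_per_param: int = BYTES_PER_PARAM) -> float:
    """Size of the model weights in GB (10^9 bytes)."""
    return num_params * bytes_per_param / 1e9


def per_gpu_weight_gb(num_params: float, tp_size: int) -> float:
    """Weights held by each GPU when the model is split evenly across tp_size GPUs."""
    return weight_size_gb(num_params) / tp_size


total_gb = weight_size_gb(QWEN_14B_PARAMS)
per_gpu_gb = per_gpu_weight_gb(QWEN_14B_PARAMS, tp_size=2)

print(f"total weights: {total_gb:.1f} GB, fits on one A10: {total_gb < A10_MEMORY_GB}")
print(f"per GPU with 2 GPUs: {per_gpu_gb:.1f} GB, fits: {per_gpu_gb < A10_MEMORY_GB}")
```

This covers only the weights. The KV cache and activations need additional GPU memory on top of this.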

This article focuses on the vLLM framework, exploring how to implement distributed inference with vLLM and Ray from a source code perspective.

2. Deploying the vLLM Distributed Inference Application in Kubernetes

2.1 Model preparation

Download the Qwen-14B-Chat model to OSS and create the corresponding PV and PVC in the cluster. Name the PVC llm-model.

kubectl apply -f- << EOF
apiVersion: v1
kind: Secret
metadata:
  name: oss-secret
stringData:
  akId: ${your-accesskey-id} # The AccessKey ID used to access OSS
  akSecret: ${your-accesskey-secret} # The AccessKey secret used to access OSS
---
apiVersion: v1
kind: PersistentVolume
metadata:
  name: llm-model
  labels:
    alicloud-pvname: llm-model
spec:
  capacity:
    storage: 30Gi 
  accessModes:
    - ReadOnlyMany
  persistentVolumeReclaimPolicy: Retain
  csi:
    driver: ossplugin.csi.alibabacloud.com
    volumeHandle: model-oss
    nodePublishSecretRef:
      name: oss-secret
      namespace: default
    volumeAttributes:
      bucket: ${your-bucket-name}
      url: ${your-bucket-endpoint} # e.g. oss-cn-hangzhou.aliyuncs.com
      otherOpts: "-o umask=022 -o max_stat_cache_size=0 -o allow_other"
      path: "/"
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: llm-model
spec:
  accessModes:
    - ReadOnlyMany
  resources:
    requests:
      storage: 30Gi
  selector:
    matchLabels:
      alicloud-pvname: llm-model
EOF

2.2 Deploying the Distributed vLLM Application

1. Run the following command to deploy the vLLM application. The containers only run sleep 7d, so the vLLM service itself is started manually in the next step.

kubectl apply -f- <<EOF
apiVersion: apps/v1 
kind: Deployment
metadata:
  name: vllm
  labels:
    app: vllm
spec:
  replicas: 2
  selector:
    matchLabels:
      app: vllm
  template:
    metadata:
      labels:
        app: vllm
    spec:
      affinity:
        podAntiAffinity:
          requiredDuringSchedulingIgnoredDuringExecution:
          - labelSelector:
              matchExpressions:
              - key: app
                operator: In
                values:
                - vllm
            topologyKey: kubernetes.io/hostname
      volumes:
      - name: model
        persistentVolumeClaim:
          claimName: llm-model
      containers:
      - name: vllm
        image: kube-ai-registry.cn-shanghai.cr.aliyuncs.com/kube-ai/vllm:0.4.1
        command:
        - "sh"
        - "-c"
        - "sleep 7d"
        ports:
        - containerPort: 8080
        readinessProbe:
          tcpSocket:
            port: 8080
          initialDelaySeconds: 30
          periodSeconds: 30
        resources:
          limits:
            nvidia.com/gpu: "1"
          requests:
            cpu: 4
            memory: 8Gi
            nvidia.com/gpu: "1"
        volumeMounts:
        - mountPath: /mnt/models
          name: model
EOF

2. Run the following command to start the vLLM application

● Start Ray

  • On Pod1, run:
ray start --head
# After Ray is started, the log displays the ray-head-address.
  • On Pod2, run:
# Set the ray-head-address to the address displayed in the Pod1 log
ray start --address=<ray-head-address> 

● Run the following command to initialize the local model on Pod2:

python3 model_init.py

The content of model_init.py is as follows:

from transformers import AutoConfig, AutoTokenizer

# Load the config and tokenizer once with trust_remote_code=True,
# so that the custom code shipped with the model is available locally on Pod2.
config = AutoConfig.from_pretrained(
    "/mnt/models/Qwen-14B-Chat",
    trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
    "/mnt/models/Qwen-14B-Chat",
    trust_remote_code=True)

● On Pod1, run the following command to start the Qwen model:

python3 -m vllm.entrypoints.openai.api_server \
--port 8080 \
--trust-remote-code \
--served-model-name qwen \
--model /mnt/models/Qwen-14B-Chat \
--gpu-memory-utilization 0.95 \
--tensor-parallel-size 2

● Log in to Pod1 to access the application.

kubectl -n <your-namespace> exec -it <pod1-name> -- bash

curl -H "Content-Type: application/json" \
     http://localhost:8080/v1/chat/completions -X POST \
     -d '{"model": "qwen", "messages": [{"role": "user", "content": "你好"}], "max_tokens": 512, "temperature": 0.7, "top_p": 0.9, "seed": 10, "stop":["<|endoftext|>", "<|im_end|>", "<|im_start|>"]}'
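
The same request can also be sent from Python. The following is a minimal sketch that uses only the standard library and mirrors the curl call above. The helper names build_chat_request and send_chat_request are hypothetical, and the base URL assumes the script runs inside Pod1.

```python
import json
import urllib.request


def build_chat_request(prompt: str,
                       base_url: str = "http://localhost:8080",
                       model: str = "qwen") -> urllib.request.Request:
    """Build the POST request for the OpenAI-compatible chat completions endpoint."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 512,
        "temperature": 0.7,
        "top_p": 0.9,
        "seed": 10,
        "stop": ["<|endoftext|>", "<|im_end|>", "<|im_start|>"],
    }
    return urllib.request.Request(
        url=f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send_chat_request(request: urllib.request.Request, timeout: float = 60.0) -> dict:
    """Send the request and decode the JSON response (requires the server to be running)."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))


request = build_chat_request("你好")
# Uncomment once the vLLM server on Pod1 is up:
# print(send_chat_request(request))
```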

3. Distributed Inference Process Analysis

1. Handler function: vllm/entrypoints/openai/api_server.py main

if __name__ == "__main__":
    # build engine args
    engine_args = AsyncEngineArgs.from_cli_args(args)
    # build engine
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)

    openai_serving_chat = OpenAIServingChat(engine, served_model_names,
                                            args.response_role,
                                            args.lora_modules,
                                            args.chat_template)

    openai_serving_completion = OpenAIServingCompletion(
        engine, served_model_names, args.lora_modules)

    app.root_path = args.root_path
    uvicorn.run(app)

2. Building the LLM Engine

engine = AsyncLLMEngine.from_engine_args(
    engine_args, usage_context=UsageContext.OPENAI_API_SERVER)

def from_engine_args():
    """Creates an async LLM engine from the engine arguments."""
    # Create the engine configs.
    engine_config = engine_args.create_engine_config()

    # Ray cluster initialization
    # 1. ray.init()
    # 2. Set the Ray placement strategy based on the number of GPUs in the cluster and the TP degree.
    initialize_ray_cluster(engine_config.parallel_config)
    from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
    executor_class = RayGPUExecutorAsync

    # Create the async LLM engine.
    engine = cls(...)  # Create an AsyncLLMEngine instance.
    # Call chain: AsyncLLMEngine.__init__ -> self._init_engine -> _AsyncLLMEngine.__init__ -> LLMEngine.__init__ -> executor_class(), which calls RayGPUExecutorAsync.__init__

3. Initializing the Ray Cluster

Setting up Ray workers involves initializing both the Ray cluster and the Ray workers themselves. During this process, the model is loaded in a distributed manner.

# The RayGPUExecutorAsync inherits the RayGPUExecutor and ExecutorAsyncBase classes and calls the self._init_executor method of RayGPUExecutor during initialization.
def _init_executor(self) -> None:
    # Create the parallel GPU workers. Core code for initializing workers:
    self._init_workers_ray(placement_group)

def _init_workers_ray():
    # Define the worker, using the Worker class from the vllm.worker.worker module.
    # The actor is a RayWorkerWrapper class
    worker = ray.remote(
        num_cpus=0,
        num_gpus=num_gpus,
        scheduling_strategy=scheduling_strategy,
        **ray_remote_kwargs,
    )(RayWorkerWrapper).remote(
        worker_module_name="vllm.worker.worker",
        worker_class_name="Worker",
        trust_remote_code=self.model_config.trust_remote_code,
    )

    # Execute the following methods in sequence on the Ray workers
    self._run_workers("get_node_and_gpu_ids",
                      use_dummy_driver=True)
    self._run_workers("update_environment_variables",
                      all_args=all_args_to_update_environment_variables)
    self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
    self._run_workers("init_device")
    self._run_workers(
        "load_model",
        max_concurrent_workers=self.parallel_config.
        max_parallel_loading_workers,
    )

def _run_workers():
    # Start the ray workers first.
    ray_worker_outputs = [
        # Each worker is a RayWorkerWrapper actor created earlier; RayWorkerWrapper inherits from the WorkerWrapperBase class.
        # execute_method is called through Ray, so the named method is executed on the remote instance.
        worker.execute_method.remote(method, *worker_args,
                                     **worker_kwargs)
        for (worker, worker_args, worker_kwargs
             ) in zip(self.workers, all_worker_args, all_worker_kwargs)
    ]

def init_worker():
    # worker_module_name is vllm.worker.worker, which is passed in the _init_workers_ray method.
    mod = importlib.import_module(self.worker_module_name)
    # Worker
    worker_class = getattr(mod, self.worker_class_name)
    self.worker = worker_class(*args, **kwargs)
    # Worker.__init__ -> ModelRunner.__init__

def init_device():
    # Initialize the machine information for distributed inference
    """Initialize the distributed environment."""
    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank)

def load_model():
    self.model_runner.load_model() # ModelRunner.load_model() -> vllm.model_executor.model_loader.loader.load_model

Below is the log output after load_model() is executed. It shows two workers, one per pod, each loading 13.2845 GB of weights, which is half of the model.

INFO 04-26 09:39:46 model_runner.py:173] Loading model weights took 13.2845 GB
(RayWorkerWrapper pid=3327, ip=192.168.12.132) INFO 04-26 09:39:51 model_runner.py:173] Loading model weights took 13.2845 GB
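
The dispatch pattern used by _run_workers, execute_method, and init_worker above can be illustrated without Ray. The following is a simplified, single-process sketch: the class WorkerWrapperSketch and the function run_workers are hypothetical stand-ins, and collections.Counter plays the role of the Worker class only so that the example runs without vLLM.

```python
import importlib


class WorkerWrapperSketch:
    """Toy stand-in for the wrapper actor: creates the worker lazily, dispatches by method name."""

    def __init__(self, worker_module_name: str, worker_class_name: str):
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        self.worker = None

    def init_worker(self, *args, **kwargs):
        # Import the module and look up the worker class by name, as init_worker does.
        mod = importlib.import_module(self.worker_module_name)
        worker_class = getattr(mod, self.worker_class_name)
        self.worker = worker_class(*args, **kwargs)

    def execute_method(self, method: str, *args, **kwargs):
        # Methods defined on the wrapper run on the wrapper; everything else runs on the worker.
        target = self if hasattr(type(self), method) else self.worker
        return getattr(target, method)(*args, **kwargs)


def run_workers(workers, method, all_args=None):
    """Call the same method on every worker, optionally with per-worker arguments."""
    all_args = all_args or [()] * len(workers)
    return [w.execute_method(method, *args) for w, args in zip(workers, all_args)]


# Two "workers"; collections.Counter stands in for vllm.worker.worker.Worker.
workers = [WorkerWrapperSketch("collections", "Counter") for _ in range(2)]
run_workers(workers, "init_worker", all_args=[("aab",), ("bcc",)])
outputs = run_workers(workers, "most_common", all_args=[(1,), (1,)])
print(outputs)
```

In vLLM the wrapper is a Ray actor, so execute_method is invoked with .remote() and the results are collected through Ray rather than returned directly.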

4. Providing External Services

Create instances of OpenAIServingChat and OpenAIServingCompletion, then start uvicorn to provide these services externally.

# Handles requests to POST /v1/chat/completions
openai_serving_chat = OpenAIServingChat(engine, served_model_names,
                                        args.response_role,
                                        args.lora_modules,
                                        args.chat_template)
# Handles requests to POST /v1/completions
openai_serving_completion = OpenAIServingCompletion(
    engine, served_model_names, args.lora_modules)

app.root_path = args.root_path
uvicorn.run(app)

3.1 Distributed Inference Process

When --tensor-parallel-size is greater than 1, vLLM automatically uses Ray to perform a distributed deployment.

1. When building the LLM engine, the Ray cluster is initialized.

# Ray cluster initialization
initialize_ray_cluster(engine_config.parallel_config)

The parallel_config is configured as follows: pp = 1, tp = 2, and world_size = 2

{'pipeline_parallel_size': 1, 'tensor_parallel_size': 2, 'worker_use_ray': True, 'max_parallel_loading_workers': None, 'disable_custom_all_reduce': False, 'tokenizer_pool_config': None, 'ray_workers_use_nsight': False, 'placement_group': None, 'world_size': 2}

A placement_group is created for the worker processes during initialization:

(1) Determine the total number of GPUs in the Ray cluster.

(2) Request one GPU bundle per worker based on the world size: placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size).

(3) Create the placement_group. Ray then starts the worker actors on the corresponding nodes.
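
Steps (1) and (2) can be sketched in plain Python as follows. The function name build_placement_group_specs and the wording of the error message are illustrative assumptions, not vLLM's exact code.

```python
from typing import Dict, List


def build_placement_group_specs(world_size: int, gpus_in_cluster: int) -> List[Dict[str, int]]:
    """Return one single-GPU bundle per worker, or fail if the cluster is too small."""
    if world_size > gpus_in_cluster:
        raise ValueError(
            f"The cluster has {gpus_in_cluster} GPUs, "
            f"but {world_size} are required (illustrative message).")
    return [{"GPU": 1}] * world_size


# tp = 2 and pp = 1 give world_size = 2; the two pods provide one GPU each.
placement_group_specs = build_placement_group_specs(world_size=2, gpus_in_cluster=2)
print(placement_group_specs)

# With Ray installed, bundles like these are what ray.util.placement_group(...) accepts
# to reserve the GPUs (step 3).
```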

2. Run get_node_and_gpu_ids on each worker

# Obtain the ID of the node that the worker runs on and the IDs of the GPUs assigned to the worker.
def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
    node_id = ray.get_runtime_context().get_node_id()
    gpu_ids = ray.get_gpu_ids()
    return node_id, gpu_ids

3. Run update_environment_variables on each worker

# The node and GPU information obtained in the previous step.
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                            use_dummy_driver=True)

# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [({
    "CUDA_VISIBLE_DEVICES":
    ",".join(map(str, node_gpus[node_id])),
    "VLLM_INSTANCE_ID":
    VLLM_INSTANCE_ID,
    "VLLM_TRACE_FUNCTION":
    os.getenv("VLLM_TRACE_FUNCTION", "0"),
}, ) for (node_id, _) in worker_node_and_gpu_ids]
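
The excerpt uses node_gpus without showing how it is built. The following sketch shows one plausible way to derive it from the output of get_node_and_gpu_ids, under the assumption that it maps each node ID to all GPU IDs used on that node. The function name and the node IDs are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def build_cuda_visible_devices(
        worker_node_and_gpu_ids: List[Tuple[str, List[int]]]) -> List[Dict[str, str]]:
    """Build one environment-variable dict per worker, in worker order."""
    # Collect every GPU ID used on each node (assumed meaning of node_gpus).
    node_gpus = defaultdict(list)
    for node_id, gpu_ids in worker_node_and_gpu_ids:
        node_gpus[node_id].extend(gpu_ids)
    for gpu_ids in node_gpus.values():
        gpu_ids.sort()

    # Every worker on a node sees all GPUs that vLLM uses on that node.
    return [{"CUDA_VISIBLE_DEVICES": ",".join(map(str, node_gpus[node_id]))}
            for node_id, _ in worker_node_and_gpu_ids]


# The deployment in this article: two pods on different nodes, one GPU each.
two_pods = build_cuda_visible_devices([("node-a", [0]), ("node-b", [0])])
# A hypothetical single node with two GPUs.
one_node = build_cuda_visible_devices([("node-a", [0]), ("node-a", [1])])
print(two_pods, one_node)
```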

4. Run init_worker and init_device on each worker

# Parameters used to start the worker
init_worker_all_kwargs = []
# worker_node_and_gpu_ids is the information of the GPU on the worker obtained in Step 2.
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
    local_rank = node_workers[node_id].index(rank)
    init_worker_all_kwargs.append(
        collect_arg_helper_func(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            is_driver_worker=rank == 0,
        ))

def init_device(self) -> None:
    if self.device_config.device.type == "cuda":
        # torch.distributed.all_reduce does not free the input tensor until
        # the synchronization point. This causes the memory usage to grow
        # as the number of all_reduce calls increases. This env var disables
        # this behavior.
        # Related issue:
        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
        self.device = torch.device(f"cuda:{self.local_rank}")
        torch.cuda.set_device(self.device)

        _check_if_gpu_supports_dtype(self.model_config.dtype)
        torch.cuda.empty_cache()
        self.init_gpu_memory = torch.cuda.mem_get_info()[0]
    else:
        raise RuntimeError(
            f"Not support device type: {self.device_config.device}")
    # Initialize the distributed environment.
    init_worker_distributed_environment(self.parallel_config, self.rank,
                                        self.distributed_init_method,
                                        self.local_rank)
    # Set random seed.
    set_random_seed(self.model_config.seed)

The core method init_worker_distributed_environment constructs the world information for the distributed cluster, similar to world info in Horovod and DeepSpeed.

The parameters for this method are as follows:

worker1: self.rank=0, self.local_rank=0, self.distributed_init_method="tcp://192.168.12.120:42167" (the Ray head node)

{'pipeline_parallel_size': 1, 'tensor_parallel_size': 2, 'worker_use_ray': True, 'max_parallel_loading_workers': None, 'disable_custom_all_reduce': False, 'tokenizer_pool_config': None, 'ray_workers_use_nsight': False, 'placement_group': <ray.util.placement_group.PlacementGroup object at 0x7fdeaa896ce0>, 'world_size': 2}, {'id': PlacementGroupID(51489eb26a9335f31ed1bdb4eace04000000), 'bundle_cache': [{'GPU': 1.0}, {'GPU': 1.0}]}, self.rank=0, tcp://192.168.12.120:42167, self.local_rank=0

worker2: self.rank=1, self.local_rank=0, self.distributed_init_method="tcp://192.168.12.120:42167"

{'pipeline_parallel_size': 1, 'tensor_parallel_size': 2, 'worker_use_ray': True, 'max_parallel_loading_workers': None, 'disable_custom_all_reduce': False, 'tokenizer_pool_config': None, 'ray_workers_use_nsight': False, 'world_size': 2}, self.rank=1, tcp://192.168.12.120:42167, self.local_rank=0

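self.rank is the global index of a worker and increases across the whole cluster. self.local_rank is the index of a worker's GPU within its own pod. Because each pod here has one GPU, both workers have local_rank 0.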
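
The relationship between rank and local_rank can be reproduced with a small sketch that follows the node_workers[node_id].index(rank) expression in the excerpt above. The function name assign_ranks and the node names are hypothetical.

```python
from collections import defaultdict
from typing import List, Tuple


def assign_ranks(worker_node_ids: List[str]) -> List[Tuple[int, int]]:
    """Return (rank, local_rank) for each worker, given the node each worker runs on."""
    # node_workers maps a node ID to the global ranks of the workers on that node.
    node_workers = defaultdict(list)
    for rank, node_id in enumerate(worker_node_ids):
        node_workers[node_id].append(rank)

    # local_rank is the position of the worker among the workers on the same node.
    return [(rank, node_workers[node_id].index(rank))
            for rank, node_id in enumerate(worker_node_ids)]


# This article's deployment: one worker on each of two pods.
print(assign_ranks(["pod1", "pod2"]))
# A hypothetical deployment with two GPUs per node.
print(assign_ranks(["node-a", "node-a", "node-b", "node-b"]))
```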

5. Execute the load_model method on each worker.

load_model loads each worker's part of the model. The process is complex and is described in the following section.

3.2 Distributed Model Loading Process

Execute the load_model method on each worker

def load_model():
    self.model_runner.load_model() 

# ModelRunner.load_model() -> vllm.model_executor.model_loader.loader.load_model
def load_model(self) -> None:
    with CudaMemoryProfiler() as m:
         # Use get_model to obtain the model.
        self.model = get_model(
            model_config=self.model_config,
            device_config=self.device_config,
            load_config=self.load_config,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
        )

    self.model_memory_usage = m.consumed_memory
    logger.info(f"Loading model weights took "
                f"{self.model_memory_usage / float(2**30):.4f} GB")

# get_model -> loader.load_model -> DefaultModelLoader.load_model
def load_model(self, *, model_config: ModelConfig,
               device_config: DeviceConfig,
               lora_config: Optional[LoRAConfig],
               vision_language_config: Optional[VisionLanguageConfig],
               parallel_config: ParallelConfig,
               scheduler_config: SchedulerConfig) -> nn.Module:
    with set_default_torch_dtype(model_config.dtype):
        with torch.device(device_config.device):
            """Initialize a model with the given configurations."""
            # Initialize the model.
            model = _initialize_model(model_config, self.load_config,
                                      lora_config, vision_language_config)

        # Call the load_weights method of the corresponding model.
        model.load_weights(
            self._get_weights_iterator(model_config.model,
                                       model_config.revision,
                                       fall_back_to_pt=getattr(
                                           model,
                                           "fall_back_to_pt_during_load",
                                           True)), )
        for _, module in model.named_modules():
            linear_method = getattr(module, "linear_method", None)
            if linear_method is not None:
                linear_method.process_weights_after_loading(module)
            if hasattr(module, "process_weights_after_loading"):
                module.process_weights_after_loading()
    return model.eval()

# Find the specific model based on the model configuration.
def _initialize_model(
        model_config: ModelConfig, load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
    """Initialize a model with the given configurations."""
    # The model class is determined by the architectures field in the config.json file of Qwen-14B-Chat.
    model_class = get_model_architecture(model_config)[0]
    linear_method = _get_linear_method(model_config, load_config)

    return model_class(config=model_config.hf_config,
                       linear_method=linear_method,
                       **_get_model_initialization_kwargs(
                           model_class, lora_config, vision_language_config))

# model_class refers to <class 'vllm.model_executor.models.qwen.QWenLMHeadModel'>
# model.load_weights calls the load_weights method of QWenLMHeadModel
# QWenLMHeadModel.load_weights
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "w2", 0),
        ("gate_up_proj", "w1", 1),
    ]

    # The weight of each layer of the model and its name.
    # self.named_parameters refers to model.named_parameters()
    params_dict = dict(self.named_parameters())

    for name, loaded_weight in weights:
        # name: transformer.h.27.mlp.c_proj.weight 
        # loaded_weight: tensor(xxx)
        if "rotary_emb.inv_freq" in name:
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            # If found in stacked_params_mapping, shard_name should be replaced with param_name.
            # For example, if name is transformer.h.0.mlp.w1.weight, it should be changed to transformer.h.0.mlp.gate_up_proj.weight.
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Python's for-else syntax; reaching here means no break statement was executed in the loop.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            # Locate the corresponding weight_loader method for the name.
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
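
The name rewriting and the for-else control flow in load_weights can be isolated in a small sketch. The function resolve_param is hypothetical. It returns only the resolved parameter name and shard ID, and does not load any weights.

```python
from typing import Optional, Tuple

STACKED_PARAMS_MAPPING = [
    # (param_name, shard_name, shard_id)
    ("gate_up_proj", "w2", 0),
    ("gate_up_proj", "w1", 1),
]


def resolve_param(name: str) -> Tuple[str, Optional[int]]:
    """Map a checkpoint weight name to (model parameter name, shard_id or None)."""
    for param_name, shard_name, shard_id in STACKED_PARAMS_MAPPING:
        if shard_name not in name:
            continue
        # The checkpoint stores w1 and w2 separately; the model merges them into gate_up_proj.
        resolved = (name.replace(shard_name, param_name), shard_id)
        break
    else:
        # No break happened: the weight is not part of a stacked parameter.
        resolved = (name, None)
    return resolved


print(resolve_param("transformer.h.0.mlp.w1.weight"))
print(resolve_param("transformer.h.0.mlp.w2.weight"))
print(resolve_param("transformer.h.0.mlp.c_proj.weight"))
```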

Weight of each model layer and its weight_loader method

# param,weight_loader

lm_head.weight, weight_loader <bound method VocabParallelEmbedding.weight_loader of ParallelLMHead()>
 transformer.h.0.attn.c_attn.weight, weight_loader <bound method QKVParallelLinear.weight_loader of QKVParallelLinear()>
 transformer.h.0.attn.c_proj.weight, weight_loader <bound method RowParallelLinear.weight_loader of RowParallelLinear()>
 transformer.h.0.ln_1.weight, weight_loader <function default_weight_loader at 0x7f66201ee0e0>
 transformer.h.0.ln_2.weight, weight_loader <function default_weight_loader at 0x7f66201ee0e0>
 transformer.h.0.mlp.c_proj.weight, weight_loader <bound method RowParallelLinear.weight_loader of RowParallelLinear()>
 transformer.h.0.mlp.gate_up_proj.weight, weight_loader <bound method MergedColumnParallelLinear.weight_loader of MergedColumnParallelLinear()>
 transformer.ln_f.weight, weight_loader <function default_weight_loader at 0x7f66201ee0e0>
 transformer.wte.weight, weight_loader <bound method VocabParallelEmbedding.weight_loader of VocabParallelEmbedding()>

Each layer of the model has its own distributed loading method. For example, transformer.h.0.attn.c_proj.weight uses RowParallelLinear.weight_loader.

class RowParallelLinear(torch.nn.Module):
    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
      # Obtain the tp_rank of the worker and calculate the weight range to be loaded based on the tp_rank.
      tp_rank = get_tensor_model_parallel_rank()
      input_dim = getattr(param, "input_dim", None)
      param_data = param.data
      if input_dim is not None:
          shard_size = param_data.shape[input_dim]
          start_idx = tp_rank * shard_size
          loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                               shard_size)
      assert param_data.shape == loaded_weight.shape
      param_data.copy_(loaded_weight)
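
The slicing performed by narrow can be illustrated with plain Python lists in place of tensors. In this sketch the shard size is computed from the full weight and the tensor parallel size, whereas the vLLM code above reads it from the already allocated parameter. The function names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def narrow(matrix: Matrix, dim: int, start: int, length: int) -> Matrix:
    """Return the slice [start, start + length) of a 2-D list along dimension dim."""
    if dim == 0:
        return [row[:] for row in matrix[start:start + length]]
    if dim == 1:
        return [row[start:start + length] for row in matrix]
    raise ValueError("only 2-D matrices are supported in this sketch")


def load_shard(full_weight: Matrix, input_dim: int, tp_rank: int, tp_size: int) -> Matrix:
    """Select the part of the full weight that belongs to the worker with rank tp_rank."""
    full_size = len(full_weight) if input_dim == 0 else len(full_weight[0])
    assert full_size % tp_size == 0, "dimension must be divisible by tp_size"
    shard_size = full_size // tp_size
    start_idx = tp_rank * shard_size
    return narrow(full_weight, input_dim, start_idx, shard_size)


# A 2 x 4 weight split along its input dimension (dim 1) across two workers.
full_weight = [[1, 2, 3, 4],
               [5, 6, 7, 8]]
shard_rank0 = load_shard(full_weight, input_dim=1, tp_rank=0, tp_size=2)
shard_rank1 = load_shard(full_weight, input_dim=1, tp_rank=1, tp_size=2)
print(shard_rank0, shard_rank1)
```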

The model is split using the Megatron-LM algorithm. For more information, see the Megatron-LM paper [04].

4. Distributed Model Splitting Algorithm Megatron-LM

4.1 Distributed Node Communication: AllReduce

For details, see the NCCL documentation on collective operations: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html

(1) Reduce: aggregates the computing results of all GPUs to a specific GPU.

(2) Broadcast: synchronizes data from a GPU to all GPUs.

(3) AllReduce = Reduce + Broadcast
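
The relationship between the three collectives can be simulated in a single process, with one Python list standing in for the buffer on each GPU. This is a conceptual sketch only: real libraries such as NCCL typically use more efficient algorithms than a literal reduce followed by a broadcast, and the function names are hypothetical.

```python
from typing import List

Buffer = List[float]


def reduce_sum(buffers: List[Buffer]) -> Buffer:
    """Reduce: element-wise sum of all buffers, held by a single root GPU."""
    return [sum(values) for values in zip(*buffers)]


def broadcast(value: Buffer, num_ranks: int) -> List[Buffer]:
    """Broadcast: copy the root's buffer to every GPU."""
    return [list(value) for _ in range(num_ranks)]


def all_reduce_sum(buffers: List[Buffer]) -> List[Buffer]:
    """AllReduce: every GPU ends up with the element-wise sum of all buffers."""
    return broadcast(reduce_sum(buffers), len(buffers))


# Two GPUs, each holding a partial result.
gpu_buffers = [[1.0, 2.0], [10.0, 20.0]]
result = all_reduce_sum(gpu_buffers)
print(result)
```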

4.2 Transformer Splitting

A Transformer layer consists of a self-attention block followed by a two-layer multi-layer perceptron (MLP).

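MLP

The MLP consists of two parts. The first part is a matrix multiplication followed by the nonlinear GeLU function: Y = GeLU(XA).

If A is split by row (A = [A1; A2]) and X by column (X = [X1, X2]), then Y = GeLU(X1A1 + X2A2). Because GeLU is nonlinear, GeLU(X1A1 + X2A2) ≠ GeLU(X1A1) + GeLU(X2A2), so A cannot use row parallelism without a synchronization before GeLU. A must use column parallelism instead: with A = [A1, A2], each GPU computes its own part independently, [Y1, Y2] = [GeLU(XA1), GeLU(XA2)].

The second part is Z = Dropout(YB). Because Y is already split by column, B needs to use row parallelism (B = [B1; B2]), so that each GPU computes YiBi directly from its local Yi. If B used column parallelism, each GPU would need the complete Y, and an extra synchronization would be required between the two matrix multiplications.

The partial results Y1B1 and Y2B2 must be summed to obtain YB. Dropout randomly zeroes some elements of its input at a certain ratio, so this all-reduce operation must be performed before Dropout.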
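
The MLP splitting rule can be verified numerically. The following sketch uses plain Python lists and small made-up matrices. It checks that splitting A by column and B by row reproduces the unsplit result after one sum, and that splitting A by row does not commute with GeLU. Dropout is omitted, and all names and values are illustrative.

```python
import math


def gelu(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def gelu_all(m):
    return [[gelu(v) for v in row] for row in m]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def split_cols(m, parts):
    size = len(m[0]) // parts
    return [[row[p * size:(p + 1) * size] for row in m] for p in range(parts)]


def split_rows(m, parts):
    size = len(m) // parts
    return [m[p * size:(p + 1) * size] for p in range(parts)]


def max_abs_diff(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


X = [[1.0, -2.0], [0.5, 3.0]]                            # input, 2 x 2
A = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]         # first weight, 2 x 4
B = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [-1.0, 2.0]]   # second weight, 4 x 2

# Unsplit reference: Z = GeLU(XA) B
Y_reference = gelu_all(matmul(X, A))
Z_reference = matmul(Y_reference, B)

# Column-parallel A, row-parallel B: each "GPU" computes GeLU(X A_i) B_i locally.
partials = [matmul(gelu_all(matmul(X, A_i)), B_i)
            for A_i, B_i in zip(split_cols(A, 2), split_rows(B, 2))]
Z_parallel = add(partials[0], partials[1])   # this sum is the all-reduce

# Row-parallel A: applying GeLU before summing the partial products gives a wrong Y.
row_partials = [matmul(X_i, A_i)
                for X_i, A_i in zip(split_cols(X, 2), split_rows(A, 2))]
Y_wrong = add(gelu_all(row_partials[0]), gelu_all(row_partials[1]))

print(max_abs_diff(Z_parallel, Z_reference), max_abs_diff(Y_wrong, Y_reference))
```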

Self-Attention

In the multi-head attention mechanism, each attention head has its own QKV matrices, and each GPU computes only a subset of the attention heads. Therefore, the number of attention heads must be divisible by tp_size. Otherwise, the following error occurs (Qwen-14B with tp=3):

ValueError: Total number of attention heads (40) must be divisible by tensor parallel size (3).
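
The divisibility requirement can be expressed as a small check. The function heads_per_gpu is a hypothetical helper that reuses the wording of the error message above. It is not vLLM's actual validation code.

```python
def heads_per_gpu(total_num_heads: int, tensor_parallel_size: int) -> int:
    """Return how many attention heads each GPU computes, or fail if they cannot be split evenly."""
    if total_num_heads % tensor_parallel_size != 0:
        raise ValueError(
            f"Total number of attention heads ({total_num_heads}) "
            "must be divisible by tensor parallel size "
            f"({tensor_parallel_size}).")
    return total_num_heads // tensor_parallel_size


# 40 attention heads with tp=2: each GPU computes 20 heads.
print(heads_per_gpu(40, 2))

# 40 attention heads with tp=3: the heads cannot be distributed evenly.
try:
    heads_per_gpu(40, 3)
except ValueError as error:
    print(error)
```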

Similarly, an all-reduce operation is required before Dropout.

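As a result, each Transformer layer requires two all-reduce operations per forward pass: one after self-attention and one after the MLP. Qwen-14B has 40 Transformer layers, so the layers account for 2 × 40 = 80 all-reduce operations. Together with one more outside the Transformer layers, one forward pass requires 81 all-reduce operations. When the inference service is deployed across nodes, this network communication becomes a significant overhead.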

References

[01] Inference Process Analysis
https://zhuanlan.zhihu.com/p/649974825 (In Chinese)

[02] [Deep Learning] [Distributed Training] Pipeline Parallelism, Tensor Parallelism, and 3D Parallelism
https://zhuanlan.zhihu.com/p/617087561 (In Chinese)

[03] Efficient Training Techniques from Hugging Face (IV): Multi-GPU Distributed Training (DP, PP, TP, ZeRO) _zero-dp
https://blog.csdn.net/qq_56591814/article/details/134099476 (In Chinese)

[04] Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism
https://arxiv.org/pdf/1909.08053
