Batteries-included performant language model serving with vLLM
We walk through the architecture, dataflow, and some design decisions of the vLLM library.
Introduction
With the advent of powerful open-weights large language models (e.g. the Meta AI Llama series, Mistral 7B/Mixtral 8x7B, the DeepSeek series, and the Qwen series), much attention has been placed on the infrastructure and optimization of local/distributed inference and finetuning. vLLM, developed around the introduction of paged attention[1], is one such library. The library boasts an impressive following, with surprisingly thorough documentation, a vast set of features (e.g. quantization, LoRA adapters, tool calling, structured outputs, speculative decoding, prefix caching, and more), and integrations with extensions that support pre- and post-training. It’s therefore a natural library to peek under the hood of: that’s what we’ll attempt to do in this post.
Note that many components of vLLM deserve blog posts in their own right; subsystems like the vLLM implementation of flash attention, speculative decoding, paged attention, and others may be covered in future articles. In this post, we’ll build intuition for the library’s architecture, dataflow, and core components. Our goal is to develop an understanding of the library’s design, and the way modules interact in the serving flow.[2]
Entrypoint(s)
vLLM supports multiple entrypoints against which model inference can be
performed: examples include an API server, a command line interface, and Python
classes (LLM
for offline/batch inference, and AsyncLLMEngine
for
real-time inference). Here, we’ll focus on a non-pipeline-parallel scenario
utilizing the Python LLM
entrypoint that implements offline inference (in
hardware, one can imagine this as a single-node, (possibly) multi-GPU setting).
A full understanding of this pipeline will translate well to other, more
complicated scenarios.
The LLM Python Class Entrypoint
The LLM
class is really a user-friendly interface for the LLMEngine
, which
is the workhorse that powers the inference pipeline. The class provides the
following important API signatures, among others:
# NOTE `generate`, `encode`, `embed`, `classify` all follow similar signatures.
# Their difference is in the model architecture (whether the model is "pooled"
# or not):
def generate(
    self,
    prompts: Union[Union[PromptType, Sequence[PromptType]],
                   Optional[Union[str, list[str]]]] = None,
    sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None,
    lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    priority: Optional[list[int]] = None,
) -> list[RequestOutput]:
    ...

def beam_search(
    self,
    prompts: list[Union[TokensPrompt, TextPrompt]],
    params: BeamSearchParams,
) -> list[BeamSearchOutput]:
    ...
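For orientation, a minimal offline-inference call through this entrypoint looks roughly like the following (the model name and sampling settings are arbitrary placeholders):
from vllm import LLM, SamplingParams

# Minimal offline batch inference; model name and sampling values are
# placeholders.
llm = LLM(model="facebook/opt-125m")
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
outputs = llm.generate(
    ["The capital of France is", "Paged attention is"],
    sampling_params,
)
for output in outputs:
    # Each RequestOutput carries the prompt and its generated completions.
    print(output.prompt, "->", output.outputs[0].text)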
Inputs are cast to their appropriate types and passed to the workhorse, LLMEngine
,
which we’ll explore next.
The LLM Engine
The LLM engine manages the end-to-end execution of a token generation pipeline,
including a tokenizer, a (possibly distributed) language model, a key-value
cache for intermediate model states, and postprocessing. There exist two
flavors of the LLM engine: LLMEngine
(for offline synchronous execution) and
AsyncLLMEngine
(for online asynchronous execution). The implementations are
somewhat similar; while the former performs all steps of the pipeline in a
synchronous manner, the latter executes as many steps on separate executor
threads as possible, to free the base event loop and support multiple
concurrent generation calls.
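To make the asynchronous flavor concrete, a rough sketch of driving AsyncLLMEngine directly is shown below (class locations and signatures vary somewhat across vLLM versions, so treat the details as approximate):
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

async def main():
    # Placeholders throughout; the point is that `generate` returns an async
    # generator, so many requests can be interleaved on one event loop.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    params = SamplingParams(max_tokens=32)
    final = None
    async for request_output in engine.generate(
            "Hello, my name is", params, request_id="req-0"):
        final = request_output  # outputs stream in as tokens are decoded
    print(final.outputs[0].text)

asyncio.run(main())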
While some components of the token generation pipeline are embarrassingly
parallel (e.g. tokenization), other components are more tricky: for example,
natural questions emerge about the batching mechanics of multiple sequences
with generate
requests scheduled close to one another (the answer: continuous
batching), KV cache management when multiple generation calls oversubscribe GPU
memory (the answer: paged attention), and KV cache reuse when multiple requests
share the same prefix (the answer: prefix caching). And these questions become
even more involved when considering their interactions with performance
improvements using techniques such as speculative decoding.
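To make the first of these concrete: continuous batching admits and retires sequences at token-step granularity, rather than waiting for an entire batch to finish. The toy loop below (unrelated to vLLM’s actual scheduler) illustrates the idea:
# Toy continuous batching: finished sequences leave the running batch every
# step, and freed slots are immediately backfilled from the waiting queue.
def step(running, waiting, finished_ids, max_batch_size=2):
    running = [r for r in running if r not in finished_ids]
    while waiting and len(running) < max_batch_size:
        running.append(waiting.pop(0))
    return running, waiting

running, waiting = ["req-1", "req-2"], ["req-3", "req-4"]
running, waiting = step(running, waiting, finished_ids={"req-1"})
print(running, waiting)  # ['req-2', 'req-3'] ['req-4']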
To support modular development and clarify responsibilities of individual components, the vLLM developers segmented the stack into a few core modules: the scheduler, executor, worker, and model runner.
Scheduler
A scheduler is defined per vLLM virtual engine[3]: in synchronous mode there exists only one virtual engine, while asynchronous mode supports proper pipelining of requests and therefore multiple virtual engines. The scheduler’s job is to produce, for its engine, a scheduling decision containing the batch of sequence groups to perform a forward pass on, along with associated memory-management decisions.
Each scheduled sequence group (a group of sequences generated from the same
prompt) is represented by a SequenceGroupMetadata
object, containing its
request ID, metadata about the request, the tokenizer representation of the
input, sampling parameters, and other auxiliary information. The dataclass is
replicated below, with some irrelevant information omitted:
class SequenceGroupMetadata(msgspec.Struct):
    request_id: str
    is_prompt: bool
    seq_data: dict[int, SequenceData]
    sampling_params: Optional[SamplingParams]
    # The block tables: sequence id => list of physical block numbers:
    block_tables: dict[int, list[int]]
    do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
    lora_request: Optional[LoRARequest] = None
    # Block numbers that are already computed, for use in prefix caching:
    computed_block_nums: Optional[list[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())
    # Multimodal information:
    token_type_ids: Optional[list[int]] = None
    multi_modal_data: Optional[Any] = None
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
    token_chunk_size: Optional[int] = None
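The block_tables field is where paged attention surfaces in this metadata: each sequence’s logical KV-cache blocks map to (possibly non-contiguous) physical block numbers. A toy illustration with made-up numbers:
# Hypothetical numbers: with a block size of 16 tokens, a 40-token sequence
# occupies 3 physical KV-cache blocks, which need not be contiguous in GPU
# memory.
block_size = 16
num_tokens = 40
block_tables = {0: [7, 12, 3]}  # sequence id 0 -> physical blocks 7, 12, 3
num_blocks_needed = -(-num_tokens // block_size)  # ceiling division
assert num_blocks_needed == len(block_tables[0]) == 3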
The memory management information is represented by the SchedulerOutputs
dataclass, replicated below:
@dataclass
class SchedulerOutputs:
    # Scheduled sequence groups.
    scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup]
    # Number of prefill groups scheduled.
    num_prefill_groups: int
    # Total number of batched tokens.
    num_batched_tokens: int
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int, int]]
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int, int]]
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]]
    # Sequence groups that are going to be ignored.
    ignored_seq_groups: List[SequenceGroup]
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int
    # The number of requests in the running queue.
    running_queue_size: int
    preempted: int
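The three block lists encode the memory-management side of the decision: each entry is a (source, destination) pair of block numbers that the workers’ cache engines act on before the forward pass. With made-up block numbers:
# Hypothetical block numbers, for illustration only:
blocks_to_swap_in = [(4, 18)]   # CPU block 4  -> GPU block 18 (resume a swapped group)
blocks_to_swap_out = [(9, 2)]   # GPU block 9  -> CPU block 2  (preempt under memory pressure)
blocks_to_copy = [(18, 25)]     # GPU block 18 -> GPU block 25 (copy-on-write, e.g. forked sequences)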
At LLM engine initialization time, schedulers are initialized with empty queues
for sequence groups in the waiting, running, and swapped states. When the
engine receives a request via the add_request
method (one request per prompt
in batch mode), it performs tokenization and adds the resulting “processed”
request to the scheduler’s waiting pool[4], to be scheduled at the next invocation
of LLMEngine.step
.
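In code, driving the engine at this level looks roughly like the loop below (a sketch; exact argument names differ slightly across vLLM versions):
from vllm import EngineArgs, LLMEngine, SamplingParams

# A sketch of the add_request / step loop that the LLM class wraps; the model
# name and prompt are placeholders.
engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
engine.add_request("request-0",
                   "The capital of France is",
                   SamplingParams(max_tokens=16))
finished = []
while engine.has_unfinished_requests():
    # Each step schedules a batch, runs one forward pass, and postprocesses.
    for request_output in engine.step():
        if request_output.finished:
            finished.append(request_output)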
When step
is called, the scheduler is first invoked via its schedule
method,
which handles the core business logic for request selection and memory allocation
(ultimately returned as a List[SequenceGroupMetadata]
and SchedulerOutputs
to the engine). The default scheduler is memory-aware (via its block space
manager) and optimizes for throughput, always preferring to batch as many
prefills as possible and delay decodes as necessary[5][6]. For more details,
the interested reader is encouraged to examine the _schedule_default
method
in the Scheduler
class.
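As a rough mental model, the default policy can be caricatured as follows (a simplification, not the actual _schedule_default; the block-manager methods are named here for illustration only):
# A simplified sketch of a throughput-oriented scheduling pass. The block
# manager methods below mirror vLLM's block space manager in spirit, but are
# not the exact implementation.
def schedule_sketch(waiting, running, swapped, block_manager, token_budget):
    scheduled, batched_tokens = [], 0
    # 1. Prefer prefills: admit waiting groups while KV blocks and the token
    #    budget allow.
    while waiting and block_manager.can_allocate(waiting[0]):
        group = waiting[0]
        if batched_tokens + group.num_prompt_tokens > token_budget:
            break
        waiting.pop(0)
        block_manager.allocate(group)
        scheduled.append(group)
        batched_tokens += group.num_prompt_tokens
    if scheduled:
        return scheduled  # a prefill-only batch
    # 2. Otherwise schedule decodes, preempting (swapping out) groups that can
    #    no longer be granted a new KV block.
    for group in running:
        if block_manager.can_append_slots(group):
            block_manager.append_slots(group)
            scheduled.append(group)
        else:
            block_manager.swap_out(group)
            swapped.append(group)
    return scheduled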
When the scheduler completes its job, the LLM engine (within step
) receives
the aforementioned scheduled sequence groups and scheduler outputs, which it
consolidates into an ExecuteModelRequest
to be passed to the model executor
for forwarding (in the asynchronous engine, this occurs in a separate thread).
Executors and Workers
A model executor is called from the LLM engine as follows:
self.model_executor: ExecutorBase = ... # Initialized upstream
execute_model_req = ... # The request consolidated from the scheduler
outputs = self.model_executor.execute_model(execute_model_req=execute_model_req)
which is implemented in the base executor as a collective remote procedure call
that runs "execute_model"
on all of the executor’s workers and collects the
response:
def execute_model(
    self, execute_model_req: ExecuteModelRequest
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
    output = self.collective_rpc("execute_model", args=(execute_model_req, ))
    return output[0]
The implementation of collective_rpc
is determined by the implementation
of the _run_workers
method per-executor; other executor methods follow a
similar pattern of utilizing collective communication to perform a set of
actions across their subsidiary workers. To date, there exist two executor
implementations: one based on Python’s multiprocessing support (workers are
Python processes; only a single node is supported), and another based on a
distributed Ray cluster (workers are Ray workers; multiple nodes are
supported). An executor, thus, can be understood simply as an abstraction that
defines simple APIs over a collection of workers; each worker handles
implementation details of the executor’s methods, and the executor collects a
response to return to the LLM engine.
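The pattern is easy to caricature with toy classes (not vLLM’s actual executor): the executor broadcasts a method name to its workers and gathers one result per worker.
# Toy stand-in for the executor/worker split. In vLLM the dispatch crosses
# process (multiprocessing) or actor (Ray) boundaries; here the workers are
# plain in-process objects.
class ToyExecutor:
    def __init__(self, workers):
        self.workers = workers  # one handle per device/process

    def collective_rpc(self, method, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Broadcast the call to every worker and gather one result each.
        return [getattr(worker, method)(*args, **kwargs)
                for worker in self.workers]

    def execute_model(self, execute_model_req):
        outputs = self.collective_rpc("execute_model",
                                      args=(execute_model_req,))
        return outputs[0]  # the driver worker's output goes back to the engine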
Since workers perform model-related operations, they are naturally (hardware)
device-specific. WorkerBase
defines the interface a worker must implement,
with one worker implementation per device type (CPU, GPU, or other accelerator)[7]:
def init_device(self) -> None:
    """Initialize device state, such as loading the model or other on-device
    memory allocations."""
    ...

def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
    """Initialize the KV cache with the given size in blocks."""
    ...

def get_model(self) -> nn.Module:
    ...

def load_model(self) -> None:
    """Load model onto target device."""
    ...

def execute_model(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
    ...

def determine_num_available_blocks(self) -> Tuple[int, int]:
    """Determine the number of available blocks for the GPU KV cache and
    swappable CPU KV cache.

    The implementation may run profiling or other heuristics to determine
    the size of caches.

    Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
    are blocks that are "active" on the device and can be appended to.
    num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
    appended to.
    """
    ...

def add_lora(self, lora_request: LoRARequest) -> bool:
    ...

def remove_lora(self, lora_id: int) -> bool:
    ...

def pin_lora(self, lora_id: int) -> bool:
    ...

def list_loras(self) -> Set[int]:
    ...
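Of these, determine_num_available_blocks and initialize_cache pair up at startup: the engine asks each worker how many KV-cache blocks fit (the docstring leaves room for profiling or other heuristics) and then sizes the cache accordingly. Schematically (a sketch, assuming worker implements the interface above):
# Schematic startup flow, assuming `worker` implements the interface above:
worker.init_device()
worker.load_model()
num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)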
Within the worker
subdirectory, you’ll see various device-specific worker
implementations; to date, this includes CPU, HPU, Neuron, OpenVINO, TPU,
and XPU (sadly, no Metal shader support yet).
In summary:
- An executor (either multiprocess or Ray) orchestrates collective communication amongst its workers
- A worker has a device-specific implementation that runs in its encapsulated process or Ray worker
We now turn to the implementation of the model logic that each worker executes; this is encapsulated in the model runner.
Model Runner
The model runner is not technically a fully separate object: a runner for a particular device has a 1:1 mapping with a worker implementation for that device. Instead, the class forms a logical encapsulation of model-related methods that operate on the worker’s device. In particular, while the worker implementation manages the KV cache and other collective communication primitives for its device, the model runner focuses solely on the execution of the model forward pass. Its interface is simple:
class ModelRunnerBase(ABC, Generic[T]):

    @abstractmethod
    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> T:
        """Make an instance of a ModelRunnerInputBase from the broadcasted
        tensor dict."""
        ...

    @abstractmethod
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> T:
        """Prepares the inputs to ModelRunnerBase.execute_model from an
        execution request. This method may move data to the worker's local
        device. It is not allowed to communicate with other workers or
        devices."""
        ...

    @abstractmethod
    def get_model(self) -> nn.Module:
        ...

    @abstractmethod
    def execute_model(
        self,
        model_input: T,
        kv_caches: Optional[List[torch.Tensor]],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        **kwargs,
    ) -> Optional[List[SamplerOutput]]:
        ...
Note that each model runner has its own device input class (this is the T
generic type above); this representation is constructed from the
ExecuteModelRequest
, and cast to device-specific tensors (or other
representations entirely) that can be forwarded through the nn.Module
representing the worker’s shard of the (potentially distributed) LLM.
To build further intuition, we replicate the input preparation and
model execution implementation for distributed workers below (note
that the device-specific implementations are in <device>_model_runner.py
for the interested reader):
def prepare_input(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
        str, torch.Tensor]]]:
    if self.is_driver_worker:
        return self._get_driver_input_and_broadcast(execute_model_req)
    else:
        return self._get_worker_input_from_broadcast()

def execute_model(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[List[SamplerOutput]]:
    # Some portions (e.g. timing and empty-input handling) elided for
    # simplicity:
    inputs = self.prepare_input(execute_model_req)
    model_input, worker_input, kwargs = inputs
    num_steps = worker_input.num_steps

    # Process an execution request; e.g., issue cache operations
    # via the cache engine:
    self.execute_worker(worker_input)

    # Pipeline parallelism support: receive intermediate tensors from the
    # previous pipeline rank, if any.
    intermediate_tensors = None
    if not get_pp_group().is_first_rank:
        intermediate_tensors = IntermediateTensors(
            get_pp_group().recv_tensor_dict(
                all_gather_group=get_tp_group()))

    output = self.model_runner.execute_model(
        model_input=model_input,
        kv_caches=self.kv_cache[worker_input.virtual_engine]
        if self.kv_cache is not None else None,
        intermediate_tensors=intermediate_tensors,
        num_steps=num_steps,
        **kwargs,
    )

    # Pipeline parallelism support: non-final ranks forward their
    # intermediate tensors to the next rank instead of sampling.
    if not get_pp_group().is_last_rank:
        assert isinstance(output, IntermediateTensors)
        get_pp_group().send_tensor_dict(output.tensors,
                                        all_gather_group=get_tp_group())
        return [None]

    # Output is List[SamplerOutput]:
    return output
The output of each model runner’s (= each worker’s) execution of execute_model
is gathered across workers in the executor implementation and returned to the
engine (either synchronously or asynchronously, depending on the type of
engine). The engine then applies any necessary postprocessing and returns the
results to the caller, completing a call to step.
Summary
In this post, we’ve touched on the core components that form a vLLM inference pass. We haven’t spent much time on the optimizations that make vLLM so powerful (e.g. the implementation of paged attention, speculative decoding, flash attention kernels, etc.), but understanding the inference pipeline end-to-end hopefully sheds light on where to dive deeper. For example, the scheduler (via its block allocator) and the worker (via its cache engine) both play a role in managing KV cache blocks; the scheduler drives continuous batching and the tradeoff between prefill and decode phases; and prefix caching is handled by the block allocator, which the scheduler manages.
Notes
1. Hence the name: “virtual” LLM, akin to (paged) virtual memory in an operating system.
2. I’m sure I missed out on a lot. Please feel free to drop me a line and correct me.
3. Virtual engines are a core component of the vLLM implementation of pipeline parallelism. Concretely, pipeline parallelism shards a model across multiple devices such that each device hosts a subset of the model’s layers; this is differentiated from tensor parallelism, in which the model layers themselves (and therefore the model’s intermediate tensors) are sharded along non-batch dimensions (sharding along the batch dimension corresponds to data parallelism). To achieve proper pipelining, we require multiple independent “streams” that are data-independent of one another and can occupy accelerator time while other streams execute on other devices; this leads to the concept of a virtual engine in vLLM, where each virtual engine manages one such stream of execution.
4. The processed sequence group is added to precisely the scheduler (recall, there exists one scheduler per virtual engine = stream of execution) that has the fewest unfinished sequences, to load-balance across streams.
5. A typical language model inference call can be split into two phases: the (typically compute-bound) “prefill” phase, in which the key-value cache of the prompt is filled by a parallelizable computation of (causal) attention across prompt tokens, and the (typically memory-bound) “decode” phase, in which new tokens are generated autoregressively (potentially with the help of techniques such as speculative decoding).
6. There is more subtlety when “chunked prefill” is enabled, which chunks prefills and places prefill and decode work in the same batch to improve GPU utilization. More on that in a later post.
7. Most of the worker interface is shared with its parent executor interface.