Skip to content

vllm.model_executor.layers.attention

Modules:

| Name | Description |
| --- | --- |
| attention | |
| chunked_local_attention | |
| cross_attention | |
| encoder_only_attention | |
| kv_transfer_utils | |
| mla_attention | MLA Common Components |
| mm_encoder_attention | |
| static_sink_attention | |

__all__ module-attribute

__all__ = [
    "Attention",
    "ChunkedLocalAttention",
    "CrossAttention",
    "EncoderOnlyAttention",
    "MLAAttention",
    "MMEncoderAttention",
    "StaticSinkAttention",
]

Attention

Bases: Module, AttentionLayerBase

Attention layer.

This class takes query, key, and value tensors as input. The input tensors can either contain prompt tokens or generation tokens. The class does the following:

  1. Store the input key and value tensors in the KV cache.
  2. Perform (multi-head/multi-query/grouped-query) attention.
  3. Return the output tensor.
Source code in vllm/model_executor/layers/attention/attention.py
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.

    The kernel work is delegated to `self.impl`, an instance of the
    implementation class returned by the selected backend's `get_impl_cls()`.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        use_alibi_sqrt: bool | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        head_size_v: int | None = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads. Must be divisible by
                `num_kv_heads`.
            head_size: Per-head size of the query and key tensors.
            scale: Attention scale, forwarded unchanged to the backend impl.
            num_kv_heads: Number of key/value heads; defaults to `num_heads`.
            alibi_slopes: Optional ALiBi slopes, forwarded to the impl.
            use_alibi_sqrt: Request the backend's ALiBi-sqrt variant. None is
                treated as False; a truthy value is rejected if the backend's
                `supports_alibi_sqrt()` is False.
            cache_config: Supplies the model-level sliding window, KV cache
                dtype, block size and `calculate_kv_scales`. NOTE: this object
                may be mutated here (`cache_dtype`, `calculate_kv_scales`,
                `enable_prefix_caching`).
            quant_config: Quantization config; a non-None `kv_cache_scheme`
                attribute forces an "fp8" KV cache.
            logits_soft_cap: Optional soft cap, forwarded to the impl.
            per_layer_sliding_window: When not None, overrides the
                model-level sliding window from `cache_config`.
            prefix: Layer name. Used as the key in
                `compilation_config.static_forward_context`; must be unique.
            attn_type: Attention type, forwarded to backend selection and to
                the impl.
            kv_sharing_target_layer_name: Optional name of another layer whose
                KV cache this layer shares; validated against the layers
                registered so far.
            attn_backend: Explicit backend class. When None, one is chosen
                via `get_attn_backend(...)`.
            head_size_v: Per-head size of the value tensor; defaults to
                `head_size`.
            **extra_impl_args: Extra keyword arguments forwarded to the impl
                constructor. A non-None "sinks" entry sets `self.has_sink`.

        Raises:
            ValueError: If `use_alibi_sqrt` is requested but unsupported by
                the backend, or if `prefix` is already registered.
        """
        super().__init__()
        # An explicit per-layer sliding window takes precedence over the
        # model-level value carried by cache_config.
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        vllm_config = get_current_vllm_config()
        # Without a cache_config, fall back to an "auto" KV cache dtype,
        # a block size of 16 and no dynamic KV-scale calculation.
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False

        # llm-compressor mdls need to set cache_dtype to "fp8" manually.
        # Note: this also mutates the shared cache_config object.
        if getattr(quant_config, "kv_cache_scheme", None) is not None:
            kv_cache_dtype = "fp8"
            calculate_kv_scales = False
            if cache_config is not None:
                cache_config.cache_dtype = "fp8"
                cache_config.calculate_kv_scales = False

        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        if num_kv_heads is None:
            num_kv_heads = num_heads
        # Grouped-query attention requires each KV head to serve a whole
        # number of query heads.
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )
        self.quant_config = quant_config
        self.layer_name = prefix

        self.num_heads = num_heads
        self.head_size = head_size
        self.head_size_v = self.head_size if head_size_v is None else head_size_v
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        # NOTE: model_config may be None during certain tests
        model_config = vllm_config.model_config
        self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
                use_mm_prefix=self.use_mm_prefix,
                attn_type=attn_type,
            )
        else:
            self.attn_backend = attn_backend
        # Normalize None -> False and reject the flag if the backend cannot
        # honor it. The flag is only forwarded to the impl when the backend
        # advertises support, so other impls never receive the kwarg.
        backend_supports_alibi_sqrt = self.attn_backend.supports_alibi_sqrt()
        use_alibi_sqrt = use_alibi_sqrt if use_alibi_sqrt else False
        if use_alibi_sqrt and not backend_supports_alibi_sqrt:
            raise ValueError(
                f"use_alibi_sqrt is not supported by backend "
                f"{self.attn_backend.get_name()}."
            )
        self.use_alibi_sqrt = bool(use_alibi_sqrt)
        if backend_supports_alibi_sqrt:
            extra_impl_args["use_alibi_sqrt"] = self.use_alibi_sqrt
        # prefix caching + batch invariance is currently not supported for
        # FLASHINFER and TRITON_MLA.
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and (
                self.attn_backend.get_name() == "FLASHINFER"
                or self.attn_backend.get_name() == "TRITON_MLA"
            )
        ):
            logger.warning_once(
                "Disabling prefix caching for FLASHINFER/TRITON_MLA "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        # Instantiate the backend-specific implementation that performs the
        # actual attention computation.
        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        # Register this layer by name. The same mapping is handed to
        # validate_kv_sharing_target below, and forward passes
        # self.layer_name (== prefix) to the unified attention ops.
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        # (one empty placeholder per pipeline_parallel_size entry)
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Initialize KV cache quantization attributes
        # NOTE(review): presumably this sets the q/k/v scale attributes read
        # below (`q_scale`) and in forward/calc_kv_scales (`_q_scale`,
        # `q_range`, ...) -- confirm in _init_kv_cache_quant.
        _init_kv_cache_quant(self, quant_config, prefix)

        # for attn backends supporting query quantization
        # NOTE(review): this gate accepts any dtype starting with "fp8"
        # (e.g. "fp8_e5m2"), but forward() asserts kv_cache_dtype is in
        # {"fp8", "fp8_e4m3"} whenever query_quant is set. Confirm that such
        # dtypes never reach a backend with supports_quant_query_input.
        self.query_quant = None
        if self.impl.supports_quant_query_input and self.kv_cache_dtype.startswith(
            "fp8"
        ):
            # One scale per KV head when q_scale has num_kv_heads elements;
            # otherwise fall back to a single per-tensor scale.
            is_per_head = (
                hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads
            )
            # NOTE: `block_size` is reused here as the quantization group
            # width (query elements served by one KV head). It no longer
            # holds the KV cache block size from above.
            block_size = self.head_size * self.num_heads // self.num_kv_heads
            self.query_quant = QuantFP8(
                static=True,
                group_shape=GroupShape(-1, block_size)
                if is_per_head
                else GroupShape.PER_TENSOR,
            )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        Args:
            query: Query activations. When `use_output` is set, they are
                viewed as (-1, num_heads, head_size).
            key: Key activations, viewed as (-1, num_kv_heads, head_size)
                when `use_output` is set. May be None.
            value: Value activations, viewed as (-1, num_kv_heads,
                head_size_v) when `use_output` is set. May be None.
            output_shape: Optional shape of the output buffer. Defaults to
                (num_tokens, num_heads * head_size_v).

        Returns:
            When `use_output` is set, a 2-D tensor (-1, output_shape[-1]) in
            the query's original dtype. Otherwise, whatever the unified
            attention op returns.
        """
        # One-shot calibration of q/k/v scales; presumably routes to this
        # layer's calc_kv_scales via the layer name -- confirm.
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
        # Captured before query quantization so the output buffer keeps the
        # incoming activation dtype.
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            # NOTE(review): __init__ creates query_quant for any "fp8*" dtype,
            # so e.g. "fp8_e5m2" would trip this assert.
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

            # check if query quantization is supported
            # (redundant with the __init__ gate that created query_quant)
            if self.impl.supports_quant_query_input:
                query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            if output_shape is None:
                # Handle both 2D [num_tokens, hidden] and
                # 3D [num_tokens, heads, head_dim] query
                num_tokens = query.shape[0]
                output_shape = torch.Size(
                    (num_tokens, self.num_heads * self.head_size_v)
                )
            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
            hidden_size = output_shape[-1]
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size_v)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size_v)
            if self.use_direct_call:
                # Backends whose forward does not write the KV cache get a
                # separate update; its result is passed to the attention call
                # as kv_cache_dummy_dep (presumably to order the two ops).
                kv_cache_dummy_dep = None
                if not self.attn_backend.forward_includes_kv_cache_update:
                    kv_cache_dummy_dep = unified_kv_cache_update(
                        key, value, self.layer_name
                    )
                unified_attention_with_output(
                    query,
                    key,
                    value,
                    output,
                    self.layer_name,
                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                )
            else:
                # Opaque custom-op path (see use_direct_call in __init__).
                kv_cache_dummy_dep = None
                if not self.attn_backend.forward_includes_kv_cache_update and (
                    # torch can only dispatch custom op if a tensor is passed
                    key is not None or value is not None
                ):
                    kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
                        key, value, self.layer_name
                    )
                torch.ops.vllm.unified_attention_with_output(
                    query,
                    key,
                    value,
                    output,
                    self.layer_name,
                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                )
            return output.view(-1, hidden_size)
        else:
            # The backend allocates and returns its own output tensor.
            assert self.attn_backend.forward_includes_kv_cache_update, (
                "Split KV cache update not supported when output tensor not provided."
            )
            if self.use_direct_call:
                return unified_attention(query, key, value, self.layer_name)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name
                )

    def calc_kv_scales(self, query, key, value):
        """Compute static q/k/v scales from one batch of activations.

        Each scale is the absolute max of the tensor divided by the matching
        `*_range` attribute. The result is written into the `_*_scale`
        tensors in place and cached as Python floats in `_*_scale_float`.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Describe the impl's head layout, scale and class for module repr."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Post-load hook: let the impl finalize, then default quant scales.

        Args:
            act_dtype: Activation dtype, forwarded to the impl's hook.
        """
        self.impl.process_weights_after_loading(act_dtype)

        # If we should not load quant weights, we initialize the scales to 1.0
        # as the default value. See [Note: Register q/k/v/prob scales in state dict]
        # for more details.
        quant_method = (
            self.quant_config.get_quant_method(self, prefix=self.layer_name)
            if self.quant_config
            else None
        )
        if not should_load_quant_weights(quant_method):
            set_default_quant_scales(self, register_buffer=False)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class selected for this layer."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe this decoder layer's KV cache layout.

        Returns a SlidingWindowSpec when the layer has a sliding window,
        otherwise a FullAttentionSpec.
        """
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        if self.sliding_window is not None:
            assert not vllm_config.model_config.use_mla, (
                "MLA is not supported for slidingwindow"
            )
            # NOTE(review): unlike FullAttentionSpec below, head_size_v is
            # not passed here; confirm SlidingWindowSpec is correct when
            # head_size_v != head_size.
            return SlidingWindowSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
                sliding_window=self.sliding_window,
            )
        else:
            return FullAttentionSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                head_size_v=self.head_size_v,
                dtype=self.kv_cache_torch_dtype,
            )

attn_backend instance-attribute

attn_backend = get_attn_backend(
    head_size,
    dtype,
    kv_cache_dtype,
    block_size,
    use_mla=False,
    has_sink=has_sink,
    use_mm_prefix=use_mm_prefix,
    attn_type=attn_type,
)

attn_type instance-attribute

attn_type = attn_type

backend instance-attribute

backend = AttentionBackendEnum[get_name()]

calculate_kv_scales instance-attribute

calculate_kv_scales = calculate_kv_scales

dtype instance-attribute

dtype = dtype

has_sink instance-attribute

has_sink = get('sinks') is not None

head_size instance-attribute

head_size = head_size

head_size_v instance-attribute

head_size_v = (
    head_size if head_size_v is None else head_size_v
)

impl instance-attribute

impl = impl_cls(
    num_heads,
    head_size,
    scale,
    num_kv_heads,
    alibi_slopes,
    sliding_window,
    kv_cache_dtype,
    logits_soft_cap,
    attn_type,
    kv_sharing_target_layer_name,
    **extra_impl_args,
)

kv_cache instance-attribute

kv_cache = [
    (tensor([])) for _ in (range(pipeline_parallel_size))
]

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

kv_cache_torch_dtype instance-attribute

kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
    kv_cache_dtype, model_config
)

kv_sharing_target_layer_name instance-attribute

kv_sharing_target_layer_name = kv_sharing_target_layer_name

layer_name instance-attribute

layer_name = prefix

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

quant_config instance-attribute

quant_config = quant_config

query_quant instance-attribute

query_quant = None

sliding_window instance-attribute

sliding_window = sliding_window

use_alibi_sqrt instance-attribute

use_alibi_sqrt = bool(use_alibi_sqrt)

use_direct_call instance-attribute

use_direct_call = not opaque_attention_op()

use_mm_prefix instance-attribute

use_mm_prefix = model_config is not None and is_mm_prefix_lm

use_output instance-attribute

use_output = accept_output_buffer

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int | None = None,
    alibi_slopes: list[float] | None = None,
    use_alibi_sqrt: bool | None = None,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    logits_soft_cap: float | None = None,
    per_layer_sliding_window: int | None = None,
    prefix: str = "",
    attn_type: str = DECODER,
    kv_sharing_target_layer_name: str | None = None,
    attn_backend: type[AttentionBackend] | None = None,
    head_size_v: int | None = None,
    **extra_impl_args,
) -> None

The KV cache is stored inside this class and is accessed via self.kv_cache.

Source code in vllm/model_executor/layers/attention/attention.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int | None = None,
    alibi_slopes: list[float] | None = None,
    use_alibi_sqrt: bool | None = None,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    logits_soft_cap: float | None = None,
    per_layer_sliding_window: int | None = None,
    prefix: str = "",
    attn_type: str = AttentionType.DECODER,
    kv_sharing_target_layer_name: str | None = None,
    attn_backend: type[AttentionBackend] | None = None,
    head_size_v: int | None = None,
    **extra_impl_args,
) -> None:
    """
    The KV cache is stored inside this class and is accessed via
    `self.kv_cache`.

    Args:
        num_heads: Number of query heads. Must be divisible by `num_kv_heads`.
        head_size: Per-head size of the query and key tensors.
        scale: Attention scale, forwarded unchanged to the backend impl.
        num_kv_heads: Number of key/value heads; defaults to `num_heads`.
        alibi_slopes: Optional ALiBi slopes, forwarded to the impl.
        use_alibi_sqrt: Request the backend's ALiBi-sqrt variant. None is
            treated as False; a truthy value is rejected if the backend's
            `supports_alibi_sqrt()` is False.
        cache_config: Supplies the model-level sliding window, KV cache dtype,
            block size and `calculate_kv_scales`. NOTE: this object may be
            mutated here (`cache_dtype`, `calculate_kv_scales`,
            `enable_prefix_caching`).
        quant_config: Quantization config; a non-None `kv_cache_scheme`
            attribute forces an "fp8" KV cache.
        logits_soft_cap: Optional soft cap, forwarded to the impl.
        per_layer_sliding_window: When not None, overrides the model-level
            sliding window from `cache_config`.
        prefix: Layer name. Used as the key in
            `compilation_config.static_forward_context`; must be unique.
        attn_type: Attention type, forwarded to backend selection and to the
            impl.
        kv_sharing_target_layer_name: Optional name of another layer whose KV
            cache this layer shares; validated against the layers registered
            so far.
        attn_backend: Explicit backend class. When None, one is chosen via
            `get_attn_backend(...)`.
        head_size_v: Per-head size of the value tensor; defaults to
            `head_size`.
        **extra_impl_args: Extra keyword arguments forwarded to the impl
            constructor. A non-None "sinks" entry sets `self.has_sink`.

    Raises:
        ValueError: If `use_alibi_sqrt` is requested but unsupported by the
            backend, or if `prefix` is already registered.
    """
    super().__init__()
    # An explicit per-layer sliding window takes precedence over the
    # model-level value carried by cache_config.
    if per_layer_sliding_window is not None:
        # per-layer sliding window
        sliding_window = per_layer_sliding_window
    elif cache_config is not None:
        # model-level sliding window
        sliding_window = cache_config.sliding_window
    else:
        sliding_window = None

    vllm_config = get_current_vllm_config()
    # Without a cache_config, fall back to an "auto" KV cache dtype,
    # a block size of 16 and no dynamic KV-scale calculation.
    if cache_config is not None:
        kv_cache_dtype = cache_config.cache_dtype
        block_size = cache_config.block_size
        calculate_kv_scales = cache_config.calculate_kv_scales
    else:
        kv_cache_dtype = "auto"
        block_size = 16
        calculate_kv_scales = False

    # llm-compressor mdls need to set cache_dtype to "fp8" manually.
    # Note: this also mutates the shared cache_config object.
    if getattr(quant_config, "kv_cache_scheme", None) is not None:
        kv_cache_dtype = "fp8"
        calculate_kv_scales = False
        if cache_config is not None:
            cache_config.cache_dtype = "fp8"
            cache_config.calculate_kv_scales = False

    self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
        kv_cache_dtype, vllm_config.model_config
    )
    self.kv_cache_dtype = kv_cache_dtype
    self.calculate_kv_scales = calculate_kv_scales
    if num_kv_heads is None:
        num_kv_heads = num_heads
    # Grouped-query attention requires each KV head to serve a whole number
    # of query heads.
    assert num_heads % num_kv_heads == 0, (
        f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
    )
    self.quant_config = quant_config
    self.layer_name = prefix

    self.num_heads = num_heads
    self.head_size = head_size
    self.head_size_v = self.head_size if head_size_v is None else head_size_v
    self.num_kv_heads = num_kv_heads
    self.sliding_window = sliding_window
    self.has_sink = extra_impl_args.get("sinks") is not None

    # NOTE: model_config may be None during certain tests
    model_config = vllm_config.model_config
    self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

    # During model initialization, the default dtype is set as the model
    # weight and activation dtype.
    dtype = torch.get_default_dtype()
    if attn_backend is None:
        self.attn_backend = get_attn_backend(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=False,
            has_sink=self.has_sink,
            use_mm_prefix=self.use_mm_prefix,
            attn_type=attn_type,
        )
    else:
        self.attn_backend = attn_backend
    # Normalize None -> False and reject the flag if the backend cannot honor
    # it. The flag is only forwarded to the impl when the backend advertises
    # support, so other impls never receive the kwarg.
    backend_supports_alibi_sqrt = self.attn_backend.supports_alibi_sqrt()
    use_alibi_sqrt = use_alibi_sqrt if use_alibi_sqrt else False
    if use_alibi_sqrt and not backend_supports_alibi_sqrt:
        raise ValueError(
            f"use_alibi_sqrt is not supported by backend "
            f"{self.attn_backend.get_name()}."
        )
    self.use_alibi_sqrt = bool(use_alibi_sqrt)
    if backend_supports_alibi_sqrt:
        extra_impl_args["use_alibi_sqrt"] = self.use_alibi_sqrt
    # prefix caching + batch invariance is currently not supported for
    # FLASHINFER and TRITON_MLA.
    if (
        cache_config is not None
        and cache_config.enable_prefix_caching
        and vllm_is_batch_invariant()
        and (
            self.attn_backend.get_name() == "FLASHINFER"
            or self.attn_backend.get_name() == "TRITON_MLA"
        )
    ):
        logger.warning_once(
            "Disabling prefix caching for FLASHINFER/TRITON_MLA "
            "with batch invariance, as it is not yet supported.",
            scope="local",
        )
        cache_config.enable_prefix_caching = False

    # Instantiate the backend-specific implementation that performs the
    # actual attention computation.
    impl_cls = self.attn_backend.get_impl_cls()
    self.impl = impl_cls(
        num_heads,
        head_size,
        scale,
        num_kv_heads,
        alibi_slopes,
        sliding_window,
        kv_cache_dtype,
        logits_soft_cap,
        attn_type,
        kv_sharing_target_layer_name,
        **extra_impl_args,
    )
    self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
    self.dtype = dtype

    # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
    # torch.compile works by registering the attention as one giant
    # opaque custom op. For other platforms, we directly call them
    # and let torch.compile handle them.
    self.use_direct_call = not current_platform.opaque_attention_op()

    self.use_output = self.attn_backend.accept_output_buffer
    # Register this layer by name. The same mapping is handed to
    # validate_kv_sharing_target below.
    compilation_config = vllm_config.compilation_config
    if prefix in compilation_config.static_forward_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    compilation_config.static_forward_context[prefix] = self
    self.attn_type = attn_type

    if kv_sharing_target_layer_name is not None:
        validate_kv_sharing_target(
            prefix,
            kv_sharing_target_layer_name,
            compilation_config.static_forward_context,
        )
    self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

    # use a placeholder kv cache tensor during init, which will be replaced
    # by bind_kv_cache
    # this variable will not be accessed if use_direct_call is True
    # (one empty placeholder per pipeline_parallel_size entry)
    self.kv_cache = [
        torch.tensor([])
        for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
    ]

    # Initialize KV cache quantization attributes
    # NOTE(review): presumably this sets the scale attributes read below
    # (`q_scale`) and elsewhere (`_q_scale`, `q_range`, ...) -- confirm in
    # _init_kv_cache_quant.
    _init_kv_cache_quant(self, quant_config, prefix)

    # for attn backends supporting query quantization
    # NOTE(review): this gate accepts any dtype starting with "fp8"
    # (e.g. "fp8_e5m2"), but forward() asserts kv_cache_dtype is in
    # {"fp8", "fp8_e4m3"} whenever query_quant is set. Confirm that such
    # dtypes never reach a backend with supports_quant_query_input.
    self.query_quant = None
    if self.impl.supports_quant_query_input and self.kv_cache_dtype.startswith(
        "fp8"
    ):
        # One scale per KV head when q_scale has num_kv_heads elements;
        # otherwise fall back to a single per-tensor scale.
        is_per_head = (
            hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads
        )
        # NOTE: `block_size` is reused here as the quantization group width
        # (query elements served by one KV head). It no longer holds the KV
        # cache block size from above.
        block_size = self.head_size * self.num_heads // self.num_kv_heads
        self.query_quant = QuantFP8(
            static=True,
            group_shape=GroupShape(-1, block_size)
            if is_per_head
            else GroupShape.PER_TENSOR,
        )

calc_kv_scales

calc_kv_scales(query, key, value)
Source code in vllm/model_executor/layers/attention/attention.py
def calc_kv_scales(self, query, key, value):
    """Derive static q/k/v quantization scales from a single batch.

    For each of query/key/value, the scale is ``abs(tensor).max()`` divided
    by the layer's matching ``*_range``. It is copied in place into the
    ``_*_scale`` tensor and then cached as a Python float in
    ``_*_scale_float``. Calibration is one-shot, so the
    ``calculate_kv_scales`` flag is cleared afterwards.
    """
    scale_buffers = (self._q_scale, self._k_scale, self._v_scale)
    observations = (
        (query, self.q_range),
        (key, self.k_range),
        (value, self.v_range),
    )
    for buf, (activations, max_repr) in zip(scale_buffers, observations):
        buf.copy_(torch.abs(activations).max() / max_repr)

    # Mirror the tensor scales as host-side floats.
    self._q_scale_float, self._k_scale_float, self._v_scale_float = (
        buf.item() for buf in (self._q_scale, self._k_scale, self._v_scale)
    )
    # Scales are computed only once.
    self.calculate_kv_scales = False

extra_repr

extra_repr() -> str
Source code in vllm/model_executor/layers/attention/attention.py
def extra_repr(self) -> str:
    """Summarize the impl's head layout, scale and class name.

    The output has the form
    ``head_size=..., num_heads=..., num_kv_heads=..., scale=..., backend=...``.
    """
    impl = self.impl
    fields = [
        f"head_size={impl.head_size}",  # type: ignore
        f"num_heads={impl.num_heads}",  # type: ignore
        f"num_kv_heads={impl.num_kv_heads}",  # type: ignore
        f"scale={impl.scale}",  # type: ignore
        f"backend={impl.__class__.__name__}",
    ]
    return ", ".join(fields)

forward

forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    output_shape: Size | None = None,
) -> Tensor

The KV cache is stored inside this class and is accessed via self.kv_cache.

Attention metadata (attn_metadata) is set using a context manager in the model runner's execute_model method. It is accessed via forward context using vllm.forward_context.get_forward_context().attn_metadata.

Source code in vllm/model_executor/layers/attention/attention.py
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    # For some alternate attention backends like MLA the attention output
    # shape does not match the query shape, so we optionally let the model
    # definition specify the output tensor shape.
    output_shape: torch.Size | None = None,
) -> torch.Tensor:
    """
    The KV cache is stored inside this class and is accessed via
    `self.kv_cache`.

    Attention metadata (`attn_metadata`) is set using a context manager in
    the model runner's `execute_model` method. It is accessed via forward
    context using
    `vllm.forward_context.get_forward_context().attn_metadata`.

    Args:
        query: Query activations. When `use_output` is set, they are viewed
            as (-1, num_heads, head_size).
        key: Key activations, viewed as (-1, num_kv_heads, head_size) when
            `use_output` is set. May be None.
        value: Value activations, viewed as (-1, num_kv_heads, head_size_v)
            when `use_output` is set. May be None.
        output_shape: Optional shape of the output buffer. Defaults to
            (num_tokens, num_heads * head_size_v).

    Returns:
        When `use_output` is set, a 2-D tensor (-1, output_shape[-1]) in the
        query's original dtype. Otherwise, whatever the unified attention op
        returns.
    """
    # One-shot calibration of q/k/v scales; presumably routes to this
    # layer's calc_kv_scales via the layer name -- confirm.
    if self.calculate_kv_scales:
        torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
    # Captured before query quantization so the output buffer keeps the
    # incoming activation dtype.
    output_dtype = query.dtype
    if self.query_quant is not None:
        # quantizing with a simple torch operation enables
        # torch.compile to fuse this into previous ops
        # which reduces overheads during decoding.
        # Otherwise queries are quantized using custom ops
        # which causes decoding overheads
        # NOTE(review): __init__ creates query_quant for any "fp8*" dtype,
        # so e.g. "fp8_e5m2" would trip this assert.
        assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

        # check if query quantization is supported
        # (redundant with the __init__ gate that created query_quant)
        if self.impl.supports_quant_query_input:
            query, _ = self.query_quant(query, self._q_scale)

    if self.use_output:
        if output_shape is None:
            # Handle both 2D [num_tokens, hidden] and
            # 3D [num_tokens, heads, head_dim] query
            num_tokens = query.shape[0]
            output_shape = torch.Size(
                (num_tokens, self.num_heads * self.head_size_v)
            )
        output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
        hidden_size = output_shape[-1]
        # Reshape the query, key, and value tensors.
        # NOTE(woosuk): We do this outside the custom op to minimize the
        # CPU overheads from the non-CUDA-graph regions.
        query = query.view(-1, self.num_heads, self.head_size)
        output = output.view(-1, self.num_heads, self.head_size_v)
        if key is not None:
            key = key.view(-1, self.num_kv_heads, self.head_size)
        if value is not None:
            value = value.view(-1, self.num_kv_heads, self.head_size_v)
        if self.use_direct_call:
            # Backends whose forward does not write the KV cache get a
            # separate update; its result is passed to the attention call as
            # kv_cache_dummy_dep (presumably to order the two ops).
            kv_cache_dummy_dep = None
            if not self.attn_backend.forward_includes_kv_cache_update:
                kv_cache_dummy_dep = unified_kv_cache_update(
                    key, value, self.layer_name
                )
            unified_attention_with_output(
                query,
                key,
                value,
                output,
                self.layer_name,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
        else:
            # Opaque custom-op path (use_direct_call is False).
            kv_cache_dummy_dep = None
            if not self.attn_backend.forward_includes_kv_cache_update and (
                # torch can only dispatch custom op if a tensor is passed
                key is not None or value is not None
            ):
                kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
                    key, value, self.layer_name
                )
            torch.ops.vllm.unified_attention_with_output(
                query,
                key,
                value,
                output,
                self.layer_name,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
        return output.view(-1, hidden_size)
    else:
        # The backend allocates and returns its own output tensor.
        assert self.attn_backend.forward_includes_kv_cache_update, (
            "Split KV cache update not supported when output tensor not provided."
        )
        if self.use_direct_call:
            return unified_attention(query, key, value, self.layer_name)
        else:
            return torch.ops.vllm.unified_attention(
                query, key, value, self.layer_name
            )

get_attn_backend

get_attn_backend() -> type[AttentionBackend]
Source code in vllm/model_executor/layers/attention/attention.py
def get_attn_backend(self) -> type[AttentionBackend]:
    """Return the attention backend class this layer was built with."""
    backend_cls = self.attn_backend
    return backend_cls

get_kv_cache_spec

get_kv_cache_spec(vllm_config: VllmConfig) -> KVCacheSpec
Source code in vllm/model_executor/layers/attention/attention.py
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    """Build the KV cache spec for this decoder attention layer.

    Returns a ``FullAttentionSpec`` (which also carries ``head_size_v``) when
    the layer has no sliding window. Otherwise returns a
    ``SlidingWindowSpec``, which is rejected for MLA models.
    """
    # Re-read the block size: it can change after model loading.
    current_block_size = vllm_config.cache_config.block_size
    # Encoder-decoder / encoder-only layers must not reach this path.
    assert self.attn_type == AttentionType.DECODER

    shared_fields = dict(
        block_size=current_block_size,
        num_kv_heads=self.num_kv_heads,
        head_size=self.head_size,
        dtype=self.kv_cache_torch_dtype,
    )
    if self.sliding_window is None:
        return FullAttentionSpec(head_size_v=self.head_size_v, **shared_fields)

    assert not vllm_config.model_config.use_mla, (
        "MLA is not supported for slidingwindow"
    )
    return SlidingWindowSpec(sliding_window=self.sliding_window, **shared_fields)

process_weights_after_loading

process_weights_after_loading(act_dtype: dtype)
Source code in vllm/model_executor/layers/attention/attention.py
def process_weights_after_loading(self, act_dtype: torch.dtype):
    """Finalize the layer once checkpoint weights are loaded.

    The backend implementation's hook runs first. Then, unless the
    quantization method says quant weights should be loaded, the q/k/v/prob
    scales are initialized to their default of 1.0.
    See [Note: Register q/k/v/prob scales in state dict] for more details.
    """
    self.impl.process_weights_after_loading(act_dtype)

    quant_method = None
    if self.quant_config:
        quant_method = self.quant_config.get_quant_method(
            self, prefix=self.layer_name
        )
    if should_load_quant_weights(quant_method):
        return
    set_default_quant_scales(self, register_buffer=False)

ChunkedLocalAttention

Bases: Attention

Source code in vllm/model_executor/layers/attention/chunked_local_attention.py
class ChunkedLocalAttention(Attention):
    """Attention layer backed by a chunked-local variant of the chosen backend.

    The backend returned by ``get_attn_backend`` is wrapped with
    ``create_chunked_local_attention_backend`` using ``attention_chunk_size``.
    The layer reports a ``ChunkedLocalAttentionSpec`` for its KV cache.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        attention_chunk_size: int,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        kv_sharing_target_layer_name: str | None = None,
        prefix: str = "",
    ):
        self.attention_chunk_size = attention_chunk_size

        # Without a cache config, backend selection uses the "auto" KV cache
        # dtype and a block size of 16.
        kv_cache_dtype, block_size = (
            (cache_config.cache_dtype, cache_config.block_size)
            if cache_config is not None
            else ("auto", 16)
        )

        # Backend selection uses torch's current default dtype.
        base_backend = get_attn_backend(
            head_size, torch.get_default_dtype(), kv_cache_dtype, block_size
        )
        chunked_backend = create_chunked_local_attention_backend(
            base_backend, attention_chunk_size, block_size
        )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=chunked_backend,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe this layer's KV cache as a ChunkedLocalAttentionSpec."""
        chunk_size = self.attention_chunk_size
        assert chunk_size
        return ChunkedLocalAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
            attention_chunk_size=chunk_size,
        )

attention_chunk_size instance-attribute

attention_chunk_size = attention_chunk_size

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    attention_chunk_size: int,
    num_kv_heads: int | None = None,
    alibi_slopes: list[float] | None = None,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    kv_sharing_target_layer_name: str | None = None,
    prefix: str = "",
)
Source code in vllm/model_executor/layers/attention/chunked_local_attention.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    attention_chunk_size: int,
    num_kv_heads: int | None = None,
    alibi_slopes: list[float] | None = None,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    kv_sharing_target_layer_name: str | None = None,
    prefix: str = "",
):
    """Pick a backend, wrap it for chunked-local attention, then init the layer."""
    self.attention_chunk_size = attention_chunk_size

    # Without a cache config, backend selection uses the "auto" KV cache
    # dtype and a block size of 16.
    kv_cache_dtype, block_size = (
        (cache_config.cache_dtype, cache_config.block_size)
        if cache_config is not None
        else ("auto", 16)
    )

    # Backend selection uses torch's current default dtype.
    base_backend = get_attn_backend(
        head_size, torch.get_default_dtype(), kv_cache_dtype, block_size
    )
    chunked_backend = create_chunked_local_attention_backend(
        base_backend, attention_chunk_size, block_size
    )

    super().__init__(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=alibi_slopes,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=prefix,
        kv_sharing_target_layer_name=kv_sharing_target_layer_name,
        attn_backend=chunked_backend,
    )

get_kv_cache_spec

get_kv_cache_spec(vllm_config: VllmConfig) -> KVCacheSpec
Source code in vllm/model_executor/layers/attention/chunked_local_attention.py
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    """Describe this layer's KV cache as a ChunkedLocalAttentionSpec."""
    chunk_size = self.attention_chunk_size
    assert chunk_size
    return ChunkedLocalAttentionSpec(
        block_size=vllm_config.cache_config.block_size,
        num_kv_heads=self.num_kv_heads,
        head_size=self.head_size,
        dtype=self.kv_cache_torch_dtype,
        attention_chunk_size=chunk_size,
    )

CrossAttention

Bases: Attention

Cross-attention for encoder-decoder models. Handles attention between decoder queries and encoder keys/values.

Source code in vllm/model_executor/layers/attention/cross_attention.py
class CrossAttention(Attention):
    """Encoder-decoder cross-attention layer.

    Decoder queries attend to encoder keys/values. The layer is always built
    with ``AttentionType.ENCODER_DECODER`` on top of a cross-attention
    wrapper around the selected backend.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        cache_config: CacheConfig | None = None,
        attn_type: str | None = None,
        **kwargs,
    ):
        # ``attn_type`` may be omitted; if given, only ENCODER_DECODER is valid.
        assert attn_type in (None, AttentionType.ENCODER_DECODER), (
            "CrossAttention only supports AttentionType.ENCODER_DECODER"
        )

        # Without a cache config, backend selection uses the "auto" KV cache
        # dtype and a block size of 16.
        kv_cache_dtype, block_size = (
            (cache_config.cache_dtype, cache_config.block_size)
            if cache_config is not None
            else ("auto", 16)
        )

        base_backend = get_attn_backend(
            head_size,
            torch.get_default_dtype(),
            kv_cache_dtype,
            block_size,
            attn_type=AttentionType.ENCODER_DECODER,
        )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=create_cross_attention_backend(base_backend),
            attn_type=AttentionType.ENCODER_DECODER,
            **kwargs,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe this layer's KV cache as a CrossAttentionSpec."""
        spec_kwargs = dict(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
        )
        return CrossAttentionSpec(**spec_kwargs)

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    cache_config: CacheConfig | None = None,
    attn_type: str | None = None,
    **kwargs,
)
Source code in vllm/model_executor/layers/attention/cross_attention.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    cache_config: CacheConfig | None = None,
    attn_type: str | None = None,
    **kwargs,
):
    """Select a backend, wrap it for cross-attention, and init the layer."""
    # ``attn_type`` may be omitted; if given, only ENCODER_DECODER is valid.
    assert attn_type in (None, AttentionType.ENCODER_DECODER), (
        "CrossAttention only supports AttentionType.ENCODER_DECODER"
    )

    # Without a cache config, backend selection uses the "auto" KV cache
    # dtype and a block size of 16.
    kv_cache_dtype, block_size = (
        (cache_config.cache_dtype, cache_config.block_size)
        if cache_config is not None
        else ("auto", 16)
    )

    base_backend = get_attn_backend(
        head_size,
        torch.get_default_dtype(),
        kv_cache_dtype,
        block_size,
        attn_type=AttentionType.ENCODER_DECODER,
    )

    super().__init__(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        cache_config=cache_config,
        attn_backend=create_cross_attention_backend(base_backend),
        attn_type=AttentionType.ENCODER_DECODER,
        **kwargs,
    )

get_kv_cache_spec

get_kv_cache_spec(vllm_config: VllmConfig) -> KVCacheSpec
Source code in vllm/model_executor/layers/attention/cross_attention.py
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    """Describe this layer's KV cache as a CrossAttentionSpec."""
    spec_kwargs = dict(
        block_size=vllm_config.cache_config.block_size,
        num_kv_heads=self.num_kv_heads,
        head_size=self.head_size,
        dtype=self.kv_cache_torch_dtype,
    )
    return CrossAttentionSpec(**spec_kwargs)

EncoderOnlyAttention

Bases: Attention

Encoder attention is a special case that doesn't need a KV Cache.

Source code in vllm/model_executor/layers/attention/encoder_only_attention.py
class EncoderOnlyAttention(Attention):
    """Attention for encoder-only models; it does not need a KV cache.

    The layer is always built with ``AttentionType.ENCODER_ONLY`` on top of an
    encoder-only wrapper around the selected backend, and its
    ``get_kv_cache_spec`` returns ``None``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        cache_config: CacheConfig | None = None,
        attn_type: str | None = None,
        **kwargs,
    ):
        # Validate the requested attention type before selecting a backend
        # (as CrossAttention does), so an unsupported value fails fast.
        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_ONLY, (
                "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
            )

        dtype = torch.get_default_dtype()

        # Without a cache config, backend selection uses the "auto" KV cache
        # dtype and a block size of 16.
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        underlying_attn_backend = get_attn_backend(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            attn_type=AttentionType.ENCODER_ONLY,
        )

        attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=attn_backend,
            attn_type=AttentionType.ENCODER_ONLY,
            **kwargs,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
        """Return ``None``: encoder-only attention does not need a KV cache."""
        return None

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    cache_config: CacheConfig | None = None,
    attn_type: str | None = None,
    **kwargs,
)
Source code in vllm/model_executor/layers/attention/encoder_only_attention.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    cache_config: CacheConfig | None = None,
    attn_type: str | None = None,
    **kwargs,
):
    """Select a backend, wrap it for encoder-only attention, and init the layer.

    Raises:
        AssertionError: If ``attn_type`` is given and is not
            ``AttentionType.ENCODER_ONLY``.
    """
    # Validate the requested attention type before selecting a backend
    # (as CrossAttention does), so an unsupported value fails fast.
    if attn_type is not None:
        assert attn_type == AttentionType.ENCODER_ONLY, (
            "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
        )

    dtype = torch.get_default_dtype()

    # Without a cache config, backend selection uses the "auto" KV cache
    # dtype and a block size of 16.
    if cache_config is not None:
        kv_cache_dtype = cache_config.cache_dtype
        block_size = cache_config.block_size
    else:
        kv_cache_dtype = "auto"
        block_size = 16

    underlying_attn_backend = get_attn_backend(
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        attn_type=AttentionType.ENCODER_ONLY,
    )

    attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)

    super().__init__(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        cache_config=cache_config,
        attn_backend=attn_backend,
        attn_type=AttentionType.ENCODER_ONLY,
        **kwargs,
    )

get_kv_cache_spec

get_kv_cache_spec(vllm_config: VllmConfig) -> KVCacheSpec
Source code in vllm/model_executor/layers/attention/encoder_only_attention.py
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
    """Return ``None``: encoder-only attention does not need a KV cache.

    The return annotation includes ``None`` because that is the only value
    this method produces.
    """
    return None

MLAAttention

Bases: Module, AttentionLayerBase

Multi-Head Latent Attention layer.

This class takes query and compressed key/value tensors as input. The class does the following:

  1. Store the input key and value tensors in the KV cache.
  2. Perform (multi-head/multi-query/grouped-query) attention.
  3. Return the output tensor.
Source code in vllm/model_executor/layers/attention/mla_attention.py
class MLAAttention(nn.Module, AttentionLayerBase):
    """Multi-Head Latent Attention layer.

    This class takes query and compressed key/value tensors as input.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.

    The KV cache uses a single KV "head" per token of width
    ``kv_lora_rank + qk_rope_head_dim`` (see ``get_kv_cache_spec``).
    """

    def __init__(
        self,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        kv_b_proj: ColumnParallelLinear,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_sparse: bool = False,
        indexer: object | None = None,
        **extra_impl_args,
    ):
        """Select an MLA backend, build its implementation, and register the layer.

        Args:
            num_heads: Number of attention heads passed to the implementation.
            scale: Attention scale passed to the implementation.
            qk_nope_head_dim: Query/key head-dim component (non-RoPE part,
                judging by name); summed with ``qk_rope_head_dim`` into the
                implementation's ``qk_head_dim``.
            qk_rope_head_dim: Query/key head-dim component (RoPE part, judging
                by name); also contributes to ``head_size``.
            v_head_dim: Value head dimension passed to the implementation.
            q_lora_rank: Query low-rank dimension, or None.
            kv_lora_rank: Width of the compressed KV representation; together
                with ``qk_rope_head_dim`` it forms ``head_size``.
            kv_b_proj: Projection layer handed to the implementation.
            cache_config: Cache settings. When None, the KV cache dtype is
                "auto", the block size is 16 and KV scales are not calculated.
                Note that ``enable_prefix_caching`` on it may be switched off
                below.
            quant_config: Quantization config used for KV cache quantization
                setup and in ``process_weights_after_loading``.
            prefix: Layer name, registered in the static forward context.
            use_sparse: Forwarded to backend selection and stored on the layer.
            indexer: Forwarded to the implementation.
            **extra_impl_args: Extra keyword arguments for the implementation.

        Raises:
            ValueError: If a layer with the same ``prefix`` is already
                registered.
        """
        super().__init__()
        self.num_heads = num_heads
        self.scale = scale
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        # Width of one cached entry. It is the head_size for backend selection,
        # for the implementation, and for the KV cache spec (with one KV head).
        self.head_size = kv_lora_rank + qk_rope_head_dim
        self.layer_name = prefix

        # Fall back to defaults when no cache config is supplied.
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
        self.quant_config = quant_config

        # Initialize KV cache quantization attributes
        # NOTE(review): presumably _init_kv_cache_quant creates the
        # _q_scale/_k_scale/_v_scale tensors written by calc_kv_scales; confirm.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        _init_kv_cache_quant(self, quant_config, prefix)

        # Backend selection uses torch's current default dtype.
        dtype = torch.get_default_dtype()
        self.attn_backend = get_attn_backend(
            self.head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=True,
            use_sparse=use_sparse,
        )

        # In batch-invariant mode these two backends do not yet support prefix
        # caching, so it is switched off on the (shared) cache_config object.
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and (
                self.attn_backend.get_name() == "TRITON_MLA"
                or self.attn_backend.get_name() == "FLASHINFER"
            )
        ):
            logger.warning_once(
                "Disabling prefix caching for TRITON_MLA / FLASHINFER "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        # The MLA implementation is always built with a single KV head and
        # without alibi, sliding window, or logits soft capping.
        impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
        self.impl = impl_cls(
            num_heads=self.num_heads,
            head_size=self.head_size,
            scale=self.scale,
            num_kv_heads=1,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype=self.kv_cache_dtype,
            logits_soft_cap=None,
            attn_type=AttentionType.DECODER,
            kv_sharing_target_layer_name=None,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=kv_b_proj,
            indexer=indexer,
            **extra_impl_args,
        )

        # forward calls self.impl directly unless the platform requires an
        # opaque attention op, in which case it goes through torch.ops.vllm.*.
        self.use_direct_call = not current_platform.opaque_attention_op()

        # Register this layer under its name; duplicate names are rejected.
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

        # One empty placeholder per pipeline_parallel_size; forward picks an
        # entry by forward_context.virtual_engine. Presumably replaced by the
        # allocated cache tensors elsewhere; confirm.
        self.kv_cache = [
            torch.tensor([])
            for _ in range(
                get_current_vllm_config().parallel_config.pipeline_parallel_size
            )
        ]

        self.use_sparse = use_sparse

        # Initialize q/k/v range constants.
        # These are the divisors used by calc_kv_scales.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    def forward(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Run MLA attention for one forward pass.

        Args:
            q: Query tensor.
            kv_c_normed: Compressed KV tensor.
            k_pe: Additional key input (presumably the RoPE key part; confirm).
            output_shape: Shape of the output buffer. It is used only when the
                backend accepts an output buffer, and must not be None then.

        Returns:
            Either the freshly allocated output buffer (filled by the backend)
            or the result returned by the implementation / custom op.
        """
        # Scale calculation is requested through the cache config; see
        # calc_kv_scales.
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)

        if self.use_direct_call:
            # Resolve per-layer metadata and this virtual engine's cache slot.
            forward_context: ForwardContext = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]

            if self.attn_backend.accept_output_buffer:
                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                self.impl.forward(
                    self,
                    q,
                    kv_c_normed,
                    k_pe,
                    self_kv_cache,
                    attn_metadata,
                    output=output,
                )
                return output
            else:
                return self.impl.forward(
                    self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
                )
        else:
            # Opaque path: the registered custom ops look the layer up by name.
            if self.attn_backend.accept_output_buffer:
                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                torch.ops.vllm.unified_mla_attention_with_output(
                    q,
                    kv_c_normed,
                    k_pe,
                    output,
                    self.layer_name,
                )
                return output
            else:
                return torch.ops.vllm.unified_mla_attention(
                    q,
                    kv_c_normed,
                    k_pe,
                    self.layer_name,
                )

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Post-process the layer after checkpoint weights are loaded."""
        # The implementation's hook is optional.
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

        # If we should not load quant weights, we initialize the scales to 1.0
        # as the default value. See [Note: Register q/k/v/prob scales in state dict]
        # for more details.
        quant_method = (
            self.quant_config.get_quant_method(self, prefix=self.layer_name)
            if self.quant_config
            else None
        )
        if not should_load_quant_weights(quant_method):
            set_default_quant_scales(self, register_buffer=False)

    def calc_kv_scales(
        self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
    ) -> None:
        """Optional scale calculation for MLA inputs.

        Mirrors Attention.calc_kv_scales. Not all MLA backends require this.

        The q scale is ``max(|q|) / q_range``. The k and v scales are both
        derived from ``max(|kv_c_normed|)``, divided by ``k_range`` and
        ``v_range`` respectively. ``k_pe`` is not used. Afterwards
        ``calculate_kv_scales`` is set to False.
        """
        # Use safe defaults if ranges are not present
        q_range = getattr(self, "q_range", torch.tensor(1.0))
        k_range = getattr(self, "k_range", torch.tensor(1.0))
        v_range = getattr(self, "v_range", torch.tensor(1.0))

        self._q_scale.copy_(torch.abs(q).max() / q_range)
        # kv_c_normed is the compressed KV representation; use it for k/v
        kv_abs_max = torch.abs(kv_c_normed).max()
        self._k_scale.copy_(kv_abs_max / k_range)
        self._v_scale.copy_(kv_abs_max / v_range)
        # Python-float copies of the scales (.item() reads the value back).
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # Scales are computed only once.
        self.calculate_kv_scales = False

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the MLA attention backend class selected in __init__."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe the MLA KV cache: one KV head of width ``head_size``."""
        # Convert the configured dtype string to a torch dtype.
        kv_cache_dtype = kv_cache_dtype_str_to_dtype(
            self.kv_cache_dtype, vllm_config.model_config
        )
        return MLAAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_size,
            dtype=kv_cache_dtype,
            cache_dtype_str=vllm_config.cache_config.cache_dtype,
        )

attn_backend instance-attribute

attn_backend = get_attn_backend(
    head_size,
    dtype,
    kv_cache_dtype,
    block_size,
    use_mla=True,
    use_sparse=use_sparse,
)

calculate_kv_scales instance-attribute

calculate_kv_scales = calculate_kv_scales

head_size instance-attribute

head_size = kv_lora_rank + qk_rope_head_dim

impl instance-attribute

impl = impl_cls(
    num_heads=num_heads,
    head_size=head_size,
    scale=scale,
    num_kv_heads=1,
    alibi_slopes=None,
    sliding_window=None,
    kv_cache_dtype=kv_cache_dtype,
    logits_soft_cap=None,
    attn_type=DECODER,
    kv_sharing_target_layer_name=None,
    q_lora_rank=q_lora_rank,
    kv_lora_rank=kv_lora_rank,
    qk_nope_head_dim=qk_nope_head_dim,
    qk_rope_head_dim=qk_rope_head_dim,
    qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
    v_head_dim=v_head_dim,
    kv_b_proj=kv_b_proj,
    indexer=indexer,
    **extra_impl_args,
)

k_range instance-attribute

k_range = tensor(K_SCALE_CONSTANT, dtype=float32)

kv_cache instance-attribute

kv_cache = [
    (tensor([])) for _ in (range(pipeline_parallel_size))
]

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

kv_lora_rank instance-attribute

kv_lora_rank = kv_lora_rank

layer_name instance-attribute

layer_name = prefix

num_heads instance-attribute

num_heads = num_heads

q_lora_rank instance-attribute

q_lora_rank = q_lora_rank

q_range instance-attribute

q_range = tensor(Q_SCALE_CONSTANT, dtype=float32)

qk_nope_head_dim instance-attribute

qk_nope_head_dim = qk_nope_head_dim

qk_rope_head_dim instance-attribute

qk_rope_head_dim = qk_rope_head_dim

quant_config instance-attribute

quant_config = quant_config

scale instance-attribute

scale = scale

use_direct_call instance-attribute

use_direct_call = not opaque_attention_op()

use_sparse instance-attribute

use_sparse = use_sparse

v_head_dim instance-attribute

v_head_dim = v_head_dim

v_range instance-attribute

v_range = tensor(V_SCALE_CONSTANT, dtype=float32)

__init__

__init__(
    num_heads: int,
    scale: float,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    q_lora_rank: int | None,
    kv_lora_rank: int,
    kv_b_proj: ColumnParallelLinear,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
    use_sparse: bool = False,
    indexer: object | None = None,
    **extra_impl_args,
)
Source code in vllm/model_executor/layers/attention/mla_attention.py
def __init__(
    self,
    num_heads: int,
    scale: float,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    q_lora_rank: int | None,
    kv_lora_rank: int,
    kv_b_proj: ColumnParallelLinear,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
    use_sparse: bool = False,
    indexer: object | None = None,
    **extra_impl_args,
):
    """Select an MLA backend, build its implementation, and register the layer.

    Args:
        num_heads: Number of attention heads passed to the implementation.
        scale: Attention scale passed to the implementation.
        qk_nope_head_dim: Query/key head-dim component (non-RoPE part, judging
            by name); summed with ``qk_rope_head_dim`` into ``qk_head_dim``.
        qk_rope_head_dim: Query/key head-dim component (RoPE part, judging by
            name); also contributes to ``head_size``.
        v_head_dim: Value head dimension passed to the implementation.
        q_lora_rank: Query low-rank dimension, or None.
        kv_lora_rank: Width of the compressed KV representation; together with
            ``qk_rope_head_dim`` it forms ``head_size``.
        kv_b_proj: Projection layer handed to the implementation.
        cache_config: Cache settings. When None, the KV cache dtype is "auto",
            the block size is 16 and KV scales are not calculated. Note that
            ``enable_prefix_caching`` on it may be switched off below.
        quant_config: Quantization config used for KV cache quantization setup
            and in ``process_weights_after_loading``.
        prefix: Layer name, registered in the static forward context.
        use_sparse: Forwarded to backend selection and stored on the layer.
        indexer: Forwarded to the implementation.
        **extra_impl_args: Extra keyword arguments for the implementation.

    Raises:
        ValueError: If a layer with the same ``prefix`` is already registered.
    """
    super().__init__()
    self.num_heads = num_heads
    self.scale = scale
    self.qk_nope_head_dim = qk_nope_head_dim
    self.qk_rope_head_dim = qk_rope_head_dim
    self.v_head_dim = v_head_dim
    self.q_lora_rank = q_lora_rank
    self.kv_lora_rank = kv_lora_rank
    # Width of one cached entry. It is the head_size for backend selection,
    # for the implementation, and for the KV cache spec (with one KV head).
    self.head_size = kv_lora_rank + qk_rope_head_dim
    self.layer_name = prefix

    # Fall back to defaults when no cache config is supplied.
    if cache_config is not None:
        kv_cache_dtype = cache_config.cache_dtype
        block_size = cache_config.block_size
        calculate_kv_scales = cache_config.calculate_kv_scales
    else:
        kv_cache_dtype = "auto"
        block_size = 16
        calculate_kv_scales = False
    self.quant_config = quant_config

    # Initialize KV cache quantization attributes
    # NOTE(review): presumably _init_kv_cache_quant creates the
    # _q_scale/_k_scale/_v_scale tensors written by calc_kv_scales; confirm.
    self.kv_cache_dtype = kv_cache_dtype
    self.calculate_kv_scales = calculate_kv_scales
    _init_kv_cache_quant(self, quant_config, prefix)

    # Backend selection uses torch's current default dtype.
    dtype = torch.get_default_dtype()
    self.attn_backend = get_attn_backend(
        self.head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_mla=True,
        use_sparse=use_sparse,
    )

    # In batch-invariant mode these two backends do not yet support prefix
    # caching, so it is switched off on the (shared) cache_config object.
    if (
        cache_config is not None
        and cache_config.enable_prefix_caching
        and vllm_is_batch_invariant()
        and (
            self.attn_backend.get_name() == "TRITON_MLA"
            or self.attn_backend.get_name() == "FLASHINFER"
        )
    ):
        logger.warning_once(
            "Disabling prefix caching for TRITON_MLA / FLASHINFER "
            "with batch invariance, as it is not yet supported.",
            scope="local",
        )
        cache_config.enable_prefix_caching = False

    # The MLA implementation is always built with a single KV head and without
    # alibi, sliding window, or logits soft capping.
    impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
    self.impl = impl_cls(
        num_heads=self.num_heads,
        head_size=self.head_size,
        scale=self.scale,
        num_kv_heads=1,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype=self.kv_cache_dtype,
        logits_soft_cap=None,
        attn_type=AttentionType.DECODER,
        kv_sharing_target_layer_name=None,
        # MLA Args
        q_lora_rank=self.q_lora_rank,
        kv_lora_rank=self.kv_lora_rank,
        qk_nope_head_dim=self.qk_nope_head_dim,
        qk_rope_head_dim=self.qk_rope_head_dim,
        qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
        v_head_dim=self.v_head_dim,
        kv_b_proj=kv_b_proj,
        indexer=indexer,
        **extra_impl_args,
    )

    # forward calls self.impl directly unless the platform requires an opaque
    # attention op, in which case it goes through torch.ops.vllm.*.
    self.use_direct_call = not current_platform.opaque_attention_op()

    # Register this layer under its name; duplicate names are rejected.
    compilation_config = get_current_vllm_config().compilation_config
    if prefix in compilation_config.static_forward_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    compilation_config.static_forward_context[prefix] = self

    # One empty placeholder per pipeline_parallel_size; forward picks an entry
    # by forward_context.virtual_engine. Presumably replaced by the allocated
    # cache tensors elsewhere; confirm.
    self.kv_cache = [
        torch.tensor([])
        for _ in range(
            get_current_vllm_config().parallel_config.pipeline_parallel_size
        )
    ]

    self.use_sparse = use_sparse

    # Initialize q/k/v range constants.
    # These are the divisors used by calc_kv_scales.
    self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
    self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
    self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

calc_kv_scales

calc_kv_scales(
    q: Tensor, kv_c_normed: Tensor, k_pe: Tensor
) -> None

Optional scale calculation for MLA inputs.

Mirrors Attention.calc_kv_scales. Not all MLA backends require this.

Source code in vllm/model_executor/layers/attention/mla_attention.py
def calc_kv_scales(
    self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
) -> None:
    """Derive q/k/v quantization scales from one batch of MLA inputs.

    Mirrors Attention.calc_kv_scales. Not all MLA backends require this.

    The query scale is ``max(|q|) / q_range``. MLA works on the compressed
    ``kv_c_normed`` tensor, so both the key and the value scale come from
    ``max(|kv_c_normed|)``, divided by ``k_range`` and ``v_range``
    respectively. ``k_pe`` is accepted but unused. Any missing range
    attribute defaults to 1.0. Afterwards ``calculate_kv_scales`` is cleared,
    so the scales are computed only once.
    """
    ranges = {
        name: getattr(self, name, torch.tensor(1.0))
        for name in ("q_range", "k_range", "v_range")
    }

    q_peak = q.abs().max()
    kv_peak = kv_c_normed.abs().max()

    self._q_scale.copy_(q_peak / ranges["q_range"])
    self._k_scale.copy_(kv_peak / ranges["k_range"])
    self._v_scale.copy_(kv_peak / ranges["v_range"])

    # Keep plain Python float copies alongside the tensor scales.
    self._q_scale_float = self._q_scale.item()
    self._k_scale_float = self._k_scale.item()
    self._v_scale_float = self._v_scale.item()
    self.calculate_kv_scales = False

forward

forward(
    q: Tensor,
    kv_c_normed: Tensor,
    k_pe: Tensor,
    output_shape: Size | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mla_attention.py
def forward(
    self,
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output_shape: torch.Size | None = None,
) -> torch.Tensor:
    """Run MLA attention for one forward pass.

    KV scales are computed first when ``calculate_kv_scales`` is set. The
    backend implementation is then invoked either directly (when
    ``use_direct_call`` is set) or via the registered ``torch.ops.vllm`` MLA
    ops. If the backend accepts an output buffer, a buffer of
    ``output_shape`` (which must not be None then) with ``q``'s dtype and
    device is allocated, filled and returned. Otherwise the implementation's
    return value is returned.
    """
    if self.calculate_kv_scales:
        torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)

    if not self.use_direct_call:
        # Opaque path: the custom ops resolve the layer by its name.
        if not self.attn_backend.accept_output_buffer:
            return torch.ops.vllm.unified_mla_attention(
                q, kv_c_normed, k_pe, self.layer_name
            )
        out_buf = torch.empty(output_shape, dtype=q.dtype, device=q.device)
        torch.ops.vllm.unified_mla_attention_with_output(
            q, kv_c_normed, k_pe, out_buf, self.layer_name
        )
        return out_buf

    # Direct path: fetch this layer's metadata and cache slot ourselves.
    ctx = get_forward_context()
    metadata = ctx.attn_metadata
    if isinstance(metadata, dict):
        metadata = metadata[self.layer_name]
    layer_cache = self.kv_cache[ctx.virtual_engine]

    if not self.attn_backend.accept_output_buffer:
        return self.impl.forward(self, q, kv_c_normed, k_pe, layer_cache, metadata)
    out_buf = torch.empty(output_shape, dtype=q.dtype, device=q.device)
    self.impl.forward(
        self, q, kv_c_normed, k_pe, layer_cache, metadata, output=out_buf
    )
    return out_buf

get_attn_backend

get_attn_backend() -> type[AttentionBackend]
Source code in vllm/model_executor/layers/attention/mla_attention.py
def get_attn_backend(self) -> type[AttentionBackend]:
    """Return the MLA attention backend class selected at construction."""
    selected = self.attn_backend
    return selected

get_kv_cache_spec

get_kv_cache_spec(vllm_config: VllmConfig) -> KVCacheSpec
Source code in vllm/model_executor/layers/attention/mla_attention.py
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    """Describe the MLA KV cache: a single KV head of width ``head_size``.

    The layer's configured KV cache dtype string is converted to a torch
    dtype. The raw string from the engine's cache config is also recorded as
    ``cache_dtype_str``.
    """
    torch_dtype = kv_cache_dtype_str_to_dtype(
        self.kv_cache_dtype, vllm_config.model_config
    )
    cache_cfg = vllm_config.cache_config
    return MLAAttentionSpec(
        block_size=cache_cfg.block_size,
        num_kv_heads=1,
        head_size=self.head_size,
        dtype=torch_dtype,
        cache_dtype_str=cache_cfg.cache_dtype,
    )

process_weights_after_loading

process_weights_after_loading(act_dtype: dtype)
Source code in vllm/model_executor/layers/attention/mla_attention.py
def process_weights_after_loading(self, act_dtype: torch.dtype):
    """Post-process the layer after checkpoint weights have been loaded.

    Runs the implementation's ``process_weights_after_loading`` hook when it
    defines one. Then, unless the quantization method says quant weights
    should be loaded, the q/k/v/prob scales are initialized to their default
    of 1.0. See [Note: Register q/k/v/prob scales in state dict] for more
    details.
    """
    if hasattr(self.impl, "process_weights_after_loading"):
        self.impl.process_weights_after_loading(act_dtype)

    quant_method = None
    if self.quant_config:
        quant_method = self.quant_config.get_quant_method(
            self, prefix=self.layer_name
        )
    if should_load_quant_weights(quant_method):
        return
    set_default_quant_scales(self, register_buffer=False)

MMEncoderAttention

Bases: CustomOp

Multi-headed attention without any cache, used for multimodal encoder.

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
    """Multi-headed attention without any cache, used for multimodal encoder.

    Supports MHA as well as MQA/GQA: key/value heads are repeated to match the
    number of query heads before the backend kernel is called. Depending on
    the selected vision attention backend, the computation goes through either
    the flash-attention wrapper or the torch SDPA wrapper.
    """

    # --8<-- [end:mm_encoder_attn]

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float | None = None,
        num_kv_heads: int | None = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            num_heads: number of attention heads per partition.
            head_size: hidden_size per attention head.
            scale: scale factor, forwarded unchanged to the attention
                wrappers (``None`` is passed through as-is).
            num_kv_heads: number of kv heads; defaults to ``num_heads``.
                Must divide ``num_heads``.
            prefix: This has no effect on the computation; it is stored as
                ``layer_name`` and is only here to make it easier to swap
                between Attention and MultiHeadAttention.

        Raises:
            AssertionError: if ``num_heads`` is not divisible by
                ``num_kv_heads``.
        """
        super().__init__()

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        # Number of query heads sharing each kv head (1 for plain MHA).
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Get device-specific vision attention backend.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=dtype,
        )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

        # The FA version is only meaningful for flash-attention backends.
        self._fa_version = (
            get_flash_attn_version() if self.is_flash_attn_backend else None
        )

        logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

    @classmethod
    def enabled(cls) -> bool:
        """Always report this op as enabled."""
        return True

    def maybe_reshape_qkv_to_4d(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        bsz: int,
        q_len: int,
        kv_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Reshape query, key, value to 4D tensors:
        (batch_size, seq_len, num_heads, head_size)

        Query is viewed with ``num_heads`` heads and key/value with
        ``num_kv_heads`` heads. For MQA/GQA (``num_queries_per_kv > 1``), each
        key/value head is repeated along dim 2, so all three outputs end up
        with ``num_heads`` heads. ``view`` is used, so the inputs must be
        view-compatible with the target shapes.
        """
        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        return query, key, value

    def _forward_sdpa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run attention through the torch SDPA wrapper.

        Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        If the input is not already 4-D, the output is flattened back to
        ``(batch_size, seq_len, -1)``.
        """
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)
        is_reshaped = query.dim() != 4

        query, key, value = self.maybe_reshape_qkv_to_4d(
            query, key, value, bsz, q_len, kv_len
        )

        output = vit_torch_sdpa_wrapper(
            q=query,
            k=key,
            v=value,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
        )
        if is_reshaped:
            output = output.reshape(bsz, q_len, -1)
        return output

    def _forward_fa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Run attention through the flash-attention wrapper.

        Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        ``cu_seqlens`` and ``max_seqlen`` must be given together or omitted
        together. If the input is not already 4-D, the output is flattened
        back to ``(batch_size, seq_len, -1)``.
        """
        assert (cu_seqlens is not None and max_seqlen is not None) or (
            cu_seqlens is None and max_seqlen is None
        ), "cu_seqlens and max_seqlen should be both set or both None."

        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)
        is_reshaped = query.dim() != 4

        query, key, value = self.maybe_reshape_qkv_to_4d(
            query, key, value, bsz, q_len, kv_len
        )

        output = vit_flash_attn_wrapper(
            q=query,
            k=key,
            v=value,
            batch_size=bsz,
            is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
            fa_version=self._fa_version,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        if is_reshaped:
            output = output.reshape(bsz, q_len, -1)
        return output

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Native path: always uses SDPA; ``max_seqlen`` is ignored."""
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """CUDA path: flash-attention backends use FA, TORCH_SDPA uses SDPA.

        Raises:
            ValueError: for any other backend.
        """
        if self.is_flash_attn_backend:
            return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            return self._forward_sdpa(query, key, value, cu_seqlens)
        else:
            raise ValueError(
                f"Unsupported multi-modal encoder attention backend for CUDA: "
                f"{self.attn_backend}."
            )

    def forward_cpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """CPU path: always uses SDPA; ``max_seqlen`` is ignored."""
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_xpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """XPU path: only flash-attention backends are supported.

        Raises:
            ValueError: if the selected backend is not a flash-attention
                backend.
        """
        # Validate explicitly rather than with ``assert``: asserts are
        # stripped under ``python -O``, which would send a non-FA backend into
        # ``_forward_fa``. ValueError also matches ``forward_cuda``.
        if not self.is_flash_attn_backend:
            raise ValueError(
                "XPU only supports FLASH_ATTN for vision attention, got "
                f"{self.attn_backend}."
            )
        return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)

_fa_version instance-attribute

_fa_version = (
    get_flash_attn_version()
    if is_flash_attn_backend
    else None
)

attn_backend instance-attribute

attn_backend = get_vit_attn_backend(
    head_size=head_size, dtype=dtype
)

head_size instance-attribute

head_size = head_size

is_flash_attn_backend instance-attribute

is_flash_attn_backend = attn_backend in {
    FLASH_ATTN,
    ROCM_AITER_FA,
}

layer_name instance-attribute

layer_name = prefix

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = (
    num_heads if num_kv_heads is None else num_kv_heads
)

num_queries_per_kv instance-attribute

num_queries_per_kv = num_heads // num_kv_heads

scale instance-attribute

scale = scale

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float | None = None,
    num_kv_heads: int | None = None,
    prefix: str = "",
) -> None

Parameters:

Name Type Description Default
num_heads int

number of attention heads per partition.

required
head_size int

hidden_size per attention head.

required
scale float | None

scale factor.

None
num_kv_heads int | None

number of kv heads.

None
prefix str

This has no effect; it is only here to make it easier to swap between Attention and MultiHeadAttention.

''
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float | None = None,
    num_kv_heads: int | None = None,
    prefix: str = "",
) -> None:
    """
    Args:
        num_heads: number of attention heads per partition.
        head_size: hidden_size per attention head.
        scale: scale factor, stored as-is (``None`` is kept as ``None``).
        num_kv_heads: number of kv heads; falls back to ``num_heads``.
        prefix: This has no effect on the computation; it is stored as
            ``layer_name`` and only exists to make it easier to swap
            between Attention and MultiHeadAttention.

    Raises:
        AssertionError: if ``num_heads`` is not divisible by the number of
            kv heads.
    """
    super().__init__()

    kv_heads = num_heads if num_kv_heads is None else num_kv_heads

    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = scale
    self.num_kv_heads = kv_heads
    self.layer_name = prefix

    assert self.num_heads % self.num_kv_heads == 0, (
        f"num_heads ({self.num_heads}) is not "
        f"divisible by num_kv_heads ({self.num_kv_heads})"
    )
    # How many query heads share one kv head (1 means plain MHA).
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    # The default torch dtype at construction time is the model's
    # weight/activation dtype; use it to pick the vision attention backend.
    model_dtype = torch.get_default_dtype()
    self.attn_backend = get_vit_attn_backend(
        head_size=head_size,
        dtype=model_dtype,
    )

    flash_backends = {
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
    }
    self.is_flash_attn_backend = self.attn_backend in flash_backends

    if self.is_flash_attn_backend:
        self._fa_version = get_flash_attn_version()
    else:
        self._fa_version = None

    logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

_forward_fa

_forward_fa(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor

Input shape: (batch_size x seq_len x hidden_size) or (batch_size x seq_len x num_heads x head_size)

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def _forward_fa(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    """Input shape:
    (batch_size x seq_len x hidden_size) or
    (batch_size x seq_len x num_heads x head_size)
    """
    assert (cu_seqlens is not None and max_seqlen is not None) or (
        cu_seqlens is None and max_seqlen is None
    ), "cu_seqlens and max_seqlen should be both set or both None."

    bsz, q_len = query.size()[:2]
    kv_len = key.size(1)
    is_reshaped = query.dim() != 4

    query, key, value = self.maybe_reshape_qkv_to_4d(
        query, key, value, bsz, q_len, kv_len
    )

    output = vit_flash_attn_wrapper(
        q=query,
        k=key,
        v=value,
        batch_size=bsz,
        is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
        fa_version=self._fa_version,
        scale=self.scale,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )
    if is_reshaped:
        output = output.reshape(bsz, q_len, -1)
    return output

_forward_sdpa

_forward_sdpa(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
) -> Tensor

Input shape: (batch_size x seq_len x hidden_size) or (batch_size x seq_len x num_heads x head_size)

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def _forward_sdpa(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run encoder attention through the torch SDPA wrapper.

    Input shape:
    (batch_size x seq_len x hidden_size) or
    (batch_size x seq_len x num_heads x head_size)

    Args:
        query, key, value: Attention inputs; key/value heads are expanded to
            match the query heads by ``maybe_reshape_qkv_to_4d``.
        cu_seqlens: Optional cumulative sequence lengths, forwarded as-is.

    Returns:
        The wrapper output, flattened to ``(batch_size, seq_len, -1)`` when
        the inputs were not already 4-D.
    """
    batch, q_len = query.shape[:2]
    kv_len = key.shape[1]
    flat_input = query.ndim != 4

    q4, k4, v4 = self.maybe_reshape_qkv_to_4d(
        query, key, value, batch, q_len, kv_len
    )

    attn_out = vit_torch_sdpa_wrapper(
        q=q4,
        k=k4,
        v=v4,
        scale=self.scale,
        cu_seqlens=cu_seqlens,
    )
    if not flat_input:
        return attn_out
    return attn_out.reshape(batch, q_len, -1)

enabled classmethod

enabled() -> bool
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
@classmethod
def enabled(cls) -> bool:
    """Report this op as enabled.

    Always returns True, independent of ``cls`` or any configuration.
    """
    always_on = True
    return always_on

forward_cpu

forward_cpu(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def forward_cpu(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    return self._forward_sdpa(query, key, value, cu_seqlens)

forward_cuda

forward_cuda(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def forward_cuda(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    if self.is_flash_attn_backend:
        return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
    elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
        return self._forward_sdpa(query, key, value, cu_seqlens)
    else:
        raise ValueError(
            f"Unsupported multi-modal encoder attention backend for CUDA: "
            f"{self.attn_backend}."
        )

forward_native

forward_native(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def forward_native(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    return self._forward_sdpa(query, key, value, cu_seqlens)

forward_xpu

forward_xpu(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def forward_xpu(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    assert self.is_flash_attn_backend, (
        "XPU only supports FLASH_ATTN for vision attention."
    )
    return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)

maybe_reshape_qkv_to_4d

maybe_reshape_qkv_to_4d(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    bsz: int,
    q_len: int,
    kv_len: int,
) -> tuple[Tensor, Tensor, Tensor]

Reshape query, key, value to 4D tensors: (batch_size, seq_len, num_heads, head_size)

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def maybe_reshape_qkv_to_4d(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    bsz: int,
    q_len: int,
    kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Reshape query, key, value to 4D tensors:
    (batch_size, seq_len, num_heads, head_size)

    The query is viewed with ``num_heads`` heads and key/value with
    ``num_kv_heads`` heads. When several query heads share one kv head
    (MQA/GQA), every key/value head is repeated ``num_queries_per_kv`` times
    along dim 2, so all three results carry ``num_heads`` heads. Because
    ``view`` is used, inputs must be view-compatible with the target shapes.
    """
    q4 = query.view(bsz, q_len, self.num_heads, self.head_size)
    kv_shape = (bsz, kv_len, self.num_kv_heads, self.head_size)
    k4 = key.view(kv_shape)
    v4 = value.view(kv_shape)

    group_size = self.num_queries_per_kv
    if group_size > 1:
        # Handle MQA and GQA
        k4 = k4.repeat_interleave(group_size, dim=2)
        v4 = v4.repeat_interleave(group_size, dim=2)

    return q4, k4, v4

StaticSinkAttention

Bases: Attention, CustomOp

Attention with static sink tokens

Source code in vllm/model_executor/layers/attention/static_sink_attention.py
@CustomOp.register("static_sink_attention")
class StaticSinkAttention(Attention, CustomOp):
    """
    Attention with static sink tokens.

    A fixed set of ``sink_len`` key/value entries (supplied via
    ``update_sink_kv``) is written into this layer's KV cache on the first
    forward pass. The attention computation itself is delegated to
    ``super().forward`` (the ``Attention`` implementation). It uses a backend
    wrapped by ``create_static_sink_attention_backend``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        sink_len: int,
        attn_backend: type[AttentionBackend] | None = None,
        cache_config: CacheConfig | None = None,
        **kwargs,
    ):
        """Build the layer and wrap its backend for static sinks.

        Args:
            num_heads: Number of attention heads (forwarded to ``Attention``).
            head_size: Size of each attention head.
            scale: Attention scale (forwarded to ``Attention``).
            sink_len: Number of static sink tokens.
            attn_backend: Underlying backend to wrap. When ``None``, one is
                selected with ``get_attn_backend`` from the default torch
                dtype and the cache dtype / block size.
            cache_config: Cache configuration; when ``None`` the cache dtype
                falls back to ``"auto"`` and the block size to 16.
            **kwargs: Forwarded to ``Attention.__init__``.
        """
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if attn_backend is not None:
            underlying_attn_backend = attn_backend
        else:
            underlying_attn_backend = get_attn_backend(
                head_size, dtype, kv_cache_dtype, block_size
            )
        attn_backend = create_static_sink_attention_backend(
            underlying_attn_backend,  # type: ignore[arg-type]
            sink_len=sink_len,
        )
        Attention.__init__(
            self=self,
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=attn_backend,
            **kwargs,
        )
        CustomOp.__init__(self)

        self.sink_len = sink_len
        # Used by populate_sink_kv to place the sink slots; refreshed in
        # get_kv_cache_spec when the configured block size changes.
        self.block_size = block_size
        self.sink_populated = False
        self.sink_key = None
        self.sink_value = None

    def update_sink_kv(self, sink_key, sink_value) -> None:
        """Store the sink key/value tensors to be written into the cache.

        NOTE(review): this does not reset ``sink_populated``, so values set
        after the sink has already been written are not re-written.
        """
        self.sink_key = sink_key
        self.sink_value = sink_value

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Ensure the sink KV is in the cache, then run regular attention.

        Raises:
            AssertionError: if ``update_sink_kv`` has not provided both the
                sink key and value.
        """
        assert self.sink_key is not None and self.sink_value is not None, (
            "sink_key and sink_value have not been prepared"
        )
        if not self.sink_populated:
            forward_context: ForwardContext = get_forward_context()
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)

        return super().forward(query, key, value, output_shape)

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """CUDA path: identical to ``forward_native``."""
        return self.forward_native(query, key, value, output_shape)

    def forward(self, *args, **kwargs):
        """Dispatch through ``self._forward_method``.

        Defined explicitly because ``Attention`` precedes ``CustomOp`` in the
        bases, so an inherited ``forward`` would resolve to ``Attention``
        first. Presumably ``_forward_method`` is the platform-specific
        ``forward_*`` chosen by ``CustomOp`` — confirm.
        """
        return self._forward_method(*args, **kwargs)

    def populate_sink_kv(self, self_kv_cache):
        """Write the sink key/value into ``self_kv_cache`` once.

        Uses the slot indices ``[block_size, block_size + sink_len)``, i.e.
        starting right after the first ``block_size`` slots. Presumably the
        first block is reserved — TODO confirm.
        """
        sink_kv_slot_mapping = torch.arange(
            self.block_size,
            self.sink_len + self.block_size,
            device=torch.cuda.current_device(),
            dtype=torch.long,
        )
        triton_reshape_and_cache_flash_diffkv(
            self.sink_key,
            self.sink_value,
            self_kv_cache,
            sink_kv_slot_mapping,
            self.kv_cache_dtype,
            self._k_scale,
            self._v_scale,
        )
        # We only populate the sink_key and sink_value once
        self.sink_populated = True

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe the KV cache this layer needs, including the sink slots.

        Also refreshes ``self.block_size`` from ``vllm_config``, so that
        ``populate_sink_kv`` computes sink slot indices with the same block
        size this spec declares.

        Raises:
            AssertionError: if the layer is not a decoder attention layer.
        """
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        # Keep the stored block size in sync. populate_sink_kv derives the
        # sink slot indices from it, so a stale value from __init__ would
        # write the sink to slots that do not match this spec's layout.
        self.block_size = block_size

        return SinkFullAttentionSpec(
            block_size=block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            head_size_v=self.head_size_v,
            sink_len=self.sink_len,
            dtype=self.kv_cache_torch_dtype,
        )

block_size instance-attribute

block_size = block_size

sink_key instance-attribute

sink_key = None

sink_len instance-attribute

sink_len = sink_len

sink_populated instance-attribute

sink_populated = False

sink_value instance-attribute

sink_value = None

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    sink_len: int,
    attn_backend: type[AttentionBackend] | None = None,
    cache_config: CacheConfig | None = None,
    **kwargs,
)
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    sink_len: int,
    attn_backend: type[AttentionBackend] | None = None,
    cache_config: CacheConfig | None = None,
    **kwargs,
):
    """Build a static-sink attention layer.

    Args:
        num_heads: Number of attention heads (forwarded to ``Attention``).
        head_size: Size of each attention head.
        scale: Attention scale (forwarded to ``Attention``).
        sink_len: Number of static sink tokens.
        attn_backend: Backend to wrap; when ``None`` one is selected with
            ``get_attn_backend`` from the default torch dtype, the cache
            dtype and the block size.
        cache_config: Cache configuration; without it the cache dtype is
            ``"auto"`` and the block size is 16.
        **kwargs: Forwarded to ``Attention.__init__``.
    """
    model_dtype = torch.get_default_dtype()

    if cache_config is None:
        cache_dtype_str, blk = "auto", 16
    else:
        cache_dtype_str, blk = cache_config.cache_dtype, cache_config.block_size

    # Only query the backend selector when no explicit backend was given.
    base_backend = (
        attn_backend
        if attn_backend is not None
        else get_attn_backend(head_size, model_dtype, cache_dtype_str, blk)
    )
    sink_backend = create_static_sink_attention_backend(
        base_backend,  # type: ignore[arg-type]
        sink_len=sink_len,
    )

    Attention.__init__(
        self,
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        cache_config=cache_config,
        attn_backend=sink_backend,
        **kwargs,
    )
    CustomOp.__init__(self)

    self.sink_len = sink_len
    self.block_size = blk
    self.sink_populated = False
    self.sink_key = None
    self.sink_value = None

forward

forward(*args, **kwargs)
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def forward(self, *args, **kwargs):
    """Forward all arguments unchanged to ``self._forward_method``.

    Presumably ``_forward_method`` is the platform-specific ``forward_*``
    implementation selected by ``CustomOp`` — confirm.
    """
    dispatch = self._forward_method
    return dispatch(*args, **kwargs)

forward_cuda

forward_cuda(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    output_shape: Size | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def forward_cuda(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output_shape: torch.Size | None = None,
) -> torch.Tensor:
    return self.forward_native(query, key, value, output_shape)

forward_native

forward_native(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    output_shape: Size | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def forward_native(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output_shape: torch.Size | None = None,
) -> torch.Tensor:
    assert self.sink_key is not None and self.sink_value is not None, (
        "sink_key and sink_value have not been prepared"
    )
    if not self.sink_populated:
        forward_context: ForwardContext = get_forward_context()
        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
        torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)

    return super().forward(query, key, value, output_shape)

get_kv_cache_spec

get_kv_cache_spec(vllm_config: VllmConfig) -> KVCacheSpec
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    """Describe the KV cache this layer needs, including the sink slots.

    Also refreshes ``self.block_size`` from ``vllm_config``, so that
    ``populate_sink_kv`` computes sink slot indices with the same block size
    this spec declares.

    Args:
        vllm_config: Engine configuration; its ``cache_config.block_size`` is
            authoritative for the block size.

    Returns:
        A ``SinkFullAttentionSpec`` for this layer.

    Raises:
        AssertionError: if the layer is not a decoder attention layer.
    """
    # Block size may get updated after model loading, refresh it
    block_size = vllm_config.cache_config.block_size
    # Should not be called for enc-dec or encoder-only attention.
    assert self.attn_type == AttentionType.DECODER
    # Keep the stored block size in sync: populate_sink_kv derives the sink
    # slot indices from self.block_size (captured in __init__), so a stale
    # value would write the sink to slots that do not match this layout.
    self.block_size = block_size

    return SinkFullAttentionSpec(
        block_size=block_size,
        num_kv_heads=self.num_kv_heads,
        head_size=self.head_size,
        head_size_v=self.head_size_v,
        sink_len=self.sink_len,
        dtype=self.kv_cache_torch_dtype,
    )

populate_sink_kv

populate_sink_kv(self_kv_cache)
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def populate_sink_kv(self, self_kv_cache):
    """Write the stored sink key/value into ``self_kv_cache`` once.

    The sink occupies the slot indices
    ``[self.block_size, self.block_size + self.sink_len)``, i.e. it starts
    right after the first ``block_size`` slots. Presumably the first block
    is reserved — TODO confirm. The write uses
    ``triton_reshape_and_cache_flash_diffkv`` with this layer's cache dtype
    and k/v scales. Afterwards ``sink_populated`` is set so the write is not
    repeated.
    """
    first_slot = self.block_size
    end_slot = self.sink_len + first_slot
    slot_ids = torch.arange(
        first_slot,
        end_slot,
        device=torch.cuda.current_device(),
        dtype=torch.long,
    )
    triton_reshape_and_cache_flash_diffkv(
        self.sink_key,
        self.sink_value,
        self_kv_cache,
        slot_ids,
        self.kv_cache_dtype,
        self._k_scale,
        self._v_scale,
    )
    # The sink entries are static: mark them written so this runs only once.
    self.sink_populated = True

update_sink_kv

update_sink_kv(sink_key, sink_value) -> None
Source code in vllm/model_executor/layers/attention/static_sink_attention.py
def update_sink_kv(self, sink_key, sink_value) -> None:
    """Remember the sink key/value tensors for a later cache write.

    Only the two attributes are replaced. ``sink_populated`` is left
    untouched, so a sink already written to the cache is not re-written.
    """
    self.sink_key, self.sink_value = sink_key, sink_value