Attention#

pydantic model vision_architectures.layers.attention.Attention1DConfig[source]#

Bases: CustomBaseModel

JSON schema:
{
   "title": "Attention1DConfig",
   "type": "object",
   "properties": {
      "dim": {
         "anyOf": [
            {
               "type": "integer"
            },
            {
               "maxItems": 2,
               "minItems": 2,
               "prefixItems": [
                  {
                     "type": "integer"
                  },
                  {
                     "type": "integer"
                  }
               ],
               "type": "array"
            }
         ],
         "description": "Dimension of the input features. If tuple, (dim_qk, dim_v). Otherwise it is assumed to be dim of both qk and v.",
         "title": "Dim"
      },
      "num_heads": {
         "description": "Number of query heads",
         "title": "Num Heads",
         "type": "integer"
      },
      "ratio_q_to_kv_heads": {
         "default": 1,
         "description": "Ratio of query heads to key/value heads. Useful for MQA/GQA.",
         "title": "Ratio Q To Kv Heads",
         "type": "integer"
      },
      "logit_scale_learnable": {
         "default": false,
         "description": "Whether the logit scale is learnable.",
         "title": "Logit Scale Learnable",
         "type": "boolean"
      },
      "attn_drop_prob": {
         "default": 0.0,
         "description": "Dropout probability for attention weights.",
         "title": "Attn Drop Prob",
         "type": "number"
      },
      "proj_drop_prob": {
         "default": 0.0,
         "description": "Dropout probability for the projection layer.",
         "title": "Proj Drop Prob",
         "type": "number"
      },
      "max_attention_batch_size": {
         "default": -1,
         "description": "Runs attention by splitting the inputs into chunks of this size. 0 means no chunking. Useful for large inputs during inference. (This happens along batch dimension).",
         "title": "Max Attention Batch Size",
         "type": "integer"
      },
      "rotary_position_embeddings_config": {
         "anyOf": [
            {
               "$ref": "#/$defs/RotaryPositionEmbeddings1DConfig"
            },
            {
               "type": "null"
            }
         ],
         "default": null,
         "description": "Config for rotary position embeddings"
      }
   },
   "$defs": {
      "RotaryPositionEmbeddings1DConfig": {
         "properties": {
            "dim": {
               "anyOf": [
                  {
                     "type": "integer"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "Dimension of the position embeddings",
               "title": "Dim"
            },
            "base": {
               "default": 10000.0,
               "description": "Base value for the exponent.",
               "title": "Base",
               "type": "number"
            }
         },
         "title": "RotaryPositionEmbeddings1DConfig",
         "type": "object"
      }
   },
   "required": [
      "dim",
      "num_heads"
   ]
}

Config:
  • arbitrary_types_allowed: bool = True

  • extra: str = ignore

  • validate_default: bool = True

  • validate_assignment: bool = True

  • validate_return: bool = True

Fields:
  • attn_drop_prob
  • dim
  • logit_scale_learnable
  • max_attention_batch_size
  • num_heads
  • proj_drop_prob
  • ratio_q_to_kv_heads
  • rotary_position_embeddings_config
Validators:
  • validate » all fields

field dim: int | tuple[int, int] [Required]#

Dimension of the input features. If a tuple, it is interpreted as (dim_qk, dim_v); otherwise the same dimension is used for both qk and v.

Validated by:
  • validate
field num_heads: int [Required]#

Number of query heads

Validated by:
  • validate
field ratio_q_to_kv_heads: int = 1#

Ratio of query heads to key/value heads. Useful for MQA/GQA.

Validated by:
  • validate
field logit_scale_learnable: bool = False#

Whether the logit scale is learnable.

Validated by:
  • validate
field attn_drop_prob: float = 0.0#

Dropout probability for attention weights.

Validated by:
  • validate
field proj_drop_prob: float = 0.0#

Dropout probability for the projection layer.

Validated by:
  • validate
field max_attention_batch_size: int = -1#

Runs attention by splitting the inputs along the batch dimension into chunks of this size; 0 means no chunking. Useful for large inputs during inference.

Validated by:
  • validate
field rotary_position_embeddings_config: RotaryPositionEmbeddings1DConfig | None = None#

Config for rotary position embeddings

Validated by:
  • validate
property num_q_heads: int#
property num_kv_heads: int#
property gqa_mqa_enabled: bool#
property dim_qk: int#
property dim_v: int#
property per_head_dim_qk: int#
validator validate  »  all fields[source]#

Base method for validating the model after creation.
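
As a rough sketch of how the config and its derived properties fit together (the commented values are my reading of the field descriptions above, not documented output):

from vision_architectures.layers.attention import Attention1DConfig

# Hypothetical values: 8 query heads sharing 4 key/value heads (GQA), with
# separate qk and v dimensions passed as a (dim_qk, dim_v) tuple.
config = Attention1DConfig(dim=(768, 512), num_heads=8, ratio_q_to_kv_heads=2)

config.dim_qk           # 768
config.dim_v            # 512
config.num_q_heads      # 8
config.num_kv_heads     # 4, i.e. num_heads / ratio_q_to_kv_heads (assumed)
config.gqa_mqa_enabled  # True, since ratio_q_to_kv_heads != 1 (assumed)
config.per_head_dim_qk  # 96, i.e. dim_qk / num_heads (assumed)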

pydantic model vision_architectures.layers.attention.Attention3DConfig[source]#

Bases: Attention1DConfig

JSON schema:
{
   "title": "Attention3DConfig",
   "type": "object",
   "properties": {
      "dim": {
         "anyOf": [
            {
               "type": "integer"
            },
            {
               "maxItems": 2,
               "minItems": 2,
               "prefixItems": [
                  {
                     "type": "integer"
                  },
                  {
                     "type": "integer"
                  }
               ],
               "type": "array"
            }
         ],
         "description": "Dimension of the input features. If tuple, (dim_qk, dim_v). Otherwise it is assumed to be dim of both qk and v.",
         "title": "Dim"
      },
      "num_heads": {
         "description": "Number of query heads",
         "title": "Num Heads",
         "type": "integer"
      },
      "ratio_q_to_kv_heads": {
         "default": 1,
         "description": "Ratio of query heads to key/value heads. Useful for MQA/GQA.",
         "title": "Ratio Q To Kv Heads",
         "type": "integer"
      },
      "logit_scale_learnable": {
         "default": false,
         "description": "Whether the logit scale is learnable.",
         "title": "Logit Scale Learnable",
         "type": "boolean"
      },
      "attn_drop_prob": {
         "default": 0.0,
         "description": "Dropout probability for attention weights.",
         "title": "Attn Drop Prob",
         "type": "number"
      },
      "proj_drop_prob": {
         "default": 0.0,
         "description": "Dropout probability for the projection layer.",
         "title": "Proj Drop Prob",
         "type": "number"
      },
      "max_attention_batch_size": {
         "default": -1,
         "description": "Runs attention by splitting the inputs into chunks of this size. 0 means no chunking. Useful for large inputs during inference. (This happens along batch dimension).",
         "title": "Max Attention Batch Size",
         "type": "integer"
      },
      "rotary_position_embeddings_config": {
         "anyOf": [
            {
               "$ref": "#/$defs/RotaryPositionEmbeddings3DConfig"
            },
            {
               "type": "null"
            }
         ],
         "default": null,
         "description": "Config for rotary position embeddings"
      }
   },
   "$defs": {
      "RotaryPositionEmbeddings3DConfig": {
         "properties": {
            "dim": {
               "anyOf": [
                  {
                     "type": "integer"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "Dimension of the position embeddings",
               "title": "Dim"
            },
            "base": {
               "default": 10000.0,
               "description": "Base value for the exponent.",
               "title": "Base",
               "type": "number"
            },
            "split": {
               "anyOf": [
                  {
                     "maxItems": 3,
                     "minItems": 3,
                     "prefixItems": [
                        {
                           "type": "number"
                        },
                        {
                           "type": "number"
                        },
                        {
                           "type": "number"
                        }
                     ],
                     "type": "array"
                  },
                  {
                     "maxItems": 3,
                     "minItems": 3,
                     "prefixItems": [
                        {
                           "type": "integer"
                        },
                        {
                           "type": "integer"
                        },
                        {
                           "type": "integer"
                        }
                     ],
                     "type": "array"
                  }
               ],
               "default": [
                  0.3333333333333333,
                  0.3333333333333333,
                  0.3333333333333333
               ],
               "description": "Split of the position embeddings. If float, converted to int based on self.dim",
               "title": "Split"
            }
         },
         "title": "RotaryPositionEmbeddings3DConfig",
         "type": "object"
      }
   },
   "required": [
      "dim",
      "num_heads"
   ]
}

Config:
  • arbitrary_types_allowed: bool = True

  • extra: str = ignore

  • validate_default: bool = True

  • validate_assignment: bool = True

  • validate_return: bool = True

Fields:
  • rotary_position_embeddings_config
Validators:
  • validate » all fields

field rotary_position_embeddings_config: RotaryPositionEmbeddings3DConfig | None = None#

Config for rotary position embeddings

Validated by:
  • validate
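
A sketch of attaching 3D rotary position embeddings through the nested config (keys taken from the schema above; values illustrative, with the equal-thirds split mirroring the default):

from vision_architectures.layers.attention import Attention3DConfig

# Nested configs may be passed as plain dicts; pydantic coerces them to
# RotaryPositionEmbeddings3DConfig during validation.
config = Attention3DConfig(
    dim=96,
    num_heads=8,
    rotary_position_embeddings_config={"base": 10000.0, "split": (1 / 3, 1 / 3, 1 / 3)},
)
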
class vision_architectures.layers.attention.Attention1D(config={}, relative_position_bias=None, logit_scale=None, checkpointing_level=0, **kwargs)[source]#

Bases: _Attention
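
Attention1D carries no description of its own here; by analogy with Attention3D below (an assumption on my part, including the channels_first keyword of its forward pass), usage might look like:

import torch

from vision_architectures.layers.attention import Attention1D

# Illustrative sizes only. Self-attention over a channels-last token sequence.
attn = Attention1D(config={"dim": 64, "num_heads": 4})
x = torch.randn(2, 16, 64)  # (b, T, dim)
out = attn(x, x, x, channels_first=False)  # expected: same shape as x (assumed)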

class vision_architectures.layers.attention.Attention3D(config={}, relative_position_bias=None, logit_scale=None, checkpointing_level=0, **kwargs)[source]#

Bases: _Attention

Performs attention (MHA, GQA, and MQA) on 3D sequences. This class is designed for 3D inputs, e.g. medical images, videos, etc.

__init__(config={}, relative_position_bias=None, logit_scale=None, checkpointing_level=0, **kwargs)[source]#

Initializes the Attention3D module.

Parameters:
  • config (Attention3DConfig) – An instance of the Config class that contains all the configuration parameters. It can also be passed as a dictionary and the instance will be created automatically.

  • relative_position_bias (Union[RelativePositionEmbeddings3D, RelativePositionEmbeddings3DMetaNetwork, None]) – Relative position embeddings to be considered during attention. Should be callable.

  • logit_scale (Optional[float]) – Logit scale to be used for attention. If None, it is initialized based on the per-head dimension.

  • checkpointing_level (int) – The level of checkpointing to use for activation checkpointing. Refer to ActivationCheckpointing for more details.

  • **kwargs – Additional keyword arguments for configuration.
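
A minimal construction sketch (all values illustrative; the dict form of config is converted as described above):

from vision_architectures.layers.attention import Attention3D

# The config dict is coerced to Attention3DConfig automatically.
attn = Attention3D(config={"dim": 96, "num_heads": 8}, checkpointing_level=0)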

forward(query, key, value, channels_first=True, query_grid_shape=None, key_grid_shape=None)[source]#

Forward pass of the Attention3D module.

Terminology: z => depth, y => height, x => width, b => batch size

Parameters:
  • query (Tensor) – Tensor of shape (b, dim_qk, z_q, y_q, x_q) (channels first), (b, z_q, y_q, x_q, dim_qk) (channels last), or (b, T_q, dim_qk) for a flattened token sequence, representing the input to the query matrix.

  • key (Tensor) – Tensor of shape (b, dim_qk, z_kv, y_kv, x_kv) (channels first), (b, z_kv, y_kv, x_kv, dim_qk) (channels last), or (b, T_kv, dim_qk), representing the input to the key matrix.

  • value (Tensor) – Tensor of shape (b, dim_v, z_kv, y_kv, x_kv) (channels first), (b, z_kv, y_kv, x_kv, dim_v) (channels last), or (b, T_kv, dim_v), representing the input to the value matrix.

  • channels_first (bool) – Whether the inputs are in channels first format (B, C, …) or not (B, …, C).

  • query_grid_shape (Optional[tuple[int, int, int]]) – Shape of the query tokens in 3D. Used to identify the actual 3D grid and separate it from extra tokens (e.g. class tokens) when applying rotary position embeddings. Leading tokens are treated as extra tokens and only trailing tokens are used.

  • key_grid_shape (Optional[tuple[int, int, int]]) – Shape of the key/value tokens in 3D. Used to identify the actual 3D grid and separate it from extra tokens (e.g. class tokens) when applying rotary position embeddings. Leading tokens are treated as extra tokens and only trailing tokens are used.

Returns:

Tensor of shape (b, dim_qk, z_q, y_q, x_q) (channels first), (b, z_q, y_q, x_q, dim_qk) (channels last), or (b, T_q, dim_qk), matching the query layout and representing the output tokens.
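
Putting it together, a minimal self-attention sketch under these shape conventions (sizes illustrative; the output shape matching the query shape follows the Returns description above):

import torch

from vision_architectures.layers.attention import Attention3D

attn = Attention3D(config={"dim": 96, "num_heads": 8})

x = torch.randn(2, 96, 4, 8, 8)  # (b, dim_qk, z, y, x), channels first
out = attn(x, x, x, channels_first=True)  # expected shape: (2, 96, 4, 8, 8)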