rllm.nn.conv.table_conv.TabTransformerConv

class rllm.nn.conv.table_conv.TabTransformerConv(conv_dim: int, num_heads: int = 8, dropout: float = 0.3, activation: str = 'relu')

Bases: Module

The TabTransformer convolution layer introduced in the “TabTransformer: Tabular Data Modeling Using Contextual Embeddings” paper.

This layer applies Transformer self-attention across the embeddings of the categorical columns, so that each column's embedding is contextualized by the other columns in the same row.
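To illustrate what self-attention over column embeddings computes, here is a minimal pure-Python sketch for a single table row. It is a simplification and not the library's implementation: it uses a single head, takes queries, keys and values directly from the raw embeddings (no learned projections), and omits dropout and the feed-forward sublayer. The function name `self_attention` is hypothetical.

```python
import math


def self_attention(columns):
    """Single-head scaled dot-product self-attention over one row.

    `columns` is a list of num_cols vectors, each of length dim.
    Queries, keys and values are the raw embeddings (an assumption
    made for brevity; real layers use learned projections).
    """
    dim = len(columns[0])
    scale = 1.0 / math.sqrt(dim)
    out = []
    for query in columns:
        # Similarity of this column to every column, itself included.
        scores = [scale * sum(q * k for q, k in zip(query, key)) for key in columns]
        # Numerically stable softmax over the scores.
        peak = max(scores)
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Contextual embedding: weighted average of all column embeddings.
        out.append(
            [sum(w * col[d] for w, col in zip(weights, columns)) for d in range(dim)]
        )
    return out


# Three categorical columns, each embedded in 4 dimensions.
row = [
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [1.0, 1.0, 0.0, 0.0],
]
contextual = self_attention(row)
print(len(contextual), len(contextual[0]))
```

The output has the same number of columns and the same embedding width as the input, which matches the layer's description of `conv_dim` as both input and output dimensionality.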

Parameters:
  • conv_dim (int) – Input/Output dimensionality.

  • num_heads (int, optional) – Number of attention heads (default: 8).

  • dropout (float, optional) – Attention module dropout (default: 0.3).

  • activation (str, optional) – Activation function (default: “relu”).
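When choosing `conv_dim` and `num_heads`, it can help to check how the embedding width divides across heads. The sketch below assumes the common multi-head attention convention in which `conv_dim` is split evenly across heads; whether TabTransformerConv enforces this is not stated here, and the helper `head_dim` is hypothetical.

```python
def head_dim(conv_dim: int, num_heads: int = 8) -> int:
    """Per-head width, assuming conv_dim is split evenly across heads."""
    if conv_dim % num_heads != 0:
        raise ValueError(
            f"conv_dim={conv_dim} is not divisible by num_heads={num_heads}"
        )
    return conv_dim // num_heads


print(head_dim(32, 8))  # 4
```

With the values used in the example on this page (`conv_dim=32`, `num_heads=8`), each head would operate on 4 dimensions under that assumption.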

Example

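>>> import torch
>>> from rllm.types import ColType
>>> from rllm.nn.conv.table_conv import TabTransformerConv
>>> conv = TabTransformerConv(conv_dim=32, num_heads=8, dropout=0.1)
>>> # 8 rows, 10 categorical columns, each embedded in conv_dim=32 dimensions
>>> x = {ColType.CATEGORICAL: torch.randn(8, 10, 32)}
>>> out = conv(x)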
forward(x: Dict | Tensor) → Dict | Tensor

Encode categorical features with self-attention.

Parameters:

x (Union[Dict, Tensor]) – A container that supports x[ColType.CATEGORICAL] indexing. The categorical tensor is typically shaped as [batch_size, num_categorical_cols, conv_dim].

Returns:

The same container type as input, where x[ColType.CATEGORICAL] is replaced with the transformed tensor.

Return type:

Union[Dict, Tensor]
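The input/output contract of forward for the dictionary case can be sketched without the library. Everything below is illustrative: the `ColType` enum is a stand-in for rllm.types.ColType, `forward_like` is a hypothetical helper, nested lists stand in for tensors, and the sketch assumes that entries other than the categorical one pass through unchanged and that the input is not modified in place.

```python
from enum import Enum


class ColType(Enum):
    """Hypothetical stand-in for rllm.types.ColType, used only in this sketch."""

    CATEGORICAL = "categorical"
    NUMERICAL = "numerical"


def forward_like(x, transform):
    """Return a dict like x with only the categorical entry replaced."""
    out = dict(x)  # shallow copy, so the caller's dict is left untouched
    out[ColType.CATEGORICAL] = transform(x[ColType.CATEGORICAL])
    return out


x = {
    ColType.CATEGORICAL: [[1.0, 2.0]],
    ColType.NUMERICAL: [[0.5]],
}
# A trivial stand-in for the attention transform: double every value.
out = forward_like(x, lambda t: [[2 * v for v in row] for row in t])
print(out[ColType.CATEGORICAL], out[ColType.NUMERICAL])
```

The returned object is again a dictionary keyed by column type, so it can be passed to a subsequent layer that expects the same container.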