rllm.nn.conv.table_conv.TabTransformerConv
- class rllm.nn.conv.table_conv.TabTransformerConv(conv_dim: int, num_heads: int = 8, dropout: float = 0.3, activation: str = 'relu')
Bases: Module
The TabTransformer layer introduced in the “TabTransformer: Tabular Data Modeling Using Contextual Embeddings” paper.
This layer applies Transformer self-attention across the categorical column embeddings, so that each categorical column's representation is contextualized by the other categorical columns.
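To make "self-attention across categorical columns" concrete, here is a minimal pure-Python sketch for a single row, where each categorical column is one embedding vector. It is a deliberate simplification and not rLLM's implementation: it assumes a single head and uses the column embeddings directly as queries, keys and values (no learned projections, feed-forward block, dropout or normalization). The names `softmax` and `column_self_attention` are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]  # one row: rows = categorical columns, cols = embedding dims


def softmax(scores: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def column_self_attention(cols: Matrix) -> Matrix:
    """Single-head scaled dot-product attention over the column axis.

    Simplifying assumption: queries, keys and values are the column
    embeddings themselves (no learned projections, no feed-forward, no norm).
    """
    dim = len(cols[0])
    scale = 1.0 / math.sqrt(dim)
    out: Matrix = []
    for q in cols:
        # Score this column against every column, including itself.
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in cols]
        weights = softmax(scores)
        # New embedding = attention-weighted average of all column embeddings.
        out.append([sum(w * v[d] for w, v in zip(weights, cols)) for d in range(dim)])
    return out


cols = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 categorical columns, dim 2
print(column_self_attention(cols))
```

The output keeps the `[num_categorical_cols, dim]` shape, but every column's vector now mixes in information from the other columns. The actual layer processes a whole batch shaped `[batch_size, num_categorical_cols, conv_dim]` and adds learned parameters on top of this idea.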
- Parameters:
conv_dim (int) – Input/Output dimensionality.
num_heads (int, optional) – Number of attention heads (default: 8).
dropout (float, optional) – Attention module dropout (default: 0.3).
activation (str, optional) – Activation function (default: “relu”).
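The page does not state how `num_heads` interacts with `conv_dim`. If the layer uses standard multi-head attention (an assumption; for example, `torch.nn.MultiheadAttention` requires the embedding dimension to be divisible by the number of heads), each head works on `conv_dim // num_heads` dimensions. The hypothetical helper below just performs that check and arithmetic.

```python
def head_dim(conv_dim: int, num_heads: int) -> int:
    """Per-head width under standard multi-head attention (assumption).

    Standard multi-head attention splits the embedding evenly across heads,
    so conv_dim must be divisible by num_heads.
    """
    if num_heads <= 0:
        raise ValueError("num_heads must be positive")
    if conv_dim % num_heads != 0:
        raise ValueError(f"conv_dim={conv_dim} is not divisible by num_heads={num_heads}")
    return conv_dim // num_heads


print(head_dim(32, 8))  # settings from the example below: 4 dimensions per head
```

Under this assumption, a combination such as `conv_dim=30, num_heads=8` would be rejected.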
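Example
>>> import torch
>>> from rllm.types import ColType
>>> from rllm.nn.conv.table_conv import TabTransformerConv
>>> conv = TabTransformerConv(conv_dim=32, num_heads=8, dropout=0.1)
>>> # Categorical input shaped [batch_size, num_categorical_cols, conv_dim]
>>> x = {ColType.CATEGORICAL: torch.randn(8, 10, 32)}
>>> out = conv(x)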
- forward(x: Dict | Tensor) → Dict | Tensor
Encode categorical features with self-attention.
- Parameters:
x (Union[Dict, Tensor]) – A container that supports x[ColType.CATEGORICAL] indexing. The categorical tensor is typically shaped as [batch_size, num_categorical_cols, conv_dim].
- Returns:
The same container type as the input, with x[ColType.CATEGORICAL] replaced by the transformed tensor.
- Return type:
Union[Dict, Tensor]
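The return contract above (same container, only the categorical entry replaced) can be sketched for the dict case as follows. This is a hypothetical illustration, not rLLM code: `ColType` here is a local stand-in for `rllm.types.ColType` so the snippet runs without rLLM, `transform_categorical` is an invented helper, and whether the real layer copies the dict or updates it in place is not stated on this page (the sketch copies).

```python
from enum import Enum
from typing import Any, Callable, Dict


class ColType(Enum):
    # Hypothetical stand-in for rllm.types.ColType.
    CATEGORICAL = "categorical"
    NUMERICAL = "numerical"


def transform_categorical(
    x: Dict[ColType, Any], fn: Callable[[Any], Any]
) -> Dict[ColType, Any]:
    """Return a dict of the same kind in which only x[ColType.CATEGORICAL] is transformed."""
    out = dict(x)  # shallow copy: other column groups pass through unchanged
    out[ColType.CATEGORICAL] = fn(x[ColType.CATEGORICAL])
    return out


x = {ColType.CATEGORICAL: [1, 2], ColType.NUMERICAL: [3.0]}
out = transform_categorical(x, lambda t: [v * 10 for v in t])
print(out[ColType.CATEGORICAL], out[ColType.NUMERICAL])
```

In the real layer, `fn` corresponds to the self-attention encoding of the categorical tensor, and non-categorical entries such as numerical features are left for other layers to handle.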