# Source code for rllm.nn.conv.table_conv.tab_transformer_conv
from __future__ import annotations
from typing import Union, Dict
import torch
from torch import Tensor
from rllm.types import ColType
# [docs] (documentation link marker left over from the rendered source page)
class TabTransformerConv(torch.nn.Module):
    r"""Single-layer Transformer convolution for categorical column embeddings.

    Implements the layer described in the
    `"TabTransformer: Tabular Data Modeling Using Contextual Embeddings"
    <https://arxiv.org/abs/2012.06678>`_ paper. Self-attention is applied
    across the embedded categorical columns so that every column embedding
    is contextualized by the others.

    Args:
        conv_dim (int): Input/Output dimensionality. It is also used as the
            hidden width of the encoder layer's feed-forward sub-block.
        num_heads (int, optional): Number of attention heads (default: 8).
        dropout (float, optional): Dropout probability handed to the
            Transformer encoder layer (default: 0.3).
        activation (str, optional): Activation function handed to the
            Transformer encoder layer (default: "relu").

    Example:
        >>> import torch
        >>> from rllm.types import ColType
        >>> conv = TabTransformerConv(conv_dim=32, num_heads=8, dropout=0.1)
        >>> x = {ColType.CATEGORICAL: torch.randn(8, 10, 32)}
        >>> out = conv(x)
    """

    def __init__(
        self,
        conv_dim: int,
        num_heads: int = 8,
        dropout: float = 0.3,
        activation: str = "relu",
    ):
        super().__init__()
        # A stack of exactly one encoder layer, followed by a final
        # LayerNorm over the feature dimension. The feed-forward width is
        # tied to ``conv_dim`` and inputs are laid out batch-first.
        self.transformer = torch.nn.TransformerEncoder(
            encoder_layer=torch.nn.TransformerEncoderLayer(
                d_model=conv_dim,
                nhead=num_heads,
                dim_feedforward=conv_dim,
                dropout=dropout,
                activation=activation,
                batch_first=True,
            ),
            num_layers=1,
            norm=torch.nn.LayerNorm(conv_dim),
        )

    def forward(self, x: Union[Dict, Tensor]) -> Union[Dict, Tensor]:
        """Run self-attention over the categorical features.

        Args:
            x (Union[Dict, Tensor]): Either a dict that may hold a tensor
                under ``ColType.CATEGORICAL``, or a tensor that is fed to
                the encoder directly. The tensor is typically shaped as
                ``[batch_size, num_categorical_cols, conv_dim]``.

        Returns:
            Union[Dict, Tensor]: For a dict, the *same* dict object with its
            ``ColType.CATEGORICAL`` entry overwritten in place by the
            encoded tensor (the dict is returned untouched when that key is
            absent). For a tensor, the encoded tensor. Any other input is
            returned unchanged.
        """
        if isinstance(x, dict):
            cat_key = ColType.CATEGORICAL
            if cat_key in x:
                # Only the categorical entry is transformed; every other
                # entry of the dict is left as it was.
                x[cat_key] = self.transformer(x[cat_key])
            return x

        if isinstance(x, Tensor):
            return self.transformer(x)

        # Neither a dict nor a tensor: pass the value through as-is.
        return x