rllm.nn.conv.table_conv.FTTransformerConv

class rllm.nn.conv.table_conv.FTTransformerConv(conv_dim: int, feedforward_dim: int | None = None, num_heads: int = 8, dropout: float = 0.3, activation: str = 'relu', use_cls: bool = False)

Bases: Module

The FT-Transformer backbone in the “Revisiting Deep Learning Models for Tabular Data” paper.

When use_cls is True, this module concatenates a learnable CLS token embedding x_cls with the input tensor x and applies a multi-layer Transformer to the concatenated tensor. The Transformer output is then split into two parts: (1) x, corresponding to the original input columns, and (2) x_cls, corresponding to the CLS token.
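The concatenate, transform, and split steps can be sketched in plain PyTorch. This is a minimal illustration, not the rLLM implementation: the class name CLSTransformerSketch, the use of a single torch.nn.TransformerEncoderLayer, the CLS initialization, and placing the CLS token at position 0 are all assumptions made for the example.

```python
import torch
from torch import nn


class CLSTransformerSketch(nn.Module):
    """Hypothetical stand-in used only to illustrate the CLS mechanism."""

    def __init__(self, conv_dim, num_heads=8, feedforward_dim=None, dropout=0.3):
        super().__init__()
        # Learnable CLS embedding, shared by every row in the batch.
        self.cls = nn.Parameter(torch.randn(1, 1, conv_dim) * 0.02)
        # Assumption: one encoder layer stands in for the Transformer.
        self.encoder = nn.TransformerEncoderLayer(
            d_model=conv_dim,
            nhead=num_heads,
            dim_feedforward=feedforward_dim or conv_dim,
            dropout=dropout,
            activation="relu",
            batch_first=True,
        )

    def forward(self, x):
        # x: [batch_size, num_cols, conv_dim]
        x_cls = self.cls.expand(x.size(0), -1, -1)  # [batch_size, 1, conv_dim]
        h = torch.cat([x_cls, x], dim=1)            # [batch_size, 1 + num_cols, conv_dim]
        h = self.encoder(h)
        # Split the output back into the column part and the CLS part.
        return h[:, 1:, :], h[:, 0, :]


layer = CLSTransformerSketch(conv_dim=32)
x = torch.randn(16, 10, 32)
x_out, x_cls = layer(x)
print(x_out.shape, x_cls.shape)
```

Because attention runs over the concatenated sequence, the CLS vector can attend to every column, which is why it can serve as a per-row summary.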

Parameters:
  • conv_dim (int) – Input/Output dimensionality.

  • feedforward_dim (int, optional) – Hidden dimensionality used by the feedforward network of the Transformer model. If None, it is set to conv_dim (default: None).

  • num_heads (int) – Number of heads in multi-head attention (default: 8).

  • dropout (float) – The dropout probability (default: 0.3).

  • activation (str) – The activation function (default: 'relu').

  • use_cls (bool) – Whether to use a CLS token (default: False).

Example

>>> import torch
>>> from rllm.nn.conv.table_conv import FTTransformerConv
>>> conv = FTTransformerConv(conv_dim=32, num_heads=8, use_cls=False)
>>> x = torch.randn(16, 10, 32)  # [batch_size, num_cols, conv_dim]
>>> out = conv(x)

forward(x: Tensor) → Tensor

CLS-token augmented Transformer convolution.

Parameters:

x (Tensor) – Input tensor of shape [batch_size, num_cols, dim]

Returns:

If use_cls=False, output tensor of shape [batch_size, num_cols, dim] corresponding to input columns. If use_cls=True, output tensor of shape [batch_size, dim] for the CLS token representation.

Return type:

torch.Tensor
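
The two documented output shapes call for slightly different handling downstream. The sketch below uses random tensors as stand-ins for the layer's output, so it does not call rLLM at all. The linear head, the class count, and the mean pooling over columns are illustrative assumptions, not part of this API.

```python
import torch
from torch import nn

batch_size, num_cols, dim, num_classes = 16, 10, 32, 3
head = nn.Linear(dim, num_classes)  # hypothetical prediction head

# Stand-in for the use_cls=False output: one vector per column.
col_out = torch.randn(batch_size, num_cols, dim)
# Pool over the column axis first (mean pooling is one possible choice).
logits_from_cols = head(col_out.mean(dim=1))

# Stand-in for the use_cls=True output: one vector per row.
cls_out = torch.randn(batch_size, dim)
# Already a per-row summary, so it can feed the head directly.
logits_from_cls = head(cls_out)

print(logits_from_cols.shape, logits_from_cls.shape)
```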