rllm.nn.conv.table_conv.FTTransformerConv
- class rllm.nn.conv.table_conv.FTTransformerConv(conv_dim: int, feedforward_dim: int | None = None, num_heads: int = 8, dropout: float = 0.3, activation: str = 'relu', use_cls: bool = False)
Bases: Module
The FT-Transformer backbone in the “Revisiting Deep Learning Models for Tabular Data” paper.
This module concatenates a learnable CLS token embedding x_cls to the input tensor x and applies a multi-layer Transformer on the concatenated tensor. After the Transformer layer, the output tensor is divided into two parts: (1) x, corresponding to the original input tensor, and (2) x_cls, corresponding to the CLS token tensor.
- Parameters:
conv_dim (int) – Input/output dimensionality.
feedforward_dim (int, optional) – Hidden dimensionality used by the feedforward network of the Transformer model. If None, it will be set to conv_dim (default: None).
num_heads (int) – Number of heads in multi-head attention (default: 8).
dropout (float) – The dropout value (default: 0.3).
activation (str) – The activation function (default: relu).
use_cls (bool) – Whether to use a CLS token (default: False).
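The parameters above line up with the arguments of a standard PyTorch encoder layer. The following is a minimal sketch, assuming the module is built on torch.nn.TransformerEncoderLayer (an assumption made for illustration, not confirmed from the rllm source). build_encoder_layer is a hypothetical helper; it shows the documented fallback of feedforward_dim to conv_dim and the usual requirement that the embedding size be divisible by the number of heads.

```python
from typing import Optional

import torch
from torch import nn


def build_encoder_layer(
    conv_dim: int,
    feedforward_dim: Optional[int] = None,
    num_heads: int = 8,
    dropout: float = 0.3,
    activation: str = "relu",
) -> nn.TransformerEncoderLayer:
    """Hypothetical helper mirroring the documented constructor defaults."""
    # Documented fallback: feedforward_dim is set to conv_dim when None.
    if feedforward_dim is None:
        feedforward_dim = conv_dim
    # Multi-head attention splits conv_dim evenly across the heads.
    if conv_dim % num_heads != 0:
        raise ValueError("conv_dim must be divisible by num_heads")
    # Assumption: a stock PyTorch encoder layer stands in for the real block.
    return nn.TransformerEncoderLayer(
        d_model=conv_dim,
        nhead=num_heads,
        dim_feedforward=feedforward_dim,
        dropout=dropout,
        activation=activation,
        batch_first=True,  # inputs are [batch_size, num_cols, dim]
    )


layer = build_encoder_layer(conv_dim=32)
x = torch.randn(16, 10, 32)
out = layer(x)  # same shape as the input
```

With conv_dim=32 and the default num_heads=8, each attention head works on 4 features; a conv_dim such as 30 would be rejected by the divisibility check.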
Example
>>> import torch
>>> from rllm.nn.conv.table_conv import FTTransformerConv
>>> conv = FTTransformerConv(conv_dim=32, num_heads=8, use_cls=False)
>>> x = torch.randn(16, 10, 32)
>>> out = conv(x)
- forward(x: Tensor) → Tensor
CLS-token augmented Transformer convolution.
- Parameters:
x (Tensor) – Input tensor of shape [batch_size, num_cols, dim]
- Returns:
If use_cls=False, output tensor of shape [batch_size, num_cols, dim] corresponding to input columns. If use_cls=True, output tensor of shape [batch_size, dim] for the CLS token representation.
- Return type:
torch.Tensor
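The concatenate, transform, and split steps described above can be sketched in plain PyTorch. This is an illustrative stand-in, not the rllm implementation: ClsTransformerSketch is a hypothetical class, it uses a single torch.nn.TransformerEncoderLayer as the Transformer, and it assumes the CLS token is always prepended with use_cls only selecting which part is returned.

```python
import torch
from torch import nn


class ClsTransformerSketch(nn.Module):
    """Hypothetical illustration of the CLS concat/split; not rllm code."""

    def __init__(self, dim: int, num_heads: int = 8, use_cls: bool = False):
        super().__init__()
        self.use_cls = use_cls
        # One learnable CLS embedding, shared by every row in the batch.
        self.cls_embedding = nn.Parameter(torch.randn(1, 1, dim))
        # Assumption: one stock encoder layer stands in for the Transformer.
        self.encoder = nn.TransformerEncoderLayer(
            d_model=dim, nhead=num_heads, dim_feedforward=dim, batch_first=True
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.size(0)
        # Prepend the CLS token: [B, num_cols, dim] -> [B, 1 + num_cols, dim].
        x_cls = self.cls_embedding.expand(batch_size, -1, -1)
        x = torch.cat([x_cls, x], dim=1)
        x = self.encoder(x)
        # Split back into the CLS part and the per-column part.
        x_cls, x = x[:, 0, :], x[:, 1:, :]
        return x_cls if self.use_cls else x


x = torch.randn(16, 10, 32)
cols = ClsTransformerSketch(32, use_cls=False)(x)  # [16, 10, 32]
cls = ClsTransformerSketch(32, use_cls=True)(x)    # [16, 32]
```

The two calls reproduce the documented return shapes: one vector per column when use_cls=False, and a single pooled vector per row when use_cls=True.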