rllm.nn.conv.table_conv.SAINTConv
class rllm.nn.conv.table_conv.SAINTConv(conv_dim: int, num_cols: int, num_heads: int = 8, dropout: float = 0.3, activation: str = 'relu')
Bases: Module

The SAINTConv layer introduced in the “SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training” paper.
This layer applies two TransformerEncoder modules: one aggregates information across the columns (features) of each sample, and the other aggregates information across the samples (rows) of the input batch. This dual attention mechanism lets the model capture relationships both among the features of a single sample and across different samples.

Parameters:
conv_dim (int) – Dimensionality of each column embedding, for both input and output.
num_cols (int) – Number of features (columns).
num_heads (int, optional) – Number of attention heads (default: 8).
dropout (float, optional) – Dropout rate of the attention modules (default: 0.3).
activation (str, optional) – Activation function (default: “relu”).
Example
>>> import torch
>>> from rllm.nn.conv.table_conv import SAINTConv
>>> conv = SAINTConv(conv_dim=16, num_cols=8, num_heads=4, dropout=0.1)
>>> x = torch.randn(32, 8, 16)  # [batch_size, num_cols, conv_dim]
>>> out = conv(x)
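The column-then-row attention described above can be sketched in plain PyTorch as follows. This is a minimal sketch under stated assumptions, not the rllm source: the class name `SAINTConvSketch`, the single-layer encoders, the feed-forward width (set equal to the model width), and the choice to flatten each row into one token of size `num_cols * conv_dim` for the inter-sample stage (following the SAINT paper's design) are all assumptions made for illustration.

```python
import torch
from torch import nn


class SAINTConvSketch(nn.Module):
    """Hypothetical two-stage (column, then row) attention layer.

    Assumption: follows the SAINT paper's design; this is NOT the rllm implementation.
    """

    def __init__(self, conv_dim: int, num_cols: int, num_heads: int = 8,
                 dropout: float = 0.3, activation: str = "relu"):
        super().__init__()
        # Stage 1: attention between columns; each column embedding is a token.
        col_layer = nn.TransformerEncoderLayer(
            d_model=conv_dim, nhead=num_heads, dim_feedforward=conv_dim,
            dropout=dropout, activation=activation, batch_first=True)
        self.col_encoder = nn.TransformerEncoder(col_layer, num_layers=1)

        # Stage 2: attention between samples; each flattened row is a token.
        row_dim = num_cols * conv_dim
        row_layer = nn.TransformerEncoderLayer(
            d_model=row_dim, nhead=num_heads, dim_feedforward=row_dim,
            dropout=dropout, activation=activation, batch_first=True)
        self.row_encoder = nn.TransformerEncoder(row_layer, num_layers=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, num_cols, conv_dim]
        batch_size, num_cols, conv_dim = x.shape
        # Column attention runs independently for every sample in the batch.
        x = self.col_encoder(x)
        # Treat the batch as one sequence of rows: [1, batch_size, num_cols * conv_dim].
        rows = x.reshape(1, batch_size, num_cols * conv_dim)
        rows = self.row_encoder(rows)
        # Restore the per-column layout: [batch_size, num_cols, conv_dim].
        return rows.reshape(batch_size, num_cols, conv_dim)


if __name__ == "__main__":
    torch.manual_seed(0)
    conv = SAINTConvSketch(conv_dim=16, num_cols=8, num_heads=4, dropout=0.1)
    conv.eval()  # disable dropout so outputs are deterministic
    x = torch.randn(32, 8, 16)
    print(conv(x).shape)  # torch.Size([32, 8, 16])
```

Because the second stage attends over samples, the output for one sample in this sketch depends on the other samples in the same batch, so results can change with batch composition.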