rllm.nn.conv.table_conv.SAINTConv

class rllm.nn.conv.table_conv.SAINTConv(conv_dim: int, num_cols: int, num_heads: int = 8, dropout: float = 0.3, activation: str = 'relu')

Bases: Module

The SAINTConv Layer introduced in the “SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training” paper.

This layer applies two TransformerEncoder modules in sequence: the first attends across the columns (features) of each sample, and the second attends across the samples (rows) in the batch. Together they capture dependencies among the features of a single sample and among different samples.
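The column-then-row attention described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the library's implementation: the class name `DualAttentionSketch` is hypothetical, the feed-forward widths are arbitrary choices, and the row step assumes each sample is flattened into a single token of width `num_cols * conv_dim`, following the intersample attention idea of the SAINT paper.

```python
import torch
from torch import nn


class DualAttentionSketch(nn.Module):
    """Hypothetical stand-in illustrating column attention followed by row attention."""

    def __init__(self, conv_dim, num_cols, num_heads=8, dropout=0.3, activation="relu"):
        super().__init__()
        # Column attention: every column embedding is one token of width conv_dim.
        col_layer = nn.TransformerEncoderLayer(
            d_model=conv_dim,
            nhead=num_heads,
            dim_feedforward=conv_dim,  # assumption: width chosen for brevity
            dropout=dropout,
            activation=activation,
            batch_first=True,
        )
        self.col_encoder = nn.TransformerEncoder(col_layer, num_layers=1)

        # Row attention (assumption): every sample is flattened into one token
        # of width num_cols * conv_dim, so the sequence axis is the batch.
        row_dim = num_cols * conv_dim
        row_layer = nn.TransformerEncoderLayer(
            d_model=row_dim,
            nhead=num_heads,
            dim_feedforward=row_dim,  # assumption: width chosen for brevity
            dropout=dropout,
            activation=activation,
            batch_first=True,
        )
        self.row_encoder = nn.TransformerEncoder(row_layer, num_layers=1)

    def forward(self, x):
        batch_size, num_cols, conv_dim = x.shape
        # Attend over the num_cols tokens inside each sample.
        x = self.col_encoder(x)
        # Treat the whole batch as one sequence of flattened samples.
        x = x.reshape(1, batch_size, num_cols * conv_dim)
        # Attend over the batch_size samples.
        x = self.row_encoder(x)
        # Restore the original [batch_size, num_cols, conv_dim] layout.
        return x.reshape(batch_size, num_cols, conv_dim)


layer = DualAttentionSketch(conv_dim=16, num_cols=8, num_heads=4, dropout=0.1)
out = layer(torch.randn(32, 8, 16))
```

Under this assumption, `num_heads` has to divide both `conv_dim` and `num_cols * conv_dim`, and the output of one sample depends on the other samples in the same batch.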

Parameters:
  • conv_dim (int) – Embedding dimensionality of each column; the input and output share this size.

  • num_cols (int) – Number of columns (features) in the table.

  • num_heads (int, optional) – Number of attention heads (default: 8).

  • dropout (float, optional) – Attention module dropout (default: 0.3).

  • activation (str, optional) – Activation function (default: “relu”).

Example

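>>> import torch
>>> from rllm.nn.conv.table_conv import SAINTConv
>>> conv = SAINTConv(conv_dim=16, num_cols=8, num_heads=4, dropout=0.1)
>>> x = torch.randn(32, 8, 16)  # [batch_size, num_cols, conv_dim]
>>> out = conv(x)
>>> out.shape
torch.Size([32, 8, 16])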

forward(x: Tensor) → Tensor

Apply column attention then row attention.

Parameters:

x (Tensor) – Input tensor of shape [batch_size, num_cols, conv_dim].

Returns:

Output tensor of shape [batch_size, num_cols, conv_dim], the same as the input.

Return type:

torch.Tensor