Source code for rllm.nn.conv.table_conv.saint_conv
from __future__ import annotations
import torch
from torch import Tensor
[docs]
class SAINTConv(torch.nn.Module):
    r"""The SAINTConv Layer introduced in the
    `"SAINT: Improved Neural Networks for Tabular Data via Row Attention
    and Contrastive Pre-Training" <https://arxiv.org/abs/2106.01342>`__ paper.

    This layer applies two :obj:`TransformerEncoder` modules: one for
    aggregating information between columns, and another for aggregating
    information between samples. This dual attention mechanism allows the
    model to capture complex relationships both within the features of a
    single sample and across different samples.

    Each encoder is a single batch-first :obj:`TransformerEncoderLayer`
    followed by a final :obj:`LayerNorm`. The row (inter-sample) encoder
    flattens every sample into one token of width ``num_cols * conv_dim`` and
    treats the whole batch as a single sequence, so the output for a sample
    depends on the other samples in the same batch.

    Args:
        conv_dim (int): Input/Output dimensionality of each column embedding.
            Must be divisible by :obj:`num_heads`.
        num_cols (int): Number of features (columns) per sample.
        num_heads (int, optional): Number of attention heads used by both
            encoders (default: 8).
        dropout (float, optional): Dropout probability used inside both
            encoder layers (default: 0.3).
        activation (str, optional): Activation function of the feed-forward
            sub-layers, forwarded to :obj:`TransformerEncoderLayer`
            (default: "relu").

    Example:
        >>> import torch
        >>> conv = SAINTConv(conv_dim=16, num_cols=8, num_heads=4, dropout=0.1)
        >>> x = torch.randn(32, 8, 16)
        >>> out = conv(x)
        >>> out.shape
        torch.Size([32, 8, 16])
    """

    def __init__(
        self,
        conv_dim: int,
        num_cols: int,
        num_heads: int = 8,
        dropout: float = 0.3,
        activation: str = "relu",
    ):
        super().__init__()
        # Kept so forward() can validate input shapes and flatten explicitly.
        self.conv_dim = conv_dim
        self.num_cols = num_cols
        # Column Transformer: attention among the columns of one sample.
        self.col_transformer = self._build_encoder(
            conv_dim, num_heads, dropout, activation
        )
        # Row Transformer: attention among samples, each flattened to one token.
        self.row_transformer = self._build_encoder(
            conv_dim * num_cols, num_heads, dropout, activation
        )

    @staticmethod
    def _build_encoder(
        d_model: int,
        num_heads: int,
        dropout: float,
        activation: str,
    ) -> torch.nn.TransformerEncoder:
        """Build a one-layer batch-first encoder with a final LayerNorm.

        The feed-forward width is ``4 * d_model``. The layer is created before
        the norm so parameter initialisation consumes the RNG in a fixed order.
        """
        encoder_layer = torch.nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            activation=activation,
            batch_first=True,
        )
        encoder_norm = torch.nn.LayerNorm(d_model)
        return torch.nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=1,
            norm=encoder_norm,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply column attention then row attention.

        Args:
            x (Tensor): Input tensor of shape
                ``[batch_size, num_cols, conv_dim]``.

        Returns:
            torch.Tensor: Output tensor with the same shape as input.

        Raises:
            ValueError: If the trailing two dimensions of :obj:`x` are not
                ``(num_cols, conv_dim)``.
        """
        # Fail early with a readable message instead of an opaque embed-dim
        # assertion from inside the row attention.
        if tuple(x.shape[-2:]) != (self.num_cols, self.conv_dim):
            raise ValueError(
                f"Expected input with trailing dimensions "
                f"({self.num_cols}, {self.conv_dim}), got {tuple(x.shape)}."
            )
        x = self.col_transformer(x)
        shape = x.shape
        # Flatten each sample into one token and treat the whole batch as a
        # single sequence so the row transformer attends across samples.
        # The width is given explicitly (not -1) because -1 is ambiguous for a
        # zero-size batch and reshape would raise.
        x = x.reshape(1, shape[0], self.num_cols * self.conv_dim)
        x = self.row_transformer(x)
        return x.reshape(shape)