# Source code for rllm.nn.conv.table_conv.ft_transformer_conv

from __future__ import annotations
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Parameter


class FTTransformerConv(torch.nn.Module):
    r"""The FT-Transformer backbone in the `"Revisiting Deep Learning Models
    for Tabular Data" <https://arxiv.org/abs/2106.11959>`_ paper.

    This module prepends a learnable CLS token embedding :obj:`x_cls` to the
    input tensor :obj:`x` and applies a stack of Transformer encoder layers
    to the concatenated tensor. The Transformer output is then split into
    two parts: (1) the rows corresponding to the original input columns, and
    (2) the row corresponding to the CLS token. :meth:`forward` returns one
    of the two, selected by :obj:`use_cls`.

    Args:
        conv_dim (int): Input/Output dimensionality.
        feedforward_dim (int, optional): Hidden dimensionality used by the
            feedforward network of the Transformer model. If :obj:`None`,
            it will be set to :obj:`conv_dim` (default: :obj:`None`).
        num_heads (int): Number of heads in multi-head attention
            (default: 8).
        dropout (float): The dropout value (default: 0.3).
        activation (str): The activation function (default: :obj:`relu`).
        use_cls (bool): If :obj:`True`, :meth:`forward` returns the CLS token
            representation; otherwise it returns the per-column
            representations (default: :obj:`False`).
        num_layers (int): Number of stacked Transformer encoder layers
            (default: 1).

    Example:
        >>> import torch
        >>> conv = FTTransformerConv(conv_dim=32, num_heads=8, use_cls=False)
        >>> x = torch.randn(16, 10, 32)
        >>> out = conv(x)
        >>> out.shape
        torch.Size([16, 10, 32])
    """

    def __init__(
        self,
        conv_dim: int,
        feedforward_dim: Optional[int] = None,
        num_heads: int = 8,
        dropout: float = 0.3,
        activation: str = "relu",
        use_cls: bool = False,
        num_layers: int = 1,
    ):
        super().__init__()
        self.use_cls = use_cls
        encoder_layer = torch.nn.TransformerEncoderLayer(
            d_model=conv_dim,
            nhead=num_heads,
            # `or` falls back to conv_dim for None (and also for 0).
            dim_feedforward=feedforward_dim or conv_dim,
            dropout=dropout,
            activation=activation,
            batch_first=True,
        )
        encoder_norm = torch.nn.LayerNorm(conv_dim)
        self.transformer = torch.nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_layers,
            norm=encoder_norm,
        )
        # A single learnable CLS vector, broadcast over the batch in forward.
        self.cls_embedding = Parameter(torch.empty(conv_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the CLS embedding and the Transformer weights.

        The CLS embedding is drawn from N(0, 0.01^2). Every Transformer
        parameter with more than one dimension (weight matrices) gets Xavier
        uniform initialisation. LayerNorm modules are reset to their defaults
        so that calling this method after training fully restores them.
        """
        torch.nn.init.normal_(self.cls_embedding, std=0.01)
        for p in self.transformer.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)
        for module in self.transformer.modules():
            if isinstance(module, torch.nn.LayerNorm):
                module.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        r"""CLS-token augmented Transformer convolution.

        Args:
            x (Tensor): Input tensor of shape ``[batch_size, num_cols, dim]``.

        Returns:
            torch.Tensor: If ``use_cls=False``, output tensor of shape
            ``[batch_size, num_cols, dim]`` corresponding to the input
            columns. If ``use_cls=True``, output tensor of shape
            ``[batch_size, dim]`` for the CLS token representation.

        Raises:
            ValueError: If :obj:`x` is not a 3-dimensional tensor.
        """
        if x.dim() != 3:
            raise ValueError(
                "Expected input of shape [batch_size, num_cols, dim], got a "
                f"tensor with shape {tuple(x.shape)}"
            )
        batch_size = x.size(0)
        # [batch_size, 1, dim]; expand is a view, torch.cat below copies.
        x_cls = self.cls_embedding.view(1, 1, -1).expand(batch_size, 1, -1)
        # [batch_size, num_cols + 1, dim]
        x_concat = torch.cat([x_cls, x], dim=1)
        # [batch_size, num_cols + 1, dim]
        x_concat = self.transformer(x_concat)
        if self.use_cls:
            # Position 0 holds the CLS token.
            return x_concat[:, 0, :]
        return x_concat[:, 1:, :]