# Source code for rllm.nn.models.transtab

from __future__ import annotations
from typing import Any, Dict, List, Optional, Union, Tuple
import math
import os

import numpy as np
import pandas as pd
import torch
from torch import Tensor

from rllm.types import ColType
from rllm.preprocessing import TokenizerConfig
from rllm.data.table_data import TableData
from rllm.nn.loss import SupervisedVPCL, SelfSupervisedVPCL
from rllm.nn.encoder import TransTabPreEncoder
from rllm.nn.conv.table_conv import TransTabConv


class TransTabCLSToken(torch.nn.Module):
    r"""Prepends a learnable ``[CLS]`` token to each input sequence.

    Args:
        hidden_dim (int): Dimensionality of the ``[CLS]`` embedding.

    Shape:
        - Input: :math:`(N, S, H)` → Output: :math:`(N, S+1, H)`

    Examples::

        >>> cls = TransTabCLSToken(hidden_dim=8)
        >>> out = cls(torch.randn(2, 5, 8))
        >>> out["embedding"].shape
        torch.Size([2, 6, 8])
    """

    def __init__(self, hidden_dim: int) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        # Uninitialised storage; the uniform init below overwrites every entry.
        self.weight = torch.nn.Parameter(torch.empty(hidden_dim))
        bound = 1 / math.sqrt(hidden_dim)
        torch.nn.init.uniform_(self.weight, a=-bound, b=bound)

    def expand(self, *leading_dimensions):
        r"""Broadcast the ``[CLS]`` vector to ``(*leading_dimensions, H)``.

        The weight is first viewed with ``len(leading_dimensions) - 1``
        singleton axes in front of the hidden axis and then expanded, so no
        data is copied.
        """
        new_dims = (1,) * (len(leading_dimensions) - 1)
        return self.weight.view(*new_dims, -1).expand(*leading_dimensions, -1)

    def forward(
        self,
        embedding: Tensor,
        attention_mask: Optional[Tensor] = None,
        **kwargs,
    ) -> Dict[str, Optional[Tensor]]:
        r"""Insert the ``[CLS]`` token at position 0 of every sequence.

        Args:
            embedding (Tensor): Token embeddings; the ``[CLS]`` vector is
                concatenated in front along ``dim=1``.
            attention_mask (Tensor, optional): Per-token mask. When given, a
                column of ones is prepended so the ``[CLS]`` position is
                marked like the positions whose mask value is ``1``.
                Default: ``None``.
            **kwargs: Accepted and ignored.

        Returns:
            Dict[str, Optional[Tensor]]: ``{"embedding": ..., "attention_mask":
            ...}``; ``"attention_mask"`` is ``None`` when no mask was passed.
        """
        cls_column = self.expand(embedding.shape[0], 1)
        outputs = {"embedding": torch.cat([cls_column, embedding], dim=1)}
        if attention_mask is not None:
            # Build the extra column directly on the mask's device instead of
            # allocating on CPU and copying afterwards.
            cls_mask = torch.ones(
                attention_mask.shape[0], 1, device=attention_mask.device
            )
            attention_mask = torch.cat([cls_mask, attention_mask], dim=1)
        outputs["attention_mask"] = attention_mask
        return outputs


class TransTab(torch.nn.Module):
    r"""Base TransTab encoder for tabular data
    (`"TransTab" <https://arxiv.org/abs/2205.09328>`_).

    Encodes column names and cell values into token embeddings, prepends a
    learnable ``[CLS]`` token, and refines the sequence through ``num_layer``
    Transformer layers. The final ``[CLS]`` position is returned as a
    fixed-size table-level embedding.

    Args:
        categorical_columns (List[str], optional): Categorical column names.
            Default: ``None``.
        numerical_columns (List[str], optional): Numerical column names.
            Default: ``None``.
        binary_columns (List[str], optional): Binary column names.
            Default: ``None``.
        hidden_dim (int): Shared embedding dimensionality. Default: ``128``.
        num_layer (int): Number of Transformer layers. Default: ``2``.
        num_attention_head (int): Number of attention heads. Default: ``8``.
        hidden_dropout_prob (float): Dropout probability. Default: ``0.1``.
        layer_norm_eps (float): LayerNorm :math:`\varepsilon`.
            Default: ``1e-5``.
        ffn_dim (int): Feedforward inner dimension. Default: ``256``.
        activation (str): Feedforward activation. Default: ``"relu"``.
        projection_dim (int): Output size of ``self.projection_head``.
            Default: ``128``.
        overlap_ratio (float): Stored as ``self.overlap_ratio``; not read by
            this class's ``forward``. Default: ``0.1``.
        num_partition (int): Stored as ``self.num_partition``; not read by
            this class's ``forward``. Default: ``2``.
        supervised (bool): Stored as ``self.supervised``; not read by this
            class's ``forward``. Default: ``True``.
        temperature (float): Stored as ``self.temperature``.
            Default: ``10.0``.
        base_temperature (float): Stored as ``self.base_temperature``.
            Default: ``10.0``.
        tokenizer: Pre-trained tokenizer; created automatically when
            ``None``. Default: ``None``.
        **kwargs: Forwarded to :class:`~rllm.nn.encoder.TransTabPreEncoder`.

    Examples::

        >>> from rllm.nn.models import TransTab
        >>> model = TransTab(hidden_dim=32, num_layer=1, num_attention_head=4)
    """

    def __init__(
        self,
        categorical_columns: Optional[List[str]] = None,
        numerical_columns: Optional[List[str]] = None,
        binary_columns: Optional[List[str]] = None,
        hidden_dim: int = 128,
        num_layer: int = 2,
        num_attention_head: int = 8,
        hidden_dropout_prob: float = 0.1,
        layer_norm_eps: float = 1e-5,
        ffn_dim: int = 256,
        activation: str = "relu",
        projection_dim: int = 128,
        overlap_ratio: float = 0.1,
        num_partition: int = 2,
        supervised: bool = True,
        temperature: float = 10.0,
        base_temperature: float = 10.0,
        tokenizer=None,
        **kwargs,
    ) -> None:
        super().__init__()
        self.hidden_dropout_prob = hidden_dropout_prob
        self.layer_norm_eps = layer_norm_eps
        self.categorical_columns = categorical_columns
        self.numerical_columns = numerical_columns
        self.binary_columns = binary_columns

        # The pre-encoder is handed empty per-type metadata; the column lists
        # are passed separately below.
        metadata = {
            ColType.CATEGORICAL: [],
            ColType.BINARY: [],
            ColType.NUMERICAL: [],
        }
        self.pre_encoder = TransTabPreEncoder(
            out_dim=hidden_dim,
            metadata=metadata,
            categorical_columns=categorical_columns,
            numerical_columns=numerical_columns,
            binary_columns=binary_columns,
            tokenizer=tokenizer,
            hidden_dropout_prob=self.hidden_dropout_prob,
            layer_norm_eps=self.layer_norm_eps,
            use_align_layer=True,
            **kwargs,
        )

        self.convs = torch.nn.ModuleList(
            [
                TransTabConv(
                    conv_dim=hidden_dim,
                    nhead=num_attention_head,
                    dropout=hidden_dropout_prob,
                    dim_feedforward=ffn_dim,
                    activation=activation,
                    layer_norm_eps=self.layer_norm_eps,
                    norm_first=False,
                    use_layer_norm=True,
                    batch_first=True,
                )
                for _ in range(num_layer)
            ]
        )

        self.cls_token = TransTabCLSToken(hidden_dim=hidden_dim)
        # Kept on the base class so its weights stay part of the state dict.
        self.projection_head = torch.nn.Linear(hidden_dim, projection_dim, bias=False)
        self.supervised = supervised
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.num_partition = num_partition
        self.overlap_ratio = overlap_ratio
        self.ce_loss = torch.nn.CrossEntropyLoss()

    @property
    def device(self) -> torch.device:
        r"""Device of the first registered parameter."""
        return next(self.parameters()).device

    def forward(
        self,
        x: Union[pd.DataFrame, TableData, Dict[ColType, torch.Tensor]],
        y: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Encode a table batch into a ``[CLS]`` embedding of shape
        :math:`(N, H)`.

        Args:
            x (TableData): Input table batch (or any object exposing
                ``feat_dict``).
            y (Tensor, optional): Unused; kept for subclass API symmetry.
                Default: ``None``.

        Returns:
            Tensor: ``[CLS]`` embeddings of shape :math:`(N, H)`.

        Raises:
            ValueError: If ``x`` is neither a ``TableData`` nor an object
                with a ``feat_dict`` attribute.
        """
        if not (hasattr(x, "feat_dict") or isinstance(x, TableData)):
            raise ValueError(
                f"Expected input type TableData or object with feat_dict, got {type(x)}"
            )
        proc_out = self.pre_encoder(x)

        # Prepend the CLS token (and extend the mask by one position).
        cls_out = self.cls_token(
            proc_out["embedding"], attention_mask=proc_out["attention_mask"]
        )
        emb = cls_out["embedding"]  # (batch, seq_len+1, hidden_dim)
        mask = cls_out["attention_mask"]  # (batch, seq_len+1)

        # Transformer encoding. Reading the CLS slot from ``emb`` (rather than
        # a loop-local variable) keeps this valid when there are zero layers.
        for conv in self.convs:
            emb = conv(x=emb, src_key_padding_mask=mask)

        return emb[:, 0, :]  # (batch, hidden_dim)

    def save(self, ckpt_dir: str) -> None:
        r"""Save model weights to ``ckpt_dir``.

        Writes the pre-encoder through ``self.pre_encoder.save`` and this
        module's full ``state_dict`` to ``pytorch_model.bin``.
        """
        # Create the directory up front so neither writer depends on the
        # other having created it.
        os.makedirs(ckpt_dir, exist_ok=True)
        self.pre_encoder.save(ckpt_dir)
        model_path = os.path.join(ckpt_dir, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)
        print(f"Saved TransTab weights to {model_path}")

    def load(self, ckpt_dir: str) -> None:
        r"""Restore model weights from ``ckpt_dir`` (``strict=False``).

        Also re-reads the column lists from the pre-encoder and keeps the raw
        contents of ``input_encoder.bin`` in ``self._table_encoder_state``.
        """
        self.pre_encoder.load(ckpt_dir)

        # Synchronize column mapping
        self.categorical_columns = self.pre_encoder.categorical_columns
        self.numerical_columns = self.pre_encoder.numerical_columns
        self.binary_columns = self.pre_encoder.binary_columns

        model_path = os.path.join(ckpt_dir, "pytorch_model.bin")
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        print(f"Loaded TransTab weights from {model_path}")
        print(f" Missing keys: {missing}")
        print(f" Unexpected keys: {unexpected}")

        # Cache the pre-trained table_encoder state for subsequent update calls
        # NOTE(review): nothing in this class reads the cache — confirm the
        # consumer lives in the pre-encoder or a caller.
        pe_path = os.path.join(ckpt_dir, "input_encoder.bin")
        self._table_encoder_state = torch.load(
            pe_path, map_location="cpu", weights_only=True
        )

    def update(self, config: Dict[str, Any]) -> None:
        r"""Extend column mappings and/or the number of target classes.

        Args:
            config (Dict[str, Any]): May contain ``"cat"``, ``"num"`` and
                ``"bin"`` (forwarded to ``self.pre_encoder.update``) and
                ``"num_class"`` (forwarded to
                :meth:`_adapt_to_new_num_class`). Other keys are ignored.
        """
        col_map = {k: v for k, v in config.items() if k in ("cat", "num", "bin")}
        if col_map:
            self.pre_encoder.update(
                cat=col_map.get("cat"),
                num=col_map.get("num"),
                bin=col_map.get("bin"),
            )
            self.categorical_columns = self.pre_encoder.categorical_columns
            self.numerical_columns = self.pre_encoder.numerical_columns
            self.binary_columns = self.pre_encoder.binary_columns
            print("Extended column mappings in TransTab via table_encoder.update().")

        if "num_class" in config:
            self._adapt_to_new_num_class(config["num_class"])

    def _adapt_to_new_num_class(self, num_class: int) -> None:
        r"""Rebuild ``self.clf`` and ``self.loss_fn`` for a new ``num_class``.

        Does nothing when the instance has no ``clf`` attribute or when
        ``num_class`` is unchanged. The new head is freshly initialised and
        placed on the model's current device.
        """
        if not hasattr(self, "clf") or num_class == getattr(self, "num_class", None):
            return

        # Resolve the device before swapping the head out.
        device = self.device
        self.num_class = num_class
        hidden_dim = self.cls_token.hidden_dim
        out_dim = 1 if num_class <= 2 else num_class
        # Without the explicit move the new head would be created on CPU and
        # mismatch a model that already lives on another device.
        self.clf = torch.nn.Sequential(
            torch.nn.LayerNorm(hidden_dim),
            torch.nn.Linear(hidden_dim, out_dim),
        ).to(device)
        if num_class > 2:
            self.loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
        else:
            self.loss_fn = torch.nn.BCEWithLogitsLoss(reduction="none")
        print(f"Rebuilt classifier for num_class={num_class}.")
class TransTabClassifier(TransTab):
    r"""TransTab with a LayerNorm + Linear classification head on the
    ``[CLS]`` embedding.

    Supports binary (``num_class <= 2``, BCEWithLogitsLoss) and multi-class
    (CrossEntropyLoss) settings.

    Args:
        categorical_columns (List[str], optional): Categorical column names.
            Default: ``None``.
        numerical_columns (List[str], optional): Numerical column names.
            Default: ``None``.
        binary_columns (List[str], optional): Binary column names.
            Default: ``None``.
        num_class (int): Number of target classes. Default: ``2``.
        hidden_dim (int): Shared embedding dimensionality. Default: ``128``.
        num_layer (int): Number of Transformer layers. Default: ``2``.
        num_attention_head (int): Number of attention heads. Default: ``8``.
        hidden_dropout_prob (float): Dropout probability. Default: ``0.1``.
        ffn_dim (int): Feedforward inner dimension. Default: ``256``.
        activation (str): Feedforward activation. Default: ``"relu"``.
        tokenizer: Pre-trained tokenizer. Default: ``None``.
        **kwargs: Forwarded to :class:`TransTab`.

    Examples::

        >>> from rllm.nn.models import TransTabClassifier
        >>> model = TransTabClassifier(num_class=2, hidden_dim=32)
    """

    def __init__(
        self,
        categorical_columns: Optional[List[str]] = None,
        numerical_columns: Optional[List[str]] = None,
        binary_columns: Optional[List[str]] = None,
        num_class: int = 2,
        hidden_dim: int = 128,
        num_layer: int = 2,
        num_attention_head: int = 8,
        hidden_dropout_prob: float = 0.1,
        ffn_dim: int = 256,
        activation: str = "relu",
        tokenizer=None,
        **kwargs,
    ) -> None:
        super().__init__(
            categorical_columns=categorical_columns,
            numerical_columns=numerical_columns,
            binary_columns=binary_columns,
            hidden_dim=hidden_dim,
            num_layer=num_layer,
            num_attention_head=num_attention_head,
            hidden_dropout_prob=hidden_dropout_prob,
            ffn_dim=ffn_dim,
            activation=activation,
            tokenizer=tokenizer,
            **kwargs,
        )
        self.num_class = num_class
        # Binary problems use a single logit; multi-class uses one per class.
        out_dim = 1 if num_class <= 2 else num_class
        self.clf = torch.nn.Sequential(
            torch.nn.LayerNorm(hidden_dim),
            torch.nn.Linear(hidden_dim, out_dim),
        )
        if num_class > 2:
            self.loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
        else:
            self.loss_fn = torch.nn.BCEWithLogitsLoss(reduction="none")

    def forward(
        self,
        x: Union[pd.DataFrame, TableData, Dict[ColType, torch.Tensor]],
        y: Optional[Union[pd.Series, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        r"""Classify a table batch and optionally compute the mean loss.

        Args:
            x (TableData): Input table batch.
            y (Tensor or Series, optional): Labels; computes mean loss when
                provided. Labels are flattened to one dimension, so both
                ``(N,)`` and ``(N, 1)`` layouts are accepted.
                Default: ``None``.

        Returns:
            Tuple[Tensor, Optional[Tensor]]: ``(logits [N,1] or [N,K], loss
            or None)``.
        """
        cls_emb = super().forward(x)
        logits = self.clf(cls_emb)
        if y is None:
            return logits, None

        if isinstance(y, torch.Tensor):
            y_ts = y.to(self.device)
        else:
            y_ts = torch.tensor(y.values, device=self.device)
        # Flatten so column-vector labels line up with the per-sample logits.
        y_ts = y_ts.reshape(-1)

        if self.num_class > 2:
            loss = self.loss_fn(logits, y_ts.long())
        else:
            loss = self.loss_fn(logits.view(-1), y_ts.float())
        return logits, loss.mean()


class TransTabForCL(TransTab):
    r"""TransTab pre-training via vertical-partition contrastive learning
    (VPCL).

    Splits each table batch into ``num_partition`` overlapping column
    subsets, encodes each view independently, projects the ``[CLS]``
    outputs, and optimises a supervised or self-supervised contrastive loss.

    Args:
        categorical_columns (List[str], optional): Categorical column names.
            Default: ``None``.
        numerical_columns (List[str], optional): Numerical column names.
            Default: ``None``.
        binary_columns (List[str], optional): Binary column names.
            Default: ``None``.
        hidden_dim (int): Shared embedding dimensionality. Default: ``128``.
        num_layer (int): Number of Transformer layers. Default: ``2``.
        num_attention_head (int): Number of attention heads. Default: ``8``.
        hidden_dropout_prob (float): Dropout probability. Default: ``0``.
        ffn_dim (int): Feedforward inner dimension. Default: ``256``.
        projection_dim (int): Projection head output dimension.
            Default: ``128``.
        overlap_ratio (float): Column overlap fraction between partitions,
            in :math:`[0,1)`. Default: ``0.1``.
        num_partition (int): Number of column subsets per sample.
            Default: ``2``.
        supervised (bool): Use supervised contrastive loss when ``True``.
            Default: ``True``.
        temperature (float): Contrastive loss temperature. Default: ``10``.
        base_temperature (float): Base temperature for loss normalisation.
            Default: ``10``.
        activation (str): Feedforward activation. Default: ``"relu"``.
        tokenizer: Pre-trained tokenizer. Default: ``None``.
        **kwargs: Forwarded to :class:`TransTab`.

    Examples::

        >>> from rllm.nn.models import TransTabForCL
        >>> model = TransTabForCL(hidden_dim=32, num_partition=2)
    """

    def __init__(
        self,
        categorical_columns=None,
        numerical_columns=None,
        binary_columns=None,
        hidden_dim=128,
        num_layer=2,
        num_attention_head=8,
        hidden_dropout_prob=0,
        ffn_dim=256,
        projection_dim=128,
        overlap_ratio=0.1,
        num_partition=2,
        supervised=True,
        temperature=10,
        base_temperature=10,
        activation="relu",
        tokenizer=None,
        **kwargs,
    ) -> None:
        # Validate before building the (expensive) encoder. The type check
        # comes first so a non-numeric value fails with the intended message
        # instead of a TypeError from the comparison.
        assert isinstance(
            num_partition, int
        ), f"number of constrative subsets must be int, got {type(num_partition)}"
        assert (
            num_partition > 0
        ), f"number of contrastive subsets must be greater than 0, got {num_partition}"
        assert (
            overlap_ratio >= 0 and overlap_ratio < 1
        ), f"overlap_ratio must be in [0, 1), got {overlap_ratio}"

        super().__init__(
            categorical_columns=categorical_columns,
            numerical_columns=numerical_columns,
            binary_columns=binary_columns,
            hidden_dim=hidden_dim,
            num_layer=num_layer,
            num_attention_head=num_attention_head,
            hidden_dropout_prob=hidden_dropout_prob,
            ffn_dim=ffn_dim,
            activation=activation,
            tokenizer=tokenizer,
            **kwargs,
        )

        # Replaces the projection head created by the base class so that it
        # honours ``projection_dim``.
        self.projection_head = torch.nn.Linear(hidden_dim, projection_dim, bias=False)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.num_partition = num_partition
        self.overlap_ratio = overlap_ratio
        self.supervised = supervised

        self.self_sup_criterion = SelfSupervisedVPCL(
            temperature=self.temperature,
            base_temperature=self.base_temperature,
            similarity="dot",  # "cosine"
        )
        self.sup_criterion = SupervisedVPCL(
            temperature=self.temperature,
            base_temperature=self.base_temperature,
            similarity="dot",
        )

    def forward(self, x, y=None):
        r"""Compute the contrastive pre-training loss for one batch.

        Args:
            x (pd.DataFrame or TableData): Input table batch.
            y (Tensor or Series, optional): Labels for supervised contrastive
                loss. Default: ``None``.

        Returns:
            Tuple[None, Tensor]: ``(None, loss)`` — no logits during
            pre-training.

        Raises:
            ValueError: If ``x`` is neither a ``pd.DataFrame`` nor an object
                exposing ``df``.
        """
        # Extract DataFrame from input
        if isinstance(x, pd.DataFrame):
            df = x
        elif hasattr(x, "df"):  # TableData
            df = x.df
        else:
            raise ValueError(
                f"Expected input type pd.DataFrame or TableData, got {type(x)}"
            )

        # Build positive pairs by splitting columns at DataFrame level
        sub_df_list = self._build_positive_pairs(df, self.num_partition)

        tokenizer = self.pre_encoder.tokenizer
        tokenizer_config = TokenizerConfig(
            tokenizer=tokenizer,
            pad_token_id=tokenizer.pad_token_id,
            tokenize_combine=True,
            include_colname=True,
            save_colname_token_ids=True,
        )

        has_col_types = hasattr(x, "col_types")
        feat_x_list = []
        for sub_df in sub_df_list:
            # Column types for this subset: reuse the caller's when present,
            # otherwise fall back to the pre-encoder's column lists.
            if has_col_types:
                sub_col_types = {
                    col: x.col_types[col]
                    for col in sub_df.columns
                    if col in x.col_types
                }
            else:
                sub_col_types = self._infer_col_types(sub_df.columns)

            sub_table = TableData(
                df=sub_df,
                col_types=sub_col_types,
                convert_text_coltypes={ColType.CATEGORICAL},
                tokenizer_config=tokenizer_config,
            )
            feat = self._encode_view(sub_table)  # [bs, hidden_dim]
            feat_x_list.append(self.projection_head(feat))  # [bs, proj_dim]

        # Stack multi-view features: [bs, num_partition, proj_dim]
        feat_x_multiview = torch.stack(feat_x_list, dim=1)

        # Compute contrastive loss
        if y is not None and self.supervised:
            if isinstance(y, torch.Tensor):
                labels = y.to(self.device)
            else:
                # Same Series handling as the classifier head.
                labels = torch.tensor(y.values, device=self.device)
            loss = self.sup_criterion(feat_x_multiview, labels.long())
        else:
            loss = self.self_sup_criterion(feat_x_multiview)

        return None, loss

    def _encode_view(self, table):
        r"""Run one column view through pre-encoder, ``[CLS]`` and the
        Transformer layers; return the ``[CLS]`` position."""
        proc = self.pre_encoder(table)
        cls_out = self.cls_token(
            proc["embedding"], attention_mask=proc["attention_mask"]
        )
        emb = cls_out["embedding"]
        mask = cls_out["attention_mask"]
        # Reading from ``emb`` keeps this valid when there are zero layers.
        for conv in self.convs:
            emb = conv(x=emb, src_key_padding_mask=mask)
        return emb[:, 0, :]

    def _infer_col_types(self, columns):
        r"""Map each column name to a ``ColType`` using the pre-encoder's
        column lists.

        Categorical and unknown columns map to ``ColType.TEXT``; a missing
        (``None``) column list is treated as empty.
        """
        categorical = set(self.pre_encoder.categorical_columns or ())
        numerical = set(self.pre_encoder.numerical_columns or ())
        binary = set(self.pre_encoder.binary_columns or ())

        col_types = {}
        for col in columns:
            if col in categorical:
                col_types[col] = ColType.TEXT
            elif col in numerical:
                col_types[col] = ColType.NUMERICAL
            elif col in binary:
                col_types[col] = ColType.BINARY
            else:
                col_types[col] = ColType.TEXT
        return col_types

    def _build_positive_pairs(self, x, n):
        r"""Split the columns of ``x`` into ``n`` contiguous, overlapping
        views.

        Every partition borrows ``ceil(len(first_partition) *
        overlap_ratio)`` columns from its right neighbour; the last one
        borrows from its left neighbour. With a single partition there is no
        neighbour, so the view is returned without any borrowed columns.

        Returns:
            List[pd.DataFrame]: One independent sub-frame per partition.
        """
        x_cols = x.columns.tolist()
        sub_col_list = np.array_split(np.array(x_cols), n)
        len_cols = len(sub_col_list[0])
        overlap = int(np.ceil(len_cols * self.overlap_ratio))

        sub_x_list = []
        for i, sub_col in enumerate(sub_col_list):
            # n == 1 must not borrow: its only "neighbour" is itself, which
            # would select every overlapping column twice.
            if overlap > 0 and n > 1:
                if i < n - 1:
                    borrowed = sub_col_list[i + 1][:overlap]
                else:
                    borrowed = sub_col_list[i - 1][-overlap:]
                sub_col = np.concatenate([sub_col, borrowed])
            sub_x_list.append(x[sub_col].copy())
        return sub_x_list