rllm.nn.models.TransTab

class rllm.nn.models.TransTab(categorical_columns: List[str] | None = None, numerical_columns: List[str] | None = None, binary_columns: List[str] | None = None, hidden_dim: int = 128, num_layer: int = 2, num_attention_head: int = 8, hidden_dropout_prob: float = 0.1, layer_norm_eps: float = 1e-05, ffn_dim: int = 256, activation: str = 'relu', projection_dim: int = 128, overlap_ratio: float = 0.1, num_partition: int = 2, supervised: bool = True, temperature: float = 10.0, base_temperature: float = 10.0, tokenizer=None, **kwargs)

Bases: Module

Base TransTab encoder for tabular data (“TransTab”).

Encodes column names and cell values into token embeddings, prepends a learnable [CLS] token, and refines the sequence through num_layer Transformer layers. The final [CLS] position is returned as a fixed-size table-level embedding.
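The data flow described above can be sketched in NumPy. This is a toy illustration under stated assumptions, not rllm's implementation: the column tokens are random stand-ins for the output of the column/value encoder, and each "Transformer layer" is reduced to a single-head self-attention step with a residual connection (no feed-forward block, LayerNorm, or dropout). It only shows how prepending [CLS] lengthens the sequence by one and how the table-level embedding is read from position 0.

```python
import numpy as np

def toy_transtab_forward(tokens: np.ndarray, cls: np.ndarray, num_layer: int) -> np.ndarray:
    """tokens: (N, L, H) column tokens; cls: (H,) [CLS] vector (random stand-in here)."""
    n, _, h = tokens.shape
    # Prepend the [CLS] token to every row: (N, L, H) -> (N, L + 1, H)
    cls_batch = np.broadcast_to(cls, (n, 1, h))
    seq = np.concatenate([cls_batch, tokens], axis=1)
    for _ in range(num_layer):
        # Simplified single-head self-attention with identity Q/K/V projections
        scores = seq @ seq.transpose(0, 2, 1) / np.sqrt(h)  # (N, L+1, L+1)
        scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)
        seq = seq + weights @ seq  # residual connection
    # The table-level embedding is the final state at the [CLS] position
    return seq[:, 0, :]  # (N, H)

rng = np.random.default_rng(0)
tokens = rng.normal(size=(4, 6, 32))  # 4 rows, 6 column tokens, hidden size 32
cls = rng.normal(size=32)
emb = toy_transtab_forward(tokens, cls, num_layer=2)
print(emb.shape)  # (4, 32)
```

Attention here only mixes tokens within the same row, so each row's embedding is independent of the other rows in the batch.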

Parameters:
  • categorical_columns (List[str], optional) – Categorical column names. Default: None.

  • numerical_columns (List[str], optional) – Numerical column names. Default: None.

  • binary_columns (List[str], optional) – Binary column names. Default: None.

  • hidden_dim (int) – Shared embedding dimensionality. Default: 128.

  • num_layer (int) – Number of Transformer layers. Default: 2.

  • num_attention_head (int) – Number of attention heads. Default: 8.

  • hidden_dropout_prob (float) – Dropout probability. Default: 0.1.

  • layer_norm_eps (float) – LayerNorm \(\varepsilon\). Default: 1e-5.

  • ffn_dim (int) – Feedforward inner dimension. Default: 256.

  • activation (str) – Feedforward activation. Default: "relu".

  • tokenizer – Pre-trained tokenizer; created automatically when None. Default: None.

  • **kwargs – Forwarded to TransTabPreEncoder.
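The three column-type lists decide how each cell is turned into tokens. The sketch below follows the input scheme described in the original TransTab paper and assumes rllm's TransTabPreEncoder behaves similarly; the helper name and the exact text format are hypothetical. Categorical cells are read as the column name joined with the cell value, numerical cells keep the column name paired with the scalar that later scales its embedding, and binary columns contribute their name only when the value is 1.

```python
from typing import Dict, List, Tuple, Union

Cell = Union[str, float, int]

def featurize_row(
    row: Dict[str, Cell],
    categorical_columns: List[str],
    numerical_columns: List[str],
    binary_columns: List[str],
) -> Tuple[List[str], List[Tuple[str, float]], List[str]]:
    """Hypothetical helper: split one row into three TransTab-style token groups."""
    # Categorical: column name and value are read together as text
    cat_text = [f"{col} {row[col]}" for col in categorical_columns if col in row]
    # Numerical: column name is kept; its value scales the name embedding later
    num_pairs = [(col, float(row[col])) for col in numerical_columns if col in row]
    # Binary: only columns whose value is 1 contribute their name
    bin_text = [col for col in binary_columns if int(row.get(col, 0)) == 1]
    return cat_text, num_pairs, bin_text

row = {"gender": "female", "age": 37, "income": 5.2, "smoker": 0, "insured": 1}
cat, num, binr = featurize_row(row, ["gender"], ["age", "income"], ["smoker", "insured"])
print(cat)   # ['gender female']
print(num)   # [('age', 37.0), ('income', 5.2)]
print(binr)  # ['insured']
```

Columns not listed in any of the three lists are simply ignored by this helper.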

Examples:

>>> from rllm.nn.models import TransTab
>>> model = TransTab(hidden_dim=32, num_layer=1, num_attention_head=4)
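In standard multi-head attention, hidden_dim is split evenly across the heads, so it usually has to be divisible by num_attention_head (PyTorch's nn.MultiheadAttention, for instance, enforces this). Assuming TransTab's layers follow that convention, the per-head width in the example above is 32 / 4 = 8, and for the defaults it is 128 / 8 = 16. A small hypothetical check:

```python
def head_dim(hidden_dim: int, num_attention_head: int) -> int:
    """Per-head width under standard multi-head attention (hypothetical helper)."""
    if hidden_dim % num_attention_head != 0:
        raise ValueError(
            f"hidden_dim={hidden_dim} is not divisible by "
            f"num_attention_head={num_attention_head}"
        )
    return hidden_dim // num_attention_head

print(head_dim(32, 4))   # 8, the configuration in the example above
print(head_dim(128, 8))  # 16, the default configuration
```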
forward(x: DataFrame | TableData | Dict[ColType, Tensor], y: Tensor | None = None) → Tensor

Encode a table batch into [CLS] embeddings of shape \((N, H)\), where \(N\) is the number of rows in the batch and \(H\) is hidden_dim.

Parameters:
  • x (DataFrame | TableData | Dict[ColType, Tensor]) – Input table batch.

  • y (Tensor, optional) – Unused; kept for subclass API symmetry. Default: None.

Returns:

[CLS] embeddings of shape \((N, H)\).

Return type:

Tensor
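Because forward returns one vector per row, its output can be passed directly to a task head. The NumPy sketch below assumes a hypothetical linear classification head (the weights W and bias b are random stand-ins and are not part of TransTab); it only illustrates the shape contract (N, H) → (N, C).

```python
import numpy as np

def linear_head(cls_emb: np.ndarray, W: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Map [CLS] embeddings (N, H) to class probabilities (N, C) with a softmax."""
    logits = cls_emb @ W + b  # (N, C)
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    return probs / probs.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
N, H, C = 5, 128, 3  # H equals the default hidden_dim
cls_emb = rng.normal(size=(N, H))  # stand-in for the output of forward
W, b = rng.normal(size=(H, C)) * 0.01, np.zeros(C)
probs = linear_head(cls_emb, W, b)
print(probs.shape)  # (5, 3)
```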

load(ckpt_dir: str) → None

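Restore model weights from ckpt_dir. Weights are loaded with strict=False, so keys that are missing from the checkpoint or absent from the model do not raise an error.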
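In PyTorch, loading a state dict with strict=False copies every parameter whose key matches and reports, rather than raising on, keys that are missing from the checkpoint or unexpected by the model (shape mismatches on matching keys still raise). The plain-Python sketch below mimics only that key-matching rule with ordinary dicts; real code would call torch.nn.Module.load_state_dict, and the key names are hypothetical.

```python
from typing import Dict, List, NamedTuple

class LoadResult(NamedTuple):
    missing_keys: List[str]
    unexpected_keys: List[str]

def load_non_strict(model_state: Dict[str, list], ckpt_state: Dict[str, list]) -> LoadResult:
    """Dict-based analogue of load_state_dict(..., strict=False)."""
    for key, value in ckpt_state.items():
        if key in model_state:
            model_state[key] = value  # matching keys are restored in place
    missing = sorted(k for k in model_state if k not in ckpt_state)
    unexpected = sorted(k for k in ckpt_state if k not in model_state)
    return LoadResult(missing, unexpected)

model = {"encoder.weight": [0.0, 0.0], "head.weight": [0.0]}
ckpt = {"encoder.weight": [0.5, -0.5], "projection.weight": [1.0]}
result = load_non_strict(model, ckpt)
print(model["encoder.weight"])  # [0.5, -0.5]
print(result.missing_keys)      # ['head.weight']
print(result.unexpected_keys)   # ['projection.weight']
```

Non-strict loading is what lets, for example, a pretrained encoder checkpoint be restored into a model that has an extra, freshly initialized head.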

save(ckpt_dir: str) → None

Save model weights to ckpt_dir.
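The file layout inside ckpt_dir is not documented here. As a hedged illustration of the directory-based save/load pattern, the sketch below writes a weight dictionary to a hypothetical file name (weights.json) inside ckpt_dir and reads it back; a real PyTorch model would use torch.save(model.state_dict(), path) and torch.load rather than JSON.

```python
import json
import os
import tempfile
from typing import Dict, List

WEIGHTS_FILE = "weights.json"  # hypothetical file name, not rllm's

def save_weights(state: Dict[str, List[float]], ckpt_dir: str) -> None:
    os.makedirs(ckpt_dir, exist_ok=True)  # create the directory if needed
    with open(os.path.join(ckpt_dir, WEIGHTS_FILE), "w") as f:
        json.dump(state, f)

def load_weights(ckpt_dir: str) -> Dict[str, List[float]]:
    with open(os.path.join(ckpt_dir, WEIGHTS_FILE)) as f:
        return json.load(f)

state = {"cls_token": [0.1, -0.2], "layer0.weight": [1.0, 2.0, 3.0]}
with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = os.path.join(tmp, "transtab_ckpt")
    save_weights(state, ckpt_dir)
    restored = load_weights(ckpt_dir)
print(restored == state)  # True
```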