rllm.nn.models.TransTab
- class rllm.nn.models.TransTab(categorical_columns: List[str] | None = None, numerical_columns: List[str] | None = None, binary_columns: List[str] | None = None, hidden_dim: int = 128, num_layer: int = 2, num_attention_head: int = 8, hidden_dropout_prob: float = 0.1, layer_norm_eps: float = 1e-05, ffn_dim: int = 256, activation: str = 'relu', projection_dim: int = 128, overlap_ratio: float = 0.1, num_partition: int = 2, supervised: bool = True, temperature: float = 10.0, base_temperature: float = 10.0, tokenizer=None, **kwargs)
Bases: Module
Base TransTab encoder for tabular data (“TransTab”).
Encodes column names and cell values into token embeddings, prepends a learnable [CLS] token, and refines the sequence through num_layer Transformer layers. The final [CLS] position is returned as a fixed-size table-level embedding.
- Parameters:
    categorical_columns (List[str], optional) – Categorical column names. Default: None.
    numerical_columns (List[str], optional) – Numerical column names. Default: None.
    binary_columns (List[str], optional) – Binary column names. Default: None.
    hidden_dim (int) – Shared embedding dimensionality. Default: 128.
    num_layer (int) – Number of Transformer layers. Default: 2.
    num_attention_head (int) – Number of attention heads. Default: 8.
    hidden_dropout_prob (float) – Dropout probability. Default: 0.1.
    layer_norm_eps (float) – LayerNorm \(\varepsilon\). Default: 1e-5.
    ffn_dim (int) – Feedforward inner dimension. Default: 256.
    activation (str) – Feedforward activation. Default: "relu".
    tokenizer – Pre-trained tokenizer; created automatically when None. Default: None.
    **kwargs – Forwarded to TransTabPreEncoder.
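Conventional multi-head attention splits the embedding evenly across heads, so hidden_dim usually has to be divisible by num_attention_head. The check below is a minimal sketch under that assumption; this page does not state which attention implementation TransTab uses, and per_head_dim is a hypothetical helper, not part of rllm.

```python
def per_head_dim(hidden_dim: int, num_attention_head: int) -> int:
    """Return the width of each attention head.

    Hypothetical pre-flight check, not an rllm function. Assumes the
    model uses conventional multi-head attention, where hidden_dim is
    split evenly across the heads.
    """
    if hidden_dim <= 0 or num_attention_head <= 0:
        raise ValueError("hidden_dim and num_attention_head must be positive")
    if hidden_dim % num_attention_head != 0:
        raise ValueError(
            f"hidden_dim={hidden_dim} is not divisible by "
            f"num_attention_head={num_attention_head}"
        )
    return hidden_dim // num_attention_head


print(per_head_dim(128, 8))  # documented defaults: prints 16
print(per_head_dim(32, 4))   # smaller configuration: prints 8
```

Running such a check before constructing the model can surface an incompatible pair of values early, assuming the divisibility constraint applies.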
Examples:
    >>> from rllm.nn.models import TransTab
    >>> model = TransTab(hidden_dim=32, num_layer=1, num_attention_head=4)
- forward(x: DataFrame | TableData | Dict[ColType, Tensor], y: Tensor | None = None) → Tensor
Encode a table batch into a [CLS] embedding of shape \((N, H)\).
- Parameters:
    x (DataFrame | TableData | Dict[ColType, Tensor]) – Input table batch.
    y (Tensor, optional) – Unused; kept for subclass API symmetry. Default: None.
- Returns:
    [CLS] embeddings of shape \((N, H)\).
- Return type:
    Tensor
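The data flow described for this class (prepend a learnable [CLS] token, apply num_layer Transformer layers, return the [CLS] position as an \((N, H)\) tensor) can be sketched in plain PyTorch. This is an illustrative stand-in, not rllm's implementation: ClsPoolingEncoder is a hypothetical name, the built-in torch.nn Transformer layers are an assumption, and random tensors take the place of the column-name and cell-value token embeddings.

```python
import torch
from torch import nn


class ClsPoolingEncoder(nn.Module):
    """Hypothetical sketch: prepend [CLS], encode, return position 0."""

    def __init__(self, hidden_dim=128, num_layer=2, num_attention_head=8,
                 ffn_dim=256, hidden_dropout_prob=0.1, layer_norm_eps=1e-5,
                 activation="relu"):
        super().__init__()
        # One learnable [CLS] vector, shared by every row in the batch.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_attention_head,
            dim_feedforward=ffn_dim,
            dropout=hidden_dropout_prob,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=True,  # inputs are (N, L, H)
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layer)

    def forward(self, tokens):
        # tokens: (N, L, H) token embeddings, one sequence per table row.
        n = tokens.size(0)
        cls = self.cls_token.expand(n, -1, -1)   # (N, 1, H)
        seq = torch.cat([cls, tokens], dim=1)    # (N, L + 1, H)
        seq = self.encoder(seq)                  # (N, L + 1, H)
        return seq[:, 0]                         # (N, H): the [CLS] position


encoder = ClsPoolingEncoder(hidden_dim=32, num_layer=1,
                            num_attention_head=4, ffn_dim=64)
encoder.eval()  # disable dropout for a deterministic forward pass
with torch.no_grad():
    out = encoder(torch.randn(4, 5, 32))  # 4 rows, 5 tokens each, H = 32
print(tuple(out.shape))  # (4, 32)
```

In this sketch the output width equals hidden_dim regardless of how many tokens each row contributes, which is what makes the [CLS] position usable as a fixed-size embedding.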