# Source code for rllm.nn.encoder.transtab_pre_encoder

from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple, Union
import collections
import json
import os

import numpy as np
import pandas as pd
import torch
from torch import Tensor
from transformers import BertTokenizerFast

from .table_pre_encoder import TablePreEncoder
from .col_encoder._transtab_num_embedding_encoder import TransTabNumEmbeddingEncoder
from .col_encoder._transtab_word_embedding_encoder import TransTabWordEmbeddingEncoder

from rllm.types import ColType
from rllm.data.table_data import TableData


class TransTabPreEncoder(TablePreEncoder):
    r"""Pre-encoder for TransTab
    (`"TransTab" <https://arxiv.org/abs/2205.09328>`_).

    Converts a :class:`~rllm.data.table_data.TableData` ``feat_dict`` into
    token embeddings consumable by downstream Transformer layers, handling
    tokenizer management, column deduplication, and word/numeric
    sub-encoders.

    Args:
        out_dim (int): Output embedding dimensionality (d_model).
        metadata (Dict[ColType, List[Dict[str, Any]]]): Per-column
            statistics metadata.
        categorical_columns (List[str], optional): Categorical column names.
            Default: ``None``.
        numerical_columns (List[str], optional): Numerical column names.
            Default: ``None``.
        binary_columns (List[str], optional): Binary column names.
            Default: ``None``.
        tokenizer (BertTokenizerFast, optional): Pre-initialised tokenizer;
            takes precedence over ``tokenizer_dir``. Default: ``None``.
        tokenizer_dir (str): Tokenizer directory; ``"bert-base-uncased"``
            is downloaded here when absent. Default: ``"./tokenizer"``.
        hidden_dropout_prob (float): Dropout for the word-embedding
            sub-encoder. Default: ``0.0``.
        layer_norm_eps (float): LayerNorm :math:`\varepsilon`.
            Default: ``1e-5``.
        use_align_layer (bool): Apply a linear projection before
            concatenation. Default: ``True``.
        disable_tokenizer_parallel (bool): Set
            ``TOKENIZERS_PARALLELISM=false``. Default: ``True``.
        ignore_duplicate_cols (bool): Auto-rename duplicates instead of
            raising. Default: ``False``.
    """

    def __init__(
        self,
        out_dim: int,
        metadata: Dict[ColType, List[Dict[str, Any]]],
        categorical_columns: Optional[List[str]] = None,
        numerical_columns: Optional[List[str]] = None,
        binary_columns: Optional[List[str]] = None,
        tokenizer: Optional[BertTokenizerFast] = None,
        tokenizer_dir: str = "./tokenizer",
        hidden_dropout_prob: float = 0.0,
        layer_norm_eps: float = 1e-5,
        use_align_layer: bool = True,
        disable_tokenizer_parallel: bool = True,
        ignore_duplicate_cols: bool = False,
    ) -> None:
        # The tokenizer must exist before the sub-encoders are built because
        # its vocabulary size and pad id parameterise the word embeddings.
        self._init_tokenizer(tokenizer, tokenizer_dir, disable_tokenizer_parallel)
        self._init_columns(
            categorical_columns,
            numerical_columns,
            binary_columns,
            ignore_duplicate_cols,
        )

        col_encoder_dict = {
            ColType.CATEGORICAL: TransTabWordEmbeddingEncoder(
                vocab_size=self.tokenizer.vocab_size,
                out_dim=out_dim,
                padding_idx=self.tokenizer.pad_token_id,
                hidden_dropout_prob=hidden_dropout_prob,
                layer_norm_eps=layer_norm_eps,
            ),
            ColType.BINARY: TransTabWordEmbeddingEncoder(
                vocab_size=self.tokenizer.vocab_size,
                out_dim=out_dim,
                padding_idx=self.tokenizer.pad_token_id,
                hidden_dropout_prob=hidden_dropout_prob,
                layer_norm_eps=layer_norm_eps,
            ),
            ColType.NUMERICAL: TransTabNumEmbeddingEncoder(hidden_dim=out_dim),
        }
        super().__init__(out_dim, metadata, col_encoder_dict)

        # Sub-modules can only be registered once the base module is set up.
        self.align_layer = (
            torch.nn.Linear(out_dim, out_dim, bias=False)
            if use_align_layer
            else torch.nn.Identity()
        )

    # ------------------------------------------------------------------
    # Tokenizer management
    # ------------------------------------------------------------------
    def _init_tokenizer(
        self,
        tokenizer: Optional[BertTokenizerFast],
        tokenizer_dir: str,
        disable_tokenizer_parallel: bool,
    ) -> None:
        """Attach a tokenizer, loading or downloading one if none is given.

        Resolution order: the explicit ``tokenizer`` argument, then an
        existing ``tokenizer_dir``, and finally ``"bert-base-uncased"``
        (which is then saved into ``tokenizer_dir``).
        """
        if tokenizer is not None:
            self.tokenizer = tokenizer
        elif os.path.exists(tokenizer_dir):
            self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_dir)
        else:
            self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
            self.tokenizer.save_pretrained(tokenizer_dir)

        self.tokenizer.model_max_length = 512
        if disable_tokenizer_parallel:
            # Process-wide side effect: changes the environment variable.
            os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # ------------------------------------------------------------------
    # Column-name management
    # ------------------------------------------------------------------
    @staticmethod
    def _deduplicate_preserve_order(seq: List[str]) -> List[str]:
        """Return ``seq`` without repeats, keeping first-occurrence order."""
        # dict keys are unique and keep insertion order.
        return list(dict.fromkeys(seq))

    def _validate_column_overlap(self, error_message: str) -> None:
        """Check the three column lists for names shared across types.

        Overlapping names are auto-renamed when ``ignore_duplicate_cols`` is
        set; otherwise each one is reported and ``ValueError`` is raised
        with ``error_message``.
        """
        col_ok, dup = self._check_column_overlap(
            self.categorical_columns, self.numerical_columns, self.binary_columns
        )
        if col_ok:
            return
        if self.ignore_duplicate_cols:
            self._solve_duplicate_cols(dup)
            return
        for c in dup:
            print(
                f"ERROR: Find duplicate cols named `{c}`; "
                f"set ignore_duplicate_cols=True to auto-resolve."
            )
        raise ValueError(error_message)

    def _init_columns(
        self,
        categorical_columns: Optional[List[str]],
        numerical_columns: Optional[List[str]],
        binary_columns: Optional[List[str]],
        ignore_duplicate_cols: bool,
    ) -> None:
        """Store de-duplicated column lists and validate cross-type overlap.

        Raises:
            ValueError: If a name appears in more than one list and
                ``ignore_duplicate_cols`` is ``False``.
        """
        self.categorical_columns: List[str] = (
            self._deduplicate_preserve_order(categorical_columns)
            if categorical_columns
            else []
        )
        self.numerical_columns: List[str] = (
            self._deduplicate_preserve_order(numerical_columns)
            if numerical_columns
            else []
        )
        self.binary_columns: List[str] = (
            self._deduplicate_preserve_order(binary_columns)
            if binary_columns
            else []
        )
        self.ignore_duplicate_cols = ignore_duplicate_cols
        self._validate_column_overlap("Column overlap detected; aborting.")

    def update(
        self,
        cat: Optional[List[str]] = None,
        num: Optional[List[str]] = None,
        bin: Optional[List[str]] = None,
    ) -> None:
        r"""Extend column lists with new names and recheck for duplicates.

        Existing names keep their position and new names are appended in the
        order given, so the resulting lists are deterministic across runs.

        Args:
            cat (List[str], optional): New categorical columns.
            num (List[str], optional): New numerical columns.
            bin (List[str], optional): New binary columns. (The name shadows
                the builtin but is kept for interface compatibility.)

        Raises:
            ValueError: On duplicate columns when ``ignore_duplicate_cols``
                is ``False``.
        """
        if cat:
            self.categorical_columns = self._deduplicate_preserve_order(
                self.categorical_columns + list(cat)
            )
        if num:
            self.numerical_columns = self._deduplicate_preserve_order(
                self.numerical_columns + list(num)
            )
        if bin:
            self.binary_columns = self._deduplicate_preserve_order(
                self.binary_columns + list(bin)
            )
        self._validate_column_overlap(
            "Column overlap detected after update; aborting."
        )

    @staticmethod
    def _check_column_overlap(
        cat_cols: Optional[List[str]] = None,
        num_cols: Optional[List[str]] = None,
        bin_cols: Optional[List[str]] = None,
    ) -> Tuple[bool, List[str]]:
        """Find names occurring more than once across the given lists.

        Returns:
            ``(ok, duplicates)`` where ``ok`` is ``True`` when no name is
            repeated and ``duplicates`` lists the repeated names.
        """
        all_cols: List[str] = []
        for cols in (cat_cols, num_cols, bin_cols):
            if cols:
                all_cols += cols

        if not all_cols:
            print("WARNING: No columns specified; default to categorical.")
            return True, []

        counter = collections.Counter(all_cols)
        dup = [col for col, cnt in counter.items() if cnt > 1]
        return not dup, dup

    def _solve_duplicate_cols(self, duplicate_cols: List[str]) -> None:
        """Rename each duplicate to ``[cat]``/``[num]``/``[bin]``-prefixed
        variants, in place, in every list that contains it."""
        tagged_lists = (
            (self.categorical_columns, "cat"),
            (self.numerical_columns, "num"),
            (self.binary_columns, "bin"),
        )
        for col in duplicate_cols:
            print(f"WARNING: Auto-resolving duplicate column `{col}`")
            for columns, tag in tagged_lists:
                if col in columns:
                    columns.remove(col)
                    columns.append(f"[{tag}]{col}")

    # ------------------------------------------------------------------
    # feat_dict adaptation (convert TableData feat_dict -> TransTab layout)
    # ------------------------------------------------------------------
    def _adapt_feat_dict(
        self,
        feat_dict: Dict[ColType, Tensor | Tuple[Tensor, Tensor]],
        colname_token_ids: Optional[Dict[str, Tuple[Tensor, Tensor]]] = None,
        shuffle: bool = False,
    ) -> Dict[str, Tensor | None]:
        r"""Adapt a ``feat_dict`` from :class:`TableData` into the TransTab
        tensor layout.

        Only ``ColType.TEXT``, ``ColType.NUMERICAL`` and ``ColType.BINARY``
        entries are read; any other entry is ignored.

        Args:
            feat_dict: ``{ColType.TEXT: (ids [B,L], mask [B,L]),
                ColType.NUMERICAL: [B,C], ...}``
            colname_token_ids: Mapping from column name to
                ``(token_ids, attention_mask)``.
            shuffle (bool): Randomly shuffle column order within each type
                (NumPy global RNG). Column names and their values are
                permuted together so the pairing is preserved.
                Default: ``False``.

        Returns:
            Dict with keys ``x_num``, ``num_col_input_ids``,
            ``num_att_mask``, ``x_cat_input_ids``, ``cat_att_mask``,
            ``x_bin_input_ids``, ``bin_att_mask``; entries that could not
            be produced are ``None``.

        Raises:
            ValueError: If the TEXT entry is not an ``(ids, mask)`` tuple.
        """
        out: Dict[str, Tensor | None] = {
            "x_num": None,
            "num_col_input_ids": None,
            "num_att_mask": None,
            "x_cat_input_ids": None,
            "cat_att_mask": None,
            "x_bin_input_ids": None,
            "bin_att_mask": None,
        }

        # TEXT (mapped to TransTab categorical)
        if ColType.TEXT in feat_dict:
            text_data = feat_dict[ColType.TEXT]
            if not isinstance(text_data, tuple):
                raise ValueError(
                    "TEXT features must be tokenized (tuple of ids and mask)"
                )
            out["x_cat_input_ids"] = text_data[0].long()
            out["cat_att_mask"] = text_data[1].long()

        # NUMERICAL: raw values plus tokenised column names.
        if ColType.NUMERICAL in feat_dict:
            x_num = feat_dict[ColType.NUMERICAL].float()
            if colname_token_ids is not None:
                num_cols = [
                    c for c in colname_token_ids if c in self.numerical_columns
                ]
                if shuffle and len(num_cols) > 1:
                    perm = np.random.permutation(len(num_cols))
                    num_cols = [num_cols[i] for i in perm]
                    # Move the value columns with their names; shuffling only
                    # the names would pair each name with another column's
                    # values. Skipped when widths disagree, since no
                    # name/value pairing is defined in that case.
                    if x_num.dim() == 2 and x_num.shape[1] == len(num_cols):
                        index = torch.as_tensor(
                            perm, dtype=torch.long, device=x_num.device
                        )
                        x_num = x_num.index_select(1, index)
                if num_cols:
                    # One row per column name (stacked on a new leading axis).
                    out["num_col_input_ids"] = torch.stack(
                        [colname_token_ids[c][0] for c in num_cols], dim=0
                    ).long()
                    out["num_att_mask"] = torch.stack(
                        [colname_token_ids[c][1] for c in num_cols], dim=0
                    ).long()
            out["x_num"] = x_num

        # BINARY: each row becomes the text of its active column names.
        if ColType.BINARY in feat_dict and colname_token_ids is not None:
            x_bin = feat_dict[ColType.BINARY]
            n_bin = x_bin.shape[1]
            # Keep each name's ORIGINAL column index so that shuffling the
            # name order never changes which flag a name is read from.
            indexed_cols = [
                (j, col)
                for j, col in enumerate(
                    c for c in colname_token_ids if c in self.binary_columns
                )
                if j < n_bin
            ]
            if shuffle and len(indexed_cols) > 1:
                order = np.random.permutation(len(indexed_cols))
                indexed_cols = [indexed_cols[i] for i in order]

            if indexed_cols:
                # Single device->host transfer instead of one per cell.
                active = (x_bin > 0.5).tolist()
                bin_texts: List[str] = [
                    " ".join(col for j, col in indexed_cols if row[j])
                    for row in active
                ]
                tokens = self.tokenizer(
                    bin_texts,
                    padding=True,
                    truncation=True,
                    add_special_tokens=False,
                    return_tensors="pt",
                )
                # All-empty texts yield a zero-width batch; leave as None.
                if tokens["input_ids"].shape[1] > 0:
                    out["x_bin_input_ids"] = tokens["input_ids"]
                    out["bin_att_mask"] = tokens["attention_mask"]

        return out

    # ------------------------------------------------------------------
    # Encoding helpers
    # ------------------------------------------------------------------
    @property
    def device(self) -> torch.device:
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def _encode_feat_dict(
        self,
        feat_dict: Dict[ColType, Tensor | Tuple[Tensor, ...]],
    ) -> Dict[ColType, Tensor]:
        """Run each feature type through its column sub-encoder.

        Numerical entries are ``(col_ids, col_mask, raw_vals)``: the column
        names are embedded with the categorical word encoder, mean-pooled
        over their unmasked tokens, and combined with the raw values by the
        numerical encoder. Other entries pass their token ids (first tuple
        element, or the tensor itself) to the matching encoder.
        """
        feat_encoded: Dict[ColType, Tensor] = {}
        for col_type, feat in feat_dict.items():
            if col_type == ColType.NUMERICAL:
                col_ids, col_mask, raw_vals = feat
                token_emb = self.col_encoder_dict[ColType.CATEGORICAL.value](
                    col_ids
                )
                mask = col_mask.unsqueeze(-1)
                token_emb = token_emb * mask
                # Clamp so a column name with no unmasked tokens pools to
                # zeros instead of producing NaN through 0/0.
                col_emb = token_emb.sum(1) / mask.sum(1).clamp(min=1)
                feat_encoded[col_type] = self.col_encoder_dict[
                    ColType.NUMERICAL.value
                ](col_emb, raw_vals=raw_vals)
            else:
                input_ids = feat[0] if isinstance(feat, tuple) else feat
                feat_encoded[col_type] = self.col_encoder_dict[col_type.value](
                    input_ids
                )
        return feat_encoded

    def _collect_masks(
        self,
        feat_dict: Dict[ColType, Tensor | Tuple[Tensor, ...]],
        emb_dict: Dict[ColType, Tensor],
        df_masks: Optional[Dict[str, Tensor]],
    ) -> Dict[ColType, Tensor]:
        """Build one float attention mask per encoded feature type.

        Numerical embeddings always get an all-ones mask. For categorical
        and binary embeddings the mask comes from ``df_masks`` when present,
        else from the second element of the feature tuple, else all ones.
        """
        masks: Dict[ColType, Tensor] = {}

        if ColType.NUMERICAL in emb_dict:
            batch, n_num, _ = emb_dict[ColType.NUMERICAL].shape
            masks[ColType.NUMERICAL] = torch.ones(batch, n_num, device=self.device)

        mask_keys = {
            ColType.CATEGORICAL: "cat_att_mask",
            ColType.BINARY: "bin_att_mask",
        }
        for col_type, key in mask_keys.items():
            if col_type not in emb_dict:
                continue
            mask = df_masks.get(key) if df_masks is not None else None
            if mask is None:
                feat = feat_dict.get(col_type)
                if isinstance(feat, tuple) and len(feat) >= 2:
                    mask = feat[1]
            if mask is None:
                batch, n_tokens, _ = emb_dict[col_type].shape
                masks[col_type] = torch.ones(batch, n_tokens, device=self.device)
            else:
                masks[col_type] = mask.to(self.device).float()
        return masks

    def _align_and_concat(
        self,
        emb_dict: Dict[ColType, Tensor],
        masks: Dict[ColType, Tensor],
    ) -> Dict[str, Tensor]:
        """Project each type's embeddings with ``align_layer`` and join them.

        Order along the sequence axis is numerical, categorical, binary.

        Raises:
            ValueError: If ``emb_dict`` contains none of those types.
        """
        emb_list: List[Tensor] = []
        mask_list: List[Tensor] = []
        for col_type in (ColType.NUMERICAL, ColType.CATEGORICAL, ColType.BINARY):
            if col_type in emb_dict:
                emb_list.append(self.align_layer(emb_dict[col_type]))
                mask_list.append(masks[col_type])

        if not emb_list:
            raise ValueError("No features were encoded; check column configuration.")

        all_emb = torch.cat(emb_list, dim=1)  # [B, total_seq_len, D]
        all_mask = torch.cat(mask_list, dim=1)  # [B, total_seq_len]
        return {"embedding": all_emb, "attention_mask": all_mask}

    def forward(
        self,
        x: Union[pd.DataFrame, Dict[ColType, Tensor | Tuple[Tensor, ...]], TableData],
        *,
        shuffle: bool = False,
        align_and_concat: bool = True,
        return_dict: bool = False,
        requires_grad: bool = False,
    ) -> Optional[Union[Dict[str, Tensor], Dict[ColType, Tensor], Tensor]]:
        r"""Encode a table batch into embeddings.

        Args:
            x (TableData): Materialised input table batch. Despite the wider
                annotation, only a :class:`TableData` or an object exposing
                ``feat_dict`` is accepted.
            shuffle (bool): Shuffle column order within each type.
                Default: ``False``.
            align_and_concat (bool): Apply alignment projection and
                concatenate all type embeddings. Default: ``True``.
            return_dict (bool): Return ``Dict[ColType, Tensor]`` instead of
                a concatenated tensor when ``align_and_concat=False``.
                Default: ``False``.
            requires_grad (bool): Enable gradients during encoding.
                Default: ``False``.

        Returns:
            ``{"embedding": [B, S, H], "attention_mask": [B, S]}`` when
            ``align_and_concat=True``; otherwise a dict or concatenated
            tensor (``None`` if nothing was encoded).

        Raises:
            TypeError: If ``x`` is not a ``TableData`` / has no
                ``feat_dict``.
            ValueError: If ``x`` is not materialised, or numerical features
                are present without matching column-name tokens.
        """
        if not (isinstance(x, TableData) or hasattr(x, "feat_dict")):
            raise TypeError(
                "TransTabPreEncoder.forward: x must be a TableData or an "
                "object with a feat_dict attribute."
            )

        with torch.set_grad_enabled(requires_grad):
            if (
                hasattr(x, "if_materialized")
                and callable(x.if_materialized)
                and not x.if_materialized()
            ):
                raise ValueError(
                    "TableData must be materialized before passing to "
                    "TransTabPreEncoder. Call table_data.lazy_materialize() first."
                )

            data = self._adapt_feat_dict(
                feat_dict=x.feat_dict,
                colname_token_ids=getattr(x, "colname_token_ids", None),
                shuffle=shuffle,
            )

            device = self.device
            feat_dict: Dict[ColType, Tensor | Tuple[Tensor, ...]] = {}
            if data["x_cat_input_ids"] is not None:
                feat_dict[ColType.CATEGORICAL] = (
                    data["x_cat_input_ids"].to(device),
                    data["cat_att_mask"].to(device),
                )
            if data["x_bin_input_ids"] is not None:
                feat_dict[ColType.BINARY] = (
                    data["x_bin_input_ids"].to(device),
                    data["bin_att_mask"].to(device),
                )
            if data["x_num"] is not None:
                if data["num_col_input_ids"] is None or data["num_att_mask"] is None:
                    # Fail with an actionable message rather than an
                    # AttributeError on ``None.to(...)``.
                    raise ValueError(
                        "Numerical features require `colname_token_ids` "
                        "entries for the configured numerical columns."
                    )
                feat_dict[ColType.NUMERICAL] = (
                    data["num_col_input_ids"].to(device),
                    data["num_att_mask"].to(device),
                    data["x_num"].to(device),
                )

            emb_dict = self._encode_feat_dict(feat_dict)

            if not align_and_concat:
                if return_dict:
                    return emb_dict
                if not emb_dict:
                    return None
                return torch.cat(list(emb_dict.values()), dim=1)

            df_masks = {
                "cat_att_mask": data["cat_att_mask"],
                "bin_att_mask": data["bin_att_mask"],
            }
            masks = self._collect_masks(feat_dict, emb_dict, df_masks=df_masks)
            return self._align_and_concat(emb_dict, masks)

    def save(self, path: str) -> None:
        r"""Save tokenizer, column config, and encoder weights to ``path``."""
        # Tokenizer & column config (backward-compatible directory layout).
        # Creating ``path/extractor`` also creates ``path`` itself.
        save_path = os.path.join(path, "extractor")
        os.makedirs(save_path, exist_ok=True)
        self.tokenizer.save_pretrained(os.path.join(save_path, "tokenizer"))

        col_type_dict = {
            "categorical": self.categorical_columns,
            "numerical": self.numerical_columns,
            "binary": self.binary_columns,
        }
        with open(
            os.path.join(save_path, "extractor.json"), "w", encoding="utf-8"
        ) as f:
            json.dump(col_type_dict, f, ensure_ascii=False)

        # Encoder weights
        encoder_path = os.path.join(path, "input_encoder.bin")
        torch.save(self.state_dict(), encoder_path)
        print(f"Saved TransTabPreEncoder weights to {encoder_path}")

    def load(self, ckpt_dir: str) -> None:
        r"""Restore tokenizer, column config, and encoder weights from
        ``ckpt_dir``.

        Weights are loaded non-strictly; missing and unexpected keys are
        printed rather than raised.
        """
        tokenizer_path = os.path.join(ckpt_dir, "extractor", "tokenizer")
        coltype_path = os.path.join(ckpt_dir, "extractor", "extractor.json")

        self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_path)
        with open(coltype_path, "r", encoding="utf-8") as f:
            col_type_dict = json.load(f)
        self.categorical_columns = col_type_dict.get("categorical", [])
        self.numerical_columns = col_type_dict.get("numerical", [])
        self.binary_columns = col_type_dict.get("binary", [])
        print(f"Loaded column configuration from {coltype_path}")

        encoder_path = os.path.join(ckpt_dir, "input_encoder.bin")
        try:
            state_dict = torch.load(
                encoder_path, map_location="cpu", weights_only=True
            )
        except TypeError:
            # torch.load rejected the ``weights_only`` keyword; retry
            # without it.
            state_dict = torch.load(encoder_path, map_location="cpu")
        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        print(f"Loaded TransTabPreEncoder weights from {encoder_path}")
        print(f" Missing keys: {missing}")
        print(f" Unexpected keys: {unexpected}")