# Source code for rllm.nn.encoder.tab_transformer_pre_encoder

from __future__ import annotations
from typing import Any, Dict, List

from .col_encoder._embedding_encoder import EmbeddingEncoder
from .col_encoder._reshape_encoder import ReshapeEncoder
from .table_pre_encoder import TablePreEncoder
from rllm.types import ColType


class TabTransformerPreEncoder(TablePreEncoder):
    r"""Pre-encoder for the TabTransformer model.

    :class:`TabTransformerPreEncoder` builds a fixed, column-type-specific
    encoder dict and passes it, together with ``out_dim`` and ``metadata``,
    to :class:`TablePreEncoder`. Categorical columns are mapped to an
    :class:`EmbeddingEncoder` and numerical columns to a
    :class:`ReshapeEncoder`. The choice of encoders itself does not depend
    on ``metadata``.

    Args:
        out_dim (int): The output dimensionality.
        metadata (Dict[ColType, List[Dict[str, Any]]]): Metadata for each
            column type, specifying the statistics and properties of the
            columns. Forwarded unchanged to :class:`TablePreEncoder`.
    """

    def __init__(
        self,
        out_dim: int,
        metadata: Dict[ColType, List[Dict[str, Any]]],
    ) -> None:
        # New encoder instances are created on every call, so two
        # TabTransformerPreEncoder objects never share an encoder object.
        col_encoder_dict = {
            ColType.CATEGORICAL: EmbeddingEncoder(),
            ColType.NUMERICAL: ReshapeEncoder(),
        }
        # The base class receives (out_dim, metadata, encoder dict)
        # positionally; how it uses them is defined in TablePreEncoder.
        super().__init__(out_dim, metadata, col_encoder_dict)