# Source code for rllm.nn.encoder.ft_transformer_pre_encoder
from __future__ import annotations
from typing import Any, Dict, List
from .table_pre_encoder import TablePreEncoder
from .col_encoder._embedding_encoder import EmbeddingEncoder
from .col_encoder._linear_encoder import LinearEncoder
from rllm.types import ColType
# [docs]
class FTTransformerPreEncoder(TablePreEncoder):
    r"""Pre-encoder that prepares tabular inputs for the FTTransformer model.

    A column-type-to-encoder mapping is assembled here and handed to
    :class:`TablePreEncoder` together with ``out_dim`` and ``metadata``:
    categorical columns are assigned an `EmbeddingEncoder`, and numerical
    columns are assigned a `LinearEncoder` whose input width is ``in_dim``.

    Args:
        out_dim (int): The output dimensionality.
        metadata (Dict[rllm.types.ColType, List[Dict[str, Any]]]):
            Metadata for each column type, specifying the statistics and
            properties of the columns.
        in_dim (int, optional): The input dimensionality for numerical
            features (default: :obj:`1`).

    Example:
        >>> from rllm.nn.encoder import FTTransformerPreEncoder
        >>> encoder = FTTransformerPreEncoder(out_dim=32, metadata={})
    """

    def __init__(
        self,
        out_dim: int,
        metadata: Dict[ColType, List[Dict[str, Any]]],
        in_dim: int = 1,
    ) -> None:
        # Populate the mapping one column type at a time. The order is
        # categorical first, then numerical. This keeps both the
        # construction sequence and the dict's key order unchanged.
        encoders = {}
        encoders[ColType.CATEGORICAL] = EmbeddingEncoder()
        encoders[ColType.NUMERICAL] = LinearEncoder(in_dim=in_dim)

        # The parent class receives the mapping as its third positional
        # argument.
        super().__init__(out_dim, metadata, encoders)