Source code for rllm.nn.encoder.resnet_pre_encoder

from typing import Any, Dict, List

from rllm.types import ColType
from .table_pre_encoder import TablePreEncoder
from .col_encoder._embedding_encoder import EmbeddingEncoder
from .col_encoder._linear_encoder import LinearEncoder
from .col_encoder._textembedding_encoder import TextEmbeddingEncoder
from .col_encoder._timestamp_encoder import TimestampEncoder


class ResNetPreEncoder(TablePreEncoder):
    r"""The pre-encoder for the ResNet TNN.

    This encoder selects one column encoder per column type and then
    delegates the shared table encoding pipeline to its base class,
    :class:`TablePreEncoder`. The column types are mapped to encoders as
    follows:

    - ``ColType.CATEGORICAL`` -> :class:`EmbeddingEncoder`
    - ``ColType.NUMERICAL`` -> :class:`LinearEncoder`
    - ``ColType.TIMESTAMP`` -> :class:`TimestampEncoder`
    - ``ColType.TEXT`` -> :class:`TextEmbeddingEncoder`

    Args:
        out_dim (int): The output dimensionality.
        metadata (Dict[ColType, List[Dict[str, Any]]]): Metadata for each
            column type, specifying the statistics and properties of the
            columns.

    Note:
        The constructor itself returns nothing; encoded outputs are
        produced when the inherited ``forward`` is called.

    Example:
        >>> from rllm.nn.encoder import ResNetPreEncoder
        >>> from rllm.types import ColType
        >>> metadata = {
        ...     ColType.CATEGORICAL: [{"num_classes": 100}],
        ...     ColType.NUMERICAL: [{"mean": 0.0, "std": 1.0}],
        ... }
        >>> encoder = ResNetPreEncoder(out_dim=32, metadata=metadata)
    """

    def __init__(
        self,
        out_dim: int,
        metadata: Dict[ColType, List[Dict[str, Any]]],
    ) -> None:
        # Select one col-encoder per column type.
        # NOTE(review): all four encoders are instantiated even when
        # ``metadata`` has no columns of that type; presumably
        # TablePreEncoder ignores unused entries -- confirm.
        col_encoder_dict = {
            ColType.CATEGORICAL: EmbeddingEncoder(),
            ColType.NUMERICAL: LinearEncoder(),
            ColType.TIMESTAMP: TimestampEncoder(),
            ColType.TEXT: TextEmbeddingEncoder(),
        }
        # The shared encoding pipeline lives in the base class; this class
        # only supplies its column-type -> encoder choices.
        super().__init__(out_dim, metadata, col_encoder_dict)