Source code for rllm.transforms.table_transforms.stack_numerical

from __future__ import annotations

from rllm.types import ColType
from rllm.data import TableData
from .col_transform import ColTransform


[docs] class StackNumerical(ColTransform): r"""The StackNumerical class is designed to transform numerical features in tabular data by stacking them into a specified dimension. This transformation changes the shape of the numerical features from [batch_size, num_cols] to [batch_size, num_cols, out_dim], effectively replicating the values along the new dimension. Args: out_dim (int): The output dimensionality to which the numerical features will be stacked. """ def __init__( self, out_dim: int, ) -> None: super().__init__() self.out_dim = out_dim def forward( self, data: TableData, ) -> TableData: if ColType.NUMERICAL in data.feat_dict.keys(): feat = data.feat_dict[ColType.NUMERICAL] data.feat_dict[ColType.NUMERICAL] = feat.unsqueeze(2).repeat( 1, 1, self.out_dim ) return data