# Source code for rllm.transforms.table_transforms.one_hot_transform

from __future__ import annotations

import torch.nn.functional as F

from rllm.types import ColType, StatType
from rllm.data import TableData
from .col_transform import ColTransform


class OneHotTransform(ColTransform):
    r"""Replace the categorical feature tensor with its one-hot encoding.

    Args:
        out_dim (int, optional): Lower bound on the size of the one-hot
            dimension. The size actually used is the larger of
            :obj:`out_dim` and the largest :obj:`StatType.COUNT` value found
            in the categorical metadata, so leaving it at 0 lets the
            metadata alone decide (default: :obj:`0`).
    """

    def __init__(
        self,
        out_dim: int = 0,
    ) -> None:
        super().__init__()
        self.out_dim = out_dim

    def forward(
        self,
        data: TableData,
    ) -> TableData:
        r"""One-hot encode ``data.feat_dict[ColType.CATEGORICAL]`` in place.

        Args:
            data (TableData): Table whose categorical features are encoded.
                It is returned untouched when it has no
                :obj:`ColType.CATEGORICAL` entry in ``feat_dict``.

        Returns:
            TableData: The same ``data`` object that was passed in.
        """
        # Nothing to encode: hand the table back unchanged.
        if ColType.CATEGORICAL not in data.feat_dict:
            return data

        categorical_stats = data.metadata[ColType.CATEGORICAL]
        categorical_feat = data.feat_dict[ColType.CATEGORICAL]

        # Largest COUNT statistic over all metadata entries (presumably one
        # entry per categorical column -- confirm against TableData). It is
        # kept on the instance, as before, so it can be read after the call.
        self.num_categories = max(
            col_stats[StatType.COUNT] for col_stats in categorical_stats
        )

        # `out_dim` only acts as a floor: the encoding is never narrower than
        # the metadata requires, but may be padded up to `out_dim`.
        width = max(self.out_dim, self.num_categories)

        # NOTE(review): F.one_hot expects an integer (long) index tensor with
        # non-negative values -- confirm how missing categories are encoded.
        data.feat_dict[ColType.CATEGORICAL] = F.one_hot(
            categorical_feat, num_classes=width
        )
        return data