Source code for rllm.transforms.table_transforms.col_normalize

from __future__ import annotations

import torch

from rllm.types import ColType, StatType
from rllm.data import TableData
from .col_transform import ColTransform


class ColNormalize(ColTransform):
    r"""Standardize the numerical features of a :class:`TableData` object.

    Every numerical feature value has the ``StatType.MEAN`` statistic
    subtracted from it and is then divided by the ``StatType.STD``
    statistic, both read from ``data.metadata[ColType.NUMERICAL]``.
    A constant ``1e-6`` is added to each standard deviation before the
    division, which keeps the divisor away from zero.
    """

    def __init__(
        self,
    ) -> None:
        super().__init__()

    def forward(
        self,
        data: TableData,
    ) -> TableData:
        r"""Normalize the numerical features stored in ``data``.

        Args:
            data: Table whose ``feat_dict`` may hold an entry under
                ``ColType.NUMERICAL`` and whose ``metadata`` holds the
                matching statistics.

        Returns:
            The same ``data`` object. When numerical features are present,
            their entry in ``data.feat_dict`` is replaced by the normalized
            tensor; otherwise ``data`` is returned untouched.
        """
        # Tables without numerical features pass through unchanged.
        if ColType.NUMERICAL not in data.feat_dict.keys():
            return data

        col_stats = data.metadata[ColType.NUMERICAL]
        numerical = data.feat_dict[ColType.NUMERICAL]

        def stat_tensor(kind):
            # Gather one statistic from every metadata entry into a tensor
            # that lives on the same device and has the same dtype as the
            # features, so the arithmetic below needs no conversion.
            # NOTE(review): presumably one metadata entry per numerical
            # column, broadcasting over the last feature dim -- confirm.
            values = [entry[kind] for entry in col_stats]
            return torch.tensor(
                values,
                device=numerical.device,
                dtype=numerical.dtype,
            )

        center = stat_tensor(StatType.MEAN)
        # Small epsilon keeps the divisor non-zero.
        scale = stat_tensor(StatType.STD) + 1e-6

        data.feat_dict[ColType.NUMERICAL] = (numerical - center) / scale
        return data