Source code for rllm.transforms.table_transforms.table_transform
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Dict, List, Callable, Optional
import torch
from torch import Tensor
from rllm.data import TableData
from rllm.types import ColType, NAMode, StatType
def _reset_parameters_soft(module: torch.nn.Module):
r"""Call reset_parameters() only when it exists. Skip activation module."""
if hasattr(module, "reset_parameters") and callable(module.reset_parameters):
module.reset_parameters()
def _get_na_mask(tensor: Tensor) -> Tensor:
r"""Obtains the NA mask of the input :obj:`Tensor`.
Args:
tensor (Tensor): Input :obj:`Tensor`.
"""
if tensor.is_floating_point():
na_mask = torch.isnan(tensor)
else:
na_mask = tensor == -1
return na_mask
[docs]
class TableTransform(torch.nn.Module, ABC):
r"""Base class for table Transform. This module transforms tensor of some
specific columns type into 3-dimensional column-wise tensor that is input
into tabular deep learning models. Columns with same ColType will be
transformed into tensors. By default, it handles missing values (NaNs)
according to the specified `na_mode`.
Args:
out_dim (int): The output dim dimensionality
col_type (ColType): Column type used for NA-mode validation.
post_module (Module, optional): The post-hoc module applied to the
output, such as activation function and normalization. Must
preserve the shape of the output. If :obj:`None`, no module will be
applied to the output. (default: :obj:`None`)
na_mode (NAMode, optional): The instruction that indicates how to
impute NaN values. (default: :obj:`None`)
transforms (List[Callable], optional): A list of transformation
functions to be applied to the input data. Each function in the
list should take the input data as an argument and return the
transformed data. (default: :obj=`None`)
"""
def __init__(
self,
out_dim: Optional[int] = None,
col_type: Optional[ColType] = None,
post_module: Optional[torch.nn.Module] = None,
na_mode: Optional[Dict[StatType, NAMode]] = None,
transforms: Optional[List[Callable]] = None,
):
r"""Since many attributes are specified later,
this is a fake initialization"""
super().__init__()
if na_mode is not None:
if (
col_type == ColType.NUMERICAL
and na_mode not in NAMode.namode_for_col_type(ColType.NUMERICAL)
):
raise ValueError(f"{na_mode} cannot be used on numerical columns.")
if (
col_type == ColType.CATEGORICAL
and na_mode not in NAMode.namode_for_col_type(ColType.CATEGORICAL)
):
raise ValueError(f"{na_mode} cannot be used on categorical columns.")
else:
na_mode = {
ColType.NUMERICAL: NAMode.MEAN,
ColType.CATEGORICAL: NAMode.MOST_FREQUENT,
ColType.BINARY: NAMode.MOST_FREQUENT,
ColType.TEXT: NAMode.MOST_FREQUENT,
}
self.out_dim = out_dim
self.post_module = post_module
self.na_mode = na_mode
self.transforms = transforms
[docs]
@abstractmethod
def reset_parameters(self):
r"""Initialize the parameters of `post_module`."""
if self.post_module is not None:
if isinstance(self.post_module, torch.nn.Sequential):
for m in self.post_module:
_reset_parameters_soft(m)
else:
_reset_parameters_soft(self.post_module)
[docs]
def forward(
self,
data: TableData,
) -> Tensor:
# NaN handling of the input Tensor
data = self.nan_forward(data)
for transform in self.transforms or []:
data = transform(data)
return data
[docs]
def nan_forward(
self,
data: TableData,
) -> Tensor:
r"""Replace NaN values in input :obj:`Tensor` given
:obj:`na_mode`.
Args:
feat: Input :obj:`Tensor`.
"""
if self.na_mode is None:
return data
# Since we are not changing the number of items in each column, it's
# faster to just clone the values, while reusing the same offset
# object.
feats = data.get_feat_dict()
for col_type, feat in feats.items():
# Skip TEXT type as it's already tokenized (tuple of tensors)
# NaN handling for TEXT is done during tokenization
if col_type == ColType.TEXT:
continue
# Skip if col_type not in na_mode (defensive check)
if col_type not in self.na_mode:
continue
feat = self._fill_nan(feat, data.metadata[col_type], self.na_mode[col_type])
# Handle NaN in case na_mode is None
feats[col_type] = torch.nan_to_num(feat, nan=0)
data.feat_dict = feats
return data
def _fill_nan(
self,
feat: Tensor,
stats_list: Dict[StatType, float],
na_mode: NAMode,
) -> Tensor:
r"""Replace NaN values in input :obj:`Tensor` given :obj:`na_mode`."""
if isinstance(feat, Tensor):
# cache for future use
na_mask = _get_na_mask(feat)
if na_mask.any():
feat = feat.clone()
else:
return feat
else:
raise ValueError(f"Unrecognized type {type(feat)} in na_forward.")
fill_values = []
for col in range(feat.size(1)):
if na_mode == NAMode.MOST_FREQUENT:
fill_value = stats_list[col][StatType.MOST_FREQUENT]
elif na_mode == NAMode.MEAN:
fill_value = stats_list[col][StatType.MEAN]
elif na_mode == NAMode.ZERO:
fill_value = 0
else:
raise ValueError(f"Unsupported NA mode {self.na_mode}")
fill_values.append(fill_value)
if na_mask.ndim == 3:
# when feat is 3D, it is faster to iterate over columns
for col, fill_value in enumerate(fill_values):
col_data = feat[:, col]
col_na_mask = na_mask[:, col].any(dim=-1)
col_data[col_na_mask] = fill_value
else: # na_mask.ndim == 2
fill_values = torch.tensor(fill_values, device=feat.device)
if feat.size(-1) != fill_values.size(-1):
raise ValueError(
"Mismatched feature width during NA fill: "
f"{feat.size(-1)} vs {fill_values.size(-1)}."
)
feat = torch.where(na_mask, fill_values, feat)
# Add better safeguard here to make sure nans are actually
# replaced, expecially when nans are represented as -1's. They are
# very hard to catch as they won't error out.
filled_values = feat
if filled_values.is_floating_point():
if torch.isnan(filled_values).any():
raise ValueError("NaN values remain after NA filling.")
else:
if (filled_values == -1).any():
raise ValueError("Invalid sentinel -1 remains after NA filling.")
return feat