Source code for rllm.nn.conv.table_conv.resnet_conv
from __future__ import annotations
import torch
from torch import Tensor
from torch.nn import (
Linear,
LayerNorm,
BatchNorm1d,
Dropout,
ReLU,
)
[docs]
class ResNetConv(torch.nn.Module):
r"""The ResNet-like TNN LayerConv introduced in the
`"Revisiting Deep Learning Models for Tabular Data"
<https://arxiv.org/abs/2106.11959>`_ paper.
This module applies a two-layer MLP block with optional normalization,
activation, and dropout, then adds a residual shortcut connection.
Args:
in_dim (int): Input feature dimensionality.
out_dim (int): Output feature dimensionality.
normalization (str | None): Normalization type. Supported values are
:obj:`"layer_norm"`, :obj:`"batch_norm"`, or :obj:`None`.
(default: :obj:`"layer_norm"`)
dropout (float): Dropout probability. (default: :obj:`0.0`)
Example:
>>> import torch
>>> conv = ResNetConv(in_dim=16, out_dim=32, normalization="layer_norm", dropout=0.1)
>>> x = torch.randn(64, 16)
>>> out = conv(x)
"""
def __init__(
self,
in_dim: int,
out_dim: int,
normalization: str | None = "layer_norm",
dropout: float = 0.0,
):
super().__init__()
self.lin1 = Linear(in_dim, out_dim)
self.lin2 = Linear(out_dim, out_dim)
self.norm1 = None
self.norm2 = None
if normalization == "layer_norm":
self.norm1 = LayerNorm(out_dim)
self.norm2 = LayerNorm(out_dim)
elif normalization == "batch_norm":
self.norm1 = BatchNorm1d(out_dim)
self.norm2 = BatchNorm1d(out_dim)
else:
self.norm1 = None
self.norm2 = None
if in_dim != out_dim:
self.short_cut = Linear(in_dim, out_dim)
else:
self.short_cut = None
self.relu = ReLU()
self.dropout = Dropout(dropout)
self.reset_parameters()
[docs]
def reset_parameters(self) -> None:
r"""Resets all learnable parameters of the module."""
self.lin1.reset_parameters()
self.lin2.reset_parameters()
if self.norm1 is not None:
self.norm1.reset_parameters()
if self.norm2 is not None:
self.norm2.reset_parameters()
if self.short_cut is not None:
self.short_cut.reset_parameters()
[docs]
def forward(self, x: Tensor) -> Tensor:
r"""Apply residual MLP transformation.
Args:
x (Tensor): Input tensor of shape :obj:`[..., in_dim]`.
Returns:
Tensor: Output tensor of shape :obj:`[..., out_dim]`.
"""
residual = x
x = self.lin1(x)
x = self.norm1(x) if self.norm1 is not None else x
x = self.relu(x)
x = self.dropout(x)
x = self.lin2(x)
x = self.norm2(x) if self.norm2 is not None else x
x = self.relu(x)
x = self.dropout(x)
if self.short_cut is not None:
residual = self.short_cut(residual)
x = x + residual
return x