from __future__ import annotations
from typing import Tuple
import torch
from torch import Tensor
class GLULayer(torch.nn.Module):
    r"""Gated Linear Unit (GLU) layer used by ExcelFormer.

    The layer first projects input features from ``in_dim`` to
    ``2 * out_dim``. The projection is split along its last dimension into a
    value half and a gate half, and the output is computed as
    ``value * tanh(gate)``.

    Args:
        in_dim (int): Input feature dimensionality.
        out_dim (int): Output feature dimensionality.

    Example:
        >>> import torch
        >>> layer = GLULayer(in_dim=32, out_dim=32)
        >>> x = torch.randn(8, 10, 32)
        >>> out = layer(x)
        >>> out.shape
        torch.Size([8, 10, 32])
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
    ):
        super().__init__()
        # A single projection produces both the value and the gate halves.
        self.fc = torch.nn.Linear(in_dim, 2 * out_dim)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the parameters of the underlying linear projection."""
        self.fc.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        """Apply linear projection and gated activation.

        Args:
            x (Tensor): Input tensor whose last dimension is ``in_dim``,
                typically of shape ``[batch_size, num_cols, in_dim]``. Any
                number of leading dimensions is accepted.

        Returns:
            Tensor: Tensor with the same leading dimensions as ``x`` and a
            last dimension of ``out_dim`` (e.g.
            ``[batch_size, num_cols, out_dim]``).
        """
        x = self.fc(x)
        # Split along the feature (last) axis instead of a fixed axis index,
        # so inputs with any number of leading dimensions are gated correctly.
        x, gates = x.chunk(2, dim=-1)
        return x * torch.tanh(gates)
class SemiPermeableAttention(torch.nn.Module):
    r"""Semi-Permeable Attention module proposed in the
    `"ExcelFormer: A neural network surpassing GBDTs on tabular data"
    <https://arxiv.org/abs/2301.02819>`_ paper.

    This module applies column-wise multi-head self-attention over tabular
    feature tokens under a lower-triangular (causal-style) mask: the token at
    position ``i`` may attend only to tokens at positions ``j <= i``.
    Disallowed positions receive a large negative additive bias (``-1e4``)
    rather than ``-inf``.

    Args:
        dim (int): Input (and output) dimensionality of each token.
        num_heads (int): Number of attention heads (default: :obj:`8`).
        head_dim (int): Dimensionality of each attention head
            (default: :obj:`16`).
        dropout (float): Dropout probability applied to the attention
            weights (default: :obj:`0.`).

    Example:
        >>> import torch
        >>> attn = SemiPermeableAttention(dim=32, num_heads=4, head_dim=8, dropout=0.1)
        >>> x = torch.randn(16, 12, 32)
        >>> out = attn(x)
        >>> out.shape
        torch.Size([16, 12, 32])
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        head_dim: int = 16,
        dropout: float = 0.0,
    ):
        super().__init__()
        inner_dim = head_dim * num_heads
        self.num_heads = num_heads
        self.scale = head_dim**-0.5
        # Queries, keys and values come from one bias-free projection.
        self.to_qkv = torch.nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = torch.nn.Linear(inner_dim, dim)
        self.dropout = torch.nn.Dropout(dropout)

    def _rearrange_qkv(self, x: Tensor) -> Tensor:
        """Reshape ``[b, n, h * d]`` into ``[b, h, n, d]`` (one slice per head)."""
        b, num_cols, dim = x.shape
        d_head = dim // self.num_heads
        x = x.reshape(b, num_cols, self.num_heads, d_head)
        x = x.permute(0, 2, 1, 3)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Compute semi-permeable self-attention on tabular tokens.

        Args:
            x (Tensor): Input tensor of shape ``[batch_size, num_cols, dim]``.

        Returns:
            Tensor: Output tensor of shape ``[batch_size, num_cols, dim]``.
        """
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q = self._rearrange_qkv(q)
        k = self._rearrange_qkv(k)
        v = self._rearrange_qkv(v)
        sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)
        # Build the mask in the scores' dtype so low-precision (fp16/bf16)
        # inputs are not silently promoted to float32. A promoted result
        # would then clash with ``v`` and ``to_out``.
        mask = self.get_attention_mask(
            input_shape=sim.size(),
            device=sim.device,
            dtype=sim.dtype,
        )
        # The additive mask is scaled together with the raw scores; the
        # masked bias stays large and negative, so those weights are ~0.
        attn = (sim + mask) * self.scale
        attn = attn.softmax(dim=-1)
        dropped_attn = self.dropout(attn)
        out = torch.einsum("b h i j, b h j d -> b h i d", dropped_attn, v)
        # reshape b h n d -> b n (h d)
        out = out.permute(0, 2, 1, 3)
        out = out.reshape(out.size(0), out.size(1), -1)
        return self.to_out(out)

    def reset_parameters(self) -> None:
        """Re-initialize the query/key/value and output projections."""
        self.to_qkv.reset_parameters()
        self.to_out.reset_parameters()

    def get_attention_mask(
        self,
        input_shape: Tuple,
        device,
        dtype: torch.dtype = torch.float32,
    ) -> Tensor:
        """Build the additive lower-triangular attention mask.

        Args:
            input_shape (Tuple): Shape of the attention scores,
                ``(batch_size, num_heads, seq_len, _)``.
            device: Device on which to create the mask.
            dtype (torch.dtype): Floating dtype of the returned mask
                (default: :obj:`torch.float32`).

        Returns:
            Tensor: Tensor of shape ``[batch_size, num_heads, seq_len,
            seq_len]`` with ``0`` where query ``i`` may attend key ``j``
            (``j <= i``) and ``-1e4`` elsewhere. It is a broadcast
            (``expand``) view of a single ``[seq_len, seq_len]`` mask, so it
            must not be modified in place.
        """
        bs, num_heads, seq_len, _ = input_shape
        seq_ids = torch.arange(seq_len, device=device)
        # allowed[i, j] is True when key position j <= query position i.
        allowed = seq_ids[None, :] <= seq_ids[:, None]
        additive = (1.0 - allowed.to(dtype)) * -1e4
        # Broadcast instead of repeat: same shape and values, but no
        # per-batch/per-head copies of the [seq_len, seq_len] mask.
        return additive.expand(bs, num_heads, seq_len, seq_len)