Source code for rllm.nn.conv.table_conv.trompt_conv

from __future__ import annotations

from typing import Any, Dict, List

import torch
from torch import Tensor
import torch.nn.functional as F

from rllm.nn.encoder.col_encoder import EmbeddingEncoder, LinearEncoder
from rllm.nn.encoder import TablePreEncoder
from rllm.types import ColType


class TromptConv(torch.nn.Module):
    r"""The TromptConv Layer introduced in the
    `"Trompt: Towards a Better Deep Neural Network for Tabular Data"
    <https://arxiv.org/abs/2305.18446>`_ paper.
    It is also known as TromptCell in the original paper.

    This layer first derives feature importance based on the
    `emb_column` and prompt embeddings `x_prompt`.
    Subsequently, it embeds the input features using a pre-encoder
    to obtain feature embeddings. Finally, it expands the features
    using the derived feature importance and the feature embeddings.

    Args:
        in_dim (int): Input dimensionality (number of feature columns).
        out_dim (int): Output dimensionality, and hidden layer dimensionality.
        num_prompts (int): Number of prompts.
        metadata (Dict[ColType, List[Dict[str, Any]]], optional): Column
            metadata forwarded unchanged to :class:`TablePreEncoder`
            (default: :obj:`None`).
        num_groups (int): Number of groups for group normalization. The
            normalized channel axis is the prompt axis, so
            :class:`torch.nn.GroupNorm` requires ``num_prompts`` to be
            divisible by ``num_groups`` (default: 2).

    Example:
        >>> import torch
        >>> conv = TromptConv(in_dim=10, out_dim=16, num_prompts=4)
        >>> x = torch.randn(8, 10, 16)
        >>> x_prompt = torch.randn(8, 4, 16)
        >>> out = conv(x, x_prompt)
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_prompts: int,
        metadata: Dict[ColType, List[Dict[str, Any]]] | None = None,
        num_groups: int = 2,
    ):
        super().__init__()
        self.num_prompts = num_prompts

        # Initialize pre-encoder with column-specific encoders
        col_encoder_dict = {
            ColType.CATEGORICAL: EmbeddingEncoder(
                post_module=torch.nn.LayerNorm(out_dim)
            ),
            ColType.NUMERICAL: LinearEncoder(
                in_dim=in_dim,
                post_module=torch.nn.Sequential(
                    torch.nn.ReLU(),
                    torch.nn.LayerNorm(out_dim),
                ),
            ),
        }
        self.init_residual_encoder = TablePreEncoder(
            out_dim=out_dim,
            metadata=metadata,
            col_encoder_dict=col_encoder_dict,
        )

        # Learnable parameters for feature importance and prompt embeddings
        self.emb_column = torch.nn.Parameter(torch.empty(in_dim, out_dim))
        self.emb_prompt = torch.nn.Parameter(torch.empty(num_prompts, out_dim))
        # Consumes the concatenation [prompt embedding, x_prompt] (2 * out_dim).
        self.lin_se_prompt = torch.nn.Linear(out_dim * 2, out_dim)
        self.ln_column = torch.nn.LayerNorm(out_dim)
        self.ln_prompt = torch.nn.LayerNorm(out_dim)
        # One scalar per prompt, used to scale the feature embeddings.
        self.expand_weight = torch.nn.Parameter(torch.empty(num_prompts))

        # Channels of the group norm are the prompts (dim 1 of [B, P, C, D]).
        self.group_norm = torch.nn.GroupNorm(
            num_groups=num_groups,
            num_channels=num_prompts,
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every learnable parameter and sub-module in place."""
        self.init_residual_encoder.reset_parameters()
        torch.nn.init.xavier_uniform_(self.emb_column)
        torch.nn.init.xavier_uniform_(self.emb_prompt)
        torch.nn.init.xavier_uniform_(self.lin_se_prompt.weight)
        torch.nn.init.zeros_(self.lin_se_prompt.bias)
        torch.nn.init.uniform_(self.expand_weight)
        self.ln_column.reset_parameters()
        self.ln_prompt.reset_parameters()
        self.group_norm.reset_parameters()

    def forward(
        self,
        x: Tensor | Dict[ColType, Tensor],
        x_prompt: Tensor,
    ) -> Tensor:
        """Expand and aggregate feature embeddings conditioned on prompts.

        Args:
            x (Tensor | Dict[ColType, Tensor]): Input feature embeddings of
                shape ``[batch_size, in_dim, out_dim]`` or raw table
                feature dict.
            x_prompt (Tensor): Prompt embeddings of shape
                ``[batch_size, num_prompts, out_dim]``.

        Returns:
            Tensor: Aggregated prompt representations of shape
            ``[batch_size, num_prompts, out_dim]``.
        """
        # Notation: B=batch_size, P=num_prompts, C=in_dim (feature columns),
        # D=out_dim.

        # Part 1: Construct feature embeddings
        x = self.init_residual_encoder(x)  # [B, C, D]
        batch_size = x.size(0)

        # Part 2: Derive feature importances
        emb_column = self.ln_column(self.emb_column)  # [C, D]
        emb_prompt = self.ln_prompt(self.emb_prompt)  # [P, D]

        # `expand` returns a broadcast view, so the shared prompt embeddings
        # are not physically copied once per sample (unlike `repeat`).
        se_prompt = emb_prompt.unsqueeze(0).expand(batch_size, -1, -1)
        se_prompt_cat = torch.cat([se_prompt, x_prompt], dim=-1)  # [B, P, 2D]
        se_prompt_cat_hat = (
            self.lin_se_prompt(se_prompt_cat) + se_prompt + x_prompt
        )  # [B, P, D]

        # The column embeddings are identical for every sample, so they are
        # contracted as a plain 2-D operand instead of a batch-sized copy.
        m_importance = torch.einsum(
            "bpd,cd->bpc", se_prompt_cat_hat, emb_column
        )  # [B, P, C]
        m_importance = F.softmax(m_importance, dim=-1)  # normalize over columns
        m_importance = m_importance.unsqueeze(dim=-1)  # [B, P, C, 1]

        # Part 3: Expand feature embeddings to accommodate multiple prompts
        x_expand_weight = torch.einsum(
            "bcd,p->bpcd", x, self.expand_weight
        )  # [B, P, C, D]
        x_expand_weight = F.relu(x_expand_weight)
        # Residual connection: [B, 1, C, D] broadcasts over the prompt axis.
        x = self.group_norm(x_expand_weight) + x.unsqueeze(1)

        # Part 4: Aggregate feature embeddings across the column axis,
        # weighted by the per-prompt importances.
        x = (x * m_importance).sum(dim=2)  # [B, P, D]
        return x