rllm.nn.conv.table_conv.TromptConv

class rllm.nn.conv.table_conv.TromptConv(in_dim: int, out_dim: int, num_prompts: int, metadata: Dict[ColType, List[Dict[str, Any]]] | None = None, num_groups: int = 2)

Bases: Module

The TromptConv layer, introduced in the “Trompt: Towards a Better Deep Neural Network for Tabular Data” paper, where it is called TromptCell.

This layer first derives feature importance from the column embeddings emb_column and the prompt embeddings x_prompt. It then embeds the input features with a pre-encoder to obtain feature embeddings. Finally, it expands the feature embeddings across the prompts and aggregates them, weighted by the derived feature importance, into one representation per prompt.
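The three steps above can be illustrated with a shape-level sketch. This is a simplified NumPy illustration that assumes the Trompt paper's formulation; it is not rllm's implementation. The function name `trompt_cell_sketch`, the per-prompt scalar expansion weights `w_expand`, and the omission of the pre-encoder, the normalization layers (including group normalization), and any prompt-side projection are all assumptions made for brevity.

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable softmax along `axis`.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def trompt_cell_sketch(
    x: np.ndarray,           # [batch_size, in_dim, out_dim] feature embeddings
    x_prompt: np.ndarray,    # [batch_size, num_prompts, out_dim] prompt embeddings
    emb_column: np.ndarray,  # [in_dim, out_dim] column embeddings
    w_expand: np.ndarray,    # [num_prompts] hypothetical per-prompt expansion weights
) -> np.ndarray:
    # Step 1: feature importance, one distribution over columns per prompt.
    # [B, P, d] . [C, d]^T -> [B, P, C], softmax over the column axis.
    importance = softmax(np.einsum("bpd,cd->bpc", x_prompt, emb_column), axis=-1)

    # Step 2: expand each column embedding into one copy per prompt.
    # [B, 1, C, d] * [1, P, 1, 1] -> [B, P, C, d] (simplified: scalar weight per prompt).
    expanded = w_expand[None, :, None, None] * x[:, None, :, :]

    # Step 3: weight the expanded embeddings by importance and sum over columns.
    # [B, P, C] and [B, P, C, d] -> [B, P, d]
    return np.einsum("bpc,bpcd->bpd", importance, expanded)


# Usage with the same shapes as the example on this page.
rng = np.random.default_rng(0)
out = trompt_cell_sketch(
    rng.standard_normal((8, 10, 16)),  # x
    rng.standard_normal((8, 4, 16)),   # x_prompt
    rng.standard_normal((10, 16)),     # emb_column
    np.ones(4),                        # w_expand
)
print(out.shape)  # (8, 4, 16)
```

The output has one out_dim-sized vector per prompt, matching the [batch_size, num_prompts, out_dim] shape documented for forward().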

Parameters:
  • in_dim (int) – Input dimensionality, i.e. the number of input columns (features).

  • out_dim (int) – Output dimensionality, also used as the hidden dimensionality.

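  • num_prompts (int) – Number of prompts.

  • metadata (Dict[ColType, List[Dict[str, Any]]], optional) – Metadata for each column type (default: None).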

  • num_groups (int) – Number of groups for group normalization (default: 2).

Example

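>>> import torch
>>> from rllm.nn.conv.table_conv import TromptConv
>>> conv = TromptConv(in_dim=10, out_dim=16, num_prompts=4)
>>> x = torch.randn(8, 10, 16)  # [batch_size, in_dim, out_dim]
>>> x_prompt = torch.randn(8, 4, 16)  # [batch_size, num_prompts, out_dim]
>>> out = conv(x, x_prompt)
>>> out.shape
torch.Size([8, 4, 16])

forward(x: Tensor | Dict[ColType, Tensor], x_prompt: Tensor) → Tensor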

Expand and aggregate feature embeddings conditioned on prompts.

Parameters:
  • x (Tensor | Dict[ColType, Tensor]) – Input feature embeddings of shape [batch_size, in_dim, out_dim], or a dict of raw table features keyed by ColType, which is first embedded by the pre-encoder.

  • x_prompt (Tensor) – Prompt embeddings of shape [batch_size, num_prompts, out_dim].

Returns:

Aggregated prompt representations of shape [batch_size, num_prompts, out_dim].

Return type:

Tensor