rllm.nn.conv.table_conv.TromptConv
- class rllm.nn.conv.table_conv.TromptConv(in_dim: int, out_dim: int, num_prompts: int, metadata: Dict[ColType, List[Dict[str, Any]]] | None = None, num_groups: int = 2)
Bases: Module
The TromptConv layer introduced in the paper “Trompt: Towards a Better Deep Neural Network for Tabular Data”. It is also known as TromptCell in the original paper.
This layer first derives feature importances from the column embeddings (emb_column) and the prompt embeddings (x_prompt). It then embeds the input features with a pre-encoder to obtain feature embeddings. Finally, it expands the feature embeddings and aggregates them, weighted by the derived feature importances, into one representation per prompt.
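The three steps above can be made concrete with a minimal, hypothetical PyTorch sketch of a Trompt cell. It loosely follows the Trompt paper and is not rLLM's source: the class TromptCellSketch and its attributes emb_prompt, prompt_fuse, expand1, expand2 and norm are illustrative assumptions, the fusion and expansion layers are simplified, and the pre-encoder step is skipped by assuming x already holds feature embeddings of shape [batch_size, in_dim, out_dim].

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TromptCellSketch(nn.Module):
    """Hypothetical, simplified Trompt cell (illustration only, not rLLM's TromptConv)."""

    def __init__(self, in_dim: int, out_dim: int, num_prompts: int, num_groups: int = 2):
        super().__init__()
        # Learnable embeddings: one per input column and one per prompt.
        self.emb_column = nn.Parameter(torch.randn(in_dim, out_dim))
        self.emb_prompt = nn.Parameter(torch.randn(num_prompts, out_dim))
        # Assumed fusion of the learned prompts with the incoming x_prompt.
        self.prompt_fuse = nn.Linear(2 * out_dim, out_dim)
        # Assumed expansion: lift each embedding value into num_prompts channels.
        self.expand1 = nn.Linear(1, num_prompts)
        self.expand2 = nn.Linear(num_prompts, num_prompts)
        # Normalizes over the prompt axis (requires num_prompts % num_groups == 0).
        self.norm = nn.GroupNorm(num_groups, num_prompts)

    def forward(self, x: torch.Tensor, x_prompt: torch.Tensor) -> torch.Tensor:
        # x: [B, C, d] feature embeddings (pre-encoder skipped); x_prompt: [B, P, d].
        prompt = self.emb_prompt.unsqueeze(0).expand(x.size(0), -1, -1)  # [B, P, d]
        fused = self.prompt_fuse(torch.cat([prompt, x_prompt], dim=-1)) + prompt + x_prompt

        # Step 1: feature importance, softmax over columns for each prompt.
        importance = F.softmax(fused @ self.emb_column.t(), dim=-1)  # [B, P, C]

        # Step 2: expand every feature embedding into P prompt-specific variants.
        x_exp = F.relu(self.expand1(x.unsqueeze(-1)))  # [B, C, d, P]
        x_exp = self.expand2(x_exp).permute(0, 3, 1, 2)  # [B, P, C, d]
        x_exp = self.norm(x_exp) + x.unsqueeze(1)  # residual, broadcast over P

        # Step 3: importance-weighted sum over the column axis.
        return (importance.unsqueeze(-1) * x_exp).sum(dim=2)  # [B, P, d]


torch.manual_seed(0)
cell = TromptCellSketch(in_dim=10, out_dim=16, num_prompts=4)
x = torch.randn(8, 10, 16)
x_prompt = torch.randn(8, 4, 16)
out = cell(x, x_prompt)
print(out.shape)  # torch.Size([8, 4, 16])
```

In this sketch the softmax over the column axis gives each prompt its own importance distribution across columns, so the weighted sum yields one out_dim-sized vector per prompt, matching the documented output shape [batch_size, num_prompts, out_dim]. Because it uses nn.GroupNorm over the prompt axis, num_prompts must be divisible by num_groups here.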
- Parameters:
in_dim (int) – Input dimensionality, i.e. the number of input columns.
out_dim (int) – Output dimensionality, also used as the hidden dimensionality.
num_prompts (int) – Number of prompts.
num_groups (int) – Number of groups for group normalization (default: 2).
Example
>>> import torch
>>> from rllm.nn.conv.table_conv import TromptConv
>>> conv = TromptConv(in_dim=10, out_dim=16, num_prompts=4)
>>> x = torch.randn(8, 10, 16)
>>> x_prompt = torch.randn(8, 4, 16)
>>> out = conv(x, x_prompt)
>>> out.shape
torch.Size([8, 4, 16])
- forward(x: Tensor | Dict[ColType, Tensor], x_prompt: Tensor) → Tensor
Expand and aggregate feature embeddings conditioned on prompts.
- Parameters:
x (Tensor | Dict[ColType, Tensor]) – Input feature embeddings of shape [batch_size, in_dim, out_dim], or a raw table feature dict.
x_prompt (Tensor) – Prompt embeddings of shape [batch_size, num_prompts, out_dim].
- Returns:
Aggregated prompt representations of shape [batch_size, num_prompts, out_dim].
- Return type:
Tensor