rllm.nn.conv.table_conv.TransTabConv

class rllm.nn.conv.table_conv.TransTabConv(conv_dim, nhead, dim_feedforward=2048, dropout=0.1, activation=<function relu>, layer_norm_eps=1e-05, batch_first=True, norm_first=False, use_layer_norm=True)

Bases: Module

Single Transformer encoder layer used in TransTab.

Combines multi-head self-attention with a gated feedforward network, residual connections, dropout, and optional LayerNorm.
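The exact form of the gate is not documented here. The sketch below is a minimal version of the feedforward sub-block. It assumes a per-token sigmoid gate that scales the first projection before the activation. `GatedFeedForward` and `_ACTIVATIONS` are hypothetical names, not rllm API, and TransTabConv may gate differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical string-to-callable table mirroring the names accepted by `activation`.
_ACTIVATIONS = {
    "relu": F.relu,
    "gelu": F.gelu,
    "selu": F.selu,
    "leakyrelu": F.leaky_relu,
}


class GatedFeedForward(nn.Module):
    """Sketch of a gated position-wise feedforward block (hypothetical name)."""

    def __init__(self, conv_dim, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super().__init__()
        self.linear1 = nn.Linear(conv_dim, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, conv_dim)
        # Assumed gate: one scalar per token, computed from the token itself.
        self.gate_linear = nn.Linear(conv_dim, 1, bias=False)
        self.dropout = nn.Dropout(dropout)
        # Accept either a known string or any callable, as `activation` does.
        self.activation = (
            _ACTIVATIONS[activation] if isinstance(activation, str) else activation
        )

    def forward(self, x):
        g = torch.sigmoid(self.gate_linear(x))  # (N, S, 1), values in (0, 1)
        h = self.linear1(x) * g                 # gate scales every hidden unit of a token
        h = self.linear2(self.dropout(self.activation(h)))
        return self.dropout(h)


ffn = GatedFeedForward(conv_dim=32, dim_feedforward=64, activation="gelu")
out = ffn(torch.randn(8, 10, 32))  # shape is preserved: (8, 10, 32)
```

With a single scalar per token, each token's feedforward update is scaled as a whole. This is cheaper than a GLU-style gate of width `dim_feedforward`.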

Parameters:
  • conv_dim (int) – Input/output embedding dimensionality.

  • nhead (int) – Number of self-attention heads.

  • dim_feedforward (int) – Feedforward inner dimension. Default: 2048.

  • dropout (float) – Dropout probability. Default: 0.1.

  • activation (str or Callable) – Feedforward activation; accepts "relu", "gelu", "selu", "leakyrelu", or a callable. Default: torch.nn.functional.relu.

  • layer_norm_eps (float) – Value \(\varepsilon\) added to the variance in LayerNorm for numerical stability. Default: 1e-5.

  • batch_first (bool) – Expect input as \((N, S, H)\) when True, else \((S, N, H)\). Default: True.

  • norm_first (bool) – Apply LayerNorm before (pre-norm) rather than after (post-norm) each sub-layer. Default: False.

  • use_layer_norm (bool) – Include LayerNorm in each sub-block. Default: True.
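`norm_first` and `use_layer_norm` decide where LayerNorm sits around each residual connection, and whether it is applied at all. The sketch below uses the standard pre-norm and post-norm definitions, which are the same ones `torch.nn.TransformerEncoderLayer` uses for its `norm_first` flag. `residual_sublayer` is a hypothetical helper, and an `nn.Linear` stands in for the attention or feedforward sub-layer.

```python
import torch
import torch.nn as nn


def residual_sublayer(x, sublayer, norm, norm_first=False, use_layer_norm=True):
    """Hypothetical helper: one sub-layer wrapped in a residual connection.

    norm_first=True   -> pre-norm:  x + sublayer(norm(x))
    norm_first=False  -> post-norm: norm(x + sublayer(x))
    use_layer_norm=False -> plain residual; `norm` is never called.
    """
    if not use_layer_norm:
        return x + sublayer(x)
    if norm_first:
        return x + sublayer(norm(x))
    return norm(x + sublayer(x))


torch.manual_seed(0)
x = torch.randn(2, 5, 16)          # (N, S, H) with batch_first=True
sublayer = nn.Linear(16, 16)       # stand-in for self-attention or the FFN
norm = nn.LayerNorm(16, eps=1e-5)  # eps plays the role of layer_norm_eps

post = residual_sublayer(x, sublayer, norm, norm_first=False)
pre = residual_sublayer(x, sublayer, norm, norm_first=True)
plain = residual_sublayer(x, sublayer, norm, use_layer_norm=False)

# The post-norm output is itself normalized (per-token mean close to 0).
# The pre-norm output is not, because its residual path bypasses the norm.
print(post.mean(dim=-1).abs().max().item())
```

Post-norm matches the original Transformer layout and is the default here (`norm_first=False`). Pre-norm keeps an unnormalized path from input to output, which is often reported to make deeper stacks easier to train.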

Shape:
  • Input: \((N, S, H)\) when batch_first=True, or \((S, N, H)\) when batch_first=False.

  • Output: same shape as the input.

Examples:

>>> import torch
>>> from rllm.nn.conv.table_conv import TransTabConv
>>> conv = TransTabConv(conv_dim=32, nhead=4, dim_feedforward=64)
>>> out = conv(torch.randn(8, 10, 32), src_key_padding_mask=torch.ones(8, 10))
>>> out.shape
torch.Size([8, 10, 32])

forward(x, src_mask=None, src_key_padding_mask=None, is_causal=None, **kwargs) → Tensor
Parameters:
  • x (Tensor) – Input of shape \((N, S, H)\).

  • src_mask (Tensor, optional) – Additive attention mask of shape \((S, S)\), added to the attention scores. Default: None.

  • src_key_padding_mask (Tensor, optional) – Keep mask of shape \((N, S)\). True marks a valid token that is attended to, and False marks a masked-out token (e.g. padding). This is the inverse of PyTorch's key_padding_mask convention, in which True means ignore, so the layer converts the mask internally. Default: None.

  • is_causal – Unused; present for API compatibility.
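The keep-mask convention is the reverse of PyTorch's. The sketch below shows the conversion and feeds the result straight into `torch.nn.MultiheadAttention`. It assumes that TransTabConv performs a boolean inversion internally. The mask values are made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, S, H = 2, 4, 8
x = torch.randn(N, S, H)

# Keep mask in the documented TransTabConv convention:
# True = valid token (attended to), False = padding (masked out).
keep_mask = torch.tensor([[True, True, True, False],
                          [True, True, False, False]])

# Assumed conversion to PyTorch's key_padding_mask convention (True = ignore).
# Calling .bool() first also accepts float 0/1 masks such as torch.ones(N, S).
key_padding_mask = ~keep_mask.bool()

attn = nn.MultiheadAttention(H, num_heads=2, batch_first=True)
out, weights = attn(x, x, x, key_padding_mask=key_padding_mask)

# weights has shape (N, S, S), averaged over heads; padded keys get weight 0.
print(weights[0, :, 3])  # column of sample 0's padded token: all zeros
```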

Returns:

Tensor with the same shape as x.

Return type:

Tensor