rllm.nn.conv.table_conv.TransTabConv
- class rllm.nn.conv.table_conv.TransTabConv(conv_dim, nhead, dim_feedforward=2048, dropout=0.1, activation=<function relu>, layer_norm_eps=1e-05, batch_first=True, norm_first=False, use_layer_norm=True)
Bases: Module
Single Transformer encoder layer for TransTab (“TransTab”).
Combines multi-head self-attention with a gated feedforward network, residual connections, dropout, and optional LayerNorm.
- Parameters:
conv_dim (int) – Input/output embedding dimensionality.
nhead (int) – Number of self-attention heads.
dim_feedforward (int) – Feedforward inner dimension. Default: 2048.
dropout (float) – Dropout probability. Default: 0.1.
activation (str or Callable) – Feedforward activation; accepts "relu", "gelu", "selu", "leakyrelu", or a callable. Default: torch.nn.functional.relu.
layer_norm_eps (float) – LayerNorm \(\varepsilon\). Default: 1e-5.
batch_first (bool) – Expect input as \((N, S, H)\) when True, else \((S, N, H)\). Default: True.
norm_first (bool) – Apply LayerNorm before (pre-norm) rather than after (post-norm) each sub-layer. Default: False.
use_layer_norm (bool) – Include LayerNorm in each sub-block. Default: True.
- Shape:
Input: \((N, S, H)\) when batch_first=True.
Output: \((N, S, H)\).
Examples:
>>> conv = TransTabConv(conv_dim=32, nhead=4, dim_feedforward=64)
>>> out = conv(torch.randn(8, 10, 32), src_key_padding_mask=torch.ones(8, 10))
>>> out.shape
torch.Size([8, 10, 32])
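The effect of norm_first and use_layer_norm on a single residual sub-block can be sketched in plain Python. This is an illustrative sketch, not the rllm implementation: layer_norm, residual_block, and halve are hypothetical helpers, the normalization has no learnable scale or shift, dropout is omitted, and dropping normalization entirely when use_layer_norm=False is an assumption about the layer's behavior.

```python
import math


def layer_norm(vec, eps=1e-5):
    """Normalize one feature vector to zero mean and unit variance (no learnable affine)."""
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]


def residual_block(vec, sublayer, norm_first=False, use_layer_norm=True):
    """Apply one residual sub-block (hypothetical helper; dropout omitted)."""
    if not use_layer_norm:
        # Assumption: without LayerNorm only the residual connection remains.
        return [v + s for v, s in zip(vec, sublayer(vec))]
    if norm_first:
        # Pre-norm: normalize the sub-layer input, then add the residual.
        return [v + s for v, s in zip(vec, sublayer(layer_norm(vec)))]
    # Post-norm: add the residual first, then normalize the sum.
    return layer_norm([v + s for v, s in zip(vec, sublayer(vec))])


def halve(vec):
    """Stand-in for self-attention or the feedforward network."""
    return [0.5 * v for v in vec]


token = [1.0, 2.0, 3.0, 6.0]
post = residual_block(token, halve, norm_first=False)  # normalized output
pre = residual_block(token, halve, norm_first=True)    # un-normalized residual stream
```

In this sketch the post-norm output is itself normalized, so its mean is close to zero, while the pre-norm output keeps the scale of the input because only the sub-layer's input is normalized.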
- forward(x, src_mask=None, src_key_padding_mask=None, is_causal=None, **kwargs) → Tensor
- Parameters:
x (Tensor) – Input of shape \((N, S, H)\).
src_mask (Tensor, optional) – Additive attention mask \((S, S)\). Default: None.
src_key_padding_mask (Tensor, optional) – Attention keep mask of shape \((N, S)\) where True means the token is valid/attended to, and False means masked out. Internally this is converted to PyTorch key_padding_mask semantics (True means ignore). Default: None.
is_causal – Unused; present for API compatibility.
- Returns:
Same shape as input.
- Return type:
Tensor
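The keep-mask convention for src_key_padding_mask is the opposite of PyTorch's key_padding_mask convention, so the conversion described above amounts to a logical inversion. The following plain-Python sketch illustrates that inversion and its effect on attention weights; to_key_padding_mask and masked_softmax are hypothetical helpers written for this illustration and are not part of rllm or PyTorch.

```python
import math


def to_key_padding_mask(keep_mask):
    """Invert a keep mask (True = valid) into ignore semantics (True = ignore)."""
    return [[not keep for keep in row] for row in keep_mask]


def masked_softmax(scores, ignore_row):
    """Softmax over key scores; ignored keys receive zero weight.

    Assumes at least one key in the row is not ignored.
    """
    masked = [-math.inf if ig else s for s, ig in zip(scores, ignore_row)]
    peak = max(masked)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]


# Two sequences of length 3; trailing positions are padding.
keep_mask = [[True, True, False], [True, False, False]]
ignore = to_key_padding_mask(keep_mask)

# Attention weights of one query over the keys of the first sequence.
weights = masked_softmax([0.2, 1.0, 3.0], ignore[0])
```

Here the third key has the largest raw score but is padding, so it ends up with zero weight and the remaining weights still sum to one.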