rllm.nn.loss.SupervisedVPCL

class rllm.nn.loss.SupervisedVPCL(temperature: float = 10.0, base_temperature: float = 10.0, similarity: Literal['dot', 'cosine'] = 'dot', eps: float = 1e-12)

Bases: ContrastiveLoss

Implementation of the supervised vertical-partition contrastive loss (Supervised-VPCL) from the “TransTab: Learning Transferable Tabular Transformers Across Tables” paper.

It extends supervised contrastive learning to vertically partitioned tabular data, where each row is split column-wise into \(K\) partitions that are embedded separately. Partitions whose source rows share the same class label form positive pairs, while partitions from rows with different labels serve as negatives.
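The pairing rule can be pictured as two boolean masks over all \(B \cdot K\) partitions in a batch. The following is a minimal sketch, not rllm's implementation; the helper `vpcl_masks` is hypothetical, and it assumes partitions are flattened row-major so that partition \(k\) of row \(i\) sits at index \(iK + k\).

```python
import torch


def vpcl_masks(labels: torch.Tensor, num_partitions: int):
    """Hypothetical helper: positive/negative masks over all B*K partitions.

    Partition k of row i is flattened to index i * K + k, so every
    partition inherits the class label of its source row.
    """
    # (B,) -> (B*K,): repeat each row label once per partition.
    part_labels = labels.repeat_interleave(num_partitions)
    # same[a, b] is True when partitions a and b come from rows with the same label.
    same = part_labels.unsqueeze(1) == part_labels.unsqueeze(0)
    positive_mask = same   # same-label pairs, including partitions of the same row
    negative_mask = ~same  # pairs whose source rows have different labels
    return positive_mask, negative_mask


labels = torch.tensor([0, 1, 0, 1])
pos, neg = vpcl_masks(labels, num_partitions=2)
print(pos.shape)       # torch.Size([8, 8])
print(pos[0].tolist())  # [True, True, False, False, True, True, False, False]
```

Each row of `pos` marks the partitions of every same-label row (including the anchor's own row), and `neg` is its complement.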

\[\ell(X, y) = - \sum_{i=1}^{B} \sum_{j=1}^{B} \sum_{k=1}^{K} \sum_{k'=1}^{K} \mathbf{1}\{y_j = y_i\} \log \frac{ \exp\big(\psi(\mathbf{v}_i^{k}, \mathbf{v}_j^{k'})\big) }{ \sum_{j^{\dagger}=1}^{B}\sum_{k^{\dagger}=1}^{K} \mathbf{1}\{y_{j^{\dagger}} \neq y_i\} \exp\big(\psi(\mathbf{v}_i^{k}, \mathbf{v}_{j^{\dagger}}^{k^{\dagger}})\big) } .\]

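where \(B\) is the batch size, \(K\) is the number of partitions per sample, \(\mathbf{v}_i^k\) is the embedding of the \(k\)-th partition of sample \(i\), \(\psi(\cdot, \cdot)\) is the pairwise similarity between partition embeddings (see the similarity parameter), and \(y_i\) is the class label of sample \(i\).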
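To make the four sums concrete, the sketch below evaluates the expression literally, taking \(\psi\) to be the dot product. It is an illustration only, not rllm's implementation: it omits the temperature and any normalization by the number of positive pairs that the library may apply, and it assumes every batch contains at least two classes so that each denominator is non-empty. The function name `vpcl_formula` is hypothetical.

```python
import torch


def vpcl_formula(features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical: evaluate the Supervised-VPCL sum as printed, with psi = dot product.

    features: (B, K, D) partition embeddings; labels: (B,).
    No temperature and no averaging over positives are applied.
    """
    B, K, D = features.shape
    v = features.reshape(B * K, D)           # partition (i, k) -> row i*K + k
    y = labels.repeat_interleave(K)          # label of each partition
    psi = v @ v.T                            # psi(v_a, v_b) = <v_a, v_b> for all pairs
    same = y.unsqueeze(1) == y.unsqueeze(0)  # indicator 1{y_j = y_i}
    if not (~same).any(dim=1).all():
        raise ValueError("every anchor needs at least one different-label partition")
    # Denominator: sum of exp(psi) over partitions whose label differs from the anchor's.
    neg_logits = psi.masked_fill(same, float("-inf"))
    log_denom = torch.logsumexp(neg_logits, dim=1, keepdim=True)  # (B*K, 1)
    log_ratio = psi - log_denom              # log(exp(psi) / denominator)
    # Sum over every same-label pair (j, k'), including the anchor itself.
    return -(log_ratio * same).sum()


feats = torch.randn(4, 2, 8)
labels = torch.tensor([0, 1, 0, 1])
print(vpcl_formula(feats, labels).shape)  # torch.Size([])
```

Working in log space with `torch.logsumexp` avoids overflowing `exp` when dot-product similarities are large.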

Parameters:
  • temperature (float) – Temperature \(\tau\) used to scale the similarity logits.

  • base_temperature (float) – Reference temperature \(\tau_0\) used in the final scaling factor \(\tau / \tau_0\).

  • similarity (str) – Similarity metric, either "dot" or "cosine".

  • eps (float) – Numerical stability constant.
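A minimal sketch of how the parameters above could enter the computation, assuming the common supervised-contrastive convention of dividing similarity logits by \(\tau\) and using eps in the L2 normalization for cosine similarity. The helpers `pairwise_logits` and `scale_loss` are hypothetical and only illustrate the roles of the arguments; they are not part of `rllm.nn.loss`.

```python
import torch
import torch.nn.functional as F


def pairwise_logits(v: torch.Tensor, similarity: str = "dot",
                    temperature: float = 10.0, eps: float = 1e-12) -> torch.Tensor:
    """Hypothetical: pairwise logits psi(v_a, v_b) / temperature for v of shape (N, D).

    Assumption: logits are divided by the temperature, and eps guards the
    L2 normalization used for cosine similarity.
    """
    if similarity == "cosine":
        v = F.normalize(v, dim=-1, eps=eps)  # unit-length rows, so v @ v.T is cosine
    elif similarity != "dot":
        raise ValueError(f"unknown similarity: {similarity!r}")
    return (v @ v.T) / temperature


def scale_loss(raw_loss: torch.Tensor, temperature: float = 10.0,
               base_temperature: float = 10.0) -> torch.Tensor:
    # Final scaling factor tau / tau_0; equals 1 with the default arguments.
    return (temperature / base_temperature) * raw_loss


v = torch.tensor([[3.0, 4.0], [0.0, 2.0]])
print(pairwise_logits(v, "dot"))     # tensor([[2.5000, 0.8000], [0.8000, 0.4000]])
print(pairwise_logits(v, "cosine"))  # tensor([[0.1000, 0.0800], [0.0800, 0.1000]])
```

With similarity="cosine" every logit lies in \([-1/\tau, 1/\tau]\), whereas dot-product logits grow with the embedding norms.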

Shapes:
  • input:

    partition embeddings \((B, K, D)\), where \(D\) is the embedding dimension, and labels \((B,)\)

  • output: scalar loss \(()\)
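A sketch of the shape contract, assuming the loss scores partitions as one flattened set of \(B \cdot K\) embeddings. `flatten_partitions` is a hypothetical helper; rllm may validate or reshape its inputs differently.

```python
import torch


def flatten_partitions(features: torch.Tensor, labels: torch.Tensor):
    """Hypothetical: check the documented shapes and flatten partitions.

    features: (B, K, D); labels: (B,). Returns (B*K, D) embeddings and (B*K,) labels.
    """
    if features.dim() != 3:
        raise ValueError(f"expected features of shape (B, K, D), got {tuple(features.shape)}")
    B, K, D = features.shape
    if labels.shape != (B,):
        raise ValueError(f"expected labels of shape ({B},), got {tuple(labels.shape)}")
    # Row-major flattening keeps the K partitions of each sample contiguous.
    flat = features.reshape(B * K, D)
    # Each partition inherits its source row's label.
    flat_labels = labels.repeat_interleave(K)
    return flat, flat_labels


feats = torch.randn(4, 2, 8)
labels = torch.tensor([0, 1, 0, 1])
flat, flat_labels = flatten_partitions(feats, labels)
print(flat.shape)             # torch.Size([8, 8])
print(flat_labels.tolist())   # [0, 0, 1, 1, 0, 0, 1, 1]
```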

Example

>>> import torch
>>> from rllm.nn.loss import SupervisedVPCL
>>> loss_fn = SupervisedVPCL()
>>> feats = torch.randn(4, 2, 8)
>>> labels = torch.tensor([0, 1, 0, 1])
>>> loss_fn(feats, labels).shape
torch.Size([])

forward(features: Tensor, labels: Tensor) → Tensor

Compute the supervised vertical-partition contrastive loss.

Parameters:
  • features (torch.Tensor) – Partition embeddings with shape \((B, K, D)\).

  • labels (torch.Tensor) – Class labels with shape \((B,)\).

Returns:

Scalar loss tensor with shape \(()\).

Return type:

torch.Tensor
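For context on where the \((B, K, D)\) features might come from, the sketch below splits the columns of a numeric table into \(K\) contiguous vertical partitions with `torch.tensor_split` and embeds each partition with its own linear layer. `ToyPartitionEncoder` is hypothetical and deliberately simple: the TransTab paper encodes partitions with a shared transformer over column tokens, and the contiguous, non-overlapping split used here is an assumption.

```python
import torch
from torch import nn


class ToyPartitionEncoder(nn.Module):
    """Hypothetical toy encoder: split F columns into K vertical partitions, embed each.

    Illustration only; not part of rllm and not TransTab's encoder.
    """

    def __init__(self, num_features: int, num_partitions: int, dim: int):
        super().__init__()
        self.num_partitions = num_partitions
        # Column counts that torch.tensor_split produces for this F and K.
        sizes = [t.numel() for t in torch.tensor_split(torch.arange(num_features), num_partitions)]
        self.encoders = nn.ModuleList([nn.Linear(s, dim) for s in sizes])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, F) -> K column blocks, each (B, F_k).
        parts = torch.tensor_split(x, self.num_partitions, dim=1)
        embs = [enc(p) for enc, p in zip(self.encoders, parts)]  # each (B, D)
        return torch.stack(embs, dim=1)                          # (B, K, D)


encoder = ToyPartitionEncoder(num_features=10, num_partitions=3, dim=8)
table = torch.randn(4, 10)
features = encoder(table)
print(features.shape)  # torch.Size([4, 3, 8])
```

The stacked output has the \((B, K, D)\) layout that forward expects for features, so it could be passed together with \((B,)\) labels.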

Example

>>> import torch
>>> from rllm.nn.loss import SupervisedVPCL
>>> loss_fn = SupervisedVPCL()
>>> feats = torch.randn(4, 2, 8)
>>> labels = torch.tensor([0, 1, 0, 1])
>>> out = loss_fn(feats, labels)
>>> out.shape
torch.Size([])