rllm.nn.loss.ContrastiveLoss

class rllm.nn.loss.ContrastiveLoss(temperature: float = 10.0, base_temperature: float = 10.0, similarity: Literal['dot', 'cosine'] = 'dot', eps: float = 1e-12)

Bases: BaseLoss

Generalized InfoNCE-style contrastive loss with a customizable positive mask.

This class provides a reusable implementation of the InfoNCE / SupCon contrastive objective used in self-supervised, supervised, and vertical-partition contrastive learning. Subclasses only define how positives are selected (via pos_mask); all numerical and normalisation steps are handled here.
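For a supervised (SupCon-style) setup, a subclass might select positives by class label, so that every other sample with the anchor's label is a positive. The sketch below builds such a mask in plain PyTorch. `label_pos_mask` is a hypothetical helper, not part of rllm, and it is not meant to show how rllm's own subclasses construct their masks.

```python
import torch


def label_pos_mask(labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: [N] class labels -> [N, N] float positive mask.

    pos_mask[a, b] = 1.0 when samples a and b share a label and a != b.
    """
    labels = labels.view(-1, 1)
    # Broadcasting [N, 1] against [1, N] gives an [N, N] bool matrix.
    same_label = labels == labels.T
    # Zero the diagonal; ContrastiveLoss also drops self-pairs internally.
    not_self = ~torch.eye(labels.shape[0], dtype=torch.bool, device=labels.device)
    return (same_label & not_self).float()


labels = torch.tensor([0, 1, 0, 2])
pos_mask = label_pos_mask(labels)
# Samples 0 and 2 are mutual positives; samples 1 and 3 have none in this batch.
```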

Per-anchor loss:

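\[\ell_i = -\frac{1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(s_{ip} / \tau)}{\sum_{a \neq i} \exp(s_{ia} / \tau)}\]

where \(s_{ij}\) is the similarity between embeddings \(i\) and \(j\) (see similarity), \(P(i)\) is the set of positives of anchor \(i\) (excluding \(i\) itself), and \(\tau\) is the temperature.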

Batch loss (with scaling factor \(\frac{\tau}{\tau_0}\)):

\[\mathcal{L} = \frac{\tau}{\tau_0} \cdot \frac{1}{N} \sum_{i=1}^{N} \ell_i\]

Parameters:
  • temperature (float) – Temperature \(\tau\) scaling the logits.

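  • base_temperature (float) – Reference temperature \(\tau_0\). The batch loss is multiplied by \(\tau / \tau_0\), so with the defaults (\(\tau = \tau_0 = 10\)) it is left unscaled.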

  • similarity (str) – Similarity used for \(s_{ij}\): "dot" (raw inner product) or "cosine" (inner product of L2-normalised embeddings).

  • eps (float) – Small constant added inside the logarithm of each denominator for numerical stability.
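To make the formulas and parameters concrete, the function below transcribes them into PyTorch. It is a sketch for building intuition, not rllm's implementation, and `contrastive_loss_reference` is a hypothetical name. Two details are assumptions rather than documented behaviour: `eps` is added inside the logarithm of the summed denominator, and anchors with no positive are dropped from the average (so \(N\) counts only anchors with at least one positive), which makes the result a zero that stays attached to the autograd graph when no anchor has a positive.

```python
import torch
import torch.nn.functional as F


def contrastive_loss_reference(
    feats: torch.Tensor,
    pos_mask: torch.Tensor,
    temperature: float = 10.0,
    base_temperature: float = 10.0,
    similarity: str = "dot",
    eps: float = 1e-12,
) -> torch.Tensor:
    """Sketch of the per-anchor and batch formulas above (not rllm's code)."""
    if similarity == "cosine":
        # Cosine similarity = inner product of L2-normalised rows.
        feats = F.normalize(feats, dim=1)
    n = feats.shape[0]
    logits = feats @ feats.T / temperature  # s_ia / tau, shape [N, N]

    # Self-pairs appear in neither the numerator nor the denominator.
    not_self = ~torch.eye(n, dtype=torch.bool, device=feats.device)
    pos = pos_mask.bool() & not_self

    # sum_{a != i} exp(s_ia / tau); eps goes inside the log (assumption).
    denom = (torch.exp(logits) * not_self).sum(dim=1, keepdim=True)
    log_prob = logits - torch.log(denom + eps)  # log of each fraction in l_i

    # l_i = -(1/|P(i)|) * sum over p in P(i); clamp(min=1) avoids 0/0.
    num_pos = pos.sum(dim=1)
    ell = -(log_prob * pos).sum(dim=1) / num_pos.clamp(min=1)

    # Assumption: average over anchors with >= 1 positive. With none, this is
    # 0 / 1, a zero that stays attached to the autograd graph.
    valid = num_pos > 0
    loss = (ell * valid).sum() / valid.sum().clamp(min=1)
    return (temperature / base_temperature) * loss
```

With `pos_mask = torch.eye(N)`, as in the class example, every positive is a self-pair, so under these assumptions no anchor has a valid positive and the sketch returns 0. The direct `exp` / `log` form is kept so that each line maps onto a term of \(\ell_i\); a log-sum-exp formulation would be more robust at very small temperatures.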

Example

>>> import torch
>>> loss_fn = ContrastiveLoss(temperature=1.0, similarity="dot")
>>> feats = torch.randn(4, 8)
>>> pos_mask = torch.eye(4)
>>> loss_fn(feats, pos_mask).ndim
0

forward(feats: Tensor, pos_mask: Tensor) → Tensor

Compute the contrastive loss given embeddings and a positive mask.

Parameters:
  • feats (Tensor) – [N, D] projected embeddings from all views / partitions.

  • pos_mask (Tensor) – [N, N] float or bool mask where pos_mask[a, b] = 1 when sample b should be treated as a positive of anchor a. Self-pairs are excluded internally.
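When feats holds several views (or column partitions) of the same batch, one common layout, assumed here rather than prescribed by rllm, is view-major stacking: row `v * B + i` holds sample `i` in view `v`. Under that layout, every row of the same sample in a different view is a positive. `multiview_pos_mask` is a hypothetical helper name.

```python
import torch


def multiview_pos_mask(batch_size: int, n_views: int) -> torch.Tensor:
    """Hypothetical helper: positives are the other views of the same sample.

    Assumes feats = torch.cat([view_0, ..., view_{V-1}], dim=0), so row
    v * batch_size + i holds sample i in view v. Returns an [N, N] float mask
    with N = batch_size * n_views.
    """
    # Sample index of every stacked row: [0..B-1, 0..B-1, ...].
    sample_id = torch.arange(batch_size).repeat(n_views)
    same_sample = sample_id.view(-1, 1) == sample_id.view(1, -1)
    # Self-pairs are excluded by ContrastiveLoss anyway; zero them for clarity.
    not_self = ~torch.eye(batch_size * n_views, dtype=torch.bool)
    return (same_sample & not_self).float()


# Two views of a batch of 3 samples -> 6 stacked rows.
z1, z2 = torch.randn(3, 8), torch.randn(3, 8)
feats = torch.cat([z1, z2], dim=0)  # [6, 8]
pos_mask = multiview_pos_mask(batch_size=3, n_views=2)  # [6, 6]
```

With V views each row gets V − 1 positives. For a vertical split into column partitions, the same mask would apply as long as each partition's embeddings are stacked in the same sample order.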

Returns:

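Scalar (0-dimensional) contrastive loss. If no anchor in the batch has a valid positive, a zero loss that is still connected to the autograd graph is returned, so loss.backward() can be called unconditionally.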

Return type:

torch.Tensor

Example

>>> import torch
>>> loss_fn = ContrastiveLoss()
>>> feats = torch.randn(3, 5)
>>> pos_mask = torch.tensor(
...     [[0, 1, 0], [1, 0, 0], [0, 0, 0]], dtype=torch.float32
... )
>>> loss_fn(feats, pos_mask).ndim
0