rllm.nn.loss.ContrastiveLoss
- class rllm.nn.loss.ContrastiveLoss(temperature: float = 10.0, base_temperature: float = 10.0, similarity: Literal['dot', 'cosine'] = 'dot', eps: float = 1e-12)
Bases: BaseLoss
Generalized InfoNCE-style contrastive loss with a customizable positive mask.
This class provides a reusable implementation of the InfoNCE / SupCon contrastive objective used in self-supervised, supervised, and vertical-partition contrastive learning. Subclasses only define how positives are selected (via pos_mask); all numerical and normalisation steps are handled here.
Per-anchor loss:
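\[\ell_i = -\frac{1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(s_{ip} / \tau)}{\sum_{a \neq i} \exp(s_{ia} / \tau)}\]
where \(s_{ij}\) is the similarity (dot or cosine, per similarity) between embeddings \(i\) and \(j\), and \(P(i)\) is the set of samples \(p \neq i\) with pos_mask[i, p] = 1.
Batch loss (with scaling factor \(\frac{\tau}{\tau_0}\)):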
\[\mathcal{L} = \frac{\tau}{\tau_0} \cdot \frac{1}{N} \sum_{i=1}^{N} \ell_i\]
- Parameters:
temperature (float) – Temperature \(\tau\) scaling the logits.
base_temperature (float) – Reference temperature \(\tau_0\).
similarity (str) – Similarity metric, "dot" or "cosine".
eps (float) – Numerical stability constant added to log-denominators.
Example
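>>> import torch
>>> from rllm.nn.loss import ContrastiveLoss
>>> loss_fn = ContrastiveLoss(temperature=1.0, similarity="dot")
>>> feats = torch.randn(4, 8)
>>> pos_mask = torch.eye(4)  # only self-pairs, which are excluded, so no valid positives
>>> loss_fn(feats, pos_mask).ndim
0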
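A minimal sketch of the formula above in plain PyTorch may help when reading or debugging the loss. It is an illustrative re-derivation, not rllm's source: the function name contrastive_loss_sketch, the per-row max shift for numerical stability, and the choice to average \(\ell_i\) only over anchors that have at least one positive (this equals the formula's \(N\) when every anchor has one) are all assumptions. Cosine similarity is assumed to be the dot product of L2-normalised embeddings.

```python
import torch
import torch.nn.functional as F


def contrastive_loss_sketch(
    feats: torch.Tensor,
    pos_mask: torch.Tensor,
    temperature: float = 10.0,
    base_temperature: float = 10.0,
    similarity: str = "dot",
    eps: float = 1e-12,
) -> torch.Tensor:
    """Illustrative re-derivation of the documented formula (not rllm's code)."""
    if similarity not in ("dot", "cosine"):
        raise ValueError(f"unknown similarity: {similarity!r}")
    n = feats.size(0)
    not_self = ~torch.eye(n, dtype=torch.bool, device=feats.device)

    # P(i): positives of anchor i, with self-pairs removed.
    pos = pos_mask.bool() & not_self
    num_pos = pos.sum(dim=1)
    valid = num_pos > 0  # anchors that have at least one positive
    if not valid.any():
        # Scalar zero that stays attached to the graph, so backward() still works.
        return feats.sum() * 0.0

    if similarity == "cosine":
        # Assumed: cosine similarity = dot product of L2-normalised rows.
        feats = F.normalize(feats, dim=1)
    logits = feats @ feats.T / temperature  # s_ia / tau, shape [N, N]

    # Shift each row by its largest non-self logit (assumed stabilisation step);
    # the shift cancels in the log-ratio and keeps the denominator >= 1.
    masked = logits.masked_fill(~not_self, float("-inf"))
    row_max = masked.max(dim=1, keepdim=True).values.detach()
    logits = logits - row_max
    masked = masked - row_max

    # log of sum_{a != i} exp(s_ia / tau), with eps added for stability.
    log_denom = torch.log(torch.exp(masked).sum(dim=1, keepdim=True) + eps)
    log_prob = logits - log_denom

    # ell_i = -(1 / |P(i)|) * sum_{p in P(i)} log_prob[i, p], valid anchors only.
    ell = -(log_prob * pos).sum(dim=1)[valid] / num_pos[valid]
    return (temperature / base_temperature) * ell.mean()
```

Shifting each row by its largest non-self logit leaves every log-ratio unchanged but keeps the denominator at least 1, so eps only guards against a degenerate denominator instead of distorting the result.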
- forward(feats: Tensor, pos_mask: Tensor) → Tensor
Compute the contrastive loss given embeddings and a positive mask.
- Parameters:
feats (Tensor) – [N, D] projected embeddings from all views / partitions.
pos_mask (Tensor) – [N, N] float or bool mask where pos_mask[a, b] = 1 when sample b should be treated as a positive of anchor a. Self-pairs are excluded internally.
- Returns:
Scalar contrastive loss. Returns 0.0 (with gradient) when no anchor in the batch has any valid positive.
- Return type:
torch.Tensor
Example
>>> import torch
>>> from rllm.nn.loss import ContrastiveLoss
>>> loss_fn = ContrastiveLoss()
>>> feats = torch.randn(3, 5)
>>> pos_mask = torch.tensor(
...     [[0, 1, 0], [1, 0, 0], [0, 0, 0]], dtype=torch.float32
... )
>>> loss_fn(feats, pos_mask).ndim
0
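Since the positive structure lives entirely in pos_mask, the supervised, self-supervised, and vertical-partition settings named in the class description differ mainly in how that mask is built. The sketch below uses hypothetical helpers (ids_to_pos_mask, label_pos_mask, multiview_pos_mask; none are part of rllm). It assumes multi-view or multi-partition embeddings are stacked view-major in feats (all rows of view 0 first, then view 1, and so on), and it assumes that for vertical partitions the embeddings of one row's partitions are positives of each other.

```python
import torch


def ids_to_pos_mask(ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pos_mask[a, b] = 1 when rows a and b share an id."""
    # The diagonal stays at 1; forward() is documented to drop self-pairs itself.
    return (ids[:, None] == ids[None, :]).float()


def label_pos_mask(labels: torch.Tensor) -> torch.Tensor:
    # Supervised (SupCon-style): samples with the same class label are positives.
    return ids_to_pos_mask(labels)


def multiview_pos_mask(num_samples: int, num_views: int) -> torch.Tensor:
    # Self-supervised views or vertical partitions (assumed view-major stacking):
    # rows [0, B) are view 0, rows [B, 2B) are view 1, and so on, so row r
    # belongs to sample r % B and all views of one sample share an id.
    sample_ids = torch.arange(num_samples).repeat(num_views)
    return ids_to_pos_mask(sample_ids)
```

With rllm available, such a mask would be passed directly as the second argument, e.g. loss_fn(feats, multiview_pos_mask(B, 2)) for a [2B, D] feats tensor.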