# Source code for rllm.nn.loss.contrastive_loss

from __future__ import annotations
from typing import Literal

import torch
import torch.nn.functional as F

from rllm.nn.loss.base_loss import BaseLoss


class ContrastiveLoss(BaseLoss):
    r"""Generalized InfoNCE-style contrastive loss with a customizable positive mask.

    This class provides a reusable implementation of the InfoNCE / SupCon
    contrastive objective used in self-supervised, supervised, and
    vertical-partition contrastive learning. Subclasses only define how
    positives are selected (via ``pos_mask``); all numerical and
    normalisation steps are handled here.

    Per-anchor loss:

    .. math::
        \ell_i = -\frac{1}{|P(i)|} \sum_{p \in P(i)}
        \log \frac{\exp(s_{ip} / \tau)}
        {\sum_{a \neq i} \exp(s_{ia} / \tau)}

    Batch loss (with scaling factor :math:`\frac{\tau}{\tau_0}`), averaged
    over the anchors that have at least one positive:

    .. math::
        \mathcal{L} = \frac{\tau}{\tau_0} \cdot
        \frac{1}{|V|} \sum_{i \in V} \ell_i,
        \qquad V = \{ i : |P(i)| > 0 \}

    Args:
        temperature (float): Temperature :math:`\tau` scaling the logits.
        base_temperature (float): Reference temperature :math:`\tau_0`.
        similarity (str): Similarity metric, ``"dot"`` or ``"cosine"``.
            Any value other than ``"cosine"`` uses the plain dot product.
        eps (float): Numerical stability constant added to log-denominators.

    Example:
        >>> import torch
        >>> loss_fn = ContrastiveLoss(temperature=1.0, similarity="dot")
        >>> feats = torch.randn(4, 8)
        >>> pos_mask = torch.eye(4)
        >>> loss_fn(feats, pos_mask).ndim
        0
    """

    def __init__(
        self,
        temperature: float = 10.0,
        base_temperature: float = 10.0,
        similarity: Literal["dot", "cosine"] = "dot",
        eps: float = 1e-12,
    ) -> None:
        super().__init__()
        self.temperature = float(temperature)
        self.base_temperature = float(base_temperature)
        self.similarity = similarity
        self.eps = eps

    def _pairwise_logits(self, feats: torch.Tensor) -> torch.Tensor:
        """Compute pairwise similarity logits scaled by temperature.

        Args:
            feats: ``[N, D]`` embedding matrix.

        Returns:
            ``[N, N]`` logit matrix where ``logits[a, b] = sim(a, b) / τ``.
        """
        if self.similarity == "cosine":
            # Unit-normalise rows so the dot product below is a cosine.
            feats = F.normalize(feats, dim=1)
        sim_matrix = torch.matmul(feats, feats.T)
        return sim_matrix / self.temperature

    def forward(
        self,
        feats: torch.Tensor,
        pos_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the contrastive loss given embeddings and a positive mask.

        Args:
            feats (Tensor): ``[N, D]`` projected embeddings from all
                views / partitions.
            pos_mask (Tensor): ``[N, N]`` float or bool mask where
                ``pos_mask[a, b] = 1`` when sample *b* should be treated
                as a positive of anchor *a*. Any non-zero entry counts as
                a positive. Self-pairs are excluded internally.

        Returns:
            torch.Tensor: Scalar contrastive loss, averaged over anchors
            that have at least one positive. When no anchor in the batch
            has a valid positive (including empty and single-sample
            batches) a fresh ``0.0`` tensor with ``requires_grad=True`` is
            returned; it is a new leaf and is not connected to ``feats``.

        Raises:
            ValueError: If ``feats`` is not 2-D or ``pos_mask`` is not
                ``[N, N]``.

        Example:
            >>> import torch
            >>> loss_fn = ContrastiveLoss()
            >>> feats = torch.randn(3, 5)
            >>> pos_mask = torch.tensor(
            ...     [[0, 1, 0], [1, 0, 0], [0, 0, 0]], dtype=torch.float32
            ... )
            >>> loss_fn(feats, pos_mask).ndim
            0
        """
        if feats.dim() != 2:
            raise ValueError(
                f"feats must be a 2-D [N, D] tensor, got shape {tuple(feats.shape)}"
            )
        device = feats.device
        num_samples = feats.shape[0]
        if tuple(pos_mask.shape) != (num_samples, num_samples):
            raise ValueError(
                f"pos_mask must have shape ({num_samples}, {num_samples}), "
                f"got {tuple(pos_mask.shape)}"
            )

        # 1. Boolean positive mask with self-pairs removed. Binarising here
        #    keeps the numerator and the divisor |P(i)| consistent even if
        #    the caller passes non-binary values.
        eye = torch.eye(num_samples, dtype=torch.bool, device=device)
        positives = pos_mask.to(device=device).bool() & ~eye
        pos_count = positives.sum(dim=1)  # [N]
        valid = pos_count > 0  # [N] bool

        # 2. Bail out before any reduction: with N == 0 the row-wise max
        #    below would raise, and with N == 1 it would produce
        #    -inf - (-inf) = nan.
        if not valid.any():
            return torch.tensor(0.0, device=device, requires_grad=True)

        # 3. Pairwise logits [N, N]; diagonal set to -inf so self-pairs
        #    contribute exactly 0 after exp.
        logits = self._pairwise_logits(feats).masked_fill(eye, float("-inf"))

        # 4. Numerical stability: subtract the row-wise max. At least one
        #    off-diagonal positive exists here, so N >= 2 and every row has
        #    an off-diagonal entry to take the max over.
        row_max, _ = torch.max(logits, dim=1, keepdim=True)
        logits = logits - row_max.detach()

        # 5. Denominator: sum of exp over all non-self entries.
        log_denom = torch.log(
            torch.exp(logits).sum(dim=1, keepdim=True) + self.eps
        )

        # 6. log p(b | a) for all pairs.
        log_prob = logits - log_denom  # [N, N]

        # 7. Sum log-probabilities over each anchor's positives.
        #    torch.where (not multiplication) avoids 0 * (-inf) = nan on
        #    the diagonal.
        pos_log_prob_sum = torch.where(
            positives, log_prob, torch.zeros_like(log_prob)
        ).sum(dim=1)  # [N]

        # 8. Average over positives for valid anchors only, without
        #    in-place writes into an autograd intermediate.
        mean_log_prob_pos = pos_log_prob_sum[valid] / pos_count[valid].to(
            log_prob.dtype
        )

        # 9. Scale and average over valid anchors.
        scale = self.temperature / self.base_temperature
        return (-scale * mean_log_prob_pos).mean()