Source code for rllm.utils.seg_reduce

import torch
from torch import Tensor


def seg_sum(data: Tensor, segment_ids: Tensor, num_segments: int) -> Tensor:
    r"""Compute the sum of elements in :obj:`data` for each segment
    specified by :obj:`segment_ids`.

    Args:
        data (Tensor): Input tensor of shape :obj:`[N, *]`. Rows (the first
            dimension) are grouped into segments; any number of trailing
            dimensions, including none, is supported.
        segment_ids (Tensor): A one-dimensional ``int64`` tensor of length
            :obj:`N` giving the segment of each row of :obj:`data`, with
            values in :obj:`[0, num_segments)`.
        num_segments (int): Total number of segments.

    Returns:
        Tensor: Segment sums with shape :obj:`[num_segments, *]` and the
        dtype and device of :obj:`data`. Segments that receive no rows are
        zero.
    """
    # scatter_reduce needs an index with the same shape as src, so broadcast
    # segment_ids across every trailing dimension of data (works for 1-D,
    # 2-D and higher-rank inputs alike).
    index_shape = (-1,) + (1,) * (data.dim() - 1)
    index = segment_ids.reshape(index_shape).expand_as(data)
    output = data.new_zeros((num_segments,) + tuple(data.shape[1:]))
    return torch.scatter_reduce(
        output,
        dim=0,
        index=index,
        src=data,
        reduce="sum",
    )
def seg_softmax(data: Tensor, segment_ids: Tensor, num_segs: int) -> Tensor:
    r"""Compute the segment-wise softmax scores of elements in :obj:`data`.

    Rows of :obj:`data` that share a value in :obj:`segment_ids` are
    normalized together along the first dimension, independently for every
    trailing position (e.g. every column of a 2-D input).

    Args:
        data (Tensor): Input scores of shape :obj:`[N, *]`.
        segment_ids (Tensor): A one-dimensional ``int64`` tensor of length
            :obj:`N` holding the segment of each row of :obj:`data`, with
            values in :obj:`[0, num_segs)`.
        num_segs (int): Total number of segments.

    Returns:
        Tensor: Softmax scores with the same shape and dtype as :obj:`data`.
    """
    # scatter_reduce needs an index with the same shape as src, so broadcast
    # segment_ids across every trailing dimension of data.
    index_shape = (-1,) + (1,) * (data.dim() - 1)
    index = segment_ids.reshape(index_shape).expand_as(data)
    out_shape = (num_segs,) + tuple(data.shape[1:])

    # Per-segment maximum for numerical stability. include_self=False keeps
    # the zero-initialized buffer out of the reduction; otherwise a segment
    # whose values are all very negative would be shifted by 0 rather than
    # by its true max, and every exp() in it would underflow to zero.
    max_values = data.new_zeros(out_shape).scatter_reduce(
        0, index, data, reduce="amax", include_self=False
    )
    exp = torch.exp(data - max_values[segment_ids])

    # The denominator buffer must share data's dtype: scatter_reduce requires
    # matching self/src dtypes, so a float32 buffer would reject float64 or
    # float16 inputs.
    denominator = data.new_zeros(out_shape).scatter_reduce(
        0, index, exp, reduce="sum"
    )
    # Small epsilon guards the division against a zero denominator.
    return exp / (denominator[segment_ids] + 1e-16)
def seg_softmax_(data: Tensor, segment_ids: Tensor, num_segs: int) -> Tensor:
    r"""Compute the segment-wise softmax scores of elements in :obj:`data`
    using a loop-based implementation (fallback for older PyTorch versions).

    Segments are processed one at a time in Python, so the cost grows with
    :obj:`num_segs`.

    Args:
        data (Tensor): Input scores of shape :obj:`[N, *]`.
        segment_ids (Tensor): A one-dimensional integer tensor of length
            :obj:`N` holding the segment of each row of :obj:`data`, with
            values in :obj:`[0, num_segs)`.
        num_segs (int): Total number of segments.

    Returns:
        Tensor: Softmax scores with the same shape and dtype as :obj:`data`.
    """
    out_shape = (num_segs,) + tuple(data.shape[1:])
    # Both buffers use data's dtype: a float32 denominator would silently
    # lose precision for float64 input and change the result dtype for
    # float16 input.
    max_values = data.new_zeros(out_shape)
    denominator = data.new_zeros(out_shape)
    for seg in range(num_segs):
        # Build the segment mask once and derive both max and sum from it.
        segment_data = data[segment_ids == seg]
        if segment_data.size(0) > 0:
            segment_max = segment_data.max(dim=0)[0]
            max_values[seg] = segment_max
            denominator[seg] = torch.exp(segment_data - segment_max).sum(dim=0)
    exp = torch.exp(data - max_values[segment_ids])
    # Small epsilon guards the division against a zero denominator.
    return exp / (denominator[segment_ids] + 1e-16)