rllm.utils.seg_sum

class rllm.utils.seg_sum(data: Tensor, segment_ids: Tensor, num_segments: int)

Compute, for each segment, the sum of the rows of data that segment_ids assigns to that segment.
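To make the semantics concrete, here is a minimal pure-Python sketch of a segment sum over the rows of a two-dimensional input. It is not rllm's implementation. The function name seg_sum_reference is hypothetical, and the choices that segments receiving no rows come out as all-zero rows and that out-of-range ids raise an error are assumptions of this sketch, not documented behavior of rllm.utils.seg_sum.

```python
from typing import List, Sequence


def seg_sum_reference(
    data: Sequence[Sequence[float]],
    segment_ids: Sequence[int],
    num_segments: int,
) -> List[List[float]]:
    """Hypothetical reference: sum the rows of `data` per segment (not rllm's code)."""
    if len(data) != len(segment_ids):
        raise ValueError("segment_ids must have one entry per row of data")
    num_features = len(data[0]) if data else 0
    # One zero-initialized output row per segment; in this sketch,
    # segments that receive no rows simply stay zero.
    out = [[0.0] * num_features for _ in range(num_segments)]
    for row, seg in zip(data, segment_ids):
        if not 0 <= seg < num_segments:
            raise IndexError(f"segment id {seg} out of range [0, {num_segments})")
        # Accumulate this row element-wise into its segment's output row.
        for j, value in enumerate(row):
            out[seg][j] += value
    return out


data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(seg_sum_reference(data, [0, 2, 0], 3))
# [[6.0, 8.0], [0.0, 0.0], [3.0, 4.0]]
```

Rows 0 and 2 both belong to segment 0, so their element-wise sum forms the first output row; segment 1 receives no rows.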

Parameters:
  • data (Tensor) – The input tensor, typically two-dimensional with shape [N, F]; its rows are summed per segment.

  • segment_ids (Tensor) – A one-dimensional tensor of length data.size(0) giving the segment index, in the range [0, num_segments), of each row of data.

  • num_segments (int) – Total number of segments.

Returns:

Segment sums with shape [num_segments, data.size(1)].

Return type:

Tensor
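For readers working directly in PyTorch, the documented shape contract ([num_segments, data.size(1)]) can be reproduced with torch.Tensor.index_add_, which adds each row data[i] into row segment_ids[i] of a zero-initialized output. This is a sketch under that assumption, not rllm's source: the name seg_sum_torch is hypothetical, and rllm.utils.seg_sum may use a different internal approach (for example scatter_add_).

```python
import torch


def seg_sum_torch(data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int) -> torch.Tensor:
    """Hypothetical PyTorch sketch of a row-wise segment sum (not rllm's implementation)."""
    # Output has one row per segment and the same dtype/device as `data`.
    out = data.new_zeros((num_segments, data.size(1)))
    # index_add_ along dim 0 adds data[i] into out[segment_ids[i]] for every i.
    return out.index_add_(0, segment_ids, data)


data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
segment_ids = torch.tensor([0, 2, 0])  # int64 indices, one per row of data
out = seg_sum_torch(data, segment_ids, num_segments=3)
print(out)
```

Running rllm.utils.seg_sum on the same inputs and comparing the two outputs is a quick way to check whether its behavior (for instance on empty segments) matches this sketch.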