rllm.utils.seg_sum
- class rllm.utils.seg_sum(data: Tensor, segment_ids: Tensor, num_segments: int)
Bases:
Compute the sum of the rows of data for each segment specified by segment_ids.
- Parameters:
data (Tensor) – A tensor, typically two-dimensional.
segment_ids (Tensor) – A one-dimensional tensor that gives the segment assignment of each row of data (one entry per index along the first dimension).
num_segments (int) – Total number of segments.
- Returns:
Segment sums with shape [num_segments, data.size(1)].
- Return type:
Tensor
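The semantics can be illustrated with a short PyTorch sketch. It is a hypothetical re-implementation (the name seg_sum_sketch is ours, not part of rllm) built on torch.Tensor.index_add_, and it is not necessarily how rllm implements seg_sum. It assumes segment_ids holds integer indices in [0, num_segments) with one entry per row of data. Each output entry is out[s, j] = sum of data[i, j] over all i with segment_ids[i] == s; in this sketch, segments that receive no rows stay zero.

```python
import torch


def seg_sum_sketch(data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int) -> torch.Tensor:
    """Hypothetical sketch of the documented seg_sum semantics (not rllm's code).

    Row i of `data` is added into row `segment_ids[i]` of the output.
    Segments that receive no rows remain zero.
    """
    # One output row per segment, same trailing shape and dtype as `data`.
    out = data.new_zeros((num_segments, *data.shape[1:]))
    # index_add_ accumulates data[i] into out[segment_ids[i]] along dim 0,
    # so rows sharing a segment id are summed rather than overwritten.
    out.index_add_(0, segment_ids.long(), data)
    return out


data = torch.tensor([[1.0, 2.0],
                     [3.0, 4.0],
                     [5.0, 6.0]])
segment_ids = torch.tensor([0, 2, 0])

result = seg_sum_sketch(data, segment_ids, num_segments=3)
# Segment 0 sums rows 0 and 2, segment 1 is empty, segment 2 holds row 1.
print(result.tolist())
```

index_add_ is used here because it accumulates duplicate indices, which a plain indexed assignment such as out[segment_ids] = data would not. Tensor.scatter_add_ with an index expanded to the shape of data is an equivalent alternative.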