rllm.utils.lexsort

class rllm.utils.lexsort(keys: List[Tensor], dim: int = -1, descending: bool = False)[source]

Bases:

Perform an indirect stable sort using a sequence of keys.

Given multiple sorting keys, lexsort returns an array of integer indices that describes the sort order by multiple keys. The last key in the sequence is used for the primary sort order, ties are broken by the second-to-last key, and so on.

Example

>>> a = torch.tensor([1, 5, 1, 4, 3, 4, 4])  # First sequence
>>> b = torch.tensor([9, 4, 0, 4, 0, 2, 1])  # Second sequence
>>> ind = lexsort((b, a))  # Sort by `a`, then by `b`
>>> ind
tensor([2, 0, 4, 6, 5, 3, 1])
>>> [torch.tensor((a[i], b[i])) for i in ind]
[tensor([1, 0]), tensor([1, 9]), tensor([3, 0]), tensor([4, 1]),
tensor([4, 2]), tensor([4, 4]), tensor([5, 4])]
Parameters:
  • keys (List[Tensor]) – Sorting keys; the last key has the highest priority (primary sort), earlier keys break ties.

  • dim (int) – The dimension along which to sort. (default: -1)

  • descending (bool) – If True, sorts in descending order. (default: False)

Returns:

A 1-D tensor of integer indices that sorts the input keys.

Return type:

Tensor