Source code for rllm.utils._sort

from typing import List

from torch import Tensor


[docs] def lexsort( keys: List[Tensor], dim: int = -1, descending: bool = False, ) -> Tensor: r"""Perform an indirect stable sort using a sequence of keys. Given multiple sorting keys, lexsort returns an array of integer indices that describes the sort order by multiple keys. The last key in the sequence is used for the primary sort order, ties are broken by the second-to-last key, and so on. Example: >>> a = torch.tensor([1, 5, 1, 4, 3, 4, 4]) # First sequence >>> b = torch.tensor([9, 4, 0, 4, 0, 2, 1]) # Second sequence >>> ind = lexsort((b, a)) # Sort by `a`, then by `b` >>> ind tensor([2, 0, 4, 6, 5, 3, 1]) >>> [torch.tensor((a[i], b[i])) for i in ind] [tensor([1, 0]), tensor([1, 9]), tensor([3, 0]), tensor([4, 1]), tensor([4, 2]), tensor([4, 4]), tensor([5, 4])] Args: keys (List[Tensor]): Sorting keys; the last key has the highest priority (primary sort), earlier keys break ties. dim (int): The dimension along which to sort. (default: :obj:`-1`) descending (bool): If :obj:`True`, sorts in descending order. (default: :obj:`False`) Returns: Tensor: A 1-D tensor of integer indices that sorts the input keys. """ out = keys[0].argsort(dim=dim, descending=descending, stable=True) for k in keys[1:]: index = k.gather(dim=dim, index=out) index = index.argsort(dim=dim, descending=descending, stable=True) out = out.gather(dim=dim, index=index) return out