rllm.utils.sort_edge_index

rllm.utils.sort_edge_index(edge_index: Tensor, edge_attr: Tensor | None = None, num_nodes: int | None = None, sort_by_row: bool = True)

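Sort the edge index lexicographically: by source node, then by destination node (or by destination node first when sort_by_row=False). If edge_attr is given, it is reordered with the same permutation so each attribute stays attached to its edge.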

Parameters:
  • edge_index (Tensor) – The edge index tensor of shape [2, num_edges].

  • edge_attr (Tensor, optional) – Edge weights or features with edge_attr.size(0) == edge_index.size(1). If not None, it is reordered alongside edge_index and the function returns the tuple (edge_index, edge_attr). (default: None)

  • num_nodes (int, optional) – The number of nodes. If None, inferred from edge_index. (default: None)

  • sort_by_row (bool, optional) – If set to False, sorts by destination node first instead of by source node. (default: True)
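The ordering described above can be reproduced in plain PyTorch. The following is a minimal sketch under stated assumptions, not rllm's actual implementation: sort_edge_index_sketch is a hypothetical name, and inferring num_nodes as edge_index.max() + 1 is an assumption about how the default behaves.

```python
import torch
from typing import Optional, Tuple, Union


def sort_edge_index_sketch(
    edge_index: torch.Tensor,
    edge_attr: Optional[torch.Tensor] = None,
    num_nodes: Optional[int] = None,
    sort_by_row: bool = True,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Assumption: infer the node count from the largest index when not given.
    if num_nodes is None:
        num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0

    row, col = edge_index[0], edge_index[1]

    # Pack (primary, secondary) into one integer key so that a single
    # 1-D argsort yields lexicographic order over the pair.
    if sort_by_row:
        key = row * num_nodes + col  # source first, then destination
    else:
        key = col * num_nodes + row  # destination first, then source

    perm = torch.argsort(key)
    edge_index = edge_index[:, perm]

    if edge_attr is None:
        return edge_index
    # Apply the same permutation so attributes stay aligned with their edges.
    return edge_index, edge_attr[perm]
```

Packing the pair into a single key turns a two-level sort into one argsort. This assumes every index is below num_nodes and that num_nodes squared fits in int64; the relative order of duplicate edges is not guaranteed, although their attributes still move with them.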

Example

>>> import torch
>>> from rllm.utils import sort_edge_index
>>> edge_index = torch.tensor([[2, 1, 1, 0],
                               [1, 2, 0, 1]])
>>> edge_attr = torch.tensor([[1], [2], [3], [4]])
>>> sort_edge_index(edge_index)
tensor([[0, 1, 1, 2],
        [1, 0, 2, 1]])
>>> sort_edge_index(edge_index, edge_attr)
(tensor([[0, 1, 1, 2],
         [1, 0, 2, 1]]),
tensor([[4],
        [3],
        [2],
        [1]]))
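To confirm that an edge index is already in the order this function produces, a small check over consecutive edges is enough. The helper below, is_sorted_edge_index, is hypothetical (not part of rllm) and assumes the lexicographic ordering described for sort_by_row.

```python
import torch


def is_sorted_edge_index(edge_index: torch.Tensor, sort_by_row: bool = True) -> bool:
    # Pick primary/secondary keys to match the sort_by_row convention.
    if sort_by_row:
        primary, secondary = edge_index[0], edge_index[1]
    else:
        primary, secondary = edge_index[1], edge_index[0]

    if primary.numel() < 2:
        return True

    # Lexicographic check on each consecutive pair of edges: the primary key
    # must increase, or tie with a non-decreasing secondary key.
    p_prev, p_next = primary[:-1], primary[1:]
    s_prev, s_next = secondary[:-1], secondary[1:]
    ok = (p_prev < p_next) | ((p_prev == p_next) & (s_prev <= s_next))
    return bool(ok.all())


# The unsorted input from the example fails the check; the sorted output passes.
print(is_sorted_edge_index(torch.tensor([[2, 1, 1, 0], [1, 2, 0, 1]])))  # False
print(is_sorted_edge_index(torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])))  # True
```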