Source code for rllm.utils.graph_utils

from typing import Union, Tuple, Optional

import torch
from torch import Tensor
import numpy as np
import scipy.sparse as sp

from ._sort import lexsort
from rllm.utils.sparse import is_torch_sparse_tensor, sparse_mx_to_torch_sparse_tensor
import rllm.utils._pyglib

# Filter the csr warning
import warnings

warnings.filterwarnings(
    "ignore",
    message="Sparse CSR tensor support is in beta state.",
    category=UserWarning,
    module=r".*graph_utils",  # discard specific module warning
)


[docs] def adj_to_edge_index(adj: Tensor) -> Tuple[Tensor, Optional[Tensor]]: r"""Convert a sparse adjacency matrix to an edge index tensor. Args: adj (Tensor): A sparse adjacency tensor in COO, CSR, or CSC format. Returns: Tuple[Tensor, Optional[Tensor]]: A tuple :obj:`(edge_index, edge_attr)` where :obj:`edge_index` has shape :obj:`[2, num_edges]` and :obj:`edge_attr` is :obj:`None` if all edge weights are 1. """ if is_torch_sparse_tensor(adj): coo_adj = adj.to_sparse_coo().coalesce() s, d, vs = coo_adj.indices()[0], coo_adj.indices()[1], coo_adj.values() vs = None if torch.all(vs == 1) else vs return torch.stack([s, d]), vs else: raise TypeError(f"Expect adj to be a SparseTensor, got {type(adj)}.")
def remove_self_loops(adj: Tensor): r"""Remove self-loops from the adjacency matrix. Args: adj (Tensor): The adjacency matrix in sparse or dense format. Returns: Tensor: The adjacency matrix with self-loops removed. """ shape = adj.shape device = adj.device if is_torch_sparse_tensor(adj): adj = adj.coalesce() indices = adj.indices() values = adj.values() mask = indices[0] != indices[1] indices = indices[:, mask] values = values[mask] return torch.sparse_coo_tensor(indices, values, shape).to(device) loop_index = torch.arange(0, shape[0], dtype=torch.long, device=device) adj = adj.clone() adj[loop_index, loop_index] = 0 return adj def gcn_norm(adj: Tensor): r"""Normalize the sparse adjacency matrix from the `"Semi-supervised Classification with Graph Convolutional Networks" <https://arxiv.org/abs/1609.02907>`_ paper. .. math:: \mathbf{\hat{A}} = \mathbf{\hat{D}}^{-1/2} (\mathbf{A} + \mathbf{I}) \mathbf{\hat{D}}^{-1/2} Args: adj (Tensor): The sparse adjacency matrix. Supported layouts: :obj:`torch.sparse_coo`, :obj:`torch.sparse_csr`, :obj:`torch.sparse_csc`. Returns: Tensor: The symmetrically normalized sparse adjacency matrix. """ shape = adj.shape device = adj.device if is_torch_sparse_tensor(adj): adj = adj.coalesce() indices = adj.indices() mask = indices[0] != indices[1] loop_index = torch.arange(0, shape[0], dtype=torch.long, device=device) loop_index = loop_index.unsqueeze(0).repeat(2, 1) indices_sl = torch.cat([indices[:, mask], loop_index], dim=1) indices = indices[:, mask] # D adj_data = np.ones([indices.shape[1]], dtype=np.float32) adj_sp = sp.csr_matrix((adj_data, (indices[0], indices[1])), shape=shape) adj_sl_data = np.ones([indices_sl.shape[1]], dtype=np.float32) adj_sl_sp = sp.csr_matrix( (adj_sl_data, (indices_sl[0], indices_sl[1])), shape=shape ) # D-1/2 deg = np.array(adj_sl_sp.sum(axis=1)).flatten() deg_sqrt_inv = np.power(deg, -0.5) deg_sqrt_inv[deg_sqrt_inv == float("inf")] = 0.0 deg_sqrt_inv = sp.diags(deg_sqrt_inv) # filters filters = sp.coo_matrix(deg_sqrt_inv * adj_sp * deg_sqrt_inv) return sparse_mx_to_torch_sparse_tensor(filters).to(device)
[docs] def sort_edge_index( edge_index: Tensor, edge_attr: Tensor = None, num_nodes: int = None, sort_by_row: bool = True, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: r"""Sort the edge index. Args: edge_index (Tensor): The edge index tensor of shape :obj:`[2, num_edges]`. edge_attr (Tensor, optional): Edge weights with :obj:`size(0) == edge_index.size(1)`. If not :obj:`None`, returns :obj:`(edge_index, edge_attr)`. (default: :obj:`None`) num_nodes (int, optional): The number of nodes. If :obj:`None`, inferred from :obj:`edge_index`. (default: :obj:`None`) sort_by_row (bool): If set to :obj:`False`, sorts by destination node instead of source node. (default: :obj:`True`) Example: >>> edge_index = torch.tensor([[2, 1, 1, 0], [1, 2, 0, 1]]) >>> edge_attr = torch.tensor([[1], [2], [3], [4]]) >>> sort_edge_index(edge_index) tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) >>> sort_edge_index(edge_index, edge_attr) (tensor([[0, 1, 1, 2], [1, 0, 2, 1]]), tensor([[4], [3], [2], [1]])) """ if num_nodes is None: num_nodes = int(edge_index.max().item()) + 1 if edge_index.numel() > 0 else 0 index = lexsort( keys=[ edge_index[int(sort_by_row)], edge_index[1 - int(sort_by_row)], ] ) edge_index = edge_index[:, index] if edge_attr is not None: return edge_index, edge_attr[index] return edge_index
[docs] def index_to_ptr(index: Tensor, num_nodes: Optional[int] = None) -> Tensor: r"""Convert a sorted index tensor to a CSR/CSC pointer tensor. Args: index (Tensor): The sorted index tensor. num_nodes (int, optional): The number of nodes. If :obj:`None`, inferred from :obj:`index`. (default: :obj:`None`) Example: >>> index = torch.tensor([0, 1, 1, 2, 2, 3]) >>> index_2_ptr(index, 4) tensor([0, 1, 3, 5, 6]) """ if num_nodes is None: num_nodes = int(index.max().item()) + 1 if index.numel() > 0 else 0 ptr = torch.zeros(num_nodes + 1, dtype=torch.long, device=index.device) ptr[1:] = torch.bincount(index, minlength=num_nodes) return ptr.cumsum(0)
def _to_csc( input: Tensor, device: Optional[torch.device] = None, num_nodes: Optional[int] = None, share_memory: bool = False, is_sorted: bool = False, src_node_time: Optional[Tensor] = None, edge_time: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: r"""Convert an edge index or sparse adjacency matrix to CSC format. Args: input (Tensor): The input edge index of shape :obj:`[2, num_edges]` or a sparse adjacency tensor. device (torch.device, optional): The desired device of the returned tensors. If :obj:`None`, uses the input device. (default: :obj:`None`) num_nodes (int, optional): The number of nodes. If :obj:`None`, inferred from the edge index. (default: :obj:`None`) share_memory (bool): If set to :obj:`True`, moves output tensors to shared memory for multi-process data loading. (default: :obj:`False`) is_sorted (bool): If set to :obj:`True`, skips sorting the edge index. (default: :obj:`False`) src_node_time (Tensor, optional): Source node timestamps. If not :obj:`None`, the edge index is sorted by source node time for temporal sampling. (default: :obj:`None`) edge_time (Tensor, optional): Edge timestamps. If not :obj:`None`, the edge index is sorted by edge time for temporal sampling. (default: :obj:`None`) Returns: Tuple[Tensor, Tensor, Optional[Tensor]]: The column indices, row indices, and the permutation index. """ device = input.device if device is None else device # sparse adj if isinstance(input, Tensor) and input.is_sparse: adj = input if is_torch_sparse_tensor(adj): if src_node_time is not None or edge_time is not None: raise NotImplementedError( "Do not support temporal convert for torch sparse tensor." ) csc_t: torch.sparse.Tensor = adj.to_sparse_csc() col_ptr = csc_t.ccol_indices() row = csc_t.row_indices() perm = None # edge index elif isinstance(input, Tensor) and input.dim() == 2 and input.size(0) == 2: row, col = input[0, :], input[1, :] if num_nodes is None: # If not provided, use max destination node index + 1 # as the node number. # This is used to build col_ptr. num_nodes = col.max() + 1 if not is_sorted: if src_node_time is None and edge_time is None: perm = torch.argsort(col) col = col[perm] row = row[perm] elif edge_time is not None and src_node_time is None: perm = lexsort(keys=[edge_time, col]) col = col[perm] row = row[perm] elif src_node_time is not None and edge_time is None: perm = lexsort(keys=[src_node_time[row], col]) col = col[perm] row = row[perm] else: raise NotImplementedError( "Only support one temporal sort for now." "But both `src_node_time` and `edge_time` are not `None`." ) col_ptr = index_to_ptr(col, num_nodes) else: raise ValueError( "No edge found. Edge type should be either `adj` or `edge_index`." ) col_ptr = col_ptr.to(device=device) row = row.to(device=device) perm = perm.to(device=device) if perm is not None else None if not col_ptr.is_cuda and share_memory: col_ptr.share_memory_() row.share_memory_() if perm is not None: perm.share_memory_() return col_ptr, row, perm def to_bidirectional( row: Tensor, col: Tensor, rev_row: Tensor, rev_col: Tensor, ) -> Tuple[Tensor, Tensor]: r"""Merge a directed edge set and its reverse into a deduplicated bidirectional edge set. Args: row (Tensor): Source node indices of the forward edges. col (Tensor): Destination node indices of the forward edges. rev_row (Tensor): Source node indices of the reverse edges. rev_col (Tensor): Destination node indices of the reverse edges. Returns: Tuple[Tensor, Tensor]: Deduplicated :obj:`(row, col)` tensors representing the bidirectional edge set. """ assert row.numel() == col.numel() assert rev_row.numel() == rev_col.numel() edge_index = row.new_empty(2, row.numel() + rev_row.numel()) edge_index[0, :row.numel()] = row edge_index[1, :row.numel()] = col edge_index[0, row.numel():] = rev_col edge_index[1, row.numel():] = rev_row # Fast path: use PyG's `coalesce` (C++/CUDA-backed via torch-sparse/pyg-lib) # to de-duplicate edges. This is significantly faster than `torch.unique` # on a 2xE tensor, especially when called repeatedly per edge type/batch. if rllm.utils._pyglib.WITH_PYG_LIB: try: from torch_geometric.utils import coalesce # type: ignore (edge_index, _) = coalesce( edge_index, None, sort_by_row=False, reduce='any', ) return edge_index[0], edge_index[1] except Exception: pass # Fallback: pure PyTorch de-duplication. edge_index = torch.unique(edge_index, dim=1) return edge_index[0], edge_index[1]