Source code for rllm.utils.sparse
import numpy as np
import torch
from torch import Tensor
[docs]
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    r"""Build a sparse COO :class:`torch.Tensor` from a scipy sparse matrix.

    The matrix is first converted to COO format and its values are cast to
    :obj:`float32`; the row/column coordinates are stacked into a
    :obj:`(2, nnz)` :obj:`int64` index tensor.

    Args:
        sparse_mx (scipy.sparse.spmatrix): The scipy sparse matrix to
            convert.

    Returns:
        Tensor: A sparse COO tensor with :obj:`float32` values and the same
        shape as :obj:`sparse_mx`.
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    # Row ids on the first line, column ids on the second.
    coords = np.vstack((coo.row, coo.col)).astype(np.int64)
    return torch.sparse_coo_tensor(
        torch.from_numpy(coords),
        torch.from_numpy(coo.data),
        torch.Size(coo.shape),
    )
[docs]
def is_torch_sparse_tensor(src):
    r"""Check whether :obj:`src` is a sparse :class:`torch.Tensor`.

    A tensor counts as sparse when its layout is one of
    :obj:`torch.sparse_coo`, :obj:`torch.sparse_csr` or
    :obj:`torch.sparse_csc`.

    Args:
        src (Any): The object to inspect.

    Returns:
        bool: :obj:`True` if :obj:`src` is a tensor with one of the layouts
        above, :obj:`False` otherwise (including for non-tensor inputs).
    """
    # Anything that is not a tensor cannot be a sparse tensor.
    if not isinstance(src, torch.Tensor):
        return False
    return src.layout in (torch.sparse_coo, torch.sparse_csr, torch.sparse_csc)
[docs]
def get_indices(adj: Tensor):
    r"""Get indices of non-zero elements from an adjacency matrix.

    Sparse inputs in COO, CSR or CSC layout yield the coordinates of their
    stored elements (after coalescing); any other tensor is treated as dense
    and scanned with :meth:`torch.Tensor.nonzero`.

    Args:
        adj (Tensor): The adjacency matrix, either dense or sparse
            (COO, CSR or CSC layout).

    Returns:
        Tensor: The indices, with one column per element and one row per
        indexed dimension.
    """
    layout = adj.layout
    if layout == torch.sparse_coo:
        return adj.coalesce().indices()
    if layout in (torch.sparse_csr, torch.sparse_csc):
        # coalesce() is only defined for the COO layout, so compressed
        # layouts have to be converted to COO before reading coordinates.
        return adj.to_sparse_coo().coalesce().indices()
    return adj.nonzero().t()
def set_values(adj: Tensor, values: Tensor) -> Tensor:
    r"""Replace the values of a sparse tensor while keeping its indices.

    Args:
        adj (Tensor): The sparse tensor whose values are to be replaced.
            Must be in COO, CSR, or CSC format. A COO tensor is coalesced
            before its indices are read.
        values (Tensor): The new values. Its first dimension must match the
            number of stored elements in :obj:`adj` (for COO input, the
            number of elements after coalescing, in the order of the
            coalesced indices). Any trailing dimensions are appended to the
            size of the result.

    Returns:
        Tensor: A new sparse tensor with the same sparsity pattern as
        :obj:`adj` but with the updated :obj:`values`, created on
        :obj:`adj.device`.

    Raises:
        ValueError: If :obj:`adj` is not in COO, CSR, or CSC layout.
    """
    # Extra trailing dimensions of `values` extend the output size.
    if values.dim() > 1:
        size = adj.size() + values.size()[1:]
    else:
        size = adj.size()

    if adj.layout == torch.sparse_coo:
        # indices() raises on an uncoalesced tensor; coalescing first makes
        # the call safe and leaves already-coalesced input unaffected.
        indices = adj.coalesce().indices()
        return torch.sparse_coo_tensor(
            indices, values, size, device=adj.device
        )
    elif adj.layout == torch.sparse_csr:
        return torch.sparse_csr_tensor(
            adj.crow_indices(), adj.col_indices(), values, size,
            device=adj.device,
        )
    elif adj.layout == torch.sparse_csc:
        return torch.sparse_csc_tensor(
            adj.ccol_indices(), adj.row_indices(), values, size,
            device=adj.device,
        )
    else:
        raise ValueError(f"Unsupported sparse tensor layout: {adj.layout}")