Source code for rllm.dataloader.neighbor_loader

from typing import Optional, List, Tuple, Callable

import torch
from torch import Tensor

from rllm.data import GraphData


class NeighborLoader(torch.utils.data.DataLoader):
    r"""The random neighbor sampler, which allows for mini-batch training of
    GNNs on large-scale graphs where full-batch training is not feasible.

    The loader iterates over seed nodes. For every mini-batch of seeds it
    samples in-neighbors hop by hop (one hop per entry of
    :obj:`num_neighbors`) and returns the relabeled per-hop adjacencies.

    Args:
        data (GraphData): The graph data to be sampled. It must expose
            :obj:`device`, :obj:`num_nodes` and either :obj:`edge_index`
            or :obj:`adj`.
        num_neighbors (List[int]): The number of neighbors to sample for each
            node in each layer. A negative value keeps all neighbors.
        seeds (Optional[Tensor]): The nodes to sample from. May be an index
            tensor, a boolean mask or any sequence accepted by
            :func:`torch.tensor`. If None, all nodes will be used.
        transform (Optional[Callable]): A function/transform that takes in
            a sampled adjacency and returns a transformed version. It is
            applied to every non-empty per-hop adjacency before returning it.
        replace (bool, optional): Whether to sample with replacement. If
            :obj:`True`, exactly :obj:`num_neighbor` neighbors are drawn for
            every node that has at least one neighbor, so the same neighbor
            may appear more than once. (default: :obj:`False`)
        shuffle (bool, optional): Whether to shuffle the data at every epoch.
            (default: :obj:`False`)
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        num_workers (int, optional): How many subprocesses to use for data
            loading. (default: :obj:`0`)
        **kwargs: Additional keyword arguments to be passed to the
            :class:`torch.utils.data.DataLoader` class.
    """

    def __init__(
        self,
        data: "GraphData",
        num_neighbors: List[int],
        seeds: Optional[Tensor] = None,
        transform: Optional[Callable] = None,
        replace: bool = False,
        shuffle: bool = False,
        batch_size: int = 1,
        num_workers: int = 0,
        **kwargs,
    ):
        # The loader supplies its own dataset (the seeds) and collate
        # function, so user-provided values for these are dropped.
        kwargs.pop("dataset", None)
        kwargs.pop("collate_fn", None)

        self.device = data.device
        self.num_neighbors = num_neighbors
        self.replace = replace
        self.transform = transform
        self.num_nodes = data.num_nodes

        if seeds is None:
            seeds = torch.arange(self.num_nodes, dtype=torch.long)
        elif not isinstance(seeds, Tensor):
            seeds = torch.tensor(seeds, dtype=torch.long, device=self.device)
        elif seeds.dtype == torch.bool:
            # A boolean mask selects the nodes whose entry is True.
            seeds = seeds.nonzero(as_tuple=False).flatten()

        # Prepare the CSC structure used for sampling.
        self._build_csc(data)

        super().__init__(
            dataset=seeds,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=self.collate_fn,
            **kwargs,
        )

    def _build_csc(self, data: "GraphData") -> None:
        r"""Build a compressed sparse column (CSC) representation of the
        graph for efficient neighbor sampling.

        Sets :obj:`src_sorted` / :obj:`dst_sorted` (edges ordered by
        destination) and :obj:`col_ptr`, where the in-neighbors of node
        :obj:`v` are :obj:`src_sorted[col_ptr[v]:col_ptr[v + 1]]`.

        Args:
            data (GraphData): The graph providing :obj:`edge_index` or
                :obj:`adj`.

        Raises:
            ValueError: If :obj:`data` has neither :obj:`edge_index` nor
                :obj:`adj`.
        """
        if hasattr(data, "edge_index"):
            edge_index = data.edge_index
        elif hasattr(data, "adj"):
            edge_index = data.adj.coalesce().indices()
        else:
            raise ValueError(
                "NeighborLoader requires `data` to provide either "
                "`edge_index` or `adj`."
            )

        # Group edges by destination so each node's in-edges are contiguous.
        order = torch.argsort(edge_index[1])
        self.dst_sorted = edge_index[1][order]
        self.src_sorted = edge_index[0][order]

        cnts = torch.bincount(self.dst_sorted, minlength=self.num_nodes)
        self.col_ptr = torch.empty(
            self.num_nodes + 1, dtype=torch.long, device=self.device
        )
        self.col_ptr[0] = 0
        self.col_ptr[1:] = torch.cumsum(cnts, dim=0)

    def get_in_neighbors(self, node: int) -> torch.Tensor:
        r"""Get the in-neighbors of a given node in the graph.

        Args:
            node (int): The node for which to get the in-neighbors.

        Returns:
            Tensor: The indices of in-neighbor nodes.
        """
        start = self.col_ptr[node].item()
        end = self.col_ptr[node + 1].item()
        return self.src_sorted[start:end]

    def sample_neighbors_one_layer(
        self, seed_nodes: List[int], num_neighbor: int
    ) -> Tuple[Tensor, Tensor]:
        r"""Sample neighbors for a given set of seed nodes.

        Nodes without in-neighbors contribute no edges. A negative
        :obj:`num_neighbor` keeps every neighbor. Otherwise, sampling is done
        with replacement if :obj:`self.replace` is set and without
        replacement (at most :obj:`num_neighbor` distinct edges) if not.

        Args:
            seed_nodes (List[int]): The nodes to sample neighbors from.
            num_neighbor (int): The number of neighbors to sample for
                each node.

        Returns:
            Tuple[Tensor, Tensor]: A tuple containing the sampled source
            nodes and destination nodes.
        """
        sampled_src_list = []
        dst_list = []

        for node in seed_nodes:
            neighbors = self.get_in_neighbors(node)
            n_neighbors = neighbors.numel()
            if n_neighbors == 0:
                continue

            if num_neighbor < 0:
                sampled = neighbors
            elif self.replace:
                # Draw indices independently, so duplicates are possible and
                # the fan-out is exactly `num_neighbor`.
                picks = torch.randint(
                    0, n_neighbors, (num_neighbor,), device=self.device
                )
                sampled = neighbors[picks]
            elif n_neighbors <= num_neighbor:
                # Not more neighbors than requested: keep them all.
                sampled = neighbors
            else:
                perm = torch.randperm(n_neighbors, device=self.device)
                sampled = neighbors[perm[:num_neighbor]]

            sampled_src_list.append(sampled)
            dst_list.append(
                torch.full(
                    (sampled.numel(),),
                    node,
                    dtype=torch.long,
                    device=self.device,
                )
            )

        if sampled_src_list:
            return torch.cat(sampled_src_list), torch.cat(dst_list)
        return (
            torch.empty((0,), dtype=torch.long, device=self.device),
            torch.empty((0,), dtype=torch.long, device=self.device),
        )

    def collate_fn(
        self,
        batch: List[Tensor],
    ) -> Tuple[int, Tensor, List[Tensor]]:
        r"""Collate function for the NeighborLoader. Samples neighbors for
        each node in the batch and returns the sampled subgraph.

        Args:
            batch (List[Tensor]): A list of seed node indices.

        Returns:
            Tuple[int, Tensor, List[Tensor]]: A tuple of
            :obj:`(batch_size, n_id, adjs)` where :obj:`batch_size` is the
            number of seed nodes, :obj:`n_id` contains all sampled node
            indices (seeds first, then newly discovered nodes in discovery
            order), and :obj:`adjs` is a list of sparse adjacency matrices
            per hop whose indices refer to positions in :obj:`n_id`. Each
            adjacency is square with side ``max local index + 1``. A hop
            that sampled no edges is represented by an empty 1-D tensor
            and is not passed through :obj:`transform`.
        """
        batch = [int(node) for node in batch]

        raw_adjs = []
        seed_nodes = batch
        n_id = list(batch)
        seen = set(n_id)

        for num_neighbor in self.num_neighbors:
            sampled_src, dst = self.sample_neighbors_one_layer(
                seed_nodes, num_neighbor
            )
            raw_adjs.append((sampled_src, dst))

            # Append nodes that have not been seen yet, in discovery order.
            for node in sampled_src.tolist():
                if node not in seen:
                    seen.add(node)
                    n_id.append(node)

            # The sampled sources become the seeds of the next hop.
            seed_nodes = sampled_src.unique().tolist()

        n_id = torch.tensor(n_id, dtype=torch.long, device=self.device)

        # Map global node ids to their position in `n_id`: look each id up
        # in the sorted copy, then translate back through the sort order.
        sorted_, perm = torch.sort(n_id)

        adjs = []
        for src, dst in raw_adjs:
            if src.numel() == 0:
                adjs.append(
                    torch.empty((0,), dtype=torch.long, device=self.device)
                )
                continue

            src_ = perm[torch.searchsorted(sorted_, src)]
            dst_ = perm[torch.searchsorted(sorted_, dst)]
            edge_index = torch.stack([src_, dst_], dim=0)

            size = int(edge_index.max()) + 1
            adj = torch.sparse_coo_tensor(
                indices=edge_index,
                values=torch.ones(edge_index.shape[1], device=self.device),
                size=(size, size),
                device=self.device,
            )
            if self.transform is not None:
                adj = self.transform(adj)
            adjs.append(adj)

        return len(batch), n_id, adjs