rllm.dataloader.BRIDGELoader

class rllm.dataloader.BRIDGELoader(table: TableData, non_table: Tensor | None, graph: GraphData, num_samples: List[int], train_mask: Tensor | None = None, **kwargs)

Bases: NeighborLoader

BRIDGELoader is a specialized data loader for the BRIDGE model. It extends NeighborLoader so that graph and table data are processed in a unified way: it samples neighborhoods on a graph whose nodes include both table rows and non-table entities, then maps the sampled node indices back to the table and the non-table data.

BRIDGE always places the table nodes before the non-table nodes in the graph's node index. After sampling n_id, the loader therefore splits it at self.sep, the length of the table: with 0-based indexing, entries with n_id < self.sep are table node indices and entries with n_id >= self.sep are non-table node indices.
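The split described above can be sketched in plain Python as follows. This is an illustration under stated assumptions, not rllm's code: the helper name split_node_ids is hypothetical, lists stand in for tensors, and it assumes that a non-table node's row in the non-table data is n_id - sep.

```python
from typing import List, Tuple


def split_node_ids(n_id: List[int], sep: int) -> Tuple[List[int], List[int]]:
    """Split sampled global node ids into table ids and non-table row ids.

    Hypothetical helper. Assumes table nodes occupy global ids [0, sep)
    and non-table nodes occupy ids [sep, ...).
    """
    # Table nodes come first in the node index, so their ids are already row ids.
    table_ids = [i for i in n_id if i < sep]
    # Assumption: shifting by sep turns a global id into a non-table row index.
    non_table_ids = [i - sep for i in n_id if i >= sep]
    return table_ids, non_table_ids


# 5 table rows (ids 0-4) followed by non-table nodes (ids 5, 6, ...).
table_ids, non_table_ids = split_node_ids([3, 0, 7, 5, 4], sep=5)
print(table_ids)      # [3, 0, 4]
print(non_table_ids)  # [2, 0]
```

With tensors, the same split is usually written as two boolean masks (n_id < sep and n_id >= sep), which keeps the operation vectorized.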

Parameters:
  • table (TableData) – The table data to be sampled.

  • non_table (Optional[Tensor]) – The non-table data to be sampled.

  • graph (GraphData) – The graph data to be sampled.

  • num_samples (List[int]) – The number of neighbors to sample at each hop, with one entry per layer of the graph.

  • train_mask (Optional[Tensor]) – The mask selecting which nodes serve as seed nodes for sampling.

  • **kwargs – Additional keyword arguments to be passed to the NeighborLoader class.
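As a rough illustration of how a node mask can drive mini-batching, the sketch below turns a boolean mask into batches of seed node indices. It assumes that train_mask marks the seed nodes and that a batch_size option is among the keyword arguments forwarded to NeighborLoader; both are assumptions, and the helper seed_batches is hypothetical.

```python
from typing import List


def seed_batches(train_mask: List[bool], batch_size: int) -> List[List[int]]:
    """Hypothetical helper: turn a node mask into batches of seed indices."""
    # Assumption: every node whose mask entry is True becomes a seed node.
    seeds = [i for i, keep in enumerate(train_mask) if keep]
    # Chunk the seeds in order; a real loader may also shuffle them.
    return [seeds[i:i + batch_size] for i in range(0, len(seeds), batch_size)]


mask = [True, False, True, True, False, True, True]
print(seed_batches(mask, batch_size=2))  # [[0, 2], [3, 5], [6]]
```

Each inner list corresponds to the batch argument that collate_fn receives.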

collate_fn(batch)

Collate function for the loader. Samples neighbors for each seed node in the batch and returns the sampled subgraph.

Parameters:

batch (List[Tensor]) – A list of seed node indices.

Returns:

A tuple of (batch_size, n_id, adjs), where batch_size is the number of seed nodes, n_id contains all sampled node indices, and adjs is a list of sparse adjacency matrices, one per hop.

Return type:

Tuple[int, Tensor, List[Tensor]]
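To make the return value concrete, here is a minimal pure-Python sketch of a collate step that produces the same (batch_size, n_id, adjs) shape. It is not rllm's implementation: the helper name sample_subgraph, the dict-of-lists graph, and the use of COO edge lists in place of sparse adjacency matrices are all assumptions. Seeds occupy the first batch_size entries of n_id, each hop samples at most num_samples[k] neighbors per frontier node, and adjs[k] holds the edges of hop k + 1 in local indices.

```python
import random
from typing import Dict, List, Tuple

Edge = Tuple[int, int]  # (source, target) as local positions in n_id


def sample_subgraph(
    neighbors: Dict[int, List[int]],
    batch: List[int],
    num_samples: List[int],
    rng: random.Random,
) -> Tuple[int, List[int], List[List[Edge]]]:
    """Hypothetical multi-hop neighbor sampler mirroring collate_fn's output."""
    n_id = list(batch)                           # seed nodes come first
    local = {g: i for i, g in enumerate(n_id)}   # global id -> local position
    frontier = list(batch)
    adjs: List[List[Edge]] = []
    for fanout in num_samples:
        edges: List[Edge] = []
        next_frontier: List[int] = []
        for dst in frontier:
            nbrs = neighbors.get(dst, [])
            # Keep all neighbors if there are few enough, else sample `fanout`.
            picked = nbrs if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
            for src in picked:
                if src not in local:             # first time this node is seen
                    local[src] = len(n_id)
                    n_id.append(src)
                    next_frontier.append(src)
                edges.append((local[src], local[dst]))
        adjs.append(edges)
        frontier = next_frontier                 # expand only newly added nodes
    return len(batch), n_id, adjs


graph = {0: [1, 2], 1: [3], 2: [3], 3: []}
print(sample_subgraph(graph, [0], [2, 2], random.Random(0)))
# (1, [0, 1, 2, 3], [[(1, 0), (2, 0)], [(3, 1), (3, 2)]])
```

This sketch lists hops from the seeds outward and expands only newly discovered nodes at each hop; an actual loader may order adjs differently (for example, outermost hop first) and may resample neighbors for all nodes gathered so far.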