rllm.dataloader.BRIDGELoader
- class rllm.dataloader.BRIDGELoader(table: TableData, non_table: Tensor | None, graph: GraphData, num_samples: List[int], train_mask: Tensor | None = None, **kwargs)
Bases: NeighborLoader
BRIDGELoader is a data loader specialized for the BRIDGE model. It extends NeighborLoader to sample subgraphs whose nodes may be either table rows or non-table nodes, so that graph data and table data are processed in a unified manner.
BRIDGE always places table embeddings before non-table embeddings in the graph node index. After sampling n_id, the loader therefore splits it at the length of the table: entries with n_id <= self.sep are table node indices, and entries with n_id > self.sep are non-table node indices.
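The split rule above can be illustrated with a small pure-Python sketch. It applies the stated rule literally (ids <= sep are table nodes, ids > sep are non-table nodes). The helper names split_n_id and to_non_table_rows are hypothetical, and the row mapping assumes that non-table nodes start at sep + 1, directly after the table nodes; this page does not say how self.sep is derived from the table length, so the sketch takes it as a parameter.

```python
from typing import List, Tuple


def split_n_id(n_id: List[int], sep: int) -> Tuple[List[int], List[int]]:
    """Split sampled global node ids into table ids and non-table ids.

    Hypothetical helper: ids <= sep are table nodes, ids > sep are
    non-table nodes. The original sampling order is preserved in each part.
    """
    table_ids = [i for i in n_id if i <= sep]
    non_table_ids = [i for i in n_id if i > sep]
    return table_ids, non_table_ids


def to_non_table_rows(non_table_ids: List[int], sep: int) -> List[int]:
    """Map global non-table ids to row positions in the non-table tensor.

    ASSUMPTION: non-table nodes are numbered from sep + 1 onward.
    """
    return [i - (sep + 1) for i in non_table_ids]


# Example: table nodes occupy ids 0..3, so sep = 3.
n_id = [0, 5, 2, 7, 3]
table_ids, non_table_ids = split_n_id(n_id, sep=3)
print(table_ids)                            # [0, 2, 3]
print(non_table_ids)                        # [5, 7]
print(to_non_table_rows(non_table_ids, 3))  # [1, 3]
```

With real tensors the same split would normally be written with boolean masks rather than list comprehensions; plain lists are used here only to keep the sketch dependency-free.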
- Parameters:
table (TableData) – The table data to be sampled.
non_table (Optional[Tensor]) – The non-table data to be sampled.
graph (GraphData) – The graph data to be sampled.
num_samples (List[int]) – The number of neighbors to sample per node at each layer (hop) of the graph.
train_mask (Optional[Tensor]) – The mask to be used for sampling.
**kwargs – Additional keyword arguments to be passed to the NeighborLoader class.
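To make the roles of train_mask and num_samples more concrete, here is a small pure-Python sketch. It assumes, without confirmation from this page, that the seed nodes are the indices where train_mask is True, that num_samples[0] applies to the first hop out from the seeds, and that each hop expands only the nodes added at the previous hop (a GraphSAGE-style sampling tree). The helper names seeds_from_mask and max_nodes_per_hop are hypothetical.

```python
from typing import List


def seeds_from_mask(train_mask: List[bool]) -> List[int]:
    """Hypothetical helper: return the indices where the mask is True."""
    return [i for i, keep in enumerate(train_mask) if keep]


def max_nodes_per_hop(batch_size: int, num_samples: List[int]) -> List[int]:
    """Upper bound on the number of sampled nodes at each hop.

    Entry 0 is the seed batch; entry k + 1 is at most entry k times
    num_samples[k]. ASSUMPTION: each hop expands only the previous hop's
    nodes. Real counts are usually smaller because of low-degree nodes
    and duplicate neighbors.
    """
    bounds = [batch_size]
    for fanout in num_samples:
        bounds.append(bounds[-1] * fanout)
    return bounds


mask = [True, False, True, True, False]
print(seeds_from_mask(mask))           # [0, 2, 3]
print(max_nodes_per_hop(32, [10, 5]))  # [32, 320, 1600]
```

The bound grows multiplicatively with the number of hops, which is why num_samples is usually kept short and its entries small.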
- collate_fn(batch)[source]¶
Collate function for the NeighborLoader. Samples neighbors for each node in the batch and returns the sampled subgraph.
- Parameters:
batch (List[Tensor]) – A list of seed node indices.
- Returns:
A tuple (batch_size, n_id, adjs), where batch_size is the number of seed nodes, n_id contains all sampled node indices, and adjs is a list of sparse adjacency matrices, one per hop.
- Return type:
Tuple[int, Tensor, List[Tensor]]
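The shape of the (batch_size, n_id, adjs) return value can be sketched with a minimal pure-Python neighbor sampler. This is not the rllm implementation: collate_sketch and its adj argument (a dict from node id to neighbor ids) are hypothetical, seeds are assumed to come first in n_id, each hop is assumed to expand only newly added nodes, and adjs holds plain (neighbor, target) edge lists in place of the sparse adjacency matrices the real collate_fn returns.

```python
import random
from typing import Dict, List, Tuple

Edge = Tuple[int, int]  # (sampled neighbor, target node), global ids


def collate_sketch(
    batch: List[int],
    adj: Dict[int, List[int]],
    num_samples: List[int],
    rng: random.Random,
) -> Tuple[int, List[int], List[List[Edge]]]:
    """Hypothetical neighbor-sampling collate returning (batch_size, n_id, adjs)."""
    n_id = list(batch)  # seeds first, so n_id[:batch_size] == batch
    seen = set(batch)
    frontier = list(batch)
    adjs: List[List[Edge]] = []
    for fanout in num_samples:
        edges: List[Edge] = []
        next_frontier: List[int] = []
        for node in frontier:
            neighbors = adj.get(node, [])
            # Sample without replacement, at most `fanout` neighbors.
            picked = rng.sample(neighbors, min(fanout, len(neighbors)))
            for nbr in picked:
                edges.append((nbr, node))
                if nbr not in seen:  # record each node once in n_id
                    seen.add(nbr)
                    n_id.append(nbr)
                    next_frontier.append(nbr)
        adjs.append(edges)
        frontier = next_frontier
    return len(batch), n_id, adjs


adj = {0: [1, 2, 3], 1: [0, 4], 2: [0], 3: [0, 4], 4: [1, 3]}
batch_size, n_id, adjs = collate_sketch([0], adj, [2, 1], random.Random(0))
print(batch_size, n_id[:batch_size], len(adjs))  # 1 [0] 2
```

Keeping the seeds at the front of n_id lets a model read the outputs for the batch from the first batch_size rows; whether adjs runs from the seeds outward or in reverse order is not specified on this page.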