Source code for rllm.dataloader.bridge_loader

from typing import List, Optional

from torch import Tensor

from rllm.data import TableData, GraphData
from rllm.dataloader.neighbor_loader import NeighborLoader


class BRIDGELoader(NeighborLoader):
    r"""Neighbor-sampling loader that also gathers BRIDGE input features.

    On top of the sampling done by `NeighborLoader`, every collated batch
    is extended with the table rows (and, when available, the non-table
    rows) that belong to the sampled node ids.

    BRIDGE places table nodes before non-table nodes in the graph node
    index. With `self.sep = len(table)`, a sampled id is resolved as:

    * `n_id < self.sep`: a table node, used to index `table` directly;
    * `n_id >= self.sep`: a non-table node, used to index `non_table`
      at `n_id - self.sep`.

    Args:
        table (TableData): The table data to be sampled.
        non_table (Optional[Tensor]): The non-table data to be sampled.
            When `None`, every sampled id is used to index `table`.
        graph (GraphData): The graph data to be sampled; forwarded to
            `NeighborLoader` as `data`.
        num_samples (List[int]): The number of samples to be taken from
            each layer of the graph; forwarded as `num_neighbors`.
        train_mask (Optional[Tensor]): The mask to be used for sampling;
            forwarded as `seeds`.
        **kwargs: Additional keyword arguments to be passed to the
            `NeighborLoader` class.
    """

    def __init__(
        self,
        table: TableData,
        non_table: Optional[Tensor],
        graph: GraphData,
        num_samples: List[int],
        train_mask: Optional[Tensor] = None,
        **kwargs
    ):
        # `replace` and `transform` are pinned here; everything else the
        # parent accepts can still be supplied through **kwargs.
        super().__init__(
            data=graph,
            num_neighbors=num_samples,
            seeds=train_mask,
            replace=False,
            transform=None,
            **kwargs
        )

        self.table = table
        self.non_table = non_table
        # Number of table entries: the first node id that is NOT a table node.
        self.sep = len(table)

    def collate_fn(self, batch):
        r"""Collate a batch and attach the matching table / non-table data.

        Args:
            batch: The batch handed to `NeighborLoader.collate_fn`.

        Returns:
            A 5-tuple `(batch, n_id, adjs, table_data, non_table_data)`.
            The first three items are exactly what
            `NeighborLoader.collate_fn` returned. `table_data` is `table`
            indexed by the sampled table ids; `non_table_data` is
            `non_table` indexed by the sampled non-table ids shifted down
            by `self.sep`, or `None` when no non-table data was given.
        """
        sampled_batch, node_ids, adjs = super().collate_fn(batch)

        if self.non_table is not None:
            boundary = self.sep
            # Ids below the boundary address the table directly.
            table_rows = self.table[node_ids[node_ids < boundary]]
            # Remaining ids are rebased so they start at 0 in `non_table`.
            non_table_rows = self.non_table[
                node_ids[node_ids >= boundary] - boundary
            ]
            return sampled_batch, node_ids, adjs, table_rows, non_table_rows

        # No non-table part: all sampled ids index the table.
        return sampled_batch, node_ids, adjs, self.table[node_ids], None