rllm.dataloader.RelbenchLoader

class rllm.dataloader.RelbenchLoader(dataset: RelBenchDataset, task: RelBenchTask | str, split: str = 'train', shuffle: bool = False, batch_size: int = 512, num_neighbors: List[int] = [15, 10], to_bidirectional: bool = False, use_pyg_lib: bool = True)

Bases: DataLoader

DataLoader for RelBench datasets that builds each mini-batch by heterogeneous neighbor sampling.

Parameters:
  • dataset (RelBenchDataset) – The RelBench dataset.

  • task (Union[RelBenchTask, str]) – The task to load, given either as a RelBenchTask object or by name.

  • split (str) – The data split to load: 'train', 'val', or 'test'. (default: 'train')

  • shuffle (bool) – Whether to shuffle the data. (default: False)

  • batch_size (int) – The batch size. (default: 512)

  • num_neighbors (List[int]) – Number of neighbors to sample per node at each hop; the length of the list sets the number of hops. (default: [15, 10])

  • to_bidirectional (bool) – Whether to make the graph bidirectional by adding a reverse edge for every edge. (default: False)

  • use_pyg_lib (bool) – Whether to use pyg-lib for neighbor sampling. If pyg-lib is not installed, the loader falls back to a pure-Python sampler. (default: True)
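Reverse edges matter because many neighbor samplers traverse each edge type in only one direction, so some node types may otherwise be unreachable from the seed nodes. The following is a minimal sketch of what adding reverse edges means for a heterogeneous edge dictionary. The dict layout, the function name add_reverse_edges, and the "rev_" prefix are assumptions for illustration (PyG's ToUndirected transform uses a similar "rev_" convention); this is not RelbenchLoader's internal code.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]          # (source type, relation, destination type)
EdgeIndex = Tuple[List[int], List[int]]  # (row = source ids, col = destination ids)


def add_reverse_edges(edges: Dict[EdgeType, EdgeIndex]) -> Dict[EdgeType, EdgeIndex]:
    """Return a new dict with a reverse edge type added for every original edge type."""
    out = dict(edges)
    for (src, rel, dst), (row, col) in edges.items():
        # Swap the endpoints so the relation can also be traversed dst -> src.
        # The "rev_" prefix is an assumed naming convention.
        out[(dst, f"rev_{rel}", src)] = (list(col), list(row))
    return out


edges = {("transaction", "made_by", "customer"): ([0, 1, 2], [5, 5, 7])}
bidir = add_reverse_edges(edges)
print(sorted(bidir))
# [('customer', 'rev_made_by', 'transaction'), ('transaction', 'made_by', 'customer')]
```

The input dict is left unchanged; only the returned copy contains the reverse edge types.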

collate_fn(index: List[int] | Tensor) → HeteroSamplerOutput

Sample a heterogeneous mini-batch subgraph around the input nodes given by index.
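The hop-by-hop expansion behind this kind of sampling can be illustrated in pure Python. In this sketch, num_neighbors[i] caps how many neighbors each node samples per edge type at hop i, so [15, 10] would draw up to 15 neighbors around each seed and then up to 10 around each newly reached node. The adjacency layout and the function name sample_hetero are hypothetical, and the sketch ignores any temporal constraints; it is not the pyg-lib sampler or the loader's fallback sampler.

```python
import random
from typing import Dict, List, Set, Tuple

EdgeType = Tuple[str, str, str]
# Incoming adjacency per edge type: {destination id: [source ids]}.
Adjacency = Dict[EdgeType, Dict[int, List[int]]]


def sample_hetero(
    adj: Adjacency,
    seed_type: str,
    seeds: List[int],
    num_neighbors: List[int],
    rng: random.Random,
) -> Tuple[Dict[str, List[int]], Dict[EdgeType, List[Tuple[int, int]]]]:
    """Sample a len(num_neighbors)-hop subgraph around the seed nodes."""
    nodes: Dict[str, List[int]] = {seed_type: list(seeds)}
    seen: Dict[str, Set[int]] = {seed_type: set(seeds)}
    edges: Dict[EdgeType, List[Tuple[int, int]]] = {}
    frontier: Dict[str, List[int]] = {seed_type: list(seeds)}

    for fanout in num_neighbors:  # one iteration per hop
        next_frontier: Dict[str, List[int]] = {}
        for (src_t, rel, dst_t), in_nbrs in adj.items():
            for dst in frontier.get(dst_t, []):
                candidates = in_nbrs.get(dst, [])
                # Sample without replacement, at most `fanout` neighbors per node.
                for src in rng.sample(candidates, min(fanout, len(candidates))):
                    edges.setdefault((src_t, rel, dst_t), []).append((src, dst))
                    if src not in seen.setdefault(src_t, set()):
                        seen[src_t].add(src)
                        nodes.setdefault(src_t, []).append(src)
                        next_frontier.setdefault(src_t, []).append(src)
        frontier = next_frontier  # only newly reached nodes are expanded next hop
    return nodes, edges


adj = {
    ("transaction", "made_by", "customer"): {0: [10, 11, 12]},
    ("customer", "rev_made_by", "transaction"): {10: [0], 11: [0], 12: [0]},
}
nodes, edges = sample_hetero(adj, "customer", [0], num_neighbors=[2, 1], rng=random.Random(0))
```

Here the first hop picks two of customer 0's three transactions, and the second hop walks the reverse edges back to customer 0 without adding new nodes.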

static filter_edge_store_(store: NodeStorage, out_store: NodeStorage, row: Tensor, col: Tensor, perm: Tensor | None = None)

Filter an edge storage, writing into out_store only the sampled edges represented by (row, col).

Parameters:
  • store (NodeStorage) – The source edge storage.

  • out_store (NodeStorage) – The output edge storage to write into.

  • row (Tensor) – Source node indices of sampled edges.

  • col (Tensor) – Destination node indices of sampled edges.

  • perm (Tensor, optional) – Edge permutation indices. (default: None)
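A minimal sketch of this filtering step, using plain dicts and lists in place of the storage and Tensor types. It assumes that every entry other than edge_index is a per-edge attribute and that perm[i] gives the original position of the i-th sampled edge; both are assumptions, since the parameters above do not pin them down. As the out_store description indicates, the result is written into out_store rather than returned.

```python
from typing import Any, Dict, List, Optional


def filter_edge_store_(
    store: Dict[str, Any],
    out_store: Dict[str, Any],
    row: List[int],
    col: List[int],
    perm: Optional[List[int]] = None,
) -> None:
    """Write the sampled edges of `store` into `out_store` (modified in place)."""
    # Connectivity comes straight from the sampled (row, col) pairs.
    out_store["edge_index"] = (list(row), list(col))
    if perm is None:
        return  # this sketch cannot align per-edge attributes without perm
    for key, value in store.items():
        if key != "edge_index":
            # Assumed: perm[i] is the original position of the i-th sampled edge.
            out_store[key] = [value[p] for p in perm]


store = {
    "edge_index": ([0, 1, 2], [5, 5, 7]),
    "amount": [9.5, 3.0, 4.25],
}
out: Dict[str, Any] = {}
filter_edge_store_(store, out, row=[0, 1], col=[0, 0], perm=[2, 0])
# out == {"edge_index": ([0, 1], [0, 0]), "amount": [4.25, 9.5]}
```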

filter_fn(out: HeteroSamplerOutput) → HeteroGraphData

Join the sampled node and edge indices with their features and metadata.

Parameters:

out (HeteroSamplerOutput) – Raw sampler output containing node and edge indices.

Returns:

A mini-batch heterogeneous graph with node features, edge indices, timestamps, and target labels attached.

Return type:

HeteroGraphData

static filter_node_store_(store: NodeStorage, out_store: NodeStorage, index: Tensor)

Filter a node storage, writing into out_store only the nodes given by index.

Parameters:
  • store (NodeStorage) – The source node storage.

  • out_store (NodeStorage) – The output node storage to write into.

  • index (Tensor) – The 1-D tensor of node indices to keep.
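The node-side counterpart can be sketched the same way, again with plain dicts and lists standing in for the storage and Tensor types. The sketch treats any list whose length equals num_nodes as a per-node attribute and copies everything else unchanged; that length-based rule is a heuristic assumed for illustration, not necessarily how the real method decides.

```python
from typing import Any, Dict, List


def filter_node_store_(store: Dict[str, Any], out_store: Dict[str, Any], index: List[int]) -> None:
    """Copy into `out_store` only the nodes of `store` selected by `index` (in place)."""
    num_nodes = store["num_nodes"]
    for key, value in store.items():
        if key == "num_nodes":
            out_store[key] = len(index)
        elif isinstance(value, list) and len(value) == num_nodes:
            # Per-node attribute: keep the selected entries, in index order.
            out_store[key] = [value[i] for i in index]
        else:
            # Anything else (e.g. a type name) is copied unchanged.
            out_store[key] = value


store = {"num_nodes": 4, "x": [[0.1], [0.2], [0.3], [0.4]], "time": [10, 20, 30, 40], "name": "customer"}
out: Dict[str, Any] = {}
filter_node_store_(store, out, index=[3, 1])
# out == {"num_nodes": 2, "x": [[0.4], [0.2]], "time": [40, 20], "name": "customer"}
```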

static get_loaders(dataset: RelBenchDataset, task: RelBenchTask | str, batch_size: int = 512, num_neighbors: List[int] = [15, 10], to_bidirectional: bool = False) → List[RelbenchLoader]

Create one loader for each split: train, validation, and test.

Parameters:
  • dataset (RelBenchDataset) – The RelBench dataset.

  • task (Union[RelBenchTask, str]) – The task to load.

  • batch_size (int) – The batch size. (default: 512)

  • num_neighbors (List[int]) – Number of neighbors to sample at each hop. (default: [15, 10])

  • to_bidirectional (bool) – Whether to add reverse edges. (default: False)

Returns:

A list of [train_loader, val_loader, test_loader].

Return type:

List[RelbenchLoader]

static index_select(value: TableData, index: Tensor, dim: int = 0) → Tensor

Index the value table along dimension dim using the entries in index.

Parameters:
  • value (TableData) – The input table.

  • index (Tensor) – The 1-D tensor containing the indices to select.

  • dim (int, optional) – The dimension along which to index. (default: 0)

Returns:

The indexed sub-table.

Return type:

TableData
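The indexing semantics can be illustrated on a plain row-major table. In this sketch a list of rows stands in for TableData (an assumption), dim=0 selects rows and dim=1 selects columns, both in the order given by index, and repeated indices are allowed. It is an illustration of index selection, not the method's implementation.

```python
from typing import Any, List, Sequence

Table = List[List[Any]]  # row-major: table[r][c]


def index_select(value: Table, index: Sequence[int], dim: int = 0) -> Table:
    """Select rows (dim=0) or columns (dim=1) of a row-major table, in index order."""
    if dim == 0:
        return [list(value[i]) for i in index]
    if dim == 1:
        return [[row[j] for j in index] for row in value]
    raise ValueError(f"dim must be 0 or 1, got {dim}")


table = [[1, "a"], [2, "b"], [3, "c"]]
print(index_select(table, [2, 0]))      # [[3, 'c'], [1, 'a']]
print(index_select(table, [1], dim=1))  # [['a'], ['b'], ['c']]
```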