rllm.data.HeteroGraphData

class rllm.data.HeteroGraphData(mapping: Mapping[str, Any] | None = None, **kwargs)

Bases: BaseGraph

A class for storing heterogeneous graph data that fits easily into CPU memory.

The recognized edge keys are adj and edge_index. Any other key stored on an edge type is treated as an edge attribute.

Methods of initialization:
  1. Assign attributes:

from rllm.data import HeteroGraphData

data = HeteroGraphData()
data['paper']['x'] = x_paper  # item-style access
data['paper'].x = x_paper     # equivalent attribute-style access
Tip:

Although a node attribute can have any name, x is preferred, since x_dict() collects the attribute x.

  2. Pass them as keyword arguments:

data = HeteroGraphData(
    paper={'x': x_paper, 'y': labels},
    writer={'x': x_writer},
    writer__of__paper={'adj': adj},
)
  3. Pass them as a dictionary:

data = HeteroGraphData(
    {
        'paper': {'x': x_paper, 'y': labels},
        'writer': {'x': x_writer},
        ('writer', 'of', 'paper'): {'adj': adj},
    }
)
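
In the first method, data['paper']['x'] and data['paper'].x are presented as two ways to set the same attribute. The following is a minimal sketch of how a dict-backed storage object can support both access styles at once, assuming such a design; ToyStorage is a hypothetical class for illustration and not rLLM's storage implementation.

```python
from typing import Any


class ToyStorage:
    """Hypothetical dict-backed storage: s['x'] and s.x refer to the same entry."""

    def __init__(self, **attrs: Any) -> None:
        # Use object.__setattr__ so the backing dict becomes a real instance
        # attribute instead of being routed through our own __setattr__ below.
        object.__setattr__(self, "_mapping", dict(attrs))

    def __getitem__(self, key: str) -> Any:
        return self._mapping[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self._mapping[key] = value

    def __getattr__(self, key: str) -> Any:
        # Called only when normal lookup fails, i.e. for stored attributes.
        try:
            return self._mapping[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key: str, value: Any) -> None:
        self._mapping[key] = value

    def keys(self):
        return self._mapping.keys()


paper = ToyStorage()
paper["x"] = [[0.1, 0.2], [0.3, 0.4]]  # item-style assignment
print(paper.x is paper["x"])           # True: both styles read the same object
paper.y = [0, 1]                       # attribute-style assignment
print(sorted(paper.keys()))            # ['x', 'y']
```

Routing both styles to one dictionary is what keeps them interchangeable: there is a single source of truth, so no synchronization is needed.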

Attributes such as train_mask can also be stored directly on the data object:

data.train_mask = train_mask

Add further node or edge types by assigning a dictionary of attributes to a new key:

data[key] = {  # key is a node type or an edge type
    ...
}

An edge type can be given as a key in any of the following forms:
data['src__tgt'] = {'adj': adj}
data[src, tgt] = {'adj': adj}
data[src, rel, tgt] = {'adj': adj}

A node type is given as a plain string key:
data['node type'] = {'x': x}
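
The key forms above can be told apart by type and shape. Note that data[src, tgt] passes the tuple (src, tgt) to __getitem__, so it reaches the same code path as an explicit tuple. The sketch below uses a hypothetical helper, classify_key, and assumes that any string containing '__' names an edge type while a plain string names a node type. How rLLM fills in a relation name for two-part keys is not stated here, so the sketch leaves them as 2-tuples.

```python
from typing import Tuple, Union

Key = Union[str, Tuple[str, ...]]


def classify_key(key: Key) -> Tuple[str, Union[str, Tuple[str, ...]]]:
    """Hypothetical helper: decide whether a key names a node type or an edge type."""
    if isinstance(key, tuple):
        # data[src, tgt] and data[src, rel, tgt] both arrive here as tuples.
        if len(key) not in (2, 3):
            raise ValueError(f"edge type tuple must have 2 or 3 parts: {key!r}")
        return "edge", key
    if "__" in key:
        # 'src__tgt' or 'src__rel__tgt' string form.
        parts = tuple(key.split("__"))
        if len(parts) not in (2, 3):
            raise ValueError(f"edge type string must have 2 or 3 parts: {key!r}")
        return "edge", parts
    # Any other plain string is treated as a node type.
    return "node", key


print(classify_key("paper"))              # ('node', 'paper')
print(classify_key("writer__of__paper"))  # ('edge', ('writer', 'of', 'paper'))
print(classify_key(("writer", "paper")))  # ('edge', ('writer', 'paper'))
```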
adj_dict()

Collects the attribute adj from all edge types.

collect_attr(key: str | Tuple[str, str, str], exlude_None: bool = False) → Dict[str | Tuple[str, str, str], Any]

Collects the attribute key from all node and edge types.

Parameters:
  • key (str) – The attribute key to collect.

  • exlude_None (bool, optional) – If set to True, types whose value for key is None are excluded from the result. (default: False)

Example

>>> data = HeteroGraphData()
>>> data['paper'].x = ...
>>> data['author'].x = ...
>>> data['author', 'writes', 'paper'].edge_index = ...
>>> data.collect_attr('x')
{'paper': ..., 'author': ...}
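
The behavior of collect_attr can be modeled on plain dictionaries. In the sketch below, which is an assumption rather than rLLM's code, a graph is a dict mapping each node or edge type to a dict of its attributes. Types that lack key are skipped, just as the edge type is absent from the example output above. The hypothetical flag exclude_none plays the role of the documented exlude_None.

```python
from typing import Any, Dict


def collect_attr_sketch(stores: Dict[Any, Dict[str, Any]], key: str,
                        exclude_none: bool = False) -> Dict[Any, Any]:
    """Gather stores[t][key] for every node or edge type t that has `key`."""
    out: Dict[Any, Any] = {}
    for type_key, attrs in stores.items():
        if key not in attrs:
            continue  # this type does not define the attribute at all
        value = attrs[key]
        if exclude_none and value is None:
            continue  # drop explicit None values when requested
        out[type_key] = value
    return out


stores = {
    "paper": {"x": [1.0], "y": None},
    "author": {"x": [2.0]},
    ("author", "writes", "paper"): {"edge_index": [[0], [0]]},
}
print(collect_attr_sketch(stores, "x"))                     # {'paper': [1.0], 'author': [2.0]}
print(collect_attr_sketch(stores, "y"))                     # {'paper': None}
print(collect_attr_sketch(stores, "y", exclude_none=True))  # {}
```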
edge_items()

Returns a list of edge type and edge storage pairs.

property edge_stores

Returns a list of all edge storages of the graph.

property edge_types

Returns a list of all edge types of the graph.
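
The three edge accessors can be read as views of one type-to-storage mapping: edge_types are the keys, edge_stores the values, and edge_items the (type, storage) pairs. node_items, node_stores, and node_types presumably mirror them for node types. The sketch below assumes a hypothetical model in which edge types are tuple keys, node types are string keys, and insertion order is preserved.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical model: each node or edge type maps to a dict acting as its storage.
graph: Dict[Any, Dict[str, Any]] = {
    "paper": {"x": [[1.0], [2.0]]},
    "author": {"x": [[3.0]]},
    ("author", "writes", "paper"): {"edge_index": [[0, 0], [0, 1]]},
}


def edge_items(g: Dict[Any, Dict[str, Any]]) -> List[Tuple[Tuple[str, ...], Dict[str, Any]]]:
    """(edge type, edge storage) pairs, in insertion order."""
    # In this model, tuple-valued keys are edge types; string keys are node types.
    return [(k, v) for k, v in g.items() if isinstance(k, tuple)]


edge_types = [k for k, _ in edge_items(graph)]
edge_stores = [v for _, v in edge_items(graph)]
print(edge_types)   # [('author', 'writes', 'paper')]
print(edge_stores)  # [{'edge_index': [[0, 0], [0, 1]]}]
```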

metadata()

Returns the heterogeneous metadata, i.e., its node and edge types.

>>> data = HeteroGraphData()
>>> data['paper'].x = ...
>>> data['author'].x = ...
>>> data['author', 'writes', 'paper'].edge_index = ...
>>> data.metadata()
(['paper', 'author'], [('author', 'writes', 'paper')])
node_items()

Returns a list of node type and node storage pairs.

property node_stores

Returns a list of all node storages of the graph.

property node_types

Returns a list of all node types of the graph.

set_value_dict(key: str, value_d: Dict[str | Tuple[str, str, str], Any]) → HeteroGraphData

Sets the attribute key on each node and edge type in value_d.

Parameters:
  • key (str) – The attribute key to set.

  • value_d (Dict[Union[NodeType, EdgeType], Any]) – The attribute values, keyed by node or edge type.
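
set_value_dict can be seen as the counterpart of collect_attr: it writes one attribute across many types at once. Below is a minimal sketch on a plain-dictionary model; set_value_dict_sketch is hypothetical, and creating storage for types not yet present is an assumption, since the documentation does not say what happens in that case.

```python
from typing import Any, Dict


def set_value_dict_sketch(stores: Dict[Any, Dict[str, Any]], key: str,
                          value_d: Dict[Any, Any]) -> Dict[Any, Dict[str, Any]]:
    """Write value_d[t] into stores[t][key] for every type t in value_d."""
    for type_key, value in value_d.items():
        # Assumption: a type not yet present gets a fresh attribute dict.
        stores.setdefault(type_key, {})[key] = value
    # Return the container itself, echoing the documented HeteroGraphData return.
    return stores


stores = {"paper": {"x": [1.0, 2.0]}, "author": {"x": [3.0]}}
set_value_dict_sketch(stores, "y", {"paper": [0, 1], "author": [1]})
print(stores["paper"])  # {'x': [1.0, 2.0], 'y': [0, 1]}
```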

property stores

Returns a list of all storages of the graph.

to_csc_dict(device: device | None = None, share_memory: bool = False, is_sorted: bool = False, node_time_d: Dict[str, Tensor] | None = None, edge_time_d: Dict[Tuple[str, str, str] | str, Tensor] | None = None) → Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, Tensor | None]]

Converts the edges of each edge type into CSC (compressed sparse column) format for sampling. Returns three dictionaries, each keyed by edge type, holding the column pointers (colptr), the row indices, and the edge permutations, respectively.

Parameters:
  • device (torch.device, optional) – The device to move the tensors to.

  • share_memory (bool, optional) – If set to True, will share memory with the original tensor. This can speed up processing when multiple processes are used.

  • is_sorted (bool, optional) – If set to True, the edge index is assumed to be already sorted by column and is not sorted again.

  • node_time_d (Dict[str, Tensor], optional) – The node time attribute dictionary.

  • edge_time_d (Dict[Tuple[str, str, str] | str, Tensor], optional) – The edge time attribute dictionary.

Returns:

  • colptr_d holds the column pointers for each edge type.

  • row_d holds the row indices for each edge type.

  • perm_d holds the permutation indices for each edge type.
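
The CSC conversion for a single edge type can be sketched in pure Python. This sketch assumes the usual neighbor-sampling convention that row holds source-node indices and col holds target-node indices, so edges are grouped by target node. to_csc and to_csc_dict_sketch are hypothetical helpers: they ignore device, share_memory, and the time dictionaries, and they key results by edge-type tuple.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]


def to_csc(row: List[int], col: List[int], num_cols: int,
           is_sorted: bool = False) -> Tuple[List[int], List[int], List[int]]:
    """Convert COO edges (row[i] -> col[i]) to CSC.

    Sources of target c are row_sorted[colptr[c]:colptr[c + 1]], and
    perm[j] is the original edge id stored at CSC position j.
    """
    num_edges = len(col)
    if is_sorted:
        perm = list(range(num_edges))  # caller guarantees col is already sorted
    else:
        # Stable sort of edge ids by target (column) index.
        perm = sorted(range(num_edges), key=lambda e: col[e])
    row_sorted = [row[e] for e in perm]

    # Count edges per column, then take a prefix sum to get column pointers.
    colptr = [0] * (num_cols + 1)
    for c in col:
        colptr[c + 1] += 1
    for c in range(num_cols):
        colptr[c + 1] += colptr[c]
    return colptr, row_sorted, perm


def to_csc_dict_sketch(edge_index_d: Dict[EdgeType, Tuple[List[int], List[int]]],
                       num_nodes_d: Dict[str, int]):
    """Apply to_csc per edge type; columns index nodes of the target type."""
    colptr_d, row_d, perm_d = {}, {}, {}
    for edge_type, (row, col) in edge_index_d.items():
        dst = edge_type[-1]
        colptr_d[edge_type], row_d[edge_type], perm_d[edge_type] = to_csc(
            row, col, num_nodes_d[dst])
    return colptr_d, row_d, perm_d


et = ("author", "writes", "paper")
colptr_d, row_d, perm_d = to_csc_dict_sketch({et: ([0, 1, 2, 0], [1, 0, 1, 2])},
                                             {"author": 3, "paper": 3})
print(colptr_d[et])  # [0, 1, 3, 4]
print(row_d[et])     # [1, 0, 2, 0]
print(perm_d[et])    # [1, 0, 2, 3]
```

Here paper 1 has colptr[1] = 1 and colptr[2] = 3, so its incoming authors are row_d[et][1:3] == [0, 2], matching the original edges 0→1 and 2→1. Keeping perm lets per-edge attributes (for example, edge times) be reordered to match the CSC layout.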

validate() → bool

Validates the graph data by checking the following:
  1. Node and edge types are matched.
  2. Edge types are valid.
  3. Edge indices are valid.
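
One plausible reading of the three checks is sketched below on a hypothetical model: num_nodes maps each node type to its node count, and each edge type maps to a pair of index lists (row, col). The interpretation of each check is an assumption, and edge types are assumed to have been normalized to (src, rel, dst) triples. validate_sketch is not rLLM's implementation.

```python
from typing import Any, Dict, List, Tuple


def validate_sketch(num_nodes: Dict[str, int],
                    edges: Dict[Any, Tuple[List[int], List[int]]]) -> bool:
    for edge_type, (row, col) in edges.items():
        # Check 2 (edge types are valid): a (src, rel, dst) triple of strings.
        if (not isinstance(edge_type, tuple) or len(edge_type) != 3
                or not all(isinstance(p, str) for p in edge_type)):
            return False
        src, _, dst = edge_type
        # Check 1 (node and edge types are matched): both endpoints are known node types.
        if src not in num_nodes or dst not in num_nodes:
            return False
        # Check 3 (edge indices are valid): equal lengths and in-range indices.
        if len(row) != len(col):
            return False
        if any(not 0 <= r < num_nodes[src] for r in row):
            return False
        if any(not 0 <= c < num_nodes[dst] for c in col):
            return False
    return True


num_nodes = {"paper": 2, "author": 1}
print(validate_sketch(num_nodes, {("author", "writes", "paper"): ([0, 0], [0, 1])}))  # True
print(validate_sketch(num_nodes, {("author", "writes", "paper"): ([0, 1], [0, 1])}))  # False
```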

x_dict()

Collects the attribute x from all node types.