rllm.nn.models.HeteroSAGE

class rllm.nn.models.HeteroSAGE(node_types: List[str], edge_types: List[Tuple[str, str, str]], hidden_dim: int, aggr: str = 'mean', num_layers: int = 2)

Bases: Module

A heterogeneous version of the GraphSAGE model, for graphs with multiple node types and edge types.

Parameters:
  • node_types (List[str]) – The names of the node types in the graph.

  • edge_types (List[Tuple[str, str, str]]) – The edge types, each given as a (source node type, relation, destination node type) triple.

  • hidden_dim (int) – The number of hidden channels.

  • aggr (str) – The aggregation method. (default: "mean")

  • num_layers (int) – The number of layers. (default: 2)
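
Under the usual message-passing convention, an edge type (source, relation, destination) sends messages from source nodes to destination nodes, so a node type only receives messages if some edge type ends at it. The helper below is a hypothetical sketch (group_by_destination is not part of rllm) that checks every edge type against the declared node types and groups edge types by the node type they update.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# (source node type, relation, destination node type)
EdgeType = Tuple[str, str, str]


def group_by_destination(
    node_types: List[str], edge_types: List[EdgeType]
) -> Dict[str, List[EdgeType]]:
    """Map each destination node type to the edge types whose messages reach it.

    Hypothetical helper for illustration only; not part of rllm.
    """
    known = set(node_types)
    grouped: Dict[str, List[EdgeType]] = defaultdict(list)
    for src, rel, dst in edge_types:
        # Both endpoints of an edge type must be declared node types.
        if src not in known or dst not in known:
            raise ValueError(f"edge type {(src, rel, dst)} uses an undeclared node type")
        grouped[dst].append((src, rel, dst))
    return dict(grouped)


groups = group_by_destination(
    ["user", "item"],
    [("user", "rates", "item"), ("item", "rated_by", "user")],
)
print(groups)
# {'item': [('user', 'rates', 'item')], 'user': [('item', 'rated_by', 'user')]}
```

With only ("user", "rates", "item"), as in the example that follows, only "item" would be a destination; adding a reverse edge type such as the hypothetical ("item", "rated_by", "user") is a common way to let both node types receive messages.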

Example

>>> from rllm.nn.models import HeteroSAGE
>>> model = HeteroSAGE(
...     node_types=["user", "item"],
...     edge_types=[("user", "rates", "item")],
...     hidden_dim=16,
... )

forward(x_dict: Dict[str, Tensor], edge_index_dict: Dict[Tuple[str, str, str], Tensor]) → Dict[str, Tensor]

Runs heterogeneous GraphSAGE message passing.

Parameters:
  • x_dict (Dict[str, Tensor]) – A dictionary mapping each node type to its input node features.

  • edge_index_dict (Dict[Tuple[str, str, str], Tensor]) – A dictionary mapping each edge type to its edge indices.

Returns:

A dictionary mapping each node type to its updated node embeddings.

Return type:

Dict[str, Tensor]
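
To make the dictionary-in, dictionary-out contract concrete, here is a dependency-free sketch of a single heterogeneous GraphSAGE-style layer. It is not rllm's implementation. The following are all assumptions made for illustration: the per-edge-type update (a self term plus a mean over incoming neighbors, each passed through its own weight matrix), summing contributions from edge types that share a destination, storing edge indices as a (sources, targets) pair of index lists, and the names hetero_sage_layer and weights.

```python
from typing import Dict, List, Tuple

Vector = List[float]
Matrix = List[List[float]]
EdgeType = Tuple[str, str, str]


def matvec(w: Matrix, x: Vector) -> Vector:
    # Dense matrix-vector product: (W x)_i = sum_j W[i][j] * x[j].
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def hetero_sage_layer(
    x_dict: Dict[str, List[Vector]],
    edge_index_dict: Dict[EdgeType, Tuple[List[int], List[int]]],
    weights: Dict[EdgeType, Tuple[Matrix, Matrix]],
) -> Dict[str, List[Vector]]:
    """Hypothetical sketch of one heterogeneous GraphSAGE-style layer.

    For each edge type (src, rel, dst), every dst node v receives
    W_self @ x_dst[v] + W_neigh @ mean(x_src[u] for each edge u -> v);
    contributions from edge types with the same dst are summed.
    """
    out: Dict[str, List[Vector]] = {}
    for edge_type, (sources, targets) in edge_index_dict.items():
        src, _, dst = edge_type
        w_self, w_neigh = weights[edge_type]
        x_src, x_dst = x_dict[src], x_dict[dst]
        dim = len(x_src[0])
        # Accumulate incoming neighbor features and counts per destination node.
        sums = [[0.0] * dim for _ in x_dst]
        counts = [0] * len(x_dst)
        for u, v in zip(sources, targets):
            sums[v] = [a + b for a, b in zip(sums[v], x_src[u])]
            counts[v] += 1
        # Mean aggregation; a node with no neighbors keeps the zero vector.
        means = [[s / c for s in row] if c else row for row, c in zip(sums, counts)]
        updates = [
            [a + b for a, b in zip(matvec(w_self, x_v), matvec(w_neigh, m_v))]
            for x_v, m_v in zip(x_dst, means)
        ]
        if dst in out:
            # Sum with contributions from other edge types ending at dst.
            out[dst] = [[a + b for a, b in zip(p, q)] for p, q in zip(out[dst], updates)]
        else:
            out[dst] = updates
    return out


identity = [[1.0, 0.0], [0.0, 1.0]]
x_dict = {"user": [[1.0, 0.0], [0.0, 1.0]], "item": [[1.0, 1.0]]}
# Users 0 and 1 both rate item 0.
edge_index_dict = {("user", "rates", "item"): ([0, 1], [0, 0])}
weights = {("user", "rates", "item"): (identity, identity)}
out = hetero_sage_layer(x_dict, edge_index_dict, weights)
print(out)  # {'item': [[1.5, 1.5]]}
```

The sketch omits the nonlinearity; a multi-layer model would typically stack such layers with an activation in between. In this sketch, node types that no edge type points to (here "user") are absent from the output; how HeteroSAGE handles that case is not stated on this page.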

reset_parameters()

Resets all learnable parameters of the module.