rllm.nn.models.HeteroSAGE
- class rllm.nn.models.HeteroSAGE(node_types: List[str], edge_types: List[Tuple[str, str, str]], hidden_dim: int, aggr: str = 'mean', num_layers: int = 2)
Bases: Module
The heterogeneous version of the GraphSAGE model, for graphs with multiple node types and edge types.
- Parameters:
node_types (List[str]) – The list of node types.
edge_types (List[Tuple[str, str, str]]) – The list of edge types, each given as a (source node type, relation, destination node type) triple.
hidden_dim (int) – The number of hidden channels.
aggr (str) – The neighborhood aggregation method. (default: "mean")
num_layers (int) – The number of layers. (default: 2)
Example
>>> from rllm.nn.models import HeteroSAGE
>>> model = HeteroSAGE(
...     node_types=["user", "item"],
...     edge_types=[("user", "rates", "item")],
...     hidden_dim=16,
... )
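The `forward` method documented below takes two dictionaries, one keyed by node type and one keyed by edge type. The plain-Python sketch that follows builds toy inputs for the `user`/`item` graph from the example above, using nested lists as stand-ins for tensors. It assumes the PyG-style layout, which this page does not state: each feature matrix is `[num_nodes, num_features]`, and each edge index is `[2, num_edges]` with source indices in row 0 and destination indices in row 1. `check_hetero_inputs` is a hypothetical helper, not part of rllm.

```python
# Hypothetical sketch: plain-Python stand-ins for the inputs forward() expects.
# Assumption (PyG convention, not confirmed by this page): x_dict values are
# [num_nodes, num_features] matrices; edge_index_dict values are [2, num_edges]
# index matrices with row 0 = source indices and row 1 = destination indices.
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]

x_dict: Dict[str, List[List[float]]] = {
    "user": [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],  # 3 users, 2 features each
    "item": [[0.5, 0.5], [2.0, 0.0]],              # 2 items, 2 features each
}

edge_index_dict: Dict[EdgeType, List[List[int]]] = {
    # user 0 rates item 0, user 1 rates item 0, user 2 rates item 1
    ("user", "rates", "item"): [[0, 1, 2], [0, 0, 1]],
}


def check_hetero_inputs(x_dict, edge_index_dict) -> None:
    """Hypothetical helper: check that every edge points at an existing node."""
    for (src, rel, dst), (src_idx, dst_idx) in edge_index_dict.items():
        etype = (src, rel, dst)
        if len(src_idx) != len(dst_idx):
            raise ValueError(f"{etype}: index rows have different lengths")
        if src not in x_dict or dst not in x_dict:
            raise KeyError(f"{etype}: missing node features")
        if any(not 0 <= i < len(x_dict[src]) for i in src_idx):
            raise IndexError(f"{etype}: source index out of range")
        if any(not 0 <= j < len(x_dict[dst]) for j in dst_idx):
            raise IndexError(f"{etype}: destination index out of range")


check_hetero_inputs(x_dict, edge_index_dict)  # returns None when inputs are consistent
```

With real tensors, the same structure would hold `torch.Tensor` values instead of nested lists.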
- forward(x_dict: Dict[str, Tensor], edge_index_dict: Dict[Tuple[str, str, str], Tensor]) → Dict[str, Tensor]
Run heterogeneous GraphSAGE message passing.
- Parameters:
x_dict (Dict[str, Tensor]) – Input node features by node type.
edge_index_dict (Dict[Tuple[str, str, str], Tensor]) – Edge indices by edge type.
- Returns:
Updated node embeddings for each node type.
- Return type:
Dict[str, Tensor]
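To make "heterogeneous GraphSAGE message passing" more concrete, here is a minimal plain-Python sketch of a single layer. For every edge type `(src, rel, dst)` it aggregates source-node features into each destination node using `aggr`, applies the GraphSAGE update `W_self · h_v + W_neigh · agg({h_u})`, and sums the results over edge types that share a destination type. The per-edge-type weights, the summation across edge types, and the absence of bias and activation are assumptions modeled on PyG's `HeteroConv` wrapping `SAGEConv`; rllm's actual implementation may differ. `hetero_sage_layer` is a hypothetical name.

```python
# Minimal sketch of ONE heterogeneous GraphSAGE layer in plain Python.
# Assumptions (not taken from rllm): separate (W_self, W_neigh) per edge type,
# results summed over edge types with the same destination type (like PyG's
# HeteroConv default), no bias, no nonlinearity.
from typing import Dict, List, Tuple

Vec = List[float]
Mat = List[List[float]]
EdgeType = Tuple[str, str, str]


def matvec(w: Mat, v: Vec) -> Vec:
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def add(a: Vec, b: Vec) -> Vec:
    return [x + y for x, y in zip(a, b)]


def aggregate(msgs: List[Vec], dim: int, aggr: str) -> Vec:
    """Combine neighbor feature vectors; nodes with no neighbors get zeros."""
    if not msgs:
        return [0.0] * dim
    if aggr == "sum":
        return [sum(col) for col in zip(*msgs)]
    if aggr == "mean":
        return [sum(col) / len(msgs) for col in zip(*msgs)]
    if aggr == "max":
        return [max(col) for col in zip(*msgs)]
    raise ValueError(f"unsupported aggr: {aggr}")


def hetero_sage_layer(
    x_dict: Dict[str, List[Vec]],
    edge_index_dict: Dict[EdgeType, List[List[int]]],
    weights: Dict[EdgeType, Tuple[Mat, Mat]],  # (W_self, W_neigh) per edge type
    aggr: str = "mean",
) -> Dict[str, List[Vec]]:
    out: Dict[str, List[Vec]] = {}
    for etype, (src_idx, dst_idx) in edge_index_dict.items():
        src, _, dst = etype
        w_self, w_neigh = weights[etype]
        # Collect the source features arriving at each destination node.
        inbox: List[List[Vec]] = [[] for _ in x_dict[dst]]
        for i, j in zip(src_idx, dst_idx):
            inbox[j].append(x_dict[src][i])
        dim = len(x_dict[src][0])
        per_type = [
            add(matvec(w_self, x_dict[dst][v]),
                matvec(w_neigh, aggregate(inbox[v], dim, aggr)))
            for v in range(len(x_dict[dst]))
        ]
        # Sum contributions from every edge type that targets the same node type.
        out[dst] = [add(a, b) for a, b in zip(out[dst], per_type)] if dst in out else per_type
    return out


identity = [[1.0, 0.0], [0.0, 1.0]]
x = {"user": [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], "item": [[0.5, 0.5], [2.0, 0.0]]}
edges = {("user", "rates", "item"): [[0, 1, 2], [0, 0, 1]]}
out = hetero_sage_layer(x, edges, {("user", "rates", "item"): (identity, identity)})
print(out)  # {'item': [[1.0, 1.0], [3.0, 1.0]]}
```

In this sketch, node types that receive no incoming edge type (here `user`) are absent from the output, as with `HeteroConv`. A model with `num_layers` layers would presumably stack such layers; whether rllm adds reverse edges, activations, or dropout between layers is not stated on this page.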