rllm.nn.conv.graph_conv.SAGEConv
- class rllm.nn.conv.graph_conv.SAGEConv(in_dim: int, out_dim: int, aggr: str | Aggregator | None = 'sum', activation: Callable | None = F.relu, dropout: float = 0.0, bias: bool = False, dst_in_dim: int | None = None, **kwargs)
Bases: MessagePassing
A simple SAGEConv (GraphSAGE) layer implemented with message passing, as introduced in the paper “Inductive Representation Learning on Large Graphs”.
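For reference, the core per-layer update in the cited paper first aggregates the neighbors of a node v and then combines the result with v's own representation. The notation below follows the paper; how rLLM parameterizes W internally is not stated on this page, so treat the split form on the right as an equivalent rewriting, not a description of the implementation:

$$
\mathbf{h}_{\mathcal{N}(v)} = \mathrm{AGGREGATE}\big(\{\mathbf{h}_u : u \in \mathcal{N}(v)\}\big),
\qquad
\mathbf{h}_v' = \sigma\big(\mathbf{W}\,[\mathbf{h}_v \,\Vert\, \mathbf{h}_{\mathcal{N}(v)}]\big)
= \sigma\big(\mathbf{W}_{\text{self}}\,\mathbf{h}_v + \mathbf{W}_{\text{neigh}}\,\mathbf{h}_{\mathcal{N}(v)}\big)
$$

where $\mathbf{W} = [\mathbf{W}_{\text{self}} \;\; \mathbf{W}_{\text{neigh}}]$ is split column-wise, $\sigma$ plays the role of the activation parameter, and AGGREGATE corresponds to aggr.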
- Supported aggregators:
sum, mean, max_pool, mean_pool, gcn, lstm
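To illustrate what the simpler aggregators compute, here is a minimal pure-Python sketch (not rLLM code) that reduces source-node features onto destination nodes. It assumes edge_index[0] holds source ids and edge_index[1] destination ids, matching the edge_index in the example at the end of this entry, and that nodes with no incoming edges receive a zero vector. The "max" branch only shows the element-wise pooling; the paper's max_pool and mean_pool first pass each neighbor through a small MLP, and lstm is omitted entirely.

```python
from collections import defaultdict


def aggregate(x, edge_index, num_dst, aggr="sum"):
    """Reduce neighbor feature vectors onto destination nodes.

    x: list of feature vectors (lists of floats), one per source node.
    edge_index: [src, dst]; edge i goes from src[i] to dst[i]
    (assumed convention, not confirmed for rLLM).
    """
    src, dst = edge_index
    inbox = defaultdict(list)  # destination id -> incoming neighbor vectors
    for s, d in zip(src, dst):
        inbox[d].append(x[s])

    dim = len(x[0])
    out = []
    for v in range(num_dst):
        msgs = inbox.get(v, [])
        if not msgs:
            out.append([0.0] * dim)  # assumption: isolated nodes get zeros
        elif aggr == "sum":
            out.append([sum(col) for col in zip(*msgs)])
        elif aggr == "mean":
            out.append([sum(col) / len(msgs) for col in zip(*msgs)])
        elif aggr == "max":
            # element-wise max only; no per-neighbor MLP as in max_pool
            out.append([max(col) for col in zip(*msgs)])
        else:
            raise ValueError(f"unsupported aggr in this sketch: {aggr}")
    return out


x = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0], [5.0, 5.0]]
edge_index = [[0, 1, 2], [2, 2, 0]]  # node 2 has two neighbors, node 0 has one
print(aggregate(x, edge_index, num_dst=4, aggr="mean"))
```

Node 2 receives both [1.0, 0.0] and [0.0, 2.0], so "sum" yields [1.0, 2.0] while "mean" yields [0.5, 1.0]; this difference in scale is the main practical distinction between the two options.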
- Parameters:
in_dim (int) – Size of each input sample.
out_dim (int) – Size of each output sample.
aggr (str or Aggregator) – The aggregation method to use, e.g., "sum", "mean", "max_pool", "mean_pool", "gcn", or "lstm". (default: "sum")
activation (Callable) – The activation function applied after aggregation. (default: F.relu)
dropout (float) – Dropout probability applied to node features before aggregation. (default: 0.0)
bias (bool) – If set to False, no bias term is added to the final output. (default: False)
dst_in_dim (Optional[int]) – The input dimension of the destination nodes. If None, in_dim is used. Useful for heterogeneous graphs where source and destination nodes have different feature dimensions. (default: None)
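To show why a separate dst_in_dim can be needed, here is a minimal pure-Python sketch of the combine step. It assumes, without confirmation from this page, that the layer keeps one weight matrix for the destination node's own features (shape out_dim × dst_in_dim) and another for the aggregated neighbor features (shape out_dim × in_dim), which is the split form of the paper's update. The helper names matvec, relu, and sage_combine are hypothetical, and the bias is omitted to mirror the default bias=False.

```python
def matvec(W, v):
    """Multiply a matrix given as a list of rows by a vector."""
    assert all(len(row) == len(v) for row in W), "shape mismatch"
    return [sum(w * t for w, t in zip(row, v)) for row in W]


def relu(v):
    return [max(0.0, t) for t in v]


def sage_combine(h_dst, h_neigh, W_self, W_neigh):
    """Hypothetical combine step: relu(W_self @ h_dst + W_neigh @ h_neigh).

    W_self: out_dim x dst_in_dim, applied to the destination node's features.
    W_neigh: out_dim x in_dim, applied to the aggregated source features.
    Because the two matrices are separate, h_dst and h_neigh may differ in size.
    """
    a = matvec(W_self, h_dst)
    b = matvec(W_neigh, h_neigh)
    return relu([p + q for p, q in zip(a, b)])


# in_dim = 3 (source features), dst_in_dim = 2, out_dim = 2
h_neigh = [1.0, 2.0, 3.0]  # aggregated neighbor features, size in_dim
h_dst = [1.0, -1.0]        # destination node's own features, size dst_in_dim
W_self = [[1.0, 0.0], [0.0, 1.0]]
W_neigh = [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0]]
print(sage_combine(h_dst, h_neigh, W_self, W_neigh))
```

With a single shared weight over a concatenation, the same effect is obtained by choosing the concatenated width as dst_in_dim + in_dim, so either parameterization can handle source and destination nodes of different sizes.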
- forward(x: Tensor | Tuple[Tensor], edge_index: Tensor, edge_weight: Tensor | None = None)
Aggregates the features of neighboring (source) nodes along edge_index and combines the result with the destination-node features.
- Parameters:
x (Union[Tensor, Tuple[Tensor]]) – Input node features: either a single Tensor of node features for homogeneous graphs, or a tuple containing the source and destination node features.
edge_index (Union[Tensor, SparseTensor]) – Graph connectivity in edge-list or sparse adjacency format.
edge_weight (Optional[Tensor]) – Optional edge weights, used by certain aggregators such as "gcn". (default: None)
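As an illustration of how per-edge weights can enter aggregation, here is a minimal pure-Python sketch in which each message is scaled by its edge weight before summing. This is one common scheme, not a description of rLLM's "gcn" aggregator, which may additionally normalize by degree or add self-loops; the function name weighted_sum_aggregate is hypothetical, and edge_index is assumed to be [src, dst] as in the example at the end of this entry.

```python
def weighted_sum_aggregate(x, edge_index, edge_weight, num_dst):
    """Sum edge-weighted source features into each destination node.

    x: list of feature vectors, one per source node.
    edge_index: [src, dst] lists (assumed convention).
    edge_weight: one float per edge.
    """
    src, dst = edge_index
    dim = len(x[0])
    out = [[0.0] * dim for _ in range(num_dst)]  # zeros for nodes without edges
    for s, d, w in zip(src, dst, edge_weight):
        for k in range(dim):
            out[d][k] += w * x[s][k]  # scale the message by its edge weight
    return out


x = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]
edge_index = [[0, 1], [2, 2]]  # nodes 0 and 1 both send to node 2
print(weighted_sum_aggregate(x, edge_index, [0.5, 2.0], num_dst=3))
```

Node 2 receives 0.5 * [1.0, 2.0] + 2.0 * [3.0, 4.0] = [6.5, 9.0], so a heavier edge contributes proportionally more to the destination embedding.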
- Returns:
Output embeddings for destination nodes.
- Return type:
Tensor
Example
>>> import torch
>>> from rllm.nn.conv.graph_conv import SAGEConv
>>> conv = SAGEConv(16, 8, aggr='sum')
>>> x = torch.randn(4, 16)
>>> edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
>>> out = conv(x, edge_index)
>>> out.shape
torch.Size([4, 8])