rllm.nn.conv.graph_conv.SAGEConv

class rllm.nn.conv.graph_conv.SAGEConv(in_dim: int, out_dim: int, aggr: str | Aggregator | None = 'sum', activation: Callable | None = F.relu, dropout: float = 0.0, bias: bool = False, dst_in_dim: int | None = None, **kwargs)

Bases: MessagePassing

GraphSAGE convolution layer, implemented with message passing, as introduced in the “Inductive Representation Learning on Large Graphs” paper.
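For background, the per-node update described in the GraphSAGE paper (Algorithm 1) first aggregates neighbor representations, then combines them with the node's own representation. The paper additionally L2-normalizes the result. This page does not say exactly how this layer maps onto these terms, for example whether it normalizes the result or splits W into separate blocks, so read the formula as context rather than as this implementation's definition.

```latex
\mathbf{h}_{\mathcal{N}(v)}^{k} = \mathrm{AGGREGATE}_k\left(\{\mathbf{h}_u^{k-1} : u \in \mathcal{N}(v)\}\right),
\qquad
\mathbf{h}_v^{k} = \sigma\left(\mathbf{W}^{k} \cdot \mathrm{CONCAT}\left(\mathbf{h}_v^{k-1},\, \mathbf{h}_{\mathcal{N}(v)}^{k}\right)\right)
```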

Supported aggregators:

sum, mean, max_pool, mean_pool, gcn, lstm
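The plain reductions behind the simplest aggregators can be sketched in pure Python. This is an illustration under stated assumptions, not rLLM's code. It assumes edge_index[0] holds source indices and edge_index[1] holds destination indices (a PyG-style convention). The helper name aggregate is hypothetical. The sketch covers sum, mean, and an elementwise max. The paper's pooling aggregators (max_pool, mean_pool) first pass each neighbor through a learned layer, and gcn and lstm work differently again, so those are omitted.

```python
from typing import List, Sequence


def aggregate(x: Sequence[Sequence[float]],
              edge_index: Sequence[Sequence[int]],
              num_dst: int,
              aggr: str = "sum") -> List[List[float]]:
    """Hypothetical helper: reduce source-node features into each destination.

    Assumes edge_index[0] = source indices, edge_index[1] = destination indices.
    """
    dim = len(x[0])
    # Collect the incoming neighbor feature vectors for every destination node.
    inbox: List[List[Sequence[float]]] = [[] for _ in range(num_dst)]
    for src, dst in zip(edge_index[0], edge_index[1]):
        inbox[dst].append(x[src])

    out: List[List[float]] = []
    for msgs in inbox:
        if not msgs:
            # Assumption: nodes with no incoming edges receive a zero vector.
            out.append([0.0] * dim)
            continue
        cols = list(zip(*msgs))  # one tuple of values per feature dimension
        if aggr == "sum":
            out.append([sum(c) for c in cols])
        elif aggr == "mean":
            out.append([sum(c) / len(msgs) for c in cols])
        elif aggr == "max":
            out.append([max(c) for c in cols])
        else:
            raise ValueError(f"unsupported aggregator: {aggr}")
    return out


# Three nodes; edges 0->2, 1->2 and 2->0.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]]
edge_index = [[0, 1, 2], [2, 2, 0]]
print(aggregate(x, edge_index, num_dst=3, aggr="mean"))  # [[5.0, 0.0], [0.0, 0.0], [2.0, 3.0]]
```

Node 1 has no incoming edges, so under the zero-fill assumption it receives a zero vector. Node 2 averages the features of nodes 0 and 1.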

Parameters:
  • in_dim (int) – Size of each input sample.

  • out_dim (int) – Size of each output sample.

  • aggr (str or Aggregator) – The aggregation method to use, e.g., "sum", "mean", "max_pool", "mean_pool", "gcn", "lstm". (default: "sum")

  • activation (Callable) – The activation function applied after aggregation. (default: F.relu)

  • dropout (float) – Dropout probability applied to node features before aggregation. (default: 0.0)

  • bias (bool) – If set to False, no bias term is added to the output. (default: False)

  • dst_in_dim (Optional[int]) – The input dimension of the destination nodes. If None, in_dim is used. Useful for heterogeneous graphs where source and destination nodes have different dimensions. (default: None)

forward(x: Tensor | Tuple[Tensor], edge_index: Tensor, edge_weight: Tensor | None = None)

Aggregate neighbor information and combine with destination features.

Parameters:
  • x (Union[Tensor, Tuple[Tensor]]) –

    • A single Tensor of node features, for homogeneous graphs.

    • A tuple containing source and destination node features, for graphs whose source and destination nodes differ (see dst_in_dim).

  • edge_index (Union[Tensor, SparseTensor]) – Graph connectivity in edge-list or sparse adjacency format.

  • edge_weight (Optional[Tensor]) – Optional edge weights used by certain aggregators such as gcn.
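The following hypothetical pure-Python forward pass illustrates the tuple input and dst_in_dim for a bipartite graph using the mean aggregator. It assumes the update act(W_neigh · mean(neighbors) + W_self · x_dst). This equals the paper's W · CONCAT(x_dst, mean(neighbors)) when W is split column-wise into two blocks. rLLM may combine the terms, add bias, or normalize differently. The names sage_forward, w_neigh and w_self are illustrative only. Note that w_neigh needs in_dim columns, w_self needs dst_in_dim columns, and the output has one row per destination node.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def matvec(w: Matrix, v: Sequence[float]) -> List[float]:
    # w has shape [out_dim, k] and v has length k; returns w @ v.
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]


def sage_forward(x: Tuple[Matrix, Matrix],
                 edge_index: Sequence[Sequence[int]],
                 w_neigh: Matrix,
                 w_self: Matrix) -> Matrix:
    """Hypothetical mean-aggregator SAGE step on a bipartite input.

    x_src rows have in_dim features, x_dst rows have dst_in_dim features.
    w_neigh is [out_dim, in_dim]; w_self is [out_dim, dst_in_dim].
    """
    x_src, x_dst = x
    out: Matrix = []
    for d, h_dst in enumerate(x_dst):
        # Source nodes with an edge pointing at destination d.
        nbrs = [x_src[s] for s, t in zip(edge_index[0], edge_index[1]) if t == d]
        if nbrs:
            agg = [sum(c) / len(nbrs) for c in zip(*nbrs)]
        else:
            agg = [0.0] * len(x_src[0])  # assumption: zero vector when isolated
        # Neighbor term plus self term, followed by ReLU.
        z = [a + b for a, b in zip(matvec(w_neigh, agg), matvec(w_self, h_dst))]
        out.append([max(0.0, v) for v in z])
    return out


x_src = [[1.0, 0.0, 2.0], [3.0, 1.0, 0.0]]  # 2 source nodes, in_dim = 3
x_dst = [[1.0], [-1.0], [0.5]]              # 3 destination nodes, dst_in_dim = 1
edge_index = [[0, 1, 1], [0, 0, 2]]         # edges 0->0, 1->0, 1->2
w_neigh = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]  # out_dim = 2
w_self = [[1.0], [-1.0]]
print(sage_forward((x_src, x_dst), edge_index, w_neigh, w_self))
# [[3.0, 0.5], [0.0, 1.0], [3.5, 0.5]]
```

Destination node 1 has no neighbors, so only its own (dst_in_dim-sized) features contribute. The result has three rows, matching the destination nodes rather than the two source nodes.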

Returns:

Output embeddings for destination nodes.

Return type:

Tensor

Example

>>> import torch
>>> from rllm.nn.conv.graph_conv import SAGEConv
>>> conv = SAGEConv(16, 8, aggr='sum')
>>> x = torch.randn(4, 16)
>>> edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
>>> out = conv(x, edge_index)
>>> out.shape
torch.Size([4, 8])

reset_parameters()

Resets all learnable parameters of the module.