rllm.nn.conv.graph_conv.MessagePassing

class rllm.nn.conv.graph_conv.MessagePassing(aggr: str | Aggregator | None = 'sum', *, aggr_kwargs: Dict[str, Any] | None = None)

Bases: Module, ABC

Base class for message passing.

Message passing is a general framework for graph neural networks. Its forward computation is defined as:

\[\mathbf{x}_i^{(k+1)} = \text{Update}^{(k)} \left( \mathbf{x}_i^{(k)}, \text{Aggregate}^{(k)} \left( \left\{ \text{Message}^{(k)} \left( \mathbf{x}_i^{(k)}, \mathbf{x}_j^{(k)}, \mathbf{e}_{j,i}^{(k)} \right) \right\}_{j \in \mathcal{N}(i)} \right) \right)\]

Parameters:
  • aggr (Optional[Union[str, Aggregator]]) – The aggregation method to use. (default: "sum")

  • aggr_kwargs (Optional[Dict[str, Any]]) – Additional arguments for the aggregator. (default: None)
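The formula above can be sketched in plain PyTorch as one round of message passing. This is a hypothetical illustration, not rLLM's implementation: `message_passing_step` is an invented name, row 0 of `edge_index` is assumed to hold the sources \(j\) and row 1 the destinations \(i\), Message is assumed to return \(\mathbf{x}_j\) scaled by an optional scalar edge weight, Aggregate is a sum, and Update adds the aggregate to \(\mathbf{x}_i\).

```python
from typing import Optional

import torch
from torch import Tensor


def message_passing_step(
    x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None
) -> Tensor:
    """Hypothetical sketch of x_i' = Update(x_i, Aggregate({Message(x_i, x_j, e_ji)}))."""
    # Assumption: column k of edge_index encodes the edge j -> i as (j, i).
    src, dst = edge_index
    # Message: one message per edge, here the source feature x_j,
    # optionally scaled by a scalar edge weight e_ji.
    msgs = x[src]
    if edge_weight is not None:
        msgs = msgs * edge_weight.view(-1, 1)
    # Aggregate: sum all messages that arrive at the same destination node i.
    aggr = torch.zeros_like(x).index_add_(0, dst, msgs)
    # Update: combine each node's own state with its aggregated neighborhood.
    return x + aggr


# Usage: a path graph 0 -> 1 -> 2 with one-hot features.
x = torch.eye(3)
edge_index = torch.tensor([[0, 1], [1, 2]])
out = message_passing_step(x, edge_index)
# out == [[1, 0, 0], [1, 1, 0], [0, 1, 1]]
```

Each node ends up with its own one-hot vector plus that of its in-neighbor; node 0 has no incoming edge, so it keeps only its own state.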

aggr_resolver(target_aggr: str | Aggregator, **kwargs) → Aggregator

Resolve target_aggr, given either as an aggregator name or as an Aggregator instance, to an Aggregator.
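A resolver of this kind can be pictured as a lookup that maps a name such as "sum" to an aggregation routine and passes an already-built aggregator through unchanged. The sketch below shows that pattern with a hypothetical registry of plain functions built on `Tensor.scatter_reduce`; the names `resolve_aggr`, `_REGISTRY` and `AggrFn` are assumptions, it uses functions instead of rLLM's `Aggregator` classes, and it omits the `**kwargs` forwarding.

```python
from typing import Callable, Dict, Union

import torch
from torch import Tensor

# Hypothetical aggregator type: (msgs, index, dim_size) -> (dim_size, F) tensor.
AggrFn = Callable[[Tensor, Tensor, int], Tensor]


def _scatter(reduce: str) -> AggrFn:
    def aggr(msgs: Tensor, index: Tensor, dim_size: int) -> Tensor:
        out = msgs.new_zeros((dim_size, msgs.size(1)))
        # Broadcast the per-message destination index across the feature dimension.
        idx = index.view(-1, 1).expand_as(msgs)
        # include_self=False: rows that receive no message keep their initial 0.
        return out.scatter_reduce(0, idx, msgs, reduce=reduce, include_self=False)

    return aggr


_REGISTRY: Dict[str, AggrFn] = {
    "sum": _scatter("sum"),
    "mean": _scatter("mean"),
    "max": _scatter("amax"),
    "min": _scatter("amin"),
}


def resolve_aggr(target: Union[str, AggrFn]) -> AggrFn:
    """Return the registered aggregator for a name, or pass a callable through."""
    if callable(target):
        return target
    try:
        return _REGISTRY[target.lower()]
    except KeyError:
        raise ValueError(f"Unknown aggregation: {target!r}") from None
```

Keeping the lookup case-insensitive and raising a `ValueError` for unknown names is a design choice of this sketch, not documented rLLM behavior.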

aggregate(msgs: Tensor, edge_index: Tensor, dim: int = 0, dim_size: int | None = None)

Aggregate messages from src nodes to dst nodes.

Parameters:
  • msgs (Tensor) – The messages to aggregate.

  • edge_index (Union[Tensor, SparseTensor]) – The edge indices.

  • dim (int) – The dimension along which to aggregate. (default: 0)

  • dim_size (Optional[int]) – The size of the output tensor at dim. If None, inferred from edge_index. (default: None)

Returns:

Aggregated node representations of shape \((\text{dim\_size}, F)\).

Return type:

Tensor
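For a sum, aggregating messages amounts to scattering each edge's message into the row of its destination node `edge_index[1]`. The sketch below is a plain-PyTorch illustration, not rLLM's code; the name `aggregate_sum` is hypothetical, and it assumes that an omitted `dim_size` is inferred as the largest destination index plus one, which is one common convention.

```python
from typing import Optional

import torch
from torch import Tensor


def aggregate_sum(
    msgs: Tensor, edge_index: Tensor, dim: int = 0, dim_size: Optional[int] = None
) -> Tensor:
    """Hypothetical sum aggregation of per-edge messages into destination nodes."""
    dst = edge_index[1]
    if dim_size is None:
        # Assumption: infer the number of destination nodes from the largest index.
        dim_size = int(dst.max()) + 1
    shape = list(msgs.shape)
    shape[dim] = dim_size
    out = msgs.new_zeros(shape)
    # Add message k into row dst[k] along `dim`.
    return out.index_add_(dim, dst, msgs)
```

Under this assumption, nodes whose index exceeds every destination index (for example, isolated nodes at the end of the node list) get no output row, so passing `dim_size=x.size(0)` explicitly keeps the output aligned with the node feature matrix.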

explain(kwargs: Dict[str, Any]) → Any

Explain the behavior of the message passing layer.

For now, this is only an interface; subclasses may implement it if necessary.

property if_explain: bool

Whether to enable explain mode.

message(x: Tensor, edge_index: Tensor, edge_weight: Tensor | None = None) → Tensor

Compute messages from src nodes \(v_j\) to dst nodes \(v_i\).

Parameters:
  • x (Tensor) – The input node feature matrix.

  • edge_index (Union[Tensor, SparseTensor]) – The edge indices or adjacency matrix.

  • edge_weight (Optional[Tensor]) – The edge weights. (default: None)

Returns:

The message tensor with size \((|E|, \text{message\_dim})\), one row per edge.

Return type:

Tensor
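In the formula, Message is evaluated once per edge \((j, i)\) and may depend on both endpoints. The hypothetical `edge_messages` function below (not rLLM's default `message()`) illustrates this with the assumed choice \(e_{j,i}(\mathbf{x}_j - \mathbf{x}_i)\), gathering \(\mathbf{x}_j\) from `edge_index[0]` and \(\mathbf{x}_i\) from `edge_index[1]`, so it returns one row per edge.

```python
from typing import Optional

import torch
from torch import Tensor


def edge_messages(
    x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None
) -> Tensor:
    """Hypothetical Message(x_i, x_j, e_ji) = e_ji * (x_j - x_i), one row per edge."""
    src, dst = edge_index
    x_j = x.index_select(0, src)  # source-node features, shape (|E|, F)
    x_i = x.index_select(0, dst)  # destination-node features, shape (|E|, F)
    msgs = x_j - x_i
    if edge_weight is not None:
        # Broadcast one scalar weight per edge across the feature dimension.
        msgs = msgs * edge_weight.view(-1, 1)
    return msgs


# Usage: 3 nodes, 4 edges -> 4 message rows.
x = torch.tensor([[0.0], [1.0], [3.0]])
edge_index = torch.tensor([[0, 1, 2, 2], [1, 2, 0, 1]])
msgs = edge_messages(x, edge_index)
# msgs == [[-1.], [-2.], [3.], [2.]]
```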

message_and_aggregate(edge_index: Tensor) → Tensor

Interface that combines message() and aggregate() in a single step; to be overridden by subclasses.
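When the message is linear in \(\mathbf{x}_j\) (for example a weighted \(\mathbf{x}_j\)) and the aggregation is a sum, both steps can be fused into one sparse-dense product with an adjacency matrix \(A\) where \(A_{i,j} = e_{j,i}\). The sketch below shows that idea with `torch.sparse_coo_tensor` and `torch.sparse.mm` and compares it with the unfused gather-and-scatter path; both function names are hypothetical and this is not rLLM's API.

```python
from typing import Optional

import torch
from torch import Tensor


def fused_sum(x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None) -> Tensor:
    """Hypothetical fused message + sum aggregation: out = A @ x with A[i, j] = e_ji."""
    src, dst = edge_index
    if edge_weight is None:
        edge_weight = torch.ones(src.numel(), dtype=x.dtype, device=x.device)
    n = x.size(0)
    # Rows index destination nodes, columns index source nodes;
    # coalesce() sums duplicate (multi-)edges, matching sum aggregation.
    adj = torch.sparse_coo_tensor(torch.stack([dst, src]), edge_weight, (n, n)).coalesce()
    return torch.sparse.mm(adj, x)


def unfused_sum(x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None) -> Tensor:
    """Reference path: materialize one message per edge, then scatter-sum."""
    src, dst = edge_index
    msgs = x[src]
    if edge_weight is not None:
        msgs = msgs * edge_weight.view(-1, 1)
    return torch.zeros_like(x).index_add_(0, dst, msgs)
```

The fused form never materializes the \((|E|, F)\) message tensor, which is the usual motivation for a combined path; it only applies when the message is linear in \(\mathbf{x}_j\).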

propagate(x: Tensor | Tuple[Tensor, Tensor], edge_index: Tensor, **kwargs) → Tensor

The initial call to start propagating messages. This method calls message() and aggregate() (or message_and_aggregate() if it is available) and then update(), completing one round of propagation.

Parameters:
  • x (Union[Tensor, Tuple[Tensor, Tensor]]) –

    • Tensor: The input node feature matrix. \((|V|, F_{in})\)

    • Tuple[Tensor, Tensor]: The input node feature matrices for the source and destination nodes, respectively.

  • edge_index (Union[Tensor, SparseTensor]) – The edge indices: a Tensor of shape \((2, |E|)\), or a SparseTensor adjacency.

  • **kwargs – Additional arguments for the message, aggregate and update functions.

Returns:

Updated destination node representations after running message, aggregate (or message_and_aggregate), and update.

Return type:

Tensor

Example

>>> import torch
>>> from rllm.nn.conv.graph_conv import GCNConv
>>> conv = GCNConv(8, 4)
>>> x = torch.randn(5, 8)
>>> edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
>>> out = conv.propagate(x, edge_index)
>>> out.shape
torch.Size([5, 8])
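To make the documented call order concrete, the sketch below is a self-contained toy base class, not rLLM's `MessagePassing`: its hook names mirror the ones documented here, but the bodies, the `fused` flag and the `MeanConv` subclass are hypothetical. It accepts either one feature matrix or an `(x_src, x_dst)` tuple and sizes the output by the destination nodes.

```python
from typing import Tuple, Union

import torch
from torch import Tensor, nn


class ToyMessagePassing(nn.Module):
    """Hypothetical stand-in for the documented call order; not rLLM's class."""

    fused: bool = False  # hypothetical switch: use message_and_aggregate() instead

    def propagate(self, x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Tensor) -> Tensor:
        # A single matrix serves as both source and destination features.
        x_src, x_dst = (x, x) if isinstance(x, Tensor) else x
        dim_size = x_dst.size(0)  # one output row per destination node
        if self.fused:
            out = self.message_and_aggregate(x_src, edge_index, dim_size)
        else:
            msgs = self.message(x_src, edge_index)
            out = self.aggregate(msgs, edge_index, dim_size)
        return self.update(out)

    def message(self, x_src: Tensor, edge_index: Tensor) -> Tensor:
        # Gather the source feature of every edge: shape (|E|, F).
        return x_src.index_select(0, edge_index[0])

    def aggregate(self, msgs: Tensor, edge_index: Tensor, dim_size: int) -> Tensor:
        # Sum the messages that share a destination index.
        out = msgs.new_zeros((dim_size, msgs.size(1)))
        return out.index_add_(0, edge_index[1], msgs)

    def message_and_aggregate(self, x_src: Tensor, edge_index: Tensor, dim_size: int) -> Tensor:
        raise NotImplementedError("subclasses that set fused=True must override this")

    def update(self, output: Tensor) -> Tensor:
        return output  # identity by default


class MeanConv(ToyMessagePassing):
    """Hypothetical subclass: mean aggregation, then a linear update."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim)

    def aggregate(self, msgs: Tensor, edge_index: Tensor, dim_size: int) -> Tensor:
        summed = super().aggregate(msgs, edge_index, dim_size)
        ones = torch.ones(edge_index.size(1), dtype=msgs.dtype, device=msgs.device)
        deg = msgs.new_zeros(dim_size).index_add_(0, edge_index[1], ones)
        # clamp avoids division by zero for nodes without incoming edges.
        return summed / deg.clamp(min=1).view(-1, 1)

    def update(self, output: Tensor) -> Tensor:
        return self.lin(output)

    def forward(self, x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Tensor) -> Tensor:
        return self.propagate(x, edge_index)
```

Sizing the output by the destination features, rather than inferring it from edge_index, is the design choice that keeps a node with no incoming edge (such as node 4 in the example graph above) in the output.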

update(output: Tensor) → Tensor

Update the dst node embeddings.