rllm.nn.conv.graph_conv.MessagePassing
- class rllm.nn.conv.graph_conv.MessagePassing(aggr: str | Aggregator | None = 'sum', *, aggr_kwargs: Dict[str, Any] | None = None)
Bases: Module, ABC
Base class for message passing.
Message passing is a general framework for graph neural networks. Its forward formula is defined as:
\[\mathbf{x}_i^{(k+1)} = \text{Update}^{(k)} \left( \mathbf{x}_i^{(k)}, \text{Aggregate}^{(k)} \left( \left\{ \text{Message}^{(k)} \left( \mathbf{x}_i^{(k)}, \mathbf{x}_j^{(k)}, \mathbf{e}_{j,i}^{(k)} \right) \right\}_{j \in \mathcal{N}(i)} \right) \right)\]
where \(\mathcal{N}(i)\) is the set of neighbors of node \(i\) and \(\mathbf{e}_{j,i}^{(k)}\) is the feature of the edge from node \(j\) to node \(i\).
- Parameters:
aggr (Optional[Union[str, Aggregator]]) – The aggregation method to use. (default: "sum")
aggr_kwargs (Optional[Dict[str, Any]]) – Additional arguments for the aggregator. (default: None)
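To make the formula concrete, here is a minimal sketch of one round in plain PyTorch. It is not rllm's implementation: the choices Message = \(\mathbf{x}_j \cdot e_{j,i}\) (one scalar weight per edge), Aggregate = sum, and Update = \(\mathbf{x}_i\) + aggregated message are illustrative assumptions, as is the convention that edge_index[0] holds source nodes \(j\) and edge_index[1] holds destination nodes \(i\). The function name message_passing_step is hypothetical.

```python
import torch


def message_passing_step(x: torch.Tensor, edge_index: torch.Tensor,
                         edge_attr: torch.Tensor) -> torch.Tensor:
    """One round of x_i' = Update(x_i, Aggregate({Message(x_i, x_j, e_ji)})).

    Illustrative choices (assumptions, not rllm's defaults):
      Message   = x_j * e_ji  (scalar edge weight per edge)
      Aggregate = sum over incoming edges
      Update    = x_i + aggregated message
    """
    src, dst = edge_index                       # edge j -> i stored as (src=j, dst=i)
    msgs = x[src] * edge_attr.view(-1, 1)       # one message per edge, shape (|E|, F)
    agg = torch.zeros_like(x).index_add_(0, dst, msgs)  # sum messages per destination
    return x + agg                              # Update step


x = torch.tensor([[1.0], [2.0], [3.0]])
edge_index = torch.tensor([[0, 1, 2], [1, 2, 1]])
edge_attr = torch.tensor([1.0, 0.5, 2.0])
out = message_passing_step(x, edge_index, edge_attr)
print(out.squeeze(1).tolist())  # [1.0, 9.0, 4.0]
```

Node 1 receives 1.0 from node 0 and 6.0 from node 2, so its new value is 2.0 + 7.0 = 9.0; node 0 has no incoming edges and keeps its value.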
- aggregate(msgs: Tensor, edge_index: Tensor, dim: int = 0, dim_size: int | None = None)
Aggregate messages from src nodes to dst nodes.
- Parameters:
msgs (Tensor) – The messages to aggregate.
edge_index (Union[Tensor, SparseTensor]) – The edge indices.
dim (int) – The dimension along which to aggregate. (default: 0)
dim_size (Optional[int]) – The size of the output tensor at dimension dim. If None, it is inferred from edge_index. (default: None)
- Returns:
Aggregated node representations of shape \((\text{dim\_size}, F)\).
- Return type:
Tensor
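A minimal sketch of what sum aggregation can look like, assuming msgs holds one row per edge along dim and edge_index[1] holds each edge's destination node. The function sum_aggregate is hypothetical and only illustrates the dim_size inference described above; rllm's Aggregator classes may work differently.

```python
from typing import Optional

import torch


def sum_aggregate(msgs: torch.Tensor, edge_index: torch.Tensor,
                  dim: int = 0, dim_size: Optional[int] = None) -> torch.Tensor:
    """Sum per-edge messages into their destination nodes (hypothetical helper)."""
    dst = edge_index[1]
    if dim_size is None:
        # Infer the number of destination nodes from the largest index seen.
        dim_size = int(dst.max()) + 1 if dst.numel() > 0 else 0
    shape = list(msgs.shape)
    shape[dim] = dim_size
    out = msgs.new_zeros(shape)
    # Add row k of msgs into row dst[k] of the output.
    return out.index_add_(dim, dst, msgs)


msgs = torch.ones(3, 2)
edge_index = torch.tensor([[0, 1, 2], [1, 1, 3]])
print(sum_aggregate(msgs, edge_index).shape)               # torch.Size([4, 2])
print(sum_aggregate(msgs, edge_index, dim_size=6).shape)   # torch.Size([6, 2])
```

Because the inferred size is max(dst) + 1, destination nodes with a larger index but no incoming edges would be missing from the output; passing dim_size explicitly (for example, the number of destination nodes) avoids that.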
- explain(kwargs: Dict[str, Any]) → Any
Explain the behavior of the message passing layer.
Currently this is only an interface; subclasses may implement it if necessary.
- property if_explain: bool
Whether to enable explain mode.
- message(x: Tensor, edge_index: Tensor, edge_weight: Tensor | None = None) → Tensor
Compute message from src nodes \(v_j\) to dst nodes \(v_i\).
- Parameters:
x (Tensor) – The input node feature matrix.
edge_index (Union[Tensor, SparseTensor]) – The edge indices or adjacency matrix.
edge_weight (Optional[Tensor]) – The edge weights. (default: None)
- Returns:
The message tensor with size \((|V_{src}|, \text{message\_dim})\).
- Return type:
Tensor
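A sketch of a typical message function, assuming edge_index[0] holds the source node of each edge and edge_weight is one scalar per edge. The name edge_messages is hypothetical, and this is not necessarily what the default message() does; in this sketch the result has one row per edge (|E| rows), which is the layout a scatter-style aggregate consumes.

```python
from typing import Optional

import torch


def edge_messages(x: torch.Tensor, edge_index: torch.Tensor,
                  edge_weight: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Gather source-node features for every edge, optionally scaled by a weight."""
    src = edge_index[0]
    msgs = x.index_select(0, src)              # (|E|, F): row k is x[src[k]]
    if edge_weight is not None:
        msgs = msgs * edge_weight.view(-1, 1)  # scale each edge's message
    return msgs


x = torch.tensor([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
edge_index = torch.tensor([[2, 0], [0, 1]])    # edges 2 -> 0 and 0 -> 1
print(edge_messages(x, edge_index, torch.tensor([0.5, 2.0])).tolist())
# [[1.5, 15.0], [2.0, 20.0]]
```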
- message_and_aggregate(edge_index: Tensor) → Tensor
The message and aggregation interface to be overridden by subclasses.
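One common way to fuse message and aggregation for weighted-sum aggregation is a sparse-dense matrix product \(\mathbf{A}\mathbf{X}\) with \(A_{i,j} = w_{j,i}\) for every edge j → i. The sketch below assumes that choice; sparse_message_and_aggregate is a hypothetical helper, not the rllm API, and the usage compares it against the explicit gather-then-scatter path.

```python
from typing import Optional

import torch


def sparse_message_and_aggregate(x: torch.Tensor, edge_index: torch.Tensor,
                                 edge_weight: Optional[torch.Tensor] = None,
                                 num_dst: Optional[int] = None) -> torch.Tensor:
    """Fused weighted-sum aggregation: out = A @ x, where A[dst, src] = weight."""
    src, dst = edge_index
    if edge_weight is None:
        edge_weight = torch.ones(src.numel(), dtype=x.dtype, device=x.device)
    if num_dst is None:
        num_dst = x.size(0)  # assume source and destination share one node set
    adj = torch.sparse_coo_tensor(
        torch.stack([dst, src]), edge_weight, (num_dst, x.size(0))
    ).coalesce()  # duplicate (dst, src) pairs are summed
    return torch.sparse.mm(adj, x)  # dense result of shape (num_dst, F)


x = torch.randn(4, 3)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
w = torch.rand(4)
fused = sparse_message_and_aggregate(x, edge_index, w)
# Reference path: gather per-edge messages, then scatter-sum into destinations.
ref = torch.zeros_like(x).index_add_(0, edge_index[1], x[edge_index[0]] * w.view(-1, 1))
print(torch.allclose(fused, ref))  # True
```

The fused form never materializes the (|E|, F) message tensor, which is the usual motivation for providing message_and_aggregate() alongside message() and aggregate().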
- propagate(x: Tensor | Tuple[Tensor, Tensor], edge_index: Tensor, **kwargs) → Tensor
The initial call to start propagating messages. This method calls message() and aggregate() (or message_and_aggregate() if it is available), followed by update(), to complete one round of propagation.
- Parameters:
x (Union[Tensor, Tuple[Tensor, Tensor]]) –
Tensor: The input node feature matrix of shape \((|V|, F_{in})\).
Tuple[Tensor, Tensor]: The input node feature matrices for the source and destination nodes.
edge_index (Union[Tensor, SparseTensor]) – The edge indices; as a Tensor, of shape \((2, |E|)\).
**kwargs – Additional arguments for the message, aggregate and update functions.
- Returns:
Updated destination node representations after running message, aggregate (or message_and_aggregate), and update.
- Return type:
Tensor
Example
>>> import torch
>>> from rllm.nn.conv.graph_conv import GCNConv
>>> conv = GCNConv(8, 4)
>>> x = torch.randn(5, 8)
>>> edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
>>> out = conv.propagate(x, edge_index)
>>> out.shape
torch.Size([5, 8])
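The call order described for propagate() can be sketched with a small stand-alone module. ToyMessagePassing is hypothetical: it does not subclass rllm's MessagePassing, uses an identity update(), and passes dim_size as the number of destination nodes so nodes without incoming edges still get an output row. With the same inputs as the example above it returns a (5, 8) tensor, since no feature transform is applied.

```python
from typing import Optional, Tuple, Union

import torch
from torch import Tensor


class ToyMessagePassing(torch.nn.Module):
    """Hypothetical stand-in mirroring message() -> aggregate() -> update()."""

    def propagate(self, x: Union[Tensor, Tuple[Tensor, Tensor]],
                  edge_index: Tensor, edge_weight: Optional[Tensor] = None) -> Tensor:
        # Bipartite input: separate source and destination feature matrices.
        x_src, x_dst = x if isinstance(x, tuple) else (x, x)
        msgs = self.message(x_src, edge_index, edge_weight)
        # One output row per destination node, including isolated ones.
        out = self.aggregate(msgs, edge_index, dim_size=x_dst.size(0))
        return self.update(out)

    def message(self, x: Tensor, edge_index: Tensor,
                edge_weight: Optional[Tensor] = None) -> Tensor:
        msgs = x.index_select(0, edge_index[0])    # one row per edge
        if edge_weight is not None:
            msgs = msgs * edge_weight.view(-1, 1)
        return msgs

    def aggregate(self, msgs: Tensor, edge_index: Tensor, dim: int = 0,
                  dim_size: Optional[int] = None) -> Tensor:
        dst = edge_index[1]
        if dim_size is None:
            dim_size = int(dst.max()) + 1
        shape = list(msgs.shape)
        shape[dim] = dim_size
        return msgs.new_zeros(shape).index_add_(dim, dst, msgs)  # sum per destination

    def update(self, out: Tensor) -> Tensor:
        return out  # identity update in this sketch


layer = ToyMessagePassing()
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
print(layer.propagate(torch.randn(5, 8), edge_index).shape)  # torch.Size([5, 8])
print(layer.propagate((torch.randn(3, 8), torch.randn(6, 8)), edge_index).shape)  # torch.Size([6, 8])
```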