Source code for rllm.nn.conv.graph_conv.message_passing

from abc import ABC
import inspect
from collections import OrderedDict
from functools import lru_cache
from typing import Tuple, Callable, Dict, Any, Union, Optional, overload

import torch
from torch import Tensor
from torch.sparse import Tensor as SparseTensor

from rllm.nn.conv.graph_conv.aggrs import Aggregator


class MessagePassing(torch.nn.Module, ABC):
    r"""Base class for message passing.

    Message passing is the general framework for graph neural networks.
    Its forward formula is defined as:

    .. math::
        \mathbf{x}_i^{(k+1)} = \text{Update}^{(k)} \left( \mathbf{x}_i^{(k)},
        \text{Aggregate}^{(k)} \left( \left\{ \text{Message}^{(k)} \left(
        \mathbf{x}_i^{(k)}, \mathbf{x}_j^{(k)}, \mathbf{e}_{j,i}^{(k)}
        \right) \right\}_{j \in \mathcal{N}(i)} \right) \right)

    Args:
        aggr (Optional[Union[str, Aggregator]]): The aggregation method to
            use. If :obj:`None`, no aggregator module is created; this is
            only meaningful for subclasses that override
            :meth:`message_and_aggregate` (or :meth:`aggregate`).
            (default: :obj:`"sum"`)
        aggr_kwargs (Optional[Dict[str, Any]]): Additional arguments for the
            aggregator. (default: :obj:`None`)
    """

    # Cache of hook signatures, keyed on the *underlying* function (plus a
    # "was it bound" flag) instead of the bound method. Using ``lru_cache`` on
    # a method would put ``self`` in the cache key and keep every layer
    # instance (and its parameters) alive for the lifetime of the cache.
    _SIGNATURE_CACHE: Dict[Any, Any] = {}

    def __init__(
        self,
        aggr: Optional[Union[str, Aggregator]] = "sum",
        *,
        aggr_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self.__explain__ = self.__is_overrided__(self.explain)
        self.__msg_aggr__ = self.__is_overrided__(self.message_and_aggregate)
        self.aggr_module = self.aggr_resolver(aggr, **(aggr_kwargs or {}))

    def propagate(
        self,
        x: Union[Tensor, Tuple[Tensor, Tensor]],
        edge_index: Union[Tensor, SparseTensor],
        **kwargs,
    ) -> Tensor:
        r"""The initial call to start propagating messages.

        This method calls :meth:`message`, :meth:`aggregate` (or
        :meth:`message_and_aggregate` if it is overridden) and :meth:`update`
        in sequence to complete one round of propagation. Each hook receives
        the arguments it declares, looked up by name in :obj:`kwargs`, then
        :obj:`x` / :obj:`edge_index`, then the hook's own defaults.

        Args:
            x (Union[Tensor, Tuple[Tensor, Tensor]]):
                - :obj:`Tensor`: The input node feature matrix.
                  :math:`(|V|, F_{in})`
                - :obj:`Tuple[Tensor, Tensor]`: The input node feature
                  matrices for source and destination nodes.
            edge_index (Union[Tensor, SparseTensor]): The edge indices.
                Tensor, :math:`(2, |E|)`, or a sparse adjacency matrix.
            **kwargs: Additional arguments for the message, aggregate and
                update functions.

        Returns:
            Tensor: Updated destination node representations after running
            message, aggregate (or message_and_aggregate), and update.

        Raises:
            ValueError: If :obj:`dim_size` is not given and :obj:`x` is
                :obj:`None`.

        Example:
            >>> import torch
            >>> from rllm.nn.conv.graph_conv import GCNConv
            >>> conv = GCNConv(8, 4)
            >>> x = torch.randn(5, 8)
            >>> edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
            >>> out = conv.propagate(x, edge_index)
            >>> out.shape
            torch.Size([5, 8])
        """
        # Infer the aggregator's output size from the destination nodes.
        if "dim_size" not in kwargs or kwargs["dim_size"] is None:
            if x is not None:
                if isinstance(x, Tensor):
                    kwargs["dim_size"] = x.size(0)
                else:
                    kwargs["dim_size"] = x[1].size(0)
            else:
                raise ValueError("dim_size must be provided while x is None.")

        # message and aggregate
        if self.__msg_aggr__:
            msg_aggr_kwargs = self.__collect__(
                self.message_and_aggregate, x, edge_index, kwargs
            )
            out = self.message_and_aggregate(**msg_aggr_kwargs)
        else:
            msg_kwargs = self.__collect__(self.message, x, edge_index, kwargs)
            out = self.message(**msg_kwargs)
            aggr_kwargs = self.__collect__(self.aggregate, x, edge_index, kwargs)
            out = self.aggregate(out, **aggr_kwargs)

        # update
        update_kwargs = self.__collect__(self.update, x, edge_index, kwargs)
        out = self.update(out, **update_kwargs)
        return out

    def message(
        self,
        x: Union[Tensor, Tuple[Tensor, Tensor]],
        edge_index: Union[Tensor, SparseTensor],
        edge_weight: Tensor = None,
    ) -> Tensor:
        r"""Compute messages from src nodes :math:`v_j` to dst nodes :math:`v_i`.

        Args:
            x (Union[Tensor, Tuple[Tensor, Tensor]]): The input node feature
                matrix, or a :obj:`(x_src, x_dst)` pair for bipartite graphs
                (only the source features are used).
            edge_index (Union[Tensor, SparseTensor]): The edge indices or adj.
            edge_weight (Tensor): The edge weights. Ignored when
                :obj:`edge_index` is a weighted sparse adjacency, whose
                values take precedence.

        Returns:
            Tensor: One message per edge, with size
            :math:`(|E|, \text{message_dim})`.
        """
        edge_index, adj_weight = self.__unify_edgeindex__(edge_index)
        if adj_weight is not None:
            edge_weight = adj_weight

        # Messages always originate at the source endpoint of each edge, so
        # bipartite input must be indexed on the source feature matrix.
        src_x = x[0] if isinstance(x, tuple) else x
        msgs = src_x.index_select(dim=0, index=edge_index[0, :])
        if edge_weight is not None:
            return msgs * edge_weight.view(-1, 1)
        return msgs

    def aggregate(
        self,
        msgs: Tensor,
        edge_index: Union[Tensor, SparseTensor],
        dim: int = 0,
        dim_size: Optional[int] = None,
    ):
        r"""Aggregate messages from src nodes to dst nodes.

        Args:
            msgs (Tensor): The messages to aggregate.
            edge_index (Union[Tensor, SparseTensor]): The edge indices.
            dim (int): The dimension along which to aggregate.
                (default: :obj:`0`)
            dim_size (Optional[int]): The size of the output tensor at
                :obj:`dim`. If :obj:`None`, the aggregator decides.
                (default: :obj:`None`)

        Returns:
            Tensor: Aggregated node representations of shape
            :math:`(\text{dim\_size}, F)`.

        Raises:
            RuntimeError: If the layer was built with :obj:`aggr=None`.
        """
        if self.aggr_module is None:
            raise RuntimeError(
                "No aggregator configured (aggr=None); override "
                "message_and_aggregate or aggregate, or pass an aggr."
            )
        edge_index, _ = self.__unify_edgeindex__(edge_index)
        # Messages are grouped by their destination node (row 1).
        return self.aggr_module(msgs, edge_index[1, :], dim=dim, dim_size=dim_size)

    def message_and_aggregate(self, edge_index: Union[Tensor, SparseTensor]) -> Tensor:
        r"""The message and aggregation interface to be overridden by subclasses."""
        raise NotImplementedError

    def update(self, output: Tensor) -> Tensor:
        r"""Update the dst node embeddings (identity by default)."""
        return output

    # Properties
    @property
    def if_message_and_aggregate(self) -> bool:
        r"""Whether :meth:`propagate` uses :meth:`message_and_aggregate`."""
        return self.__msg_aggr__

    @if_message_and_aggregate.setter
    def if_message_and_aggregate(self, msg_aggr: bool) -> None:
        self.__msg_aggr__ = msg_aggr

    # Utility functions
    def __collect__(self, func: Callable, x, edge_index, kwargs) -> Dict[str, Any]:
        r"""Collect the keyword arguments for hook :obj:`func` by name.

        Lookup order for each declared parameter: :obj:`kwargs`, then
        :obj:`x` / :obj:`edge_index`, then the parameter's default.

        Raises:
            ValueError: If a required parameter cannot be resolved.
        """
        func_params = OrderedDict(self.__func_params__(func))
        if func.__name__ in ["aggregate", "update"]:
            # Their first argument (messages / aggregated output) is passed
            # positionally by ``propagate``.
            func_params.popitem(last=False)

        coll = OrderedDict()
        for k, v in func_params.items():
            if v.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                # ``*args`` / ``**kwargs`` cannot be filled by name; skip them
                # rather than reporting them as missing required parameters.
                continue
            if k in kwargs:
                coll[k] = kwargs[k]
            elif k == "x":
                coll[k] = x
            elif k == "edge_index":
                coll[k] = edge_index
            elif v.default != inspect.Parameter.empty:
                coll[k] = v.default
            else:
                raise ValueError(f"Missing required parameter {k}.")
        return coll

    def __func_params__(self, func: Callable) -> OrderedDict:
        r"""Return the (read-only) parameter mapping of :obj:`func`'s signature.

        Results are memoised per underlying function so repeated
        :meth:`propagate` calls skip :func:`inspect.signature`, without the
        cache holding a reference to any layer instance.
        """
        is_bound = hasattr(func, "__func__")
        key = (getattr(func, "__func__", func), is_bound)
        params = MessagePassing._SIGNATURE_CACHE.get(key)
        if params is None:
            params = inspect.signature(func).parameters
            MessagePassing._SIGNATURE_CACHE[key] = params
        return params

    def __unify_edgeindex__(
        self, edge_index: Tensor
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Unify the edge index to a :math:`(2, |E|)` tensor.

        Returns:
            Tuple[Tensor, Optional[Tensor]]: The edge index and, for a
            weighted sparse adjacency, its edge weights (else :obj:`None`).

        Raises:
            ValueError: If a dense :obj:`edge_index` is not shaped
                :math:`(2, |E|)`.
        """
        # Any non-strided layout (COO, CSR, CSC, ...) is a sparse adjacency;
        # ``Tensor.is_sparse`` alone is only True for COO.
        if edge_index.layout != torch.strided:
            return self.__adj_to_edges__(edge_index)
        if edge_index.size(0) != 2:
            raise ValueError(
                f"Expect edge_index to be a 2D tensor, got {edge_index.size()}."
            )
        return edge_index, None

    def __adj_to_edges__(self, adj: SparseTensor) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Convert a sparse adjacency matrix to edge indices and weights.

        Raises:
            TypeError: If :obj:`adj` is a dense (strided) tensor.
        """
        if adj.layout == torch.strided:
            raise TypeError(f"Expect adj to be a SparseTensor, got {type(adj)}.")
        coo_adj = adj.to_sparse_coo().coalesce()
        indices, values = coo_adj.indices(), coo_adj.values()
        # An all-ones value tensor carries no information: report unweighted.
        weights = None if torch.all(values == 1) else values
        return torch.stack([indices[0], indices[1]]), weights

    def __is_overrided__(self, func: Callable) -> bool:
        r"""Return True if the subclass overrides :obj:`func`.

        Not cached: it only runs in ``__init__``, and caching a method would
        keep the instance alive.
        """
        return getattr(self.__class__, func.__name__, None) != getattr(
            MessagePassing, func.__name__
        )

    def aggr_resolver(
        self, target_aggr: Optional[Union[str, Aggregator]], **kwargs
    ) -> Optional[Aggregator]:
        r"""Resolve :obj:`target_aggr` to an aggregator instance.

        Strings are matched case-insensitively against the class names in
        ``rllm.nn.conv.graph_conv.aggrs``, ignoring ``_``, ``-``, spaces and
        an optional ``"aggregator"`` suffix (e.g. ``"sum"``).

        Returns:
            Optional[Aggregator]: The aggregator, or :obj:`None` when
            :obj:`target_aggr` is :obj:`None`.

        Raises:
            ValueError: If no aggregator matches the given name.
        """
        if target_aggr is None:
            return None
        if isinstance(target_aggr, Aggregator):
            return target_aggr

        import rllm.nn.conv.graph_conv.aggrs as aggrs

        aggrs_l = [
            getattr(aggrs, name)
            for name in dir(aggrs)
            if inspect.isclass(getattr(aggrs, name))
        ]

        def normalize_str(s: str) -> str:
            return s.lower().replace("_", "").replace("-", "").replace(" ", "")

        norm_target_aggr = normalize_str(target_aggr)
        for aggr in aggrs_l:
            aggr_name = normalize_str(aggr.__name__)
            if norm_target_aggr in [aggr_name, aggr_name.replace("aggregator", "")]:
                return aggr(**kwargs)
        raise ValueError(f"Aggregator {target_aggr} not found.")

    @overload
    def retrieve_feats(
        self,
        feats: Tensor,
        edge_index: Union[Tensor, SparseTensor],
        dim: Optional[int] = None,
        retrieve_dim: int = 0,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        r"""Non-bipartite graph, :obj:`feats` contains all nodes' features.

        Args:
            feats (Tensor): The node features.
            edge_index (Union[Tensor, SparseTensor]): The edge indices.
            dim (Optional[int]): The edge_index row to retrieve (0 = src,
                1 = dst). If None, retrieve both src and dst.
            retrieve_dim (int): The feature dimension to index along.

        Returns:
            Union[Tensor, Tuple[Tensor, Tensor]]: Node features at dim,
            or both source and destination node features.
            :math:`(|E|, F)` or (:math:`(|E|, F)` and :math:`(|E|, F)`).
        """
        ...

    @overload
    def retrieve_feats(
        self,
        feats: Tuple[Tensor, Tensor],
        edge_index: Union[Tensor, SparseTensor],
        dim: Optional[int] = None,
        retrieve_dim: int = 0,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        r"""Bipartite graph, :obj:`feats` contains source and destination
        nodes' features.

        Args:
            feats (Tuple[Tensor, Tensor]): The source and destination node
                features.
            edge_index (Union[Tensor, SparseTensor]): The edge indices.
            dim (Optional[int]): The edge_index row to retrieve (0 = src,
                1 = dst). If None, retrieve both src and dst.
            retrieve_dim (int): The feature dimension to index along.

        Returns:
            Union[Tensor, Tuple[Tensor, Tensor]]: Node features at dim,
            or both source and destination node features.
            :math:`(|E|, F)` or (:math:`(|E|, F_{src})` and
            :math:`(|E|, F_{dst})`).
        """
        ...

    def retrieve_feats(
        self,
        feats: Union[Tensor, Tuple[Tensor, Tensor]],
        edge_index: Union[Tensor, SparseTensor],
        dim: Optional[int] = None,
        retrieve_dim: int = 0,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        edge_index, _ = self.__unify_edgeindex__(edge_index)
        # A single feature matrix serves as both source and destination table.
        src_feats, dst_feats = feats if isinstance(feats, tuple) else (feats, feats)

        if dim is None:
            return (
                src_feats.index_select(retrieve_dim, edge_index[0, :]),
                dst_feats.index_select(retrieve_dim, edge_index[1, :]),
            )

        assert dim in [0, 1], f"Expect dim to be 0 or 1, got {dim}."
        table = src_feats if dim == 0 else dst_feats
        return table.index_select(retrieve_dim, edge_index[dim, :])

    # explain functions
    def explain(self, kwargs: Dict[str, Any]) -> Any:
        r"""Explain the behavior of the message passing layer.

        For now this is kept as an interface; subclasses implement it if
        necessary.
        """
        raise NotImplementedError

    @property
    def if_explain(self) -> bool:
        r"""Whether to enable explain mode."""
        return self.__explain__

    @if_explain.setter
    def if_explain(self, explain: bool) -> None:
        r"""Set the explain mode."""
        self.__explain__ = explain