# Source code for rllm.nn.conv.graph_conv.lgc_conv

from typing import Optional, Union

import torch
from torch import Tensor
from torch.sparse import Tensor as SparseTensor

from rllm.nn.conv.graph_conv import MessagePassing


class LGCConv(MessagePassing):
    r"""The LGC (Lazy Graph Convolution) implementation with message passing,
    based on the `"From Cluster Assumption to Graph Convolution: Graph-based
    Semi-Supervised Learning Revisited" <https://arxiv.org/abs/2309.13599>`__
    paper.

    This model uses the hyperparameter :math:`\beta` to control how much the
    neighbor nodes and the node itself contribute to the output:

    - If :math:`\beta = 1`, the model is equivalent to the graph convolution
      form of the GCN model (if `with_param` is `True`, the model is
      equivalent to the GCN model).
    - If :math:`\beta = 0`, the model only focuses on the node itself.

    .. math::

        \mathbf{\hat{A}} = \mathbf{\tilde{D}^{-\frac{1}{2}} \tilde{A}
        \tilde{D}^{-\frac{1}{2}}}

        \mathbf{X}^{(k+1)} = \left[\beta\mathbf{\hat{A}} + (1-\beta)
        \mathbf{I}\right] \mathbf{X}^{(k)}

    .. note::
        This layer does not normalize the adjacency matrix itself: the
        neighbor term is a plain (optionally edge-weighted) sum of source-node
        features into each target node. To match the formula above, the
        normalized weights of :math:`\mathbf{\hat{A}}` must be supplied via
        ``edge_weight`` (or, presumably, as the values of a sparse
        ``edge_index``).

    Args:
        beta (float): The hyperparameter :math:`\beta` to control the
            message attribution.
        with_param (bool): If set to `True`, the model will learn a linear
            transformation for node features.
        in_dim (int): Size of each input sample. Required when `with_param`
            is `True`.
        out_dim (int): Size of each output sample. Required when
            `with_param` is `True`.
        bias (bool): If set to `False`, no bias terms are added into the
            final output. Only available when `with_param` is `True`.

    Raises:
        ValueError: If `with_param` is `True` but `in_dim` or `out_dim` is
            missing.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`;
          edge_index is a sparse adjacency matrix
          :math:`(|\mathcal{V}|, |\mathcal{V}|)` or an edge list
          :math:`(2, |\mathcal{E}|)`
        - **output:**
          node features :math:`(|\mathcal{V}|, F_{out})`
    """

    node_dim: int = 0

    def __init__(
        self,
        beta: float = 0.5,
        *,
        with_param: bool = False,
        in_dim: Optional[int] = None,
        out_dim: Optional[int] = None,
        bias: bool = False,
    ):
        super().__init__(aggr="sum")
        self.beta = beta
        self.with_param = with_param

        if self.with_param:
            # Explicit check instead of `assert`, which is stripped under -O.
            if in_dim is None or out_dim is None:
                raise ValueError("in_dim and out_dim should be provided")
            self.lin = torch.nn.Linear(in_dim, out_dim, bias=False)
        else:
            # Define the attribute in both modes so reset_parameters() and
            # attribute access behave uniformly.
            self.lin = None

        # The bias is only meaningful together with the linear projection.
        if self.with_param and bias:
            self.bias = torch.nn.Parameter(torch.empty(out_dim))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""Re-initialize the learnable parameters.

        Uses Xavier-normal initialization for the projection weight and zeros
        for the bias. This is a no-op when the layer has no parameters
        (`with_param` is `False`).
        """
        lin = getattr(self, "lin", None)
        if lin is not None:
            torch.nn.init.xavier_normal_(lin.weight)
        bias = getattr(self, "bias", None)
        if bias is not None:
            torch.nn.init.zeros_(bias)

    def forward(
        self,
        x: Tensor,
        edge_index: Union[Tensor, SparseTensor],
        edge_weight: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Run lazy graph convolution with optional linear projection.

        Args:
            x (Tensor): Input node features.
            edge_index (Union[Tensor, SparseTensor]): Graph connectivity in
                edge-list or sparse adjacency format.
            edge_weight (Optional[Tensor]): Optional per-edge weights.

        Returns:
            Tensor: Output node features after lazy propagation.

        Example:
            >>> import torch
            >>> from rllm.nn.conv.graph_conv import LGCConv
            >>> conv = LGCConv(beta=0.5)
            >>> x = torch.randn(4, 8)
            >>> edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
            >>> out = conv(x, edge_index)
            >>> out.shape
            torch.Size([4, 8])
        """
        # Project first so propagation runs in the (possibly smaller) output
        # space.
        if self.with_param:
            x = self.lin(x)
        x = self.propagate(
            x, edge_index, edge_weight=edge_weight, dim=self.node_dim
        )
        if self.with_param and self.bias is not None:
            x = x + self.bias
        return x

    def message_and_aggregate(self, x, edge_index, edge_weight, dim):
        r"""Lazy Graph Convolution.

        .. math::
            lazy\_msgs = \beta * gcn\_msgs + (1-\beta) * \mathbf{X}

        Features are gathered from source nodes (``edge_index[0]``),
        optionally scaled by ``edge_weight``, and summed into target nodes
        (``edge_index[1]``). The result is then mixed with the node's own
        features.
        """
        edge_index, ew = self.__unify_edgeindex__(edge_index)
        # Explicit edge_weight takes precedence over weights recovered from
        # edge_index.
        if edge_weight is None and ew is not None:
            edge_weight = ew

        src_index = edge_index[0, :]
        # Already 1-D (index_select on src_index requires that). Do NOT
        # squeeze: a single-edge graph would collapse to a 0-d tensor and
        # break aggregation.
        dst_index = edge_index[1, :]

        gcn_msgs = x.index_select(dim=0, index=src_index)
        if edge_weight is not None:
            gcn_msgs = gcn_msgs * edge_weight.view(-1, 1)
        gcn_msgs = self.aggr_module(
            gcn_msgs, dst_index, dim=dim, dim_size=x.size(0)
        )
        return self.beta * gcn_msgs + (1 - self.beta) * x

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}"
            f"(beta: {self.beta}, with_param: {self.with_param})"
        )