Source code for rllm.nn.conv.graph_conv.gcn_conv

from typing import Optional, Union

import torch
from torch import Tensor
from torch.nn import Linear, Parameter
from torch.sparse import Tensor as SparseTensor
import torch.nn.init as init

from rllm.transforms.graph_transforms import GCNNorm
from rllm.nn.conv.graph_conv import MessagePassing


class GCNConv(MessagePassing):
    r"""The GCN (Graph Convolutional Network) model implementation with
    message passing, based on the `"Semi-supervised Classification with
    Graph Convolutional Networks" <https://arxiv.org/abs/1609.02907>`__
    paper.

    This model applies convolution operations to graph-structured data,
    allowing for the aggregation of feature information from neighboring
    nodes.

    .. math::
        \mathbf{X}^{\prime} = \mathbf{\tilde{A}} \mathbf{X} \mathbf{W}

    Args:
        in_dim (int): Size of each input sample.
        out_dim (int): Size of each output sample.
        bias (bool): If set to `False`, no bias terms are added into the
            final output. (default: `True`)
        normalize (bool): If set to `True`, the adjacency matrix is
            normalized with :class:`GCNNorm` before propagation, using the
            symmetric normalization method
            :math:`\mathbf{\tilde{A}} = \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}`,
            where :math:`\mathbf{D}` is the degree matrix of the graph.
            Normalization requires a sparse adjacency input.
            (default: `False`)

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`;
          edge_index is a sparse adjacency matrix
          :math:`(|\mathcal{V}|, |\mathcal{V}|)` or an edge list
          :math:`(2, |\mathcal{E}|)`
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        normalize: bool = False,
    ):
        super().__init__(aggr="gcn")
        self.in_dim = in_dim
        self.out_dim = out_dim

        # The linear layer carries no bias of its own: the bias is added in
        # `forward` only after neighborhood aggregation.
        self.linear = Linear(in_dim, out_dim, bias=False)
        if bias:
            self.bias = Parameter(torch.empty(out_dim))
        else:
            # Registering `None` keeps `self.bias` a valid attribute so the
            # `is not None` checks below work for both configurations.
            self.register_parameter("bias", None)

        self.normalize = normalize
        if normalize:
            self.norm = GCNNorm()

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the weight (Xavier normal) and zero the bias."""
        init.xavier_normal_(self.linear.weight)
        if self.bias is not None:
            init.zeros_(self.bias)

    def forward(
        self,
        x: Tensor,
        edge_index: Union[Tensor, SparseTensor],
        edge_weight: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
    ) -> Tensor:
        r"""Apply a single GCN message-passing step.

        Args:
            x (Tensor): Input node features with shape
                ``[num_nodes, in_dim]``.
            edge_index (Union[Tensor, SparseTensor]): Graph connectivity in
                edge-list or sparse adjacency format. Must be sparse when the
                layer was built with ``normalize=True``.
            edge_weight (Optional[Tensor]): Optional edge weights used during
                neighborhood aggregation.
            dim_size (Optional[int]): Optional number of destination nodes
                for aggregation output shape inference.

        Returns:
            Tensor: Output node features with shape ``[num_nodes, out_dim]``.

        Raises:
            ValueError: If ``normalize=True`` and ``edge_index`` is not a
                sparse tensor.

        Example:
            >>> import torch
            >>> from rllm.nn.conv.graph_conv import GCNConv
            >>> conv = GCNConv(16, 8)
            >>> x = torch.randn(5, 16)
            >>> edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
            >>> out = conv(x, edge_index, dim_size=x.size(0))
            >>> out.shape
            torch.Size([5, 8])
        """
        if self.normalize:
            if not edge_index.is_sparse:
                raise ValueError(
                    "GCNNorm requires sparse adjacency input when "
                    "`normalize=True`. Set `normalize=False` for edge-list "
                    "or dense inputs."
                )
            edge_index = self.norm(edge_index)

        x = self.linear(x)
        out = self.propagate(
            x, edge_index, edge_weight=edge_weight, dim_size=dim_size
        )
        if self.bias is not None:
            # Out-of-place add: `out` comes from `propagate` and may be saved
            # for backward or alias another tensor, so mutating it in place
            # can trigger autograd in-place-modification errors.
            out = out + self.bias
        return out

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.in_dim}, " f"{self.out_dim})"