# Source code for rllm.transforms.graph_transforms.gcn_norm
from functools import lru_cache
from torch import Tensor
from rllm.transforms.graph_transforms import EdgeTransform
from rllm.transforms.graph_transforms.functional import (
add_remaining_self_loops,
symmetric_norm,
)
# [docs]
class GCNNorm(EdgeTransform):
    r"""Applies the standard GCN adjacency normalization.

    Proposed in `"Semi-supervised Classification with Graph Convolutional
    Networks" <https://arxiv.org/abs/1609.02907>`__.

    .. math::

        \mathbf{\hat{A}} = \mathbf{\hat{D}}^{-1/2} (\mathbf{A} + \mathbf{I})
        \mathbf{\hat{D}}^{-1/2}

    where :math:`\mathbf{\hat{D}}` is the diagonal degree matrix of
    :math:`\mathbf{A} + \mathbf{I}`.

    The most recent result is cached on the instance: calling the transform
    again with the very same, unmodified adjacency tensor returns the
    previously computed tensor instead of recomputing it.
    """

    def __init__(self):
        super().__init__()
        # Single-entry, per-instance cache: (input tensor, its version
        # counter when cached, normalized output). This replaces a former
        # ``functools.lru_cache`` on ``forward``, which was shared by all
        # instances and kept every instance, plus up to 128 input/output
        # tensor pairs, alive for the lifetime of the process.
        self._cache = None

    def forward(self, adj: Tensor) -> Tensor:
        """Return the GCN-normalized adjacency for ``adj``.

        Args:
            adj: Adjacency matrix of the graph. NOTE(review): the accepted
                layout (dense vs. sparse) is whatever
                ``add_remaining_self_loops`` / ``symmetric_norm`` support;
                not visible here.

        Returns:
            ``symmetric_norm(add_remaining_self_loops(adj))``. The cached
            tensor is returned when ``adj`` is the same object as on the
            previous call and has not been modified in place since.
        """
        # PyTorch bumps a tensor's ``_version`` on in-place modification, so
        # a mutated ``adj`` is treated as a cache miss rather than served
        # stale. ``getattr`` defaults keep this working for tensor types
        # lacking the counter and for instances built without ``__init__``.
        version = getattr(adj, "_version", None)
        cache = getattr(self, "_cache", None)
        if cache is not None and cache[0] is adj and cache[1] == version:
            return cache[2]

        out = symmetric_norm(add_remaining_self_loops(adj))
        self._cache = (adj, version, out)
        return out