Source code for rllm.transforms.graph_transforms.add_remaining_self_loops

from functools import lru_cache

from torch import Tensor

from rllm.transforms.graph_transforms import EdgeTransform
from rllm.transforms.graph_transforms.functional import add_remaining_self_loops


class AddRemainingSelfLoops(EdgeTransform):
    r"""Adds missing self-loops to the adjacency matrix.

    .. math::
        \mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}

    The actual work is delegated to the functional
    ``add_remaining_self_loops(adj, fill_value)`` helper.

    Args:
        fill_value (Any): Value used for added self-loops.
            (default: :obj:`1.0`)
    """

    def __init__(self, fill_value=1.0):
        super().__init__()
        # Stored on the instance and read on every ``forward`` call, so
        # changing it after construction affects subsequent calls.
        self.fill_value = fill_value

    # Deliberately NOT memoized. A ``functools.lru_cache`` on this method
    # keys on ``(self, adj)``. ``torch.Tensor`` hashes by object identity, so
    # a cache hit ignores in-place edits to ``adj`` and later changes to
    # ``self.fill_value``, returning stale results. The class-wide cache also
    # keeps up to 128 adjacency tensors, their outputs and the transform
    # instances alive indefinitely.
    def forward(self, adj: Tensor) -> Tensor:
        """Return ``adj`` with the missing self-loops added.

        Args:
            adj (Tensor): The adjacency matrix to transform.

        Returns:
            Tensor: The result of
            ``add_remaining_self_loops(adj, self.fill_value)``.
        """
        return add_remaining_self_loops(adj, self.fill_value)