Source code for rllm.transforms.graph_transforms.remove_self_loops
from functools import lru_cache
from torch import Tensor
from rllm.transforms.graph_transforms import EdgeTransform
from rllm.transforms.graph_transforms.functional import remove_self_loops
[docs]
class RemoveSelfLoops(EdgeTransform):
    r"""Removes diagonal self-loop edges from an adjacency matrix.

    This transform is a thin wrapper: :meth:`forward` delegates directly to
    :func:`rllm.transforms.graph_transforms.functional.remove_self_loops`.
    """

    def __init__(self):
        super().__init__()

    # NOTE: intentionally not decorated with ``functools.lru_cache``. A cache
    # on an instance method is stored on the class and keyed on
    # ``(self, adj)``, so it keeps every transform instance and every cached
    # input/output tensor alive, and -- because tensors hash by identity --
    # it would hand back a stale result if ``adj`` were modified in place
    # between two calls.
    def forward(self, adj: Tensor) -> Tensor:
        r"""Return :obj:`adj` with its self-loop edges removed.

        Args:
            adj (Tensor): The adjacency matrix to process.

        Returns:
            Tensor: The result of ``remove_self_loops(adj)``, recomputed on
            every call from the current contents of :obj:`adj`.
        """
        return remove_self_loops(adj)