Source code for rllm.transforms.graph_transforms.node_edge_transform
import copy
from abc import ABC, abstractmethod
from typing import Union

import torch
from torch import Tensor

from rllm.data.graph_data import GraphData, HeteroGraphData
[docs]
class NodeTransform(ABC):
    r"""Base class for node-wise transformations on graph data.

    Subclasses implement :meth:`forward`, which receives a node feature
    tensor and returns the transformed tensor. Calling the transform
    dispatches on the input type:

    - :class:`GraphData`: on a shallow copy of the input, ``x`` is replaced
      by ``forward(x)`` when it is present and not ``None``.
    - :class:`HeteroGraphData`: on a shallow copy of the input, ``store.x``
      is replaced by ``forward(store.x)`` for every node store that holds an
      ``x`` entry and for which ``store.is_bipartite()`` returns true.
    - :class:`torch.Tensor`: the tensor itself (not a copy) is passed to
      ``forward`` and the result is returned.

    Shape:
        - ``GraphData``: ``data.x`` can be any node feature shape accepted by
          subclasses.
        - ``HeteroGraphData``: ``store.x`` can be any node feature shape
          accepted by subclasses.
        - :class:`torch.Tensor`: input must be a square matrix with shape
          ``[N, N]`` (only ``size(0) == size(1)`` is checked).

    Examples::

        class NormalizeNodeX(NodeTransform):
            def forward(self, x):
                return x / (x.norm(dim=-1, keepdim=True) + 1e-12)

        transform = NormalizeNodeX()
        out = transform(data)
    """

    def __call__(self, data: Union[GraphData, HeteroGraphData, Tensor]):
        """Apply :meth:`forward` to the node features held by ``data``.

        Args:
            data: A homogeneous graph, a heterogeneous graph, or a tensor.

        Returns:
            For a tensor input, ``self.forward(data)``. Otherwise a shallow
            copy of ``data``, with its node features transformed when
            ``data`` is a ``GraphData`` or ``HeteroGraphData``.

        Raises:
            AssertionError: (when assertions are enabled) if a tensor input
                has ``size(0) != size(1)``.
        """
        if isinstance(data, Tensor):
            # Tensors are deliberately NOT passed through ``copy.copy``:
            # ``torch.Tensor`` defines no ``__copy__``, so ``copy.copy`` falls
            # back to ``__reduce_ex__`` and rebuilds a new autograd leaf over
            # the same storage. That copy gives no protection against
            # in-place writes (the storage is shared) and detaches the result
            # from the caller's tensor, so gradients would never reach it.
            assert data.size(0) == data.size(1), (
                f"{self.__class__.__name__} expects a tensor with "
                f"size(0) == size(1), got shape {tuple(data.shape)}"
            )
            return self.forward(data)

        # Shallow-copy graph containers before re-assigning node features so
        # that (provided the container's copy owns its attribute storage) the
        # caller's object is not modified in place.
        data = copy.copy(data)
        if isinstance(data, GraphData):
            if getattr(data, "x", None) is not None:
                data.x = self.forward(data.x)
        elif isinstance(data, HeteroGraphData):
            for store in data.node_stores:
                # NOTE(review): ``is_bipartite()`` is called on *node* stores
                # here, mirroring the edge-store filter in EdgeTransform.
                # Confirm the node storage type defines it and that skipping
                # stores for which it is false is intended -- otherwise
                # heterogeneous node features may never be transformed.
                if "x" not in store or not store.is_bipartite():
                    continue
                store.x = self.forward(store.x)
        return data

    @abstractmethod
    def forward(self, x: Tensor) -> Tensor:
        """Return the transformed version of the node feature tensor ``x``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
[docs]
class EdgeTransform(ABC):
    r"""Base class for edge-wise transformations on graph data.

    Subclasses implement :meth:`forward`, which receives an adjacency tensor
    (dense or sparse) and returns the transformed tensor. Calling the
    transform dispatches on the input type:

    - :class:`GraphData`: on a shallow copy of the input, ``adj`` is replaced
      by ``forward(adj)`` when it is present and not ``None``.
    - :class:`HeteroGraphData`: on a shallow copy of the input,
      ``store.adj`` is replaced by ``forward(store.adj)`` for every edge
      store that holds an ``adj`` entry and for which
      ``store.is_bipartite()`` returns true; other edge stores are left as-is.
    - :class:`torch.Tensor`: the tensor itself (not a copy) is passed to
      ``forward`` and the result is returned.

    Shape:
        - ``GraphData``: ``data.adj`` can be dense or sparse and should follow
          the adjacency format expected by subclasses.
        - ``HeteroGraphData``: ``store.adj`` can be dense or sparse and should
          follow the adjacency format expected by subclasses.
        - :class:`torch.Tensor`: if dense (strided layout), input must be a
          square matrix with shape ``[N, N]``; tensors with any sparse layout
          (COO, CSR, CSC, BSR, BSC) are forwarded without a shape check.

    Examples::

        class IdentityEdge(EdgeTransform):
            def forward(self, adj):
                return adj

        transform = IdentityEdge()
        out = transform(data)
    """

    def __call__(self, data: Union[GraphData, HeteroGraphData, Tensor]):
        """Apply :meth:`forward` to the adjacency held by ``data``.

        Args:
            data: A homogeneous graph, a heterogeneous graph, or a tensor.

        Returns:
            For a tensor input, ``self.forward(data)``. Otherwise a shallow
            copy of ``data``, with its adjacency transformed when ``data`` is
            a ``GraphData`` or ``HeteroGraphData``.

        Raises:
            AssertionError: (when assertions are enabled) if a dense tensor
                input has ``size(0) != size(1)``.
        """
        if isinstance(data, Tensor):
            # Tensors are deliberately NOT passed through ``copy.copy``:
            # ``torch.Tensor`` defines no ``__copy__``, so ``copy.copy`` falls
            # back to ``__reduce_ex__`` and rebuilds a new tensor over the
            # same data. That copy gives no protection against in-place
            # writes and detaches the result from the caller's autograd
            # graph, so gradients would never reach the caller's tensor.
            #
            # ``Tensor.is_sparse`` is true only for the COO layout, so the
            # layout is compared against ``torch.strided`` to exempt CSR/CSC/
            # BSR/BSC adjacencies from the dense square-shape check as well.
            if data.layout == torch.strided:
                assert data.size(0) == data.size(1), (
                    f"{self.__class__.__name__} expects a dense tensor with "
                    f"size(0) == size(1), got shape {tuple(data.shape)}"
                )
            return self.forward(data)

        # Shallow-copy graph containers before re-assigning adjacencies so
        # that (provided the container's copy owns its attribute storage) the
        # caller's object is not modified in place.
        data = copy.copy(data)
        if isinstance(data, GraphData):
            if getattr(data, "adj", None) is not None:
                data.adj = self.forward(data.adj)
        elif isinstance(data, HeteroGraphData):
            for store in data.edge_stores:
                # Only edge stores that carry an ``adj`` and report
                # ``is_bipartite()`` are transformed; the rest are kept as-is.
                if "adj" not in store or not store.is_bipartite():
                    continue
                store.adj = self.forward(store.adj)
        return data

    @abstractmethod
    def forward(self, adj: Tensor) -> Tensor:
        """Return the transformed version of the adjacency tensor ``adj``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"