rllm.transforms.graph_transforms.NodeTransform

class rllm.transforms.graph_transforms.NodeTransform

Bases: ABC

Base class for node-wise transformations on graph data.

The transform is applied to data.x for homogeneous graphs (GraphData), to store.x of each valid node store for heterogeneous graphs (HeteroGraphData), or directly to the input when it is a square torch.Tensor.

Shape:
  • GraphData: data.x may have any node feature shape that the subclass accepts.

  • HeteroGraphData: each store.x may have any node feature shape that the subclass accepts.

  • torch.Tensor: input must be a square matrix with shape [N, N].
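To make the dispatch described above concrete, here is a minimal, torch-free sketch of how a node-wise transform base class might route its input. It is an illustration under stated assumptions, not rllm's implementation: ToyGraphData, ToyNodeStore, ToyHeteroGraphData, ToyNodeTransform and ToyNormalizeRows are hypothetical stand-ins, a nested list plays the role of torch.Tensor, and a heterogeneous store is assumed to be "valid" when its x is not None.

```python
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


# Hypothetical stand-ins; the real GraphData / HeteroGraphData APIs differ.
@dataclass
class ToyGraphData:
    x: Any


@dataclass
class ToyNodeStore:
    x: Optional[Any] = None


@dataclass
class ToyHeteroGraphData:
    stores: Dict[str, ToyNodeStore] = field(default_factory=dict)


class ToyNodeTransform(ABC):
    """Sketch of node-wise dispatch; not rllm's actual NodeTransform."""

    @abstractmethod
    def forward(self, x):
        ...

    def __call__(self, data):
        if isinstance(data, ToyGraphData):
            # Homogeneous graph: transform the single node feature matrix.
            data.x = self.forward(data.x)
        elif isinstance(data, ToyHeteroGraphData):
            # Assumption: a store is "valid" if it carries node features.
            for store in data.stores.values():
                if store.x is not None:
                    store.x = self.forward(store.x)
        elif isinstance(data, list):
            # A nested list stands in for torch.Tensor; enforce shape [N, N].
            n = len(data)
            if any(len(row) != n for row in data):
                raise ValueError("tensor input must have shape [N, N]")
            data = self.forward(data)
        else:
            raise TypeError(f"unsupported input type: {type(data).__name__}")
        return data


class ToyNormalizeRows(ToyNodeTransform):
    def forward(self, x: List[List[float]]) -> List[List[float]]:
        # Row-wise L2 normalization; 1e-12 guards against division by zero.
        out = []
        for row in x:
            norm = math.sqrt(sum(v * v for v in row))
            out.append([v / (norm + 1e-12) for v in row])
        return out
```

With this layout, subclasses only implement forward on a single feature matrix, and the base class decides where that matrix lives (one x, one x per node store, or the raw input).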

Examples:

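from rllm.transforms.graph_transforms import NodeTransform

class NormalizeNodeX(NodeTransform):
    # Row-wise L2 normalization of node features; 1e-12 avoids division by zero.
    def forward(self, x):
        return x / (x.norm(dim=-1, keepdim=True) + 1e-12)

transform = NormalizeNodeX()
out = transform(data)  # data: a GraphData or HeteroGraphData instance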
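The arithmetic in NormalizeNodeX.forward can be checked without torch. The sketch below uses a hypothetical helper, l2_normalize_rows (not part of rllm), that applies the same row-wise rule x / (||x|| + 1e-12) to nested lists. It shows why the small constant matters: an all-zero feature row stays at zero instead of becoming 0/0 = NaN.

```python
import math
from typing import List


def l2_normalize_rows(x: List[List[float]], eps: float = 1e-12) -> List[List[float]]:
    """Hypothetical pure-Python mirror of NormalizeNodeX.forward."""
    out = []
    for row in x:
        norm = math.sqrt(sum(v * v for v in row))
        # eps keeps an all-zero row at zero instead of producing 0/0 = NaN.
        out.append([v / (norm + eps) for v in row])
    return out


x = [[3.0, 4.0], [0.0, 0.0]]
print(l2_normalize_rows(x))  # approximately [[0.6, 0.8], [0.0, 0.0]]
```

Each non-zero row ends up with unit L2 norm (up to the effect of eps), so node features become comparable in scale regardless of their original magnitude.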