rllm.transforms.graph_transforms.NodeTransform
class rllm.transforms.graph_transforms.NodeTransform
Bases: ABC

Base class for node-wise transformations on graph data.
The transform is applied to data.x for homogeneous graphs, to each valid store.x for heterogeneous graphs, or directly to a square torch.Tensor.

Shape:
- GraphData: data.x can be any node feature shape accepted by subclasses.
- HeteroGraphData: store.x can be any node feature shape accepted by subclasses.
- torch.Tensor: the input must be a square matrix of shape [N, N].
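The three input cases can be sketched as a small dispatch in plain Python. This is a hypothetical illustration, not rllm's implementation: NodeTransformSketch, RowSumNormalize, and the stores attribute used for heterogeneous graphs are invented names. "Valid" is assumed to mean a store whose x is not None, and nested lists stand in for torch.Tensor so the sketch runs without PyTorch. Whether the real class modifies its input in place or returns a copy is not stated here; the sketch modifies graph inputs in place.

```python
from types import SimpleNamespace


class NodeTransformSketch:
    """Hypothetical sketch of NodeTransform-style dispatch (not rllm's code).

    Assumptions: heterogeneous graphs expose their node stores as `stores`,
    and nested Python lists stand in for torch.Tensor.
    """

    def forward(self, x):
        # Subclasses implement the actual node-wise operation.
        raise NotImplementedError

    def __call__(self, data):
        # Heterogeneous graph: apply to every store that carries features.
        if hasattr(data, "stores"):
            for store in data.stores:
                if getattr(store, "x", None) is not None:
                    store.x = self.forward(store.x)
            return data
        # Homogeneous graph: apply to data.x (any shape the subclass accepts).
        if hasattr(data, "x"):
            data.x = self.forward(data.x)
            return data
        # Bare matrix: must be square with shape [N, N].
        n = len(data)
        if any(len(row) != n for row in data):
            raise ValueError(f"expected a square [N, N] matrix with N={n}")
        return self.forward(data)


class RowSumNormalize(NodeTransformSketch):
    # Hypothetical subclass: divide each row by its sum; zero-sum rows are kept.
    def forward(self, x):
        out = []
        for row in x:
            s = sum(row)
            out.append([v / s for v in row] if s != 0 else list(row))
        return out


transform = RowSumNormalize()
print(transform([[1.0, 3.0], [2.0, 2.0]]))  # [[0.25, 0.75], [0.5, 0.5]]

graph = SimpleNamespace(x=[[1.0, 1.0, 2.0]])  # homogeneous: x need not be square
print(transform(graph).x)  # [[0.25, 0.25, 0.5]]
```

Checking for heterogeneous stores before a single x keeps a container that exposes both from being treated as one homogeneous graph, and the square-matrix check applies only to bare matrices, in line with the Shape notes.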
Examples:
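    from rllm.transforms.graph_transforms import NodeTransform

    # Subclass NodeTransform and implement forward to define the node-wise operation.
    class NormalizeNodeX(NodeTransform):
        def forward(self, x):
            # Scale each row to unit L2 norm; the 1e-12 term avoids division by zero.
            return x / (x.norm(dim=-1, keepdim=True) + 1e-12)

    transform = NormalizeNodeX()
    out = transform(data)  # data: a GraphData, HeteroGraphData, or square torch.Tensor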
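Written out, the forward method of NormalizeNodeX rescales each row x_i (for node features, one node's d-dimensional feature vector) by its L2 norm taken over the last dimension (dim=-1), with keepdim=True so the per-row norms broadcast back across the row:

$$
\tilde{x}_i = \frac{x_i}{\lVert x_i \rVert_2 + \epsilon},
\qquad
\lVert x_i \rVert_2 = \sqrt{\sum_{j=1}^{d} x_{ij}^2},
\qquad
\epsilon = 10^{-12}
$$

For example, x_i = (3, 4) has norm 5 and maps to approximately (0.6, 0.8). An all-zero row maps to zero instead of producing NaN, because ε keeps the denominator positive.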