# Source code for rllm.transforms.graph_transforms.normalize_features

from functools import lru_cache
from torch import Tensor

from rllm.transforms.graph_transforms import NodeTransform
from rllm.transforms.graph_transforms.functional import normalize_features


class NormalizeFeatures(NodeTransform):
    r"""Row-normalizes the node features.

    .. math::
        \vec{x} = \frac{\vec{x}}{||\vec{x}||_p}

    Args:
        norm (str): The norm to use to normalize each non zero sample,
            *e.g.*, `l1`, `l2`. (default: `l2`)
    """

    def __init__(self, norm: str = "l2"):
        # Name of the norm handed to `normalize_features` on every call.
        self.norm = norm

    # NOTE: `forward` is intentionally not wrapped in `functools.lru_cache`.
    # A cache on an instance method keys on `(self, x)` and holds strong
    # references to both, so it would keep this transform and every cached
    # input/output tensor alive for the lifetime of the cache. It could also
    # hand back a stale result when the same tensor object is passed again
    # after being modified in place.
    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` with the norm configured on this transform.

        Args:
            x (Tensor): The node feature tensor to normalize.

        Returns:
            Tensor: The result of ``normalize_features(x, self.norm)``.
        """
        return normalize_features(x, self.norm)