Source code for rllm.transforms.graph_transforms.svd_feature_reduction

from functools import lru_cache

from torch import Tensor

from rllm.transforms.graph_transforms import NodeTransform
from rllm.transforms.graph_transforms.functional import svd_feature_reduction


class SVDFeatureReduction(NodeTransform):
    r"""Dimensionality reduction of node features via Singular Value
    Decomposition (SVD).

    The reduction itself is delegated to
    :func:`~rllm.transforms.graph_transforms.functional.svd_feature_reduction`.

    Results are memoized per instance. The memo is keyed on the *identity*
    of the input object (not on its contents) together with the current
    :obj:`out_dim`, and it only holds weak references to the inputs, so
    neither the inputs nor the transform instance are kept alive by it.
    In-place modification of an input that was already transformed is not
    detected; pass a new object to force a recomputation.

    Args:
        out_dim (int): The dimensionality of node features after reduction.
    """

    def __init__(
        self,
        out_dim: int,
    ):
        self.out_dim = out_dim
        # (id(input), out_dim) -> (weak reference to input, reduced output)
        self._reduced_cache = {}

    def __getstate__(self):
        """Return the instance state without the memo.

        Weak references cannot be pickled, and cached outputs are cheap to
        rebuild, so the memo is dropped when the transform is serialized.
        """
        state = self.__dict__.copy()
        state.pop("_reduced_cache", None)
        return state

    def forward(self, x: Tensor) -> Tensor:
        """Reduce ``x`` to ``self.out_dim`` feature dimensions.

        Args:
            x (Tensor): Node feature matrix to reduce.

        Returns:
            Tensor: Whatever ``svd_feature_reduction(x, self.out_dim)``
            returns. Repeated calls with the same input object and the same
            ``out_dim`` return the same cached output object.
        """
        # Local import so this class does not depend on a new file-level
        # import being added.
        import weakref

        # Instances restored from pickled state have no memo yet.
        cache = getattr(self, "_reduced_cache", None)
        if cache is None:
            cache = self._reduced_cache = {}

        key = (id(x), self.out_dim)
        entry = cache.get(key)
        # ``id`` values can be reused after an object dies, so a hit only
        # counts when the weakly referenced input is this very object.
        if entry is not None and entry[0]() is x:
            return entry[1]

        reduced = svd_feature_reduction(x, self.out_dim)

        # Release outputs whose input has already been garbage collected.
        stale_keys = [k for k, (ref, _) in cache.items() if ref() is None]
        for stale_key in stale_keys:
            del cache[stale_key]

        # Storing an output that *is* the input would pin the input through
        # its own cache entry, defeating the weak reference.
        if reduced is x:
            return reduced

        try:
            x_ref = weakref.ref(x)
        except TypeError:
            # Input type does not support weak references: skip memoization.
            return reduced

        cache[key] = (x_ref, reduced)
        return reduced