# Source code for rllm.transforms.graph_transforms.graph_transform

from __future__ import annotations
from abc import ABC
from typing import List, Callable, Union

import torch

from rllm.data.graph_data import GraphData, HeteroGraphData


class GraphTransform(torch.nn.Module, ABC):
    r"""Base class that applies a sequence of transformations to graph data.

    It accepts homogeneous (:class:`GraphData`) or heterogeneous
    (:class:`HeteroGraphData`) graph objects, as well as lists or tuples
    of them.

    Args:
        transforms (List[Callable]): Transformation callables applied in
            order. Each one is called with a single data object and its
            return value is passed on to the next transform.
    """

    def __init__(
        self,
        transforms: List[Callable],
    ) -> None:
        super().__init__()
        # NOTE(review): ``data`` is never read inside this class; it is kept
        # because subclasses or external callers may rely on the attribute.
        self.data = None
        self.transforms = transforms

    def forward(
        self,
        data: Union[GraphData, HeteroGraphData, list, tuple],
    ):
        r"""Apply every transform in ``self.transforms`` to ``data``.

        Args:
            data: A single graph object, or a list/tuple of them. For a
                list/tuple, each transform is applied element-wise.

        Returns:
            The transformed object for single input. For list/tuple input,
            a ``list`` of transformed objects. This holds even when no
            transforms are configured.
        """
        # Normalise tuples up front. Otherwise an empty pipeline would return
        # the tuple unchanged, while any non-empty pipeline returns a list
        # (the comprehension below always builds a list).
        if isinstance(data, tuple):
            data = list(data)
        for transform in self.transforms:
            # Re-checked on every step: a transform may itself return a
            # list/tuple, which the following transforms then map over.
            if isinstance(data, (list, tuple)):
                data = [transform(d) for d in data]
            else:
                data = transform(data)
        return data