"""The ``RemoveTrainingClasses`` transform (``rllm.transforms.utils.remove_training_classes``)."""

from typing import List

from rllm.data.graph_data import GraphData
from rllm.transforms.utils import BaseTransform
from rllm.transforms.utils.functional import remove_training_classes


class RemoveTrainingClasses(BaseTransform):
    r"""Removes classes from the node-level training set as given by
    `data.train_mask`, *e.g.*, in order to get a zero-shot label scenario.

    Args:
        classes (List[int]): The classes to remove from the training set.
    """

    def __init__(self, classes: List[int]):
        super().__init__()
        self.classes = classes

    def forward(self, data: GraphData) -> GraphData:
        r"""Remove :obj:`self.classes` from the training set of ``data``.

        The new mask is computed by :func:`remove_training_classes` from the
        current ``data.train_mask``, the labels ``data.y`` and
        :obj:`self.classes`, and is written back onto ``data``.

        Args:
            data (GraphData): The graph to transform. It must provide
                ``train_mask`` and ``y``.

        Returns:
            GraphData: The same ``data`` object (modified in place) with its
            ``train_mask`` replaced.

        Raises:
            ValueError: If ``data`` has no ``train_mask`` attribute, or it is
                ``None``.
        """
        # ``getattr`` with a ``None`` default rejects both a missing mask and
        # one that is present but unset, instead of letting ``None`` reach
        # ``remove_training_classes`` and fail there with a less clear error.
        if getattr(data, "train_mask", None) is None:
            raise ValueError(
                "`RemoveTrainingClasses` requires `data.train_mask` to exist."
            )
        data.train_mask = remove_training_classes(
            data.train_mask,
            data.y,
            self.classes,
        )
        return data

    def __repr__(self) -> str:
        # Include the configured classes so composed transform pipelines
        # print informatively.
        return f"{self.__class__.__name__}({self.classes})"