Source code for rllm.datasets.tagdataset

import os
import os.path as osp
import pickle
import types
import warnings
from typing import Optional, Callable

import torch

from rllm.datasets.dataset import Dataset
from rllm.data.graph_data import GraphData
from rllm.utils.download import download_url
from rllm.data.storage import BaseStorage

# Silence every DeprecationWarning as soon as this module is imported.
# This modifies the global warnings filter, so it affects the whole process,
# not only code in this file.
# NOTE(review): presumably aimed at warnings raised by dependencies while
# loading the raw files -- confirm which ones, and consider scoping the
# filter with `warnings.catch_warnings()` instead.
warnings.filterwarnings("ignore", category=DeprecationWarning)


class TAGDataset(Dataset):
    """Three text-attributed-graph datasets, including
    `cora` from `Automating the Construction of Internet Portals
    <https://link.springer.com/content/pdf/10.1023/A:1009953814988.pdf>`__,
    `pubmed` from `Collective Classification in Network Data
    <https://ojs.aaai.org/aimagazine/index.php/aimagazine/article/view/2157>`__
    and `citeseer` from `CiteSeer: an automatic citation indexing system
    <https://dl.acm.org/doi/10.1145/276675.276685>`__ paper.

    This dataset also contains cached LLM predictions and confidences
    provided by the paper `Label-free Node Classification on Graphs with
    Large Language Models (LLMS) <https://arxiv.org/abs/2310.04668>`__.

    Args:
        cached_dir (str): Root directory where dataset should be saved.
        file_name (str): The name of dataset (case-insensitive), one of
            `cora`, `pubmed` and `citeseer`.
        transform (callable, optional): A function/transform that takes in
            a `GraphData` object and returns a transformed version.
            It is applied once, when the dataset is constructed.
            (default: `None`)
        use_cache (bool): If set to `False`, cached pseudo-labels
            annotated by gpt will not be attached when the dataset is
            processed. (default: `True`)
        force_reload (bool): Forwarded to the base `Dataset`; if set to
            `True`, this dataset will be re-processed. (default: `False`)

    Raises:
        AssertionError: If `file_name` is not one of the supported names.
    """

    # Download locations of the raw graph ("text") and of the cached LLM
    # annotations ("pred"), keyed by lower-case dataset name.
    urls = {
        "text": {
            "cora": "https://drive.usercontent.google.com/download?id=10iBkU36HGc9mVPOWUofVyC_9cJ-U1-xl&confirm=t",  # noqa
            "pubmed": "https://drive.usercontent.google.com/download?id=1hcqnmKXv4dk060k6VjmW66e8etPYcOW9&confirm=t",  # noqa
            "citeseer": "https://drive.usercontent.google.com/download?id=1JRJHDiKFKiUpozGqkWhDY28E6F4v5n7l&confirm=t",  # noqa
        },
        "pred": {
            "cora": "https://drive.usercontent.google.com/download?id=1jAZH9daUjg0ce9O4IqitPA0coWCJKcMs&confirm=t",  # noqa
            "pubmed": "https://drive.usercontent.google.com/download?id=1d7saxT6Uc5sA4UpZ1ujky_Ns9vChgw9L&confirm=t",  # noqa
            "citeseer": "https://drive.usercontent.google.com/download?id=1k4L-NPxbrd9hiTehspjd0Ue6NOT-PRWi&confirm=t",  # noqa
        },
    }

    def __init__(
        self,
        cached_dir: str,
        file_name: str,
        transform: Optional[Callable] = None,
        use_cache: bool = True,
        force_reload: Optional[bool] = False,
    ):
        self.name = file_name.lower()
        # Raised explicitly (instead of via `assert`) so the check is not
        # stripped under `python -O`; the exception type is kept as
        # AssertionError for backward compatibility with existing callers.
        if self.name not in ("cora", "pubmed", "citeseer"):
            raise AssertionError(
                f"Unsupported dataset name {file_name!r}; "
                "expected one of 'cora', 'pubmed', 'citeseer'."
            )
        root = os.path.join(cached_dir, f"LLMGNN_{self.name}")
        # Must be set before the base initializer runs: `process` reads it.
        self.use_cache = use_cache
        super().__init__(root, force_reload=force_reload)

        self.data_list = [GraphData.load(self.processed_paths[0])]
        self.transform = transform
        if self.transform is not None:
            # The transform is applied a single time here, not on access.
            self.data_list[0] = self.transform(self.data_list[0])

    @property
    def raw_filenames(self):
        """Names of the raw files: the graph file and the LLM cache file."""
        filenames = [
            f"{self.name}_fixed_sbert.pt",
            f"{self.name}^cache^consistency.pt",
        ]
        return filenames

    @property
    def processed_filenames(self):
        """Names of the processed files produced by `process`."""
        return ["data.pt"]

    def process(self):
        r"""
        Build the `GraphData` object from the raw files and save it to
        `self.processed_paths[0]`.

        When `use_cache` is true, the cached predictions (`pl`),
        confidences (`conf`) and a boolean `cache_mask` (`pl >= 0`) are
        attached to the data before saving.
        """
        os.makedirs(self.processed_dir, exist_ok=True)
        data = self._get_raw_text()

        if self.use_cache:
            filepath = osp.join(self.raw_dir, f"{self.name}^cache^consistency.pt")
            # SECURITY NOTE: `weights_only=False` unpickles arbitrary
            # objects; only safe because the file comes from `self.urls`.
            cache_data = torch.load(filepath, weights_only=False)
            data.pl = cache_data["pred"]
            data.conf = cache_data["conf"]
            # NOTE(review): negative `pl` presumably marks nodes without a
            # cached annotation -- confirm against the data source.
            data.cache_mask = data.pl >= 0

        data.save(self.processed_paths[0])

    def download(self):
        r"""
        Download the raw graph file and the LLM cache file from
        `self.urls` into `self.raw_dir`.
        """
        os.makedirs(self.raw_dir, exist_ok=True)
        download_url(
            self.urls["text"][self.name],
            self.raw_dir,
            f"{self.name}_fixed_sbert.pt",
        )
        download_url(
            self.urls["pred"][self.name],
            self.raw_dir,
            f"{self.name}^cache^consistency.pt",
        )

    def item(self):
        """Return the single graph held by this dataset."""
        return self.data_list[0]

    def __len__(self):
        """Return the number of graphs in the dataset, which is always 1."""
        return 1

    def __getitem__(self, index: int):
        """Return the graph at `index`; only index `0` is valid.

        Raises:
            IndexError: If `index` is not `0`.
        """
        if index != 0:
            raise IndexError(
                f"TAGDataset holds a single graph; index {index!r} is out of range"
            )
        return self.data_list[index]

    def _get_raw_text(self):
        """Load the raw graph file and convert it into a `GraphData`.

        The raw file pickles `torch_geometric` objects. A custom unpickler
        substitutes local stand-ins for those classes, so loading does not
        go through the default class lookup for `torch_geometric` modules.

        Returns:
            GraphData: Graph with `x`, `y`, sparse `adj`, `text` and
            `label_names` taken from the raw file.
        """

        class CustomUnpickler(pickle.Unpickler):
            def find_class(self, module, name):
                # Remap anything coming from a torch_geometric module:
                # GlobalStorage -> BaseStorage, everything else -> a plain
                # attribute container.
                if "torch_geometric" in module:
                    if name == "GlobalStorage":
                        return BaseStorage
                    else:
                        return types.SimpleNamespace
                return super().find_class(module, name)

        # torch.load expects a module-like object exposing `Unpickler`/`load`.
        custom_pickle_module = types.ModuleType("custom_pickle_module")
        custom_pickle_module.Unpickler = CustomUnpickler
        custom_pickle_module.load = pickle.load

        path = osp.join(self.raw_dir, f"{self.name}_fixed_sbert.pt")
        # SECURITY NOTE: `weights_only=False` unpickles arbitrary objects;
        # only safe because the file comes from `self.urls`.
        raw_data = torch.load(
            path, pickle_module=custom_pickle_module, weights_only=False
        )._store

        edge_index = raw_data.edge_index
        # The largest index in `edge_index` undercounts nodes when the
        # highest-indexed nodes are isolated, so also take the feature
        # matrix into account (assumes `x` has one row per node).
        num_nodes = max(edge_index.max().item() + 1, raw_data.x.size(0))
        adj = torch.sparse_coo_tensor(
            indices=edge_index,
            values=torch.ones(edge_index.size(1)),
            size=(num_nodes, num_nodes),
        )

        data = GraphData(x=raw_data.x, y=raw_data.y, adj=adj, text=raw_data.raw_texts)
        data.label_names = raw_data.label_names
        return data