# Source code for rllm.datasets.planetoid

import os
import os.path as osp
import pickle
from typing import Optional, Callable

import numpy as np
from numpy import ndarray
from scipy.sparse._csr import csr_matrix
import networkx as nx
import warnings

import torch

# import sys
# sys.path.append('../')
from rllm.datasets.dataset import Dataset
from rllm.data.graph_data import GraphData
from rllm.utils.sparse import sparse_mx_to_torch_sparse_tensor
from rllm.datasets.utils import index_to_mask
from rllm.utils.download import download_url

# Silence DeprecationWarning for the whole process (this installs a global
# filter at import time, not one scoped to this module).
warnings.simplefilter("ignore", DeprecationWarning)


class PlanetoidDataset(Dataset):
    r"""The citation network datasets from the `Revisiting Semi-Supervised
    Learning with Graph Embeddings <https://arxiv.org/abs/1603.08861>`__
    paper, which include :obj:`"Cora"`, :obj:`"CiteSeer"` and
    :obj:`"PubMed"`. Nodes represent documents and edges represent citation
    links.

    Args:
        cached_dir (str): Root directory where dataset should be saved.
        file_name (str): The name of dataset (case-insensitive), *e.g.*,
            `cora`, `citeseer` and `pubmed`.
        transform (callable, optional): A function/transform that takes in a
            `GraphData` object and returns a transformed version. It is
            applied once, right after the data is loaded (and after any
            `full`/`random` re-split), and its result replaces the stored
            object. (default: `None`)
        split (str, optional): The type of dataset split (case-insensitive:
            `public`, `full`, `geom-gcn`, `random`).
            If set to `public`, the split will be the public fixed split
            from the `Revisiting Semi-Supervised Learning with Graph
            Embeddings <https://arxiv.org/abs/1603.08861>`__ paper.
            If set to `full`, all nodes except those in the validation and
            test sets will be used for training (as in the `FastGCN: Fast
            Learning with Graph Convolutional Networks via Importance
            Sampling <https://arxiv.org/abs/1801.10247>`__ paper).
            If set to `geom-gcn`, the 10 public fixed splits from the
            `Geom-GCN: Geometric Graph Convolutional Networks
            <https://openreview.net/forum?id=S1e2agrFvS>`__ paper are given.
            If set to `random`, train, validation, and test sets will be
            randomly generated, according to `num_train_per_class`,
            `num_val` and `num_test`. (default: `public`)
        num_train_per_class (int, optional): The number of training samples
            per class in case of `random` split. (default: 20)
        num_val (int, optional): The number of validation samples in case of
            `random` split. (default: 500)
        num_test (int, optional): The number of test samples in case of
            `random` split. (default: 1000)
        force_reload (bool): If set to `True`, this dataset will be
            re-processed. (default: `False`)

    .. parsed-literal::

        Statistics:
        Name      Cora   CiteSeer  PubMed
        nodes     2708   3327      19717
        edges     10556  9104      88648
        features  1433   3703      500
        classes   7      6         3
    """

    url = "https://github.com/kimiyoung/planetoid/raw/master/data"
    geom_gcn_url = (
        "https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master"  # noqa
    )

    def __init__(
        self,
        cached_dir: str,
        file_name: str,
        transform: Optional[Callable] = None,
        split: str = "public",
        num_train_per_class: int = 20,
        num_val: int = 500,
        num_test: int = 1000,
        force_reload: bool = False,
    ):
        self.name = file_name.lower()
        assert self.name in ["cora", "citeseer", "pubmed"], (
            f"Unknown Planetoid dataset: {file_name!r}"
        )
        self.split = split.lower()
        assert self.split in ["public", "full", "geom-gcn", "random"], (
            f"Unknown split: {split!r}"
        )

        root = osp.join(cached_dir, self.name)
        if self.split == "geom-gcn":
            # geom-gcn masks are stored separately because process() builds
            # different masks for this split.
            root = osp.join(root, "geom-gcn")
        super().__init__(root, force_reload=force_reload)

        self.data_list = [GraphData.load(self.processed_paths[0])]

        # Compare against the lower-cased `self.split` (validated above), not
        # the raw argument, so e.g. split="Random" is honoured.
        if self.split == "full":
            self._apply_full_split(self.data_list[0])
        elif self.split == "random":
            self._apply_random_split(
                self.data_list[0], num_train_per_class, num_val, num_test
            )

        self.transform = transform
        if self.transform is not None:
            self.data_list[0] = self.transform(self.data_list[0])

    @staticmethod
    def _apply_full_split(data):
        """Train on every node not in the validation or test mask (in place)."""
        data.train_mask.fill_(True)
        data.train_mask[data.val_mask | data.test_mask] = False

    @staticmethod
    def _apply_random_split(data, num_train_per_class, num_val, num_test):
        """Resample train/val/test masks of `data` in place.

        Picks up to `num_train_per_class` random nodes per class for
        training, then `num_val` and `num_test` nodes at random from the
        remaining nodes for validation and test.
        """
        data.train_mask.fill_(False)
        for c in range(data.num_classes):
            idx = (data.y == c).nonzero(as_tuple=False).view(-1)
            idx = idx[torch.randperm(idx.size(0))[:num_train_per_class]]
            data.train_mask[idx] = True

        remaining = (~data.train_mask).nonzero(as_tuple=False).view(-1)
        remaining = remaining[torch.randperm(remaining.size(0))]

        data.val_mask.fill_(False)
        data.val_mask[remaining[:num_val]] = True

        data.test_mask.fill_(False)
        data.test_mask[remaining[num_val : num_val + num_test]] = True

    @property
    def raw_filenames(self):
        suffix = ["x", "tx", "allx", "y", "ty", "ally", "graph", "test.index"]
        return [f"ind.{self.name}.{s}" for s in suffix]

    @property
    def processed_filenames(self):
        return ["data.pt"]

    def _load_raw_file(self, filename: str):
        r"""Load one raw file from './cached_dir/{dataset}/raw/'.

        `*.test.index` files are text with one integer per line and are
        returned as a `torch.long` tensor. Every other file is unpickled:
        scipy CSR matrices and numpy arrays are converted to dense
        `torch.float` tensors; any other object (e.g. the `graph` dict of
        adjacency lists) is returned unchanged.
        """
        filepath = osp.join(self.raw_dir, filename)
        if "test.index" in filename:
            with open(filepath, "r") as f:
                # int() tolerates surrounding whitespace; skip blank lines
                # (e.g. a trailing newline) instead of failing on int("").
                out = [int(line) for line in f if line.strip()]
            return torch.as_tensor(out, dtype=torch.long)

        # NOTE(security): pickle.load can execute arbitrary code. These files
        # are downloaded from `self.url`; only use trusted sources/mirrors.
        with open(filepath, "rb") as f:
            content = pickle.load(f, encoding="latin1")
        if isinstance(content, csr_matrix):
            # toarray() yields a plain ndarray (todense() gives np.matrix).
            return torch.from_numpy(content.toarray()).float()
        if isinstance(content, ndarray):
            return torch.from_numpy(content).float()
        return content

    def _load_geom_gcn_masks(self):
        """Load the 10 geom-gcn split files and stack each mask kind.

        Returns (train_mask, val_mask, test_mask), each with one column per
        split (stacked along dim=1).
        """
        train_masks, val_masks, test_masks = [], [], []
        for i in range(10):
            name = f"{self.name}_split_0.6_0.2_{i}.npz"
            # NpzFile keeps the archive open; close it via the context manager.
            with np.load(osp.join(self.raw_dir, name)) as splits:
                train_masks.append(torch.from_numpy(splits["train_mask"]))
                val_masks.append(torch.from_numpy(splits["val_mask"]))
                test_masks.append(torch.from_numpy(splits["test_mask"]))
        return (
            torch.stack(train_masks, dim=1),
            torch.stack(val_masks, dim=1),
            torch.stack(test_masks, dim=1),
        )

    def process(self):
        r"""Process raw data and save it to './cached_dir/{dataset}/processed/'."""
        os.makedirs(self.processed_dir, exist_ok=True)
        items = [self._load_raw_file(filename) for filename in self.raw_filenames]
        x, tx, allx, y, ty, ally, graph, test_index = items

        # The public split trains on the rows of `x`; the 500 nodes that
        # follow them form the validation set.
        num_train = x.shape[0]
        train_index = torch.arange(num_train, dtype=torch.long)
        val_index = torch.arange(num_train, num_train + 500, dtype=torch.long)
        sorted_test_index, _ = test_index.sort()

        if self.name == "citeseer":
            # For citeseer, there are some isolated nodes.
            # We should find them and add them as
            # zero-vector in the right position.
            min_index = int(sorted_test_index[0])
            max_index = int(sorted_test_index[-1])
            num_slots = max_index - min_index + 1
            tx_ext = torch.zeros(num_slots, tx.shape[1], dtype=tx.dtype)
            tx_ext[sorted_test_index - min_index, :] = tx
            ty_ext = torch.zeros(num_slots, ty.shape[1], dtype=ty.dtype)
            ty_ext[sorted_test_index - min_index, :] = ty
            tx, ty = tx_ext, ty_ext

        # Test rows are stored in sorted-index order; move each one to the
        # node position given by test_index.
        x = torch.cat([allx, tx], dim=0)
        x[test_index] = x[sorted_test_index]
        y = torch.cat([ally, ty], dim=0).max(dim=1)[1]
        y[test_index] = y[sorted_test_index]
        num_nodes = x.shape[0]

        if self.split == "geom-gcn":
            train_mask, val_mask, test_mask = self._load_geom_gcn_masks()
        else:
            train_mask = index_to_mask(train_index, num_nodes)
            val_mask = index_to_mask(val_index, num_nodes)
            test_mask = index_to_mask(test_index, num_nodes)

        # Fix the node order explicitly so adjacency row i is node i (the
        # same row as in x). Otherwise networkx would use the dict's
        # insertion order.
        G = nx.from_dict_of_lists(graph)
        G.add_nodes_from(range(num_nodes))
        adj = nx.to_scipy_sparse_array(G, nodelist=list(range(num_nodes)))
        adj_sp = sparse_mx_to_torch_sparse_tensor(adj)

        data = GraphData(
            x, y, adj_sp, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask
        )
        data.save(self.processed_paths[0])

    def download(self):
        r"""Download raw data from url to './cached_dir/{dataset}/raw/'."""
        os.makedirs(self.raw_dir, exist_ok=True)
        for filename in self.raw_filenames:
            download_url(f"{self.url}/{filename}", self.raw_dir, filename)
        if self.split == "geom-gcn":
            # `self.name` is already lower-cased in __init__.
            split_base = f"{self.geom_gcn_url}/splits/{self.name}"
            for i in range(10):
                download_url(f"{split_base}_split_0.6_0.2_{i}.npz", self.raw_dir)

    def item(self):
        """Return the single GraphData object held by this dataset."""
        return self.data_list[0]

    def __len__(self):
        return 1

    def __getitem__(self, index: int):
        if index != 0:
            raise IndexError(
                f"{type(self).__name__} holds a single graph; index must be 0, "
                f"got {index}"
            )
        return self.data_list[index]