# Source code for rllm.datasets.dblp

import os
import os.path as osp
from itertools import product
from typing import Optional, Callable

import numpy as np
import scipy.sparse as sp
import torch

from rllm.datasets.dataset import Dataset
from rllm.data.graph_data import HeteroGraphData
from rllm.utils.graph_utils import sparse_mx_to_torch_sparse_tensor
from rllm.utils.download import download_url
from rllm.utils.extract import extract_zip
from rllm.datasets.utils import index_to_mask


class DBLP(Dataset):
    r"""DBLP is a heterogeneous graph containing four types of entities,
    as collected in the `MAGNN: Metapath Aggregated Graph Neural Network for
    Heterogeneous Graph Embedding <https://arxiv.org/abs/2002.01680>`__ paper.
    The authors are divided into four research areas (database, data mining,
    artificial intelligence, information retrieval).
    Each author is described by a bag-of-words representation of their paper
    keywords.

    Args:
        cached_dir (str): Root directory where dataset should be saved.
            The dataset itself lives in the ``dblp`` sub-directory.
        transform (callable, optional): A function/transform that takes in an
            `HeteroGraphData` object and returns a transformed version.
            It is applied once, to the loaded data object, when the dataset
            is constructed.
        forced_reload (bool): Forwarded to the base class as `force_reload`.
            If set to `True`, the dataset is expected to be processed again.

    .. parsed-literal::

        Statistics:
        Name   authors  papers  terms  conferences
        nodes  4,057    14,328  7,723  20
    """

    url = "https://www.dropbox.com/s/yh4grpeks87ugr2/DBLP_processed.zip?dl=1"

    def __init__(
        self,
        cached_dir: str,
        transform: Optional[Callable] = None,
        forced_reload: Optional[bool] = False,
    ):
        self.name = "dblp"
        root = os.path.join(cached_dir, self.name)
        super().__init__(root, force_reload=forced_reload)

        # The dataset consists of a single heterogeneous graph.
        self.data_list = [HeteroGraphData.load(self.processed_paths[0])]

        self.transform = transform
        if self.transform is not None:
            self.data_list[0] = self.transform(self.data_list[0])

    @property
    def raw_filenames(self):
        r"""Names of the raw files read from the raw directory."""
        return [
            "adjM.npz",
            "features_0.npz",
            "features_1.npz",
            "features_2.npy",
            "labels.npy",
            "node_types.npy",
            "train_val_test_idx.npz",
        ]

    @property
    def processed_filenames(self):
        r"""Names of the files written by :meth:`process`."""
        return ["data.pt"]

    def process(self):
        r"""
        process data and save to './cached_dir/{dataset}/processed/'.

        Builds a `HeteroGraphData` object from the raw files: node features
        for authors, papers and terms, the conference node count, author
        labels, train/val/test masks and one adjacency per (source type,
        destination type) pair that has at least one edge.
        """
        os.makedirs(self.processed_dir, exist_ok=True)

        data = HeteroGraphData()
        node_types = ["author", "paper", "term", "conference"]

        # Author and paper features are stored as sparse matrices.
        for i, node_type in enumerate(node_types[:2]):
            x = sp.load_npz(osp.join(self.raw_dir, f"features_{i}.npz"))
            data[node_type].x = torch.from_numpy(x.todense()).to(torch.float)

        # Term features are stored as a dense array.
        x = np.load(osp.join(self.raw_dir, "features_2.npy"))
        data["term"].x = torch.from_numpy(x).to(torch.float)

        # No feature file is loaded for conferences, so only their count is
        # recorded: the number of entries equal to 3 in node_types.npy.
        node_type_idx = np.load(osp.join(self.raw_dir, "node_types.npy"))
        node_type_idx = torch.from_numpy(node_type_idx).to(torch.long)
        data["conference"].num_nodes = int((node_type_idx == 3).sum())

        # Labels are attached both to the author nodes and to the graph.
        y = np.load(osp.join(self.raw_dir, "labels.npy"))
        data["author"].y = data.y = torch.from_numpy(y).to(torch.long)

        # Split indices are converted to masks sized by the number of labels.
        split = np.load(osp.join(self.raw_dir, "train_val_test_idx.npz"))
        data.train_mask = index_to_mask(
            torch.from_numpy(split["train_idx"]).to(torch.long), len(y)
        )
        data.val_mask = index_to_mask(
            torch.from_numpy(split["val_idx"]).to(torch.long), len(y)
        )
        data.test_mask = index_to_mask(
            torch.from_numpy(split["test_idx"]).to(torch.long), len(y)
        )

        # Half-open [start, stop) index range of every node type inside the
        # global adjacency matrix. The ranges are laid out consecutively in
        # the order author, paper, term, conference.
        # NOTE(review): num_nodes for author/paper/term is presumably derived
        # from the `x` assigned above -- confirm against HeteroGraphData.
        s = {}
        start = 0
        for node_type in node_types:
            stop = start + data[node_type].num_nodes
            s[node_type] = (start, stop)
            start = stop

        # Cut the global adjacency into one block per ordered type pair and
        # keep only the blocks that actually contain edges.
        A = sp.load_npz(osp.join(self.raw_dir, "adjM.npz"))
        for src, dst in product(node_types, node_types):
            A_sub = A[s[src][0] : s[src][1], s[dst][0] : s[dst][1]].tocoo()
            if A_sub.nnz > 0:
                data[src, dst].adj = sparse_mx_to_torch_sparse_tensor(A_sub)

        data.save(self.processed_paths[0])

    def download(self):
        r"""
        download data from url to './cached_dir/{dataset}/raw/'.

        The downloaded zip archive is extracted into the raw directory and
        then deleted.
        """
        os.makedirs(self.raw_dir, exist_ok=True)
        path = download_url(self.url, self.raw_dir, "DBLP_processed.zip")
        extract_zip(path, self.raw_dir)
        os.remove(path)

    def item(self):
        r"""Return the single graph held by this dataset."""
        return self.data_list[0]

    def __len__(self):
        r"""Return the number of graphs in the dataset, which is always 1."""
        return 1

    def __getitem__(self, index: int):
        r"""Return the graph at `index`.

        Raises:
            IndexError: If `index` is anything other than 0.
        """
        if index != 0:
            raise IndexError(
                f"DBLP holds a single graph; index must be 0, got {index!r}"
            )
        return self.data_list[index]