Source code for rllm.datasets.sjtutables.tacm12k

from typing import Optional, List
import os
import os.path as osp

import numpy as np
import pandas as pd
import torch

from rllm.types import ColType
from rllm.data.table_data import TableData
from rllm.datasets.dataset import Dataset
from rllm.utils.download import download_url
from rllm.utils.extract import extract_zip


class TACM12KDataset(Dataset):
    r"""TACM12KDataset is a multi-table relational dataset containing 4
    tables, as collected in the `rLLM: Relational Table Learning with LLMs
    <https://arxiv.org/abs/2407.20157>`__ paper.

    It includes four tables: papers, authors, citations and writings.
    The papers table includes publication information of papers.
    The authors table includes author information.
    The citations table includes citation (i.e., <paper, paper>)
    information between papers.
    The writings table includes <author, write, paper> relationship
    between authors and papers.
    The default task is to predict the conference of papers.

    Besides the four tables, :obj:`data_list` also carries two embedding
    tensors loaded from the raw directory (paper embeddings at position 4
    and author embeddings at position 5). They are not counted by
    :obj:`len(dataset)` and are not reachable through ``dataset[i]``;
    read them from :obj:`data_list` directly.

    Args:
        cached_dir (str): Root directory where dataset should be saved.
        force_reload (bool): If set to `True`, this dataset will be
            re-processed.

    .. parsed-literal::

        Table1: papers
        ---------------
            Statistics:
            Name      Papers      Features
            Size      12,499      5

        Table2: authors
        ------------------
            Statistics:
            Name      Authors     Features
            Size      17,431      3

        Table3: citations
        ------------------
            Statistics:
            Name      Citations   Features
            edges     30,789      2

        Table4: writings
        ------------------
            Statistics:
            Name      Writings    Features
            edges     37,055      2
    """

    url = "https://raw.githubusercontent.com/rllm-project/rllm_datasets/refs/heads/main/sjtutables/tacm12k.zip"  # noqa

    def __init__(self, cached_dir: str, force_reload: Optional[bool] = False) -> None:
        self.name = "tacm12k"
        root = os.path.join(cached_dir, self.name)
        # NOTE(review): the base class presumably triggers download()/process()
        # and provides raw_dir / processed_dir / processed_paths -- confirm
        # against rllm.datasets.dataset.Dataset.
        super().__init__(root, force_reload=force_reload)

        # Table_ACM12K data_list:
        # 0: papers_table
        # 1: authors_table
        # 2: citations_table
        # 3: writings_table
        # 4: paper_embeddings   (torch.Tensor, not a TableData)
        # 5: author_embeddings  (torch.Tensor, not a TableData)
        self.data_list: List[TableData] = [
            TableData.load(self.processed_paths[0]),
            TableData.load(self.processed_paths[1]),
            TableData.load(self.processed_paths[2]),
            TableData.load(self.processed_paths[3]),
            torch.from_numpy(np.load(osp.join(self.raw_dir, "paper_embeddings.npy"))),
            torch.from_numpy(np.load(osp.join(self.raw_dir, "author_embeddings.npy"))),
        ]

    @property
    def raw_filenames(self):
        r"""File names read from the raw directory by :meth:`process`."""
        return [
            "papers.csv",
            "authors.csv",
            "citations.csv",
            "writings.csv",
            "masks.pt",
        ]

    @property
    def processed_filenames(self):
        r"""File names of the four processed tables, in data_list order."""
        return [
            "paper_data.pt",
            "authors_data.pt",
            "citations_data.pt",
            "writings_data.pt",
        ]

    def process(self):
        r"""
        process data and save to './cached_dir/{dataset}/processed/'.
        """
        os.makedirs(self.processed_dir, exist_ok=True)

        # papers Data
        path = osp.join(self.raw_dir, self.raw_filenames[0])
        paper_df = pd.read_csv(path, index_col=["paper_id"])
        col_types = {
            "year": ColType.CATEGORICAL,
            "conference": ColType.CATEGORICAL,
            "title": ColType.CATEGORICAL,
            "abstract": ColType.CATEGORICAL,
        }

        # Load the train/val/test masks shipped with the raw data.
        # NOTE(review): weights_only=False unpickles arbitrary objects from a
        # downloaded file; acceptable only while the archive source is trusted.
        masks_path = osp.join(self.raw_dir, self.raw_filenames[4])
        masks = torch.load(masks_path, weights_only=False)
        TableData(
            df=paper_df,
            col_types=col_types,
            target_col="conference",
            train_mask=masks["train_mask"],
            val_mask=masks["val_mask"],
            test_mask=masks["test_mask"],
        ).save(self.processed_paths[0])

        # authors Data
        path = osp.join(self.raw_dir, self.raw_filenames[1])
        author_df = pd.read_csv(path, index_col=["author_id"])
        col_types = {
            "name": ColType.CATEGORICAL,
            "firm": ColType.CATEGORICAL,
        }
        TableData(df=author_df, col_types=col_types).save(self.processed_paths[1])

        # cite Data
        path = osp.join(self.raw_dir, self.raw_filenames[2])
        cite_df = pd.read_csv(path)
        col_types = {
            "paper_id": ColType.NUMERICAL,
            "paper_id_cited": ColType.NUMERICAL,
        }
        TableData(df=cite_df, col_types=col_types).save(self.processed_paths[2])

        # writings Data
        path = osp.join(self.raw_dir, self.raw_filenames[3])
        pa_df = pd.read_csv(path)
        col_types = {
            "paper_id": ColType.NUMERICAL,
            "author_id": ColType.NUMERICAL,
        }
        TableData(df=pa_df, col_types=col_types).save(self.processed_paths[3])

    def download(self):
        r"""Download the dataset archive into the raw directory, extract it
        there and delete the archive afterwards."""
        os.makedirs(self.raw_dir, exist_ok=True)
        path = download_url(self.url, self.raw_dir, self.name + ".zip")
        extract_zip(path, self.raw_dir)
        os.remove(path)

    def __len__(self):
        r"""Number of tables in the dataset (the embeddings are not counted)."""
        return 4

    def __getitem__(self, index: int):
        r"""Return the table stored at position `index` (0 to 3).

        Raises:
            IndexError: If `index` is negative or not smaller than
                `len(self)`.
        """
        # `>=` rather than `>`: index == len(self) is already out of range.
        # The previous check let index 4 through, so iterating the dataset
        # yielded one more item than len() reported.
        if index < 0 or index >= len(self):
            raise IndexError(
                f"index {index} is out of range for {len(self)} tables"
            )
        return self.data_list[index]