# Source code for rllm.datasets.sjtutables.tlf2k

from typing import Optional, List
import os
import os.path as osp

import pandas as pd
import torch

from rllm.types import ColType
from rllm.data.table_data import TableData
from rllm.datasets.dataset import Dataset
from rllm.utils.download import download_url
from rllm.utils.extract import extract_zip


class TLF2KDataset(Dataset):
    r"""TLF2KDataset is a multi-table relational dataset containing 3 tables,
    as collected in the `rLLM: Relational Table Learning with LLMs
    <https://arxiv.org/abs/2407.20157>`__ paper.
    It contains three tables: artists, user_artists and user_friends.
    The artists table includes information about artists, such as location
    and genre. The user_artists table contains the interactions between
    users and artists in the format ``[user, artist, listening_count]``.
    The user_friends table represents bi-directional friendships between
    users.

    The default task of this dataset is to predict artists' genre.

    Args:
        cached_dir (str): Root directory where the dataset should be saved.
        force_reload (bool, optional): If set to ``True``, the dataset will
            be re-processed. (default: ``False``)

    .. parsed-literal::

        Table1: artists
        ---------------
            Statistics:
            Name        Artists         Features
            Size        9,047           10

        Table2: user_artists
        --------------------
            Statistics:
            Name        Interactions    Features
            Size        80,009          3

        Table3: user_friends
        --------------------
            Statistics:
            Name        Friendships     Features
            Size        12,717          2
    """

    url = "https://raw.githubusercontent.com/rllm-project/rllm_datasets/refs/heads/main/sjtutables/tlf2k.zip"  # noqa

    def __init__(self, cached_dir: str, force_reload: Optional[bool] = False) -> None:
        # ``self.name`` is assigned before the base-class constructor runs;
        # presumably Dataset.__init__ relies on it (e.g. while downloading or
        # processing) — TODO confirm against rllm.datasets.dataset.Dataset.
        self.name = "tlf2k"
        root = os.path.join(cached_dir, self.name)
        super().__init__(root, force_reload=force_reload)

        # Table_LastFM2K data_list (same order as ``processed_filenames``):
        # 0: artists_table
        # 1: user_artists_table
        # 2: user_friends_table
        self.data_list: List[TableData] = [
            TableData.load(self.processed_paths[0]),
            TableData.load(self.processed_paths[1]),
            TableData.load(self.processed_paths[2]),
        ]

    @property
    def raw_filenames(self) -> List[str]:
        """Names of the files expected in ``raw_dir`` after download."""
        return ["artists.csv", "user_artists.csv", "user_friends.csv", "masks.pt"]

    @property
    def processed_filenames(self) -> List[str]:
        """Names of the processed :class:`TableData` files, one per table."""
        return ["artists_data.pt", "user_artists_data.pt", "user_friends_data.pt"]

    def process(self):
        r"""Process the raw files and save them to
        ``'./cached_dir/{dataset}/processed/'``.

        Writes three :class:`TableData` files, in the order of
        :attr:`processed_filenames`:

        0. the artists table, with ``label`` as the target column and the
           train/val/test masks read from ``masks.pt``;
        1. the user_artists table;
        2. the user_friends table.
        """
        os.makedirs(self.processed_dir, exist_ok=True)

        # Artists Data
        path = osp.join(self.raw_dir, self.raw_filenames[0])
        artist_df = pd.read_csv(path)
        col_types = {
            # TODO: Process these feature with column type `Text`
            "type": ColType.CATEGORICAL,
            "name": ColType.CATEGORICAL,
            "born": ColType.CATEGORICAL,
            "yearsActive": ColType.CATEGORICAL,
            "location": ColType.CATEGORICAL,
            "biography": ColType.CATEGORICAL,
            "label": ColType.CATEGORICAL,
        }

        # Load the pre-computed train/val/test split for the artists table.
        masks_path = osp.join(self.raw_dir, self.raw_filenames[3])
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects from a downloaded file. If masks.pt holds only tensors,
        # weights_only=True would be safer — confirm before switching.
        masks = torch.load(masks_path, weights_only=False)

        TableData(
            df=artist_df,
            col_types=col_types,
            target_col="label",
            train_mask=masks["train_mask"],
            val_mask=masks["val_mask"],
            test_mask=masks["test_mask"],
        ).save(self.processed_paths[0])

        # User-Artist Relationship
        path = osp.join(self.raw_dir, self.raw_filenames[1])
        ua_df = pd.read_csv(path)
        col_types = {
            "userID": ColType.NUMERICAL,
            "artistID": ColType.NUMERICAL,
        }
        TableData(df=ua_df, col_types=col_types).save(self.processed_paths[1])

        # User-user Relationship
        path = osp.join(self.raw_dir, self.raw_filenames[2])
        uu_df = pd.read_csv(path)
        col_types = {
            "userID": ColType.NUMERICAL,
            "friendID": ColType.NUMERICAL,
        }
        TableData(df=uu_df, col_types=col_types).save(self.processed_paths[2])

    def download(self):
        """Download the dataset archive into ``raw_dir`` and extract it.

        The downloaded zip file is deleted after extraction.
        """
        os.makedirs(self.raw_dir, exist_ok=True)
        path = download_url(self.url, self.raw_dir, "TLF2K.zip")
        extract_zip(path, self.raw_dir)
        os.remove(path)

    def __len__(self) -> int:
        """Return the number of tables in the dataset (one per processed file)."""
        return len(self.processed_filenames)

    def __getitem__(self, index: int) -> TableData:
        """Return the table at ``index`` (0: artists, 1: user_artists,
        2: user_friends).

        Raises:
            IndexError: If ``index`` is negative or not less than
                ``len(self)``. Negative indices are deliberately rejected.
        """
        # Valid indices are 0 .. len(self) - 1; the previous ``index > len``
        # check wrongly let ``index == len(self)`` through.
        if not 0 <= index < len(self):
            raise IndexError(
                f"index {index} is out of range for dataset with {len(self)} tables"
            )
        return self.data_list[index]