import json
import os
import os.path as osp
from typing import Optional, Dict, Any, List, Tuple
from dataclasses import dataclass
import warnings
from enum import Enum
import tqdm
import torch
import numpy as np
import pandas as pd
from rllm.types import ColType, StatType
from rllm.data import TableData, HeteroGraphData
from rllm.datasets.dataset import Dataset
from rllm.utils import download_url, extract_zip, sort_edge_index
from rllm.utils.col_process import timecol_to_unix_time
class RelBenchTaskType(Enum):
    """Kind of prediction target a RelBench task defines.

    Each member's value is the lowercase string identifier of the task type.
    Only regression and binary classification are defined so far; multiclass
    and multilabel classification are not yet available.
    """

    REGRESSION = "regression"
    BINARY_CLASSIFICATION = "binary_classification"
@dataclass
class RelBenchTask:
    """A RelBench prediction task together with its split data.

    Instances are persisted whole through ``save`` / ``load``
    (``torch.save`` / ``torch.load``, i.e. pickle), so every field value
    must be picklable.
    """

    # Type annotations that refer to other project classes are quoted
    # (forward references) so this dataclass can be created even when those
    # names are defined later in the module.
    task_name: str
    task_type: "RelBenchTaskType"
    entity_col: str
    entity_table: str
    time_col: str
    target_col: str
    # NOTE(review): presumably the prediction horizon / window length - confirm.
    timedelta: pd.Timedelta
    num_eval_timestamps: int
    # split ("train" / "val" / "test") -> (split DataFrame, table metadata)
    task_data_dict: Dict[str, Tuple[pd.DataFrame, "RelBenchTableMeta"]]

    @staticmethod
    def _validate_pt_path(save_path) -> str:
        """Return ``save_path`` as a ``str``, requiring a ``.pt`` suffix.

        Raises:
            ValueError: If the path does not end with ``.pt``.
        """
        # An explicit check rather than ``assert`` so it survives ``python -O``.
        path = os.fspath(save_path)
        if not path.endswith(".pt"):
            raise ValueError(f"Task file path must end with '.pt', got {path!r}.")
        return path

    def save(self, save_path: str):
        """Serialize this task to ``save_path`` with ``torch.save``.

        Args:
            save_path: Destination file (``str`` or ``os.PathLike``);
                must end with ``.pt``.

        Raises:
            ValueError: If ``save_path`` does not end with ``.pt``.
        """
        torch.save(self, self._validate_pt_path(save_path))

    @staticmethod
    def load(save_path: str) -> "RelBenchTask":
        """Load a task previously written by ``save``.

        Args:
            save_path: Source file (``str`` or ``os.PathLike``);
                must end with ``.pt``.

        Returns:
            The unpickled ``RelBenchTask``.

        Raises:
            ValueError: If ``save_path`` does not end with ``.pt``.
        """
        path = RelBenchTask._validate_pt_path(save_path)
        # SECURITY: weights_only=False unpickles arbitrary objects, which can
        # execute code. Only load task files from trusted sources.
        return torch.load(path, weights_only=False)
class RelBenchDataset(Dataset):
    """Base class for RelBench relational-database datasets.

    Provides downloading, cache bookkeeping and lazy loading of processed
    artifacts; subclasses supply the dataset-specific parts (``url``,
    ``tasks``, ``table_names`` and ``process``).

    Subclasses need to assign the following attributes after processing:

        self._task_dict: Dict[str, RelBenchTask]
        self._table_dict: Dict[str, TableData]
        self._hdata: HeteroGraphData
        self._tabledata_stats_dict: Dict[str, Any]
        self._table_meta_dict: Dict[str, RelBenchTableMeta]

    Any of these that has not been assigned is loaded from ``processed_dir``
    the first time its public property is read (which raises ``ValueError``
    if the dataset has not been processed).

    NOTE(review): ``raw_dir`` / ``processed_dir`` are presumably provided by
    the ``Dataset`` base class - confirm.
    """

    COLTYPE_FILE = "coltypes.json"
    HDATA_FILE = "pkey_fkey_graph.pt"
    TABLEDATA_STATS_FILE = "tabledata_stats.json"
    TABLE_META_FILE = "table_meta.json"

    ###############################################################
    # abstract properties which need to be implemented by subclasses

    # Base URL of the raw archives. ``download`` appends each entry of
    # ``raw_zip_files`` to it verbatim, so it should end with "/".
    url = None
    val_timestamp: Optional[pd.Timestamp] = None
    test_timestamp: Optional[pd.Timestamp] = None

    @property
    def tasks(self) -> List[str]:
        """Names of the tasks provided by the dataset (subclass hook)."""
        raise NotImplementedError

    @property
    def table_names(self) -> List[str]:
        """Names of the database tables (subclass hook)."""
        raise NotImplementedError

    # placeholder implementations for abstract methods

    def process(self):
        """Turn the raw files into the cached artifacts (subclass hook)."""
        raise NotImplementedError

    #################################################################
    # interface properties and methods

    @property
    def raw_zip_files(self) -> List[str]:
        """Archive paths, relative to ``url``, fetched by ``download``."""
        return ["db.zip"] + [f"tasks/{task}.zip" for task in self.tasks]

    @property
    def raw_filenames(self):
        """Parquet file names expected in ``db_dir``, one per table."""
        return [f"{table_name}.parquet" for table_name in self.table_names]

    @property
    def processed_filenames(self):
        """Files that must all exist in ``processed_dir`` for ``has_process``."""
        return (
            [f"{table_name}.pt" for table_name in self.table_names]
            + [f"{task_name}.pt" for task_name in self.tasks]
            + [self.TABLE_META_FILE]
            + [self.COLTYPE_FILE]
            + [self.HDATA_FILE]
            + [self.TABLEDATA_STATS_FILE]
        )

    # path properties

    @property
    def db_dir(self):
        """Directory checked for the per-table parquet files."""
        return osp.join(self.raw_dir, "db")

    @property
    def task_dir(self):
        """Directory holding one sub-directory of split files per task."""
        return osp.join(self.raw_dir, "tasks")

    # after process properties (each loaded from processed_dir on first access)

    @property
    def coltypes(self) -> Dict[str, Dict]:
        """table_name -> {column name -> ColType}."""
        if not hasattr(self, "_coltypes"):
            self._coltypes = self._try_load_coltypes()
        return self._coltypes

    @property
    def table_dict(self) -> Dict[str, TableData]:
        """table_name -> TableData."""
        if not hasattr(self, "_table_dict"):
            self._table_dict = self._try_load_cached_table_dict()
        return self._table_dict

    @property
    def table_meta_dict(self) -> Dict[str, "RelBenchTableMeta"]:
        """table_name -> RelBenchTableMeta (pkey, fkeys, time column)."""
        if not hasattr(self, "_table_meta_dict"):
            self._table_meta_dict = self._try_load_cached_table_meta_dict()
        return self._table_meta_dict

    @property
    def hdata(self) -> HeteroGraphData:
        """Primary-key / foreign-key heterogeneous graph."""
        if not hasattr(self, "_hdata"):
            self._hdata = self._try_load_hdata()
        return self._hdata

    @property
    def tabledata_stats_dict(self) -> Dict[str, Any]:
        """table_name -> {ColType -> list of {StatType -> value}}."""
        if not hasattr(self, "_tabledata_stats_dict"):
            self._tabledata_stats_dict = self._try_load_tabledata_stats_dict()
        return self._tabledata_stats_dict

    @property
    def task_dict(self) -> Dict[str, RelBenchTask]:
        """task_name -> RelBenchTask."""
        if not hasattr(self, "_task_dict"):
            self._task_dict = self._try_load_task_dict()
        return self._task_dict

    def load_all(self):
        """Force load all cached properties.

        Raises:
            ValueError: If the dataset has not been processed yet.
        """
        if not self.has_process:
            raise ValueError("Dataset has not been processed yet.")
        _ = self.coltypes
        _ = self.table_dict
        _ = self.table_meta_dict
        _ = self.hdata
        _ = self.tabledata_stats_dict
        _ = self.task_dict

    #####################################################################

    @property
    def has_download(self):
        """Whether every raw file is present.

        Checks ``db_dir`` for each table's parquet file and, for each task,
        ``task_dir/<task>/{train,val,test}.parquet``. Prints a message when
        everything is present.
        """
        for db_file in self.raw_filenames:
            if not osp.exists(osp.join(self.db_dir, db_file)):
                return False
        for subtask in self.tasks:
            sub_task_dir = osp.join(self.task_dir, subtask)
            for split in ["train", "val", "test"]:
                split_file = osp.join(sub_task_dir, f"{split}.parquet")
                if not osp.exists(split_file):
                    return False
        print("All raw files are present.")
        return True

    @property
    def has_process(self):
        """Whether every file in ``processed_filenames`` exists."""
        return all(
            osp.exists(osp.join(self.processed_dir, file))
            for file in self.processed_filenames
        )

    def download(self):
        """
        Download and unzip raw files.

        Task archives are extracted into ``task_dir`` and ``db.zip`` into
        ``raw_dir``; each archive is deleted after extraction.
        """
        os.makedirs(self.raw_dir, exist_ok=True)
        os.makedirs(self.task_dir, exist_ok=True)
        print("Downloading raw files...")
        print(self.task_dir)
        for filename in self.raw_zip_files:
            url = self.url + filename
            path = download_url(url, self.raw_dir, filename)
            if filename.startswith("tasks/"):
                extract_zip(path, self.task_dir)
            else:
                # extract db files
                extract_zip(path, self.raw_dir)
            os.remove(path)

    def validate_dataset(self):
        """
        Validate the integrity of downloaded files.

        1. validate primary keys: a table with a pkey must have the index
           ``0..len(df)-1``; otherwise ``ValueError`` is raised.
        2. validate foreign keys: values outside ``[0, len(pkey_table))``
           are reported with a warning and set to None in place.
        """
        # 1. validate primary keys
        for table_name, table in self.table_dict.items():
            if table.pkey is not None:
                ser = table.df.index
                if not (ser.values == np.arange(len(ser))).all():
                    raise ValueError(
                        f"Primary key column {table.pkey} in table {table_name} is not valid."
                    )

        # 2. validate foreign keys
        for table_name, table in self.table_dict.items():
            metadata: RelBenchTableMeta = self.table_meta_dict[table_name]
            for fkey_col, pkey_table_name in metadata.fkey_col_to_pkey_table.items():
                pkey_range = len(self.table_dict[pkey_table_name].df)
                fkey_values = table.df[fkey_col]
                # Negative keys are as invalid as too-large ones. NaN compares
                # False on both sides, so missing keys are left untouched.
                mask = (fkey_values < 0) | (fkey_values >= pkey_range)
                if mask.any():
                    warnings.warn(
                        f"Foreign key column {fkey_col} in table {table_name} has values "
                        f"outside [0, {pkey_range}). Correcting them by setting to None."
                    )
                    table.df.loc[mask, fkey_col] = None

    def make_pkey_fkey_graph(self) -> Tuple[HeteroGraphData, Dict]:
        """
        Make primary key - foreign key graph for the dataset.

        This method lazy materializes each TableData, saves it to
        ``processed_dir/{table_name}.pt``, and constructs the HeteroGraphData
        from pkey-fkey relations. For every foreign-key column two edge types
        are added: ``(table, "f2p_<col>", pkey_table)`` and its reverse
        ``(pkey_table, "rev_f2p_<col>", table)``; rows with a missing foreign
        key get no edge.

        Returns:
            HeteroGraphData: Heterogeneous graph data.
            Dict: table_name -> TableData.metadata
        """
        hdata = HeteroGraphData()
        tabledata_stats_dict = {}  # table_name -> metadata
        table_dict = self.table_dict
        table_meta_dict = self.table_meta_dict

        for table_name, table in tqdm.tqdm(table_dict.items(), desc="Processing tables"):
            df = table.df
            # Ensure that pkey is consecutive.
            if table.pkey is not None:
                assert (df.index.values == np.arange(len(df))).all()

            # remove pkey, fkeys in col_to_coltype
            # NOTE(review): if ``TableData.col_types`` returns the live mapping,
            # this strips pkey/fkey columns from the table's feature types in
            # place - confirm against TableData.
            col_to_coltype = table.col_types
            self._remove_pkey_fkeys(col_to_coltype, table)

            # add constant feature in case df is empty:
            if len(col_to_coltype) == 0:
                # NOTE(review): this ``col_to_coltype`` / ``df`` pair is only
                # used below for edge construction; neither is attached to
                # ``table``, so the constant feature never reaches
                # ``lazy_materialize``. Confirm whether that is intended.
                col_to_coltype = {"__const__": ColType.NUMERICAL}
                # We need to add edges later, so we need to also keep the fkeys
                fkey_dict = {key: df[key] for key in table.fkeys}
                df = pd.DataFrame({"__const__": np.ones(len(table.df)), **fkey_dict})

            # tensorize and cache
            cache_path = osp.join(self.processed_dir, f"{table_name}.pt")
            print(f"Lazy materializing table {table_name}...")
            table.lazy_materialize(
                keep_df=True,
                text_embedder_config=getattr(self, "text_embedder_config", None),
            )
            table.save(cache_path)

            # Add table data to hetero graph data
            hdata[table_name].table = table
            if table.time_col is not None:
                hdata[table_name].time = torch.from_numpy(
                    timecol_to_unix_time(table.df[table.time_col])
                )

            # Add table column stats
            tabledata_stats_dict[table_name] = table.metadata

            # Add edges based on pkey-fkey relations
            for fkey_col_name, pkey_table_name in (
                table_meta_dict[table_name].fkey_col_to_pkey_table.items()
            ):
                pkey_index = df[fkey_col_name]  # pkey referenced by each fkey
                # Filter out dangling (missing) foreign keys
                mask = ~pkey_index.isna()
                fkey_index = torch.arange(len(pkey_index))
                # Fixed int64 (not the platform-dependent ``int``) so the dtype
                # matches ``torch.arange`` above.
                pkey_index = torch.from_numpy(
                    pkey_index[mask].astype(np.int64).values
                )
                fkey_index = fkey_index[torch.from_numpy(mask.values)]

                # Ensure every remaining fkey points at an existing pkey row
                num_pkey_rows = len(table_dict[pkey_table_name].df)
                assert ((pkey_index >= 0) & (pkey_index < num_pkey_rows)).all()

                # fkey -> pkey edges (this table -> pkey_table)
                edge_type = (table_name, f"f2p_{fkey_col_name}", pkey_table_name)
                hdata[edge_type].edge_index = sort_edge_index(
                    torch.stack([fkey_index, pkey_index], dim=0)
                )

                # pkey -> fkey edges (pkey_table -> this table); the "rev_"
                # prefix marks the reverse edge (used for undirected graphs)
                rev_edge_type = (pkey_table_name, f"rev_f2p_{fkey_col_name}", table_name)
                hdata[rev_edge_type].edge_index = sort_edge_index(
                    torch.stack([pkey_index, fkey_index], dim=0)
                )

        if hdata.validate():
            print("HeteroGraphData validation passed.")
        else:
            print("HeteroGraphData validation failed.")

        return hdata, tabledata_stats_dict

    # private methods

    def _remove_pkey_fkeys(self, col_to_type: Dict[str, Any], table: TableData):
        """Inplace remove pkey and fkeys from col_to_type (missing keys are ignored)."""
        if table.pkey is not None:
            col_to_type.pop(table.pkey, None)
        for fkey in table.fkeys:
            col_to_type.pop(fkey, None)

    def _write_json(self, filename: str, payload: Any, indent: Optional[int] = 2):
        """Serialize ``payload`` and write it to ``processed_dir/filename``.

        The payload is fully serialized before the file is opened, so a
        serialization error cannot leave a truncated file behind that
        ``has_process`` would treat as a valid cache entry.
        """
        text = json.dumps(payload, indent=indent)
        with open(osp.join(self.processed_dir, filename), "w", encoding="utf-8") as f:
            f.write(text)

    def _read_json(self, filename: str) -> Any:
        """Read ``processed_dir/filename`` as JSON."""
        with open(osp.join(self.processed_dir, filename), "r", encoding="utf-8") as f:
            return json.load(f)

    def _save_coltypes(self, coltypes: Dict[str, Dict]):
        """Save table_name -> {column -> ColType} as JSON (enum values as strings)."""
        self._write_json(
            self.COLTYPE_FILE,
            {
                table_name: {
                    col_name: col_type.value
                    for col_name, col_type in coltype_dict.items()
                }
                for table_name, coltype_dict in coltypes.items()
            },
        )

    def _try_load_coltypes(self) -> Dict[str, Dict]:
        """Load cached column types, converting strings back to ``ColType``."""
        if not self.has_process:
            raise ValueError("Dataset has not been processed yet.")
        raw_dict = self._read_json(self.COLTYPE_FILE)
        return {
            table_name: {
                col_name: ColType(col_type_str)
                for col_name, col_type_str in coltype_dict.items()
            }
            for table_name, coltype_dict in raw_dict.items()
        }

    def _try_load_cached_table_dict(self) -> Dict[str, TableData]:
        """Load cached TableData from processed_dir."""
        if not self.has_process:
            raise ValueError("Dataset has not been processed yet.")
        return {
            table_name: TableData.load(
                osp.join(self.processed_dir, f"{table_name}.pt")
            )
            for table_name in self.table_names
        }

    def _save_table_meta_dict(self, table_meta_dict: Dict[str, "RelBenchTableMeta"]):
        """Save table_meta_dict to processed_dir."""
        self._write_json(
            self.TABLE_META_FILE,
            {
                table_name: {
                    "fkey_col_to_pkey_table": meta.fkey_col_to_pkey_table,
                    "pkey_col": meta.pkey_col,
                    "time_col": meta.time_col,
                }
                for table_name, meta in table_meta_dict.items()
            },
        )

    def _try_load_cached_table_meta_dict(self) -> Dict[str, "RelBenchTableMeta"]:
        """Load cached table_meta_dict from processed_dir."""
        if not self.has_process:
            raise ValueError("Dataset has not been processed yet.")
        raw_dict = self._read_json(self.TABLE_META_FILE)
        return {
            table_name: RelBenchTableMeta(
                fkey_col_to_pkey_table=meta_dict["fkey_col_to_pkey_table"],
                pkey_col=meta_dict["pkey_col"],
                time_col=meta_dict["time_col"],
            )
            for table_name, meta_dict in raw_dict.items()
        }

    def _try_load_hdata(self) -> HeteroGraphData:
        """Load cached HeteroGraphData from processed_dir."""
        if not self.has_process:
            raise ValueError("Dataset has not been processed yet.")
        hdata_path = osp.join(self.processed_dir, self.HDATA_FILE)
        return HeteroGraphData.load(hdata_path)

    def _save_tabledata_stats_dict(self, table_stats: Dict[str, Any]):
        """Save per-table column stats as JSON; enum keys are stored by ``.value``."""
        self._write_json(
            self.TABLEDATA_STATS_FILE,
            {
                table_name: {
                    coltype.value: [
                        {
                            stat_type.value: stat_value
                            for stat_type, stat_value in col_stats.items()
                        }
                        for col_stats in stats_list
                    ]
                    for coltype, stats_list in stats_dict.items()
                }
                for table_name, stats_dict in table_stats.items()
            },
        )

    def _try_load_tabledata_stats_dict(self) -> Dict[str, Any]:
        """Load cached tabledata_stats_dict, restoring ``ColType``/``StatType`` keys."""
        if not self.has_process:
            raise ValueError("Dataset has not been processed yet.")
        raw_dict = self._read_json(self.TABLEDATA_STATS_FILE)
        return {
            table_name: {
                ColType(coltype_str): [
                    {
                        StatType(stat_type_str): stat_value
                        for stat_type_str, stat_value in stat_entry.items()
                    }
                    for stat_entry in stats_list
                ]
                for coltype_str, stats_list in stats_dict.items()
            }
            for table_name, stats_dict in raw_dict.items()
        }

    def _save_task_dict(self, task_dict: Dict[str, RelBenchTask]):
        """Save each task to ``processed_dir/{task_name}.pt``."""
        for task_name, task in task_dict.items():
            task.save(osp.join(self.processed_dir, f"{task_name}.pt"))

    def _try_load_task_dict(self) -> Dict[str, RelBenchTask]:
        """Load cached task_dict from processed_dir."""
        if not self.has_process:
            raise ValueError("Dataset has not been processed yet.")
        return {
            task_name: RelBenchTask.load(
                osp.join(self.processed_dir, f"{task_name}.pt")
            )
            for task_name in self.tasks
        }

    # override other methods

    def __len__(self):
        """Number of tables in the dataset."""
        return len(self.table_names)

    def __getitem__(self, idx):
        """Return the TableData of the ``idx``-th entry of ``table_names``.

        Negative indices are not supported.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``.
        """
        names = self.table_names
        if idx < 0 or idx >= len(names):
            raise IndexError
        # Go through the property so the cached tables are loaded on demand
        # instead of failing with AttributeError before first access.
        return self.table_dict[names[idx]]