Source code for rllm.datasets.relbench.f1

from typing import Optional
import os
import os.path as osp
import json

import torch
import pandas as pd
from pyarrow import parquet as pq

from rllm.types import TableType
from rllm.data.table_data import TableData
from rllm.datasets.relbench.base import (
    RelBenchDataset,
    RelBenchTableMeta,
    RelBenchTaskType,
    RelBenchTask
)
from rllm.utils.type_infer import TypeInferencer
from rllm.datasets.relbench.utils import (
    load_task_data,
    GloveTextEmbedding
)
from rllm.preprocessing import TextEmbedderConfig


class RelF1Dataset(RelBenchDataset):
    """
    A wrapper for the rel-f1 dataset in the RelBench benchmark from the
    `RelBench: A Benchmark for Deep Learning on Relational Databases
    <https://arxiv.org/abs/2407.20060>`__ paper, which contains Formula 1
    racing data with 9 tables and 3 tasks.

    Tables:
        - circuits
        - constructor_results
        - constructors
        - constructor_standings
        - drivers
        - qualifying
        - races
        - results
        - standings

    Tasks:
        - driver-dnf: Binary classification task to predict whether a
          driver did not finish a race (label column ``did_not_finish``).
        - driver-position: Regression task to predict the finishing
          position of a driver (label column ``position``).
        - driver-top3: Binary classification task to predict whether a
          driver finished in the top 3 (label column ``qualifying``).

    Args:
        cached_dir (str): Directory under which a ``rel-f1`` sub-directory
            is used as the dataset root.
        force_reload (bool, optional): Forwarded unchanged to the base
            class constructor. (default: :obj:`False`)
    """

    url = "https://relbench.stanford.edu/download/rel-f1/"
    val_timestamp = pd.Timestamp("2005-01-01")
    test_timestamp = pd.Timestamp("2010-01-01")

    # One row per task:
    # (task name, task type, target column, prediction window in days).
    # Every task is keyed on `drivers.driverId`, uses the `date` time column
    # and 40 evaluation timestamps, so only the differing fields are listed.
    _F1_TASK_SPECS = (
        (
            "driver-dnf",
            RelBenchTaskType.BINARY_CLASSIFICATION,
            "did_not_finish",
            30,
        ),
        (
            "driver-position",
            RelBenchTaskType.REGRESSION,
            "position",
            60,
        ),
        (
            "driver-top3",
            RelBenchTaskType.BINARY_CLASSIFICATION,
            "qualifying",
            30,
        ),
    )

    def __init__(
        self,
        cached_dir: str,
        force_reload: Optional[bool] = False
    ):
        # `name` is assigned before the base constructor runs, exactly as in
        # the original ordering.
        self.name = "rel-f1"
        root = os.path.join(cached_dir, self.name)
        super().__init__(root, force_reload=force_reload)

    def _get_device(self) -> torch.device:
        """Return the CUDA device when available, otherwise the CPU."""
        # Resolved on demand so importing this module never probes CUDA.
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @property
    def text_embedder_config(self) -> TextEmbedderConfig:
        """
        Text embedder configuration, built on first access and then cached
        on the instance.

        The config wraps a :class:`GloveTextEmbedding` placed on the device
        returned by :meth:`_get_device`, with a batch size of 256. Building
        it lazily keeps the embedder from being constructed at import time.
        """
        if not hasattr(self, "_text_embedder_config"):
            self._text_embedder_config = TextEmbedderConfig(
                text_embedder=GloveTextEmbedding(device=self._get_device()),
                batch_size=256,
            )
        return self._text_embedder_config

    @property
    def tasks(self):
        """Names of the tasks defined on this dataset."""
        return ["driver-dnf", "driver-position", "driver-top3"]

    @property
    def table_names(self):
        """Names of the tables that make up this dataset."""
        return [
            "circuits",
            "constructor_results",
            "constructors",
            "constructor_standings",
            "drivers",
            "qualifying",
            "races",
            "results",
            "standings"
        ]

    @staticmethod
    def _read_f1_table(path):
        """Read one parquet file and return ``(DataFrame, RelBenchTableMeta)``.

        The table metadata is decoded from the JSON-encoded parquet schema
        metadata entries ``fkey_col_to_pkey_table``, ``pkey_col`` and the
        optional ``time_col``.

        Raises:
            ValueError: If the parquet schema carries no metadata at all.
            KeyError: If a required metadata entry is missing.
        """
        table = pq.read_table(path)
        df = table.to_pandas()

        raw_meta = table.schema.metadata
        if raw_meta is None:
            # pyarrow reports "no metadata" as None; fail with the file name
            # instead of an opaque "'NoneType' is not subscriptable".
            raise ValueError(
                f"Parquet file '{path}' has no schema metadata; cannot "
                "recover primary/foreign key information."
            )

        def decode(key):
            return json.loads(raw_meta[key].decode("utf-8"))

        metadata = RelBenchTableMeta(
            fkey_col_to_pkey_table=decode(b"fkey_col_to_pkey_table"),
            pkey_col=decode(b"pkey_col"),
            time_col=decode(b"time_col") if b"time_col" in raw_meta else None
        )
        return df, metadata

    def process(self):
        r"""
        Process raw data and save to './cached_dir/{dataset}/processed/'.

        Steps: load the raw parquet tables and their key metadata, infer
        column types, wrap every table in a lazy :class:`TableData`,
        validate the dataset, build and save the pkey-fkey graph, and
        finally construct and save the task definitions.
        """
        os.makedirs(self.processed_dir, exist_ok=True)
        print("Processing raw data, this may take a while...")

        # 1. load parquet files
        print("Loading parquet files...")
        table_df_dict = {}
        table_meta_dict = {}
        for raw_file in self.raw_filenames:
            table_name = raw_file.removesuffix(".parquet")
            df, metadata = self._read_f1_table(
                osp.join(self.db_dir, raw_file)
            )
            table_df_dict[table_name] = df
            table_meta_dict[table_name] = metadata
        self._save_table_meta_dict(table_meta_dict)
        self._table_meta_dict = table_meta_dict

        # 2. extract coltype and cache
        print("Inferring column types...")
        table_df_coltype_dict = TypeInferencer.infer_table_df_dict_coltype(
            df_dict=table_df_dict
        )
        self._save_coltypes(table_df_coltype_dict)
        self._coltypes = table_df_coltype_dict

        # 3. convert to TableData (lazy feature)
        print("Converting to TableData...")
        table_data_dict = {}
        for table_name, df in table_df_dict.items():
            metadata: RelBenchTableMeta = table_meta_dict[table_name]
            table_data_dict[table_name] = TableData(
                name=table_name,
                df=df,
                col_types=table_df_coltype_dict[table_name],
                table_type=TableType.DATATABLE,
                pkey=metadata.pkey_col,
                # Snapshot the foreign-key column names as a list rather
                # than handing over a live dict view (views cannot be
                # pickled). NOTE(review): assumes TableData accepts any
                # sequence of column names - confirm.
                fkeys=list(metadata.fkey_col_to_pkey_table),
                time_col=metadata.time_col,
                lazy_feature=True,
            )
            # NOTE: truncating each table with
            # `upto(table_data, self.test_timestamp)` is currently disabled.
        self._table_dict = table_data_dict

        # 4. validate dataset
        print("Validating dataset...")
        self.validate_dataset()

        # 5. make pkey-fkey graph and cache
        print("Making pkey-fkey graph...")
        hdata, tabledata_stats_dict = self.make_pkey_fkey_graph()
        hdata.save(osp.join(self.processed_dir, self.HDATA_FILE))
        self._save_tabledata_stats_dict(tabledata_stats_dict)
        self._hdata = hdata
        self._tabledata_stats_dict = tabledata_stats_dict

        # 6. construct tasks (same order as listed in `_F1_TASK_SPECS`)
        self._task_dict = {}
        for task_name, task_type, target_col, days in self._F1_TASK_SPECS:
            self._task_dict[task_name] = RelBenchTask(
                task_name=task_name,
                task_type=task_type,
                entity_col="driverId",
                entity_table="drivers",
                time_col="date",
                target_col=target_col,
                timedelta=pd.Timedelta(days=days),
                num_eval_timestamps=40,
                task_data_dict=load_task_data(
                    task_path=osp.join(self.task_dir, task_name)
                )
            )
        self._save_task_dict(self._task_dict)

        print("Processing done.")