Source code for rllm.datasets.titanic
import os
import os.path as osp
from typing import Optional
import pandas as pd
from rllm.types import ColType
from rllm.data.table_data import TableData
from rllm.datasets.dataset import Dataset
from rllm.utils.download import download_url
[docs]
class Titanic(Dataset):
r"""The Titanic dataset is a widely-used dataset for machine learning
and statistical analysis, as featured in the `Titanic: Machine Learning
from Disaster <https://www.kaggle.com/c/titanic>`__ competition on Kaggle.
The dataset contains various features related to the passengers
aboard the Titanic, and the task is to predict whether a
passenger survived.
.. PassengerId: Unique identifier for each passenger.
.. Survived: Survival status (0 = No, 1 = Yes).
.. Pclass: Passenger class (1 = 1st, 2 = 2nd, 3 = 3rd).
.. Name: Name of the passenger.
.. Sex: Gender of the passenger.
.. Age: Age of the passenger in years.
.. SibSp: Number of siblings/spouses aboard the Titanic.
.. Parch: Number of parents/children aboard the Titanic.
.. Ticket: Ticket number.
.. Fare: Passenger fare.
.. Cabin: Cabin number.
.. Embarked: Port of embarkation
.. (C = Cherbourg, Q = Queenstown, S = Southampton).
Args:
cached_dir (str): Root directory where dataset should be saved.
forced_reload (bool): If set to `True`, this dataset will be
re-process again.
.. parsed-literal::
Statics:
Name Passengers Features
Size 891 12
"""
url = "https://github.com/datasciencedojo/datasets/raw/master/titanic.csv"
def __init__(
self,
cached_dir: str,
forced_reload: Optional[bool] = False,
transform=None,
tokenizer_config=None,
) -> None:
self.name = "titanic"
root = os.path.join(cached_dir, self.name)
self._tokenizer_config = tokenizer_config
super().__init__(root, force_reload=forced_reload)
self.data_list = [TableData.load(self.processed_paths[0])]
self.transform = transform
if self.transform is not None:
self.data_list[0] = self.transform(self.data_list[0])
@property
def raw_filenames(self):
return ["titanic.csv"]
@property
def processed_filenames(self):
return ["data.pt"]
[docs]
def process(self):
r"""
process data and save to './cached_dir/{dataset}/processed/'.
"""
os.makedirs(self.processed_dir, exist_ok=True)
path = osp.join(self.raw_dir, self.raw_filenames[0])
df = pd.read_csv(path, index_col=["PassengerId"])
# Note: the order of column in col_types must
# correspond to the order of column in files,
# except target column.
col_types = { # TODO Use 'Name', 'Ticket' and 'Cabin'.
"Survived": ColType.CATEGORICAL,
"Pclass": ColType.CATEGORICAL,
"Sex": ColType.CATEGORICAL,
"Age": ColType.NUMERICAL,
"SibSp": ColType.NUMERICAL,
"Parch": ColType.NUMERICAL,
"Fare": ColType.NUMERICAL,
"Embarked": ColType.CATEGORICAL,
}
data = TableData(
df=df,
col_types=col_types,
target_col="Survived",
tokenizer_config=self._tokenizer_config,
convert_text_coltypes=(
{ColType.CATEGORICAL} if self._tokenizer_config else None
),
)
data.save(self.processed_paths[0])
[docs]
def download(self):
os.makedirs(self.raw_dir, exist_ok=True)
download_url(self.url, self.raw_dir, self.raw_filenames[0])
def __len__(self):
return 1
def __getitem__(self, index: int):
if index != 0:
raise IndexError
return self.data_list[index]