Source code for rllm.datasets.adult

import os
import os.path as osp
from typing import Optional

import shutil
import numpy as np
import pandas as pd

from rllm.types import ColType
from rllm.data.table_data import TableData
from rllm.datasets.dataset import Dataset
from rllm.utils.download import download_url


class Adult(Dataset):
    r"""The Adult dataset is a dataset from a classic data mining project,
    which was extracted from the `1994 Census database
    <https://archive.ics.uci.edu/dataset/2/adult>`__.
    The dataset encompasses a variety of features pertaining to adults and
    their income. The primary objective is to predict whether an
    individual's annual income surpasses $50,000.

    Features:

    - Age: Age of the individual.
    - Workclass: Type of industry (Private, Self-emp-not-inc, Self-emp-inc,
      Federal-gov, Local-gov, State-gov, Without-pay, Never-worked).
    - fnlwgt: Final weight, i.e. the number of people the census believes
      this entry represents.
    - Education: The highest level of education achieved (Bachelors,
      Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th,
      7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool).
    - Education-Num: A numeric version of Education.
    - Marital-Status: Marital status of the individual (Married-civ-spouse,
      Divorced, Never-married, Separated, Widowed, Married-spouse-absent,
      Married-AF-spouse).
    - Occupation: The kind of work individuals perform (Tech-support,
      Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty,
      Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing,
      Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces).
    - Relationship: Relationship to head-of-household (Wife, Own-child,
      Husband, Not-in-family, Other-relative, Unmarried).
    - Race: Race of the individual (White, Asian-Pac-Islander,
      Amer-Indian-Eskimo, Other, Black).
    - Sex: Gender of the individual.
    - Capital-Gain: Total capital gains.
    - Capital-Loss: Total capital losses.
    - Hours-per-week: Average hours worked per week.
    - Native-Country: Country of origin of the individual.
    - Target: Income level.

    Args:
        cached_dir (str): Root directory where dataset should be saved.
        forced_reload (bool): If set to ``True``, this dataset will be
            processed again. The flag is forwarded to the base class as
            ``force_reload``.

    .. parsed-literal::

        Statistics:
        Name   Individuals   Features
        Size   48842         14
    """

    # Remote locations of the train split and the test split, in that order.
    urls = [
        "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data",
        "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test",
    ]

    def __init__(self, cached_dir: str, forced_reload: Optional[bool] = False) -> None:
        # `name` is assigned before the base initializer is invoked; the
        # original ordering is preserved in case the base class reads it.
        self.name = "adult"
        root = osp.join(cached_dir, self.name)
        super().__init__(root, force_reload=forced_reload)
        # The dataset consists of exactly one table, loaded from the first
        # processed path.
        self.data_list = [TableData.load(self.processed_paths[0])]

    @property
    def raw_filenames(self) -> list:
        """Local file names for the raw train and test splits, in that order."""
        return ["adult_train.csv", "adult_test.csv"]

    @property
    def processed_filenames(self) -> list:
        """File name of the serialized, processed table."""
        return ["data.pt"]

    def process(self, num_rows: Optional[int] = None) -> None:
        r"""Process the raw files and save the result to
        './cached_dir/{dataset}/processed/'.

        The train and test files are read, concatenated into one frame, the
        ``income`` labels are normalized, and the result is stored as a
        :class:`TableData` at ``self.processed_paths[0]``.

        Args:
            num_rows (int, optional): If given, keep only a random subset of
                this many rows, drawn without replacement through NumPy's
                global random state. ``None`` keeps every row.
        """
        # Start from an empty processed directory. `isdir` is already False
        # for a missing path, so no separate existence check is needed.
        if osp.isdir(self.processed_dir):
            shutil.rmtree(self.processed_dir)
        os.makedirs(self.processed_dir, exist_ok=True)

        path_train = osp.join(self.raw_dir, self.raw_filenames[0])
        path_test = osp.join(self.raw_dir, self.raw_filenames[1])

        # Note: the order of column in col_types must
        # correspond to the order of column in files,
        # except target column.
        col_types = {
            "age": ColType.NUMERICAL,
            "workclass": ColType.CATEGORICAL,
            "fnlwgt": ColType.NUMERICAL,
            "education": ColType.CATEGORICAL,
            "educational-num": ColType.NUMERICAL,
            "marital-status": ColType.CATEGORICAL,
            "occupation": ColType.CATEGORICAL,
            "relationship": ColType.CATEGORICAL,
            "race": ColType.CATEGORICAL,
            "gender": ColType.CATEGORICAL,
            "capital-gain": ColType.NUMERICAL,
            "capital-loss": ColType.NUMERICAL,
            "hours-per-week": ColType.NUMERICAL,
            "native-country": ColType.CATEGORICAL,
            "income": ColType.CATEGORICAL,
        }
        column_names = list(col_types)

        # Neither file is read with a header row, so names are supplied here.
        df_train = pd.read_csv(path_train, header=None, names=column_names)
        # The first line of the test file is skipped (presumably a non-data
        # preamble line -- verify against the raw file).
        df_test = pd.read_csv(
            path_test, header=None, names=column_names, skiprows=1
        )
        df = pd.concat([df_train, df_test], ignore_index=True)

        # Normalize the labels: drop surrounding whitespace and any trailing
        # "." so that both splits end up with the same label strings.
        # NOTE(review): only the target column is stripped; if the raw files
        # pad fields after commas, the other categorical columns keep that
        # padding -- confirm this is intended.
        df["income"] = df["income"].str.strip().str.rstrip(".")

        if num_rows is not None:
            # Generate a random subset. After `ignore_index=True` the index is
            # 0..len(df)-1, so sampled integers are valid `iloc` positions.
            selected_ids = np.random.choice(len(df), num_rows, replace=False)
            df = df.iloc[selected_ids].reset_index(drop=True)

        # Type-narrowing aid for static checkers.
        assert isinstance(df, pd.DataFrame)

        data = TableData(
            df=df,
            col_types=col_types,
            target_col="income",
        )
        data.save(self.processed_paths[0])

    def download(self) -> None:
        """Fetch the raw train and test files into ``self.raw_dir``.

        Each URL in ``urls`` is paired with the entry of ``raw_filenames`` at
        the same position and handed to ``download_url``.
        """
        os.makedirs(self.raw_dir, exist_ok=True)
        for url, filename in zip(self.urls, self.raw_filenames):
            download_url(url, self.raw_dir, filename)

    def __len__(self) -> int:
        """Return the number of tables in the dataset, which is always 1."""
        return 1

    def __getitem__(self, index: int) -> TableData:
        """Return the single table of the dataset.

        Args:
            index (int): Must be ``0``.

        Raises:
            IndexError: If ``index`` is anything other than ``0``.
        """
        if index != 0:
            raise IndexError(
                f"Adult contains a single table; index {index} is out of range"
            )
        return self.data_list[index]