"""Source code for ``rllm.datasets.churn_modelling``: the Churn Modelling dataset."""
import os
import os.path as osp
from typing import Optional
import pandas as pd
from rllm.types import ColType
from rllm.data.table_data import TableData
from rllm.datasets.dataset import Dataset
from rllm.utils.download import download_url
class ChurnModelling(Dataset):
    r"""`The Churn Modelling dataset <https://www.kaggle.com/shrutimechlearn/
    churn-modelling>`__ is used to predict which customers are likely to
    churn from the organization by analyzing various attributes and applying
    machine learning and deep learning techniques.

    Customer churn (also known as customer attrition, or loss of clients or
    customers) occurs when a customer (player, subscriber, user, etc.) ceases
    their relationship with a company or service. Online businesses typically
    treat a customer as churned once a particular amount of time has elapsed
    since the customer's last interaction with the site or service. Similar
    to predicting employee turnover, this dataset is used to predict customer
    churn.

    The dataset encompasses a variety of features pertaining to customers and
    their interactions with the company. The primary objective is to predict
    whether a customer will churn (the ``Exited`` column, used as the target).

    Columns:

    - ``RowNumber``: Row number (used as the DataFrame index).
    - ``CustomerId``: Unique identifier for the customer.
    - ``Surname``: Surname of the customer.
    - ``CreditScore``: Credit score of the customer.
    - ``Geography``: Country of the customer (France, Spain, Germany).
    - ``Gender``: Gender of the customer (Male, Female).
    - ``Age``: Age of the customer.
    - ``Tenure``: Number of years the customer has been with the company.
    - ``Balance``: Account balance of the customer.
    - ``NumOfProducts``: Number of products the customer has with the
      company.
    - ``HasCrCard``: Does the customer have a credit card? (0 = No, 1 = Yes).
    - ``IsActiveMember``: Is the customer an active member? (0 = No,
      1 = Yes).
    - ``EstimatedSalary``: Estimated salary of the customer.
    - ``Exited``: Did the customer churn? (0 = No, 1 = Yes).

    Args:
        cached_dir (str): Root directory where the dataset should be saved.
            The dataset root is ``<cached_dir>/churn``.
        forced_reload (bool): If set to ``True``, this dataset will be
            re-processed again.

    .. parsed-literal::

        Statistics:
        Name    Customers   Features
        Size    10000       14
    """

    # Location of the raw CSV fetched by ``download``.
    url = "https://raw.githubusercontent.com/sharmaroshan/Churn-Modelling-Dataset/master/Churn_Modelling.csv"

    def __init__(self, cached_dir: str, forced_reload: Optional[bool] = False) -> None:
        self.name = "churn"
        # Each dataset gets its own sub-directory under ``cached_dir``.
        root = os.path.join(cached_dir, self.name)
        # NOTE(review): presumably the base ``Dataset.__init__`` calls
        # ``download``/``process`` when needed (or when ``force_reload`` is
        # set), so the processed file exists afterwards — confirm.
        super().__init__(root, force_reload=forced_reload)
        self.data_list = [TableData.load(self.processed_paths[0])]

    @property
    def raw_filenames(self):
        """File name(s) of the raw CSV inside ``raw_dir``."""
        return ["churn.csv"]

    @property
    def processed_filenames(self):
        """File name(s) of the saved :class:`TableData` inside ``processed_dir``."""
        return ["data.pt"]

    def process(self):
        r"""Process the raw CSV and save it under ``self.processed_dir``.

        Reads ``raw_filenames[0]`` with ``RowNumber`` as the index, attaches
        the column types defined below (``Exited`` is the target column) and
        saves the resulting :class:`TableData` to ``processed_paths[0]``.
        """
        os.makedirs(self.processed_dir, exist_ok=True)
        path = osp.join(self.raw_dir, self.raw_filenames[0])
        df = pd.read_csv(path, index_col=["RowNumber"])

        # Note: the order of columns in col_types must correspond to the
        # order of columns in the file, except for the target column.
        # NOTE(review): ``CustomerId`` and ``Surname`` stay in ``df`` but have
        # no entry in ``col_types``; presumably TableData only uses typed
        # columns — confirm.
        # NOTE(review): ``HasCrCard`` and ``IsActiveMember`` are both 0/1
        # flags but are typed NUMERICAL and CATEGORICAL respectively —
        # confirm this asymmetry is intentional.
        col_types = {
            "CreditScore": ColType.NUMERICAL,
            "Geography": ColType.CATEGORICAL,
            "Gender": ColType.CATEGORICAL,
            "Age": ColType.NUMERICAL,
            "Tenure": ColType.NUMERICAL,
            "Balance": ColType.NUMERICAL,
            "NumOfProducts": ColType.NUMERICAL,
            "HasCrCard": ColType.NUMERICAL,
            "IsActiveMember": ColType.CATEGORICAL,
            "EstimatedSalary": ColType.NUMERICAL,
            "Exited": ColType.CATEGORICAL,
        }
        data = TableData(
            df=df,
            col_types=col_types,
            target_col="Exited",
        )
        data.save(self.processed_paths[0])

    def download(self):
        r"""Fetch :attr:`url` into ``raw_dir`` under the name ``raw_filenames[0]``."""
        os.makedirs(self.raw_dir, exist_ok=True)
        download_url(self.url, self.raw_dir, self.raw_filenames[0])

    def __len__(self):
        # The dataset holds exactly one table.
        return 1

    def __getitem__(self, index: int):
        r"""Return the dataset's single :class:`TableData`.

        Args:
            index (int): Must be ``0``.

        Raises:
            IndexError: If ``index`` is anything other than ``0``.
        """
        if index != 0:
            raise IndexError(
                f"{type(self).__name__} index out of range: {index!r} "
                "(only index 0 is valid)"
            )
        return self.data_list[index]