Note
Click here to download the full example code
Parallel Training
Larger datasets require more time for training. By default, the models in HiClass are trained using a single core, but it is possible to train each local classifier in parallel by leveraging the Ray library [1]. In this example, we demonstrate how to train a hierarchical classifier in parallel, using all the available cores, on a mock dataset from Kaggle [2].
import sys
from os import cpu_count
import pandas as pd
import requests
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from hiclass import LocalClassifierPerParentNode
def download(url: str, path: str) -> None:
    """
    Download a file from the internet.

    Parameters
    ----------
    url : str
        The address of the file to be downloaded.
    path : str
        The path to store the downloaded file.

    Raises
    ------
    requests.HTTPError
        If the server answers with a 4xx or 5xx status code.
    requests.Timeout
        If the server does not respond within the timeout.
    """
    # The timeout bounds each connect/read wait (not the whole transfer),
    # so a stalled connection cannot hang the script forever.
    response = requests.get(url, timeout=60)
    # Fail loudly instead of saving an HTTP error page as if it were data.
    response.raise_for_status()
    with open(path, "wb") as file:
        file.write(response.content)
# Fetch the training set and store it in the current working directory
training_data_path = "train_40k.csv"
training_data_url = "https://zenodo.org/record/6657410/files/train_40k.csv?download=1"
download(url=training_data_url, path=training_data_path)

# Read the CSV into a pandas dataframe, then replace missing cells
# with a single space
training_data = pd.read_csv(training_data_path)
training_data = training_data.fillna(" ")
# We will use logistic regression classifiers for every parent node
lr = LogisticRegression(max_iter=1000)

# os.cpu_count() returns None when the number of cores cannot be determined;
# fall back to a single job so that n_jobs is always an integer.
n_jobs = cpu_count() or 1

# Text is vectorized (token counts, then TF-IDF weighting) before it is
# handed to the hierarchical classifier.
pipeline = Pipeline(
    [
        ("count", CountVectorizer()),
        ("tfidf", TfidfTransformer()),
        (
            "lcppn",
            LocalClassifierPerParentNode(local_classifier=lr, n_jobs=n_jobs),
        ),
    ]
)
# Pick the feature column (titles) and the three category columns used as labels
label_columns = ["Cat1", "Cat2", "Cat3"]
X_train = training_data["Title"]
Y_train = training_data[label_columns]

# Workaround for: AttributeError: '_LoggingTee' object has no attribute 'fileno'
# The error only shows up while the documentation is being built,
# so this line is not needed in your own code
setattr(sys.stdout, "fileno", lambda: False)

# Finally, train the local classifier per parent node
pipeline.fit(X_train, Y_train)
Total running time of the script: (1 minute 18.178 seconds)