Explaining Local Classifier Per Node

A minimal example showing how to use the HiClass Explainer to obtain SHAP values for a Local Classifier Per Node (LCPN) model. A detailed description of the Explainer class is given in the Hierarchical Explainability part of the Algorithms Overview section. The SHAP values are computed on a synthetic platypus diseases dataset, which can be downloaded here.
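To make the notion of a SHAP value concrete, the sketch below computes exact Shapley values by brute force for a toy three-feature model. Features outside a coalition are replaced by a single baseline vector. The function names `toy_model` and `shapley_values` and the baseline are hypothetical and only illustrate the definition. Practical explainers, such as shap's TreeExplainer, use much faster model-specific algorithms instead of enumerating every coalition.

```python
from itertools import combinations
from math import factorial


def toy_model(x):
    # Hypothetical model: a linear term plus an interaction between features 1 and 2
    return 2.0 * x[0] + x[1] * x[2]


def shapley_values(model, x, baseline):
    """Exact Shapley values; features outside a coalition take their baseline value."""
    n = len(x)

    def value(coalition):
        # Only features in `coalition` keep their real value, the rest use the baseline
        z = [x[k] if k in coalition else baseline[k] for k in range(n)]
        return model(z)

    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions S that exclude i
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                total += weight * (value(s | {i}) - value(s))
        phi.append(total)
    return phi


x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
phi = shapley_values(toy_model, x, baseline)
print(phi)  # approximately [2.0, 3.0, 3.0]

# Local accuracy: the contributions sum to f(x) - f(baseline)
print(sum(phi), toy_model(x) - toy_model(baseline))
```

The interaction term x[1] * x[2] is split evenly between features 1 and 2, while feature 0 receives exactly its linear contribution.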


Out:

```
<xarray.Dataset>
Dimensions:          (class: 30, level: 3, sample: 246, feature: 9)
Coordinates:
  * class            (class) <U18 'Allergy_0' 'Allergy_1' ... 'Respiratory_1'
  * level            (level) int64 0 1 2
Dimensions without coordinates: sample, feature
Data variables:
    node             (sample, level) object 'Respiratory' ... 'Milk Allergy'
    predicted_class  (sample, level) object 'Respiratory' ... 'Milk Allergy'
    predict_proba    (sample, level, class) float64 nan nan nan ... nan nan nan
    classes          (sample, level, class) object nan nan nan ... nan nan nan
    shap_values      (level, class, sample, feature) float64 nan nan ... nan nan
```
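The printed Dataset stores `shap_values` as a four-dimensional array indexed by (level, class, sample, feature). The `class` coordinate appears to hold two entries per node (for example `Allergy_0` and `Allergy_1`). This is consistent with LCPN training one binary classifier per node, where `_1` presumably marks the positive outcome. The NumPy-only sketch below roughly mirrors the `.sel()` filter applied in the script. It builds a small random stand-in array and slices out one level, one class, and a boolean sample mask. The class list, the predictions, and the random values are hypothetical placeholders, not HiClass output.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for explanations.shap_values with dims (level, class, sample, feature)
class_names = np.array(["Allergy_0", "Allergy_1", "Respiratory_0", "Respiratory_1"])
n_levels, n_samples, n_features = 3, 10, 9
shap_values = rng.normal(size=(n_levels, len(class_names), n_samples, n_features))

# Hypothetical level-0 predictions, one per sample
level0_pred = np.array(["Respiratory", "Allergy"] * 5)
respiratory_idx = level0_pred == "Respiratory"

# Rough equivalent of sel({"level": 0, "class": "Respiratory_1", "sample": respiratory_idx}):
# "class" is a label lookup; the level labels 0, 1, 2 coincide with positions;
# "sample" has no coordinate, so the boolean mask is applied positionally
class_pos = int(np.flatnonzero(class_names == "Respiratory_1")[0])
filtered = shap_values[0, class_pos][respiratory_idx]

print(filtered.shape)  # (5, 9): one row per selected sample, one column per feature
```

After the filter, only the (sample, feature) dimensions remain, which is the 2-D shape a per-feature summary plot expects.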

```python
from sklearn.ensemble import RandomForestClassifier
from hiclass import LocalClassifierPerNode, Explainer
from hiclass.datasets import load_platypus
import shap

# Load train and test splits of the synthetic platypus diseases dataset
X_train, X_test, Y_train, Y_test = load_platypus()

# Use a random forest as the local classifier for every node
rfc = RandomForestClassifier()
classifier = LocalClassifierPerNode(local_classifier=rfc, replace_classifiers=False)

# Train the local classifier per node
classifier.fit(X_train, Y_train)

# Define the Explainer; mode="tree" selects a tree-based SHAP explainer
explainer = Explainer(classifier, data=X_train.values, mode="tree")
explanations = explainer.explain(X_test.values)
print(explanations)

# Select the test samples whose level-0 prediction is "Respiratory"
respiratory_idx = classifier.predict(X_test)[:, 0] == "Respiratory"

# Additionally keep only level 0 and the "Respiratory_1" class
# (presumably the positive output of the Respiratory node's binary classifier)
shap_filter = {"level": 0, "class": "Respiratory_1", "sample": respiratory_idx}

# Use .sel() to apply the filter and obtain the filtered results
shap_val_respiratory = explanations.sel(shap_filter)

# Plot feature importance on the test set
shap.plots.violin(
    shap_val_respiratory.shap_values,
    feature_names=X_train.columns.values,
    plot_size=(13, 8),
)
```
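The violin plot shows the full distribution of SHAP values for each feature. When a single ranking is enough, a common summary is the mean absolute SHAP value per feature. The sketch below computes it with NumPy on a small placeholder (sample, feature) matrix standing in for `shap_val_respiratory.shap_values`. The matrix values and the feature names are hypothetical. With the real Dataset, one would presumably pass `shap_val_respiratory.shap_values.values` and `X_train.columns.values` instead.

```python
import numpy as np

# Placeholder (sample, feature) SHAP matrix; real values would come from the filtered Dataset
shap_matrix = np.array([
    [0.5, -0.1, 0.0],
    [-0.3, 0.2, 0.1],
    [0.4, -0.3, np.nan],  # NaN mimics entries the explainer may leave empty
])
feature_names = np.array(["fever", "cough", "rash"])  # hypothetical feature names

# Mean absolute SHAP value per feature, ignoring NaN entries
importance = np.nanmean(np.abs(shap_matrix), axis=0)

# Rank features from most to least important
order = np.argsort(importance)[::-1]
for name, score in zip(feature_names[order], importance[order]):
    print(f"{name}: {score:.3f}")
```

`np.nanmean` is used because the full explanation Dataset contains NaN entries, as the printed output shows. A plain mean would turn any feature column containing a NaN into NaN.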

Total running time of the script: ( 0 minutes 35.808 seconds)

Gallery generated by Sphinx-Gallery