BERT sklearn

To use bert-sklearn with HiClass, some of scikit-learn's input checks need to be disabled. The reason is that BERT expects raw text as input features, while scikit-learn expects numerical features, so the standard validation would fail. To disable scikit-learn's checks, we can simply pass the parameter bert=True to the constructor of the local hierarchical classifier.
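As a minimal sketch of why the checks fail: scikit-learn's input validation essentially tries to coerce the feature matrix to a numeric array, which raises an error for raw strings. This is mimicked here with NumPy as a simplification of what sklearn's check_array does internally:

```python
import numpy as np

# Text features, as BERT expects them
X_text = ["Batman", "Rorschach"]

# scikit-learn's validation roughly amounts to a numeric coercion,
# which cannot succeed for raw strings:
try:
    np.asarray(X_text, dtype=np.float64)
    passed_validation = True
except ValueError:
    passed_validation = False

print(passed_validation)  # False: text cannot pass numeric validation
```

Setting bert=True tells HiClass to skip this validation so the raw strings reach the BERT classifier untouched.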

Out:

Building sklearn text classifier...

100%|##########| 231508/231508 [00:00<00:00, 7565485.48B/s]
Loading bert-base-uncased model...

100%|##########| 440473133/440473133 [00:10<00:00, 40350279.58B/s]

100%|##########| 433/433 [00:00<00:00, 3109817.86B/s]
Defaulting to linear classifier/regressor
Loading Pytorch checkpoint
This DataLoader will create 5 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
train data size: 2, validation data size: 0

This overload of add_ is deprecated:
        add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
        add_(Tensor other, *, Number alpha) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:1630.)

Training  : 100%|##########| 1/1 [00:04<00:00,  4.38s/it, loss=0.749]

Training  : 100%|##########| 1/1 [00:04<00:00,  4.86s/it, loss=0.703]

Training  : 100%|##########| 1/1 [00:05<00:00,  5.21s/it, loss=0.687]

Predicting: 100%|##########| 1/1 [00:01<00:00,  1.29s/it]
[['Action' 'The Dark Night']
 ['Action' 'Watchmen']]

from bert_sklearn import BertClassifier
from hiclass import LocalClassifierPerParentNode

# Define data
X_train = X_test = [
    "Batman",
    "Rorschach",
]
Y_train = [
    ["Action", "The Dark Night"],
    ["Action", "Watchmen"],
]

# Use BERT for every parent node
bert = BertClassifier()
classifier = LocalClassifierPerParentNode(
    local_classifier=bert,
    bert=True,
)

# Train local classifier per parent node
classifier.fit(X_train, Y_train)

# Predict
predictions = classifier.predict(X_test)
print(predictions)
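As the output above shows, the predictions come back as a 2-D array with one row per sample and one column per hierarchy level, so individual levels can be sliced directly. A small sketch using the labels from this example (the array is hard-coded here to stand in for the classifier's output):

```python
import numpy as np

# Stand-in for the predictions printed above
predictions = np.array([
    ["Action", "The Dark Night"],
    ["Action", "Watchmen"],
])

# First column: top-level label; second column: leaf label
print(predictions[:, 0].tolist())  # ['Action', 'Action']
print(predictions[:, 1].tolist())  # ['The Dark Night', 'Watchmen']
```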

Total running time of the script: (0 minutes 35.918 seconds)

Gallery generated by Sphinx-Gallery