MeranaTona closed this issue 3 years ago.
Okay, after adding the following, my second approach (using TFAutoModelForSequenceClassification) no longer results in a NoneType. However, the model seems to treat the problem more like a multilabel regression problem than a multilabel classification problem:
model.compile(loss='categorical_crossentropy',optimizer='adam', metrics=['accuracy'])
predictor = ktrain.get_predictor(model, t)
predictor.predict("What are we doing today?")
Out[105]: [('label1', 0.4994808),
('label2', 0.5707391),
('label3', 0.4876785),
('label4', 0.5431454),
('label5', 0.50143594),
('label6', 0.45981723),
('label7', 0.49624988),
('label8', 0.51856476),
('label9', 0.4871253)]
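These scores behave like multilabel classification output rather than regression: each label gets its own independent (sigmoid) probability, so the values do not have to sum to 1, and a label counts as predicted when its probability clears a threshold. For multi-hot targets with sigmoid outputs, Keras's `binary_crossentropy` is the usual loss; `categorical_crossentropy` assumes exactly one correct class per example. Scores clustered this tightly around 0.5 may also mean the model has barely moved from its initial state. The sketch below applies a threshold to the output above; the helper name `labels_above_threshold` and the 0.5 cutoff are illustrative assumptions, not part of ktrain.

```python
from typing import List, Sequence, Tuple


def labels_above_threshold(
    scores: Sequence[Tuple[str, float]], threshold: float = 0.5
) -> List[str]:
    """Return every label whose independent (sigmoid) probability meets the threshold.

    Hypothetical helper; 0.5 is an assumed default cutoff, not a ktrain setting.
    """
    return [label for label, prob in scores if prob >= threshold]


# Scores copied from the predictor output above.
scores = [
    ("label1", 0.4994808),
    ("label2", 0.5707391),
    ("label3", 0.4876785),
    ("label4", 0.5431454),
    ("label5", 0.50143594),
    ("label6", 0.45981723),
    ("label7", 0.49624988),
    ("label8", 0.51856476),
    ("label9", 0.4871253),
]

# Each label is scored independently, so the probabilities need not sum to 1.
total = sum(p for _, p in scores)
print(f"sum of probabilities: {total:.3f}")  # prints: sum of probabilities: 4.564
print(labels_above_threshold(scores))  # prints: ['label2', 'label4', 'label5', 'label8']
```

Raising the threshold trades recall for precision; with scores this close to 0.5, the chosen labels will be sensitive to the exact cutoff.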
Hello.
The 404 client error simply says that there is no TensorFlow version of this model (which is true). This is the first time I've seen this warning, so perhaps it comes from a newer version of requests or one of its dependencies. In any case, when this happens, ktrain falls back to loading the model as a PyTorch model, which succeeds here. If you run the entire code and ignore the 404 warning, everything should work (and it does for me):
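The fallback described above can be sketched as "try TensorFlow weights first, then convert the PyTorch checkpoint". This is a minimal sketch of the pattern, not ktrain's actual internal code; the helper name `load_with_pt_fallback` is hypothetical. It assumes the loader behaves like Hugging Face's `TFAutoModelForSequenceClassification.from_pretrained`, which raises an `OSError` when a checkpoint has no TensorFlow weights and accepts `from_pt=True` to convert PyTorch weights (this needs PyTorch installed). Loading this way may also be an option for the manual `TFAutoModelForSequenceClassification` approach mentioned in the question.

```python
from typing import Any, Callable


def load_with_pt_fallback(load_fn: Callable[..., Any], name: str, **kwargs: Any) -> Any:
    """Load native TensorFlow weights if present, else retry with from_pt=True.

    Hypothetical helper. Assumes load_fn raises OSError when no TF weight
    file exists (the 404 case) and accepts a from_pt keyword argument.
    """
    try:
        # First attempt: native TensorFlow weights.
        return load_fn(name, **kwargs)
    except OSError:
        # No TF weights found: ask the loader to convert the PyTorch checkpoint.
        return load_fn(name, from_pt=True, **kwargs)


# Example usage (assumes `transformers`, TensorFlow, and PyTorch are installed):
#   from transformers import TFAutoModelForSequenceClassification
#   model = load_with_pt_fallback(
#       TFAutoModelForSequenceClassification.from_pretrained,
#       "google/bert_uncased_L-2_H-128_A-2",
#       num_labels=9,
#   )
```

Passing the loader in as a callable keeps the retry logic independent of any particular library version.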
# load text data
categories = ['alt.atheism', 'soc.religion.christian','comp.graphics', 'sci.med']
from sklearn.datasets import fetch_20newsgroups
train_b = fetch_20newsgroups(subset='train', categories=categories, shuffle=True)
test_b = fetch_20newsgroups(subset='test',categories=categories, shuffle=True)
(x_train, y_train) = (train_b.data, train_b.target)
(x_test, y_test) = (test_b.data, test_b.target)
# build and train model
import ktrain
from ktrain import text
MODEL_NAME = 'google/bert_uncased_L-2_H-128_A-2'
t = text.Transformer(MODEL_NAME, maxlen=500, class_names=train_b.target_names)
trn = t.preprocess_train(x_train, y_train)
model = t.get_classifier()
learner = ktrain.get_learner(model, train_data=trn, batch_size=6)
learner.fit_onecycle(5e-5, 1)
Okay, thank you very much :)
Hi there!
Thank you for this simplified library. I am trying to use a very small version of BERT from Google on Hugging Face because my CPU is slow. Unfortunately, the model can't be located, and I can't figure out how to make it work. Maybe that's because I'm new.
This results in:
I have also tried loading it like this:
But it results in a NoneType:
My data are strings with multilabel targets over 9 labels:
Thank you in advance.
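For multilabel data like this, the targets are usually encoded as multi-hot vectors: one 0/1 column per label, with possibly several 1s per example. My understanding is that ktrain treats a 2D multi-hot `y_train` passed to `t.preprocess_train` as multilabel, but check the ktrain documentation for your version. The sketch below is a minimal encoder under those assumptions; the helper `to_multi_hot`, the label names, and the example texts are all hypothetical.

```python
from typing import List, Sequence


def to_multi_hot(
    label_sets: Sequence[Sequence[str]], class_names: Sequence[str]
) -> List[List[int]]:
    """Encode each example's set of labels as a 0/1 vector over class_names.

    Hypothetical helper; raises KeyError for a label not in class_names.
    """
    index = {name: i for i, name in enumerate(class_names)}
    encoded = []
    for labels in label_sets:
        row = [0] * len(class_names)
        for label in labels:
            row[index[label]] = 1  # mark this label as present
        encoded.append(row)
    return encoded


# Hypothetical data: 9 class names and two example texts with their label sets.
class_names = [f"label{i}" for i in range(1, 10)]
x_train = ["What are we doing today?", "Another example sentence."]
y_labels = [["label2", "label4"], ["label1"]]
y_train = to_multi_hot(y_labels, class_names)

# Assumed next step: convert with numpy.array(y_train), then pass x_train and
# y_train to t.preprocess_train, with class_names=class_names on text.Transformer.
```

Keeping `class_names` in a fixed order matters, because the same order determines which column of the model's output corresponds to which label.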