rbunn80110 closed this issue 5 years ago.
Hi - you're going to have to provide more context than this :) Can you show how you are creating the network, etc.? Or just share your notebook?
This works when I use the built-in models:
learn = cnn_learner(data, models.resnet152,
                    metrics=[accuracy, dice, Precision(average='macro'), Recall(average='macro'), FBeta(average='macro')],
                    callback_fns=[ShowGraph],
                    wd=1e-3,
                    bn_wd=False,
                    true_wd=True,
                    loss_func=LabelSmoothingCrossEntropy(),
                    opt_func=optar, ps=0.001).to_fp16().blend(**kwargs).show_tfms()
If I uncomment the other model and remove resnet152, I get the error.
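One likely cause, assuming the commented-out line passed a model *instance* such as `res2net()`: in fastai v1, `cnn_learner` expects `base_arch` to be a callable like `models.resnet152`, and its `create_body` helper builds the model by calling `arch(pretrained)`. Calling an `nn.Module` instance runs its forward pass, so `True` gets fed in as the input batch and the forward pass fails with a TypeError. The toy sketch below mimics that mechanism in plain Python; `ToyNet`, `toy_create_body` and `toy_arch` are hypothetical names for illustration, not fastai APIs.

```python
from typing import Callable, List


class ToyNet:
    """Hypothetical stand-in for a custom model such as res2net().

    Calling an instance runs its "forward pass", the way calling an
    nn.Module instance does in PyTorch.
    """

    def __init__(self, pretrained: bool = False):
        self.pretrained = pretrained

    def __call__(self, batch: List[float]) -> List[float]:
        # A real forward pass would fail on a non-tensor input; we mimic that here.
        if not isinstance(batch, list):
            raise TypeError(f"forward() got {type(batch).__name__}, expected a batch")
        return [2.0 * x for x in batch]


def toy_create_body(arch: Callable, pretrained: bool = True):
    """Simplified imitation of fastai v1 create_body: it calls arch(pretrained)."""
    return arch(pretrained)


def toy_arch(pretrained: bool = False) -> ToyNet:
    """Hypothetical wrapper: a function that *returns* a model, like models.resnet152."""
    return ToyNet(pretrained=pretrained)


if __name__ == "__main__":
    body = toy_create_body(toy_arch, pretrained=False)  # OK: a fresh ToyNet is built
    print(type(body).__name__, body.pretrained)
    try:
        toy_create_body(ToyNet(), pretrained=True)  # instance: runs forward(True)
    except TypeError as err:
        print("TypeError:", err)
```

If this is what is happening, passing a function that returns the model (for example something like `lambda pretrained: res2net()`, which ignores the flag) should get past this step, although `cnn_learner` still has to find where to cut the body, which is what its `cut` argument is for. `Learner` takes the model instance directly and skips the body/head construction altogether.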
I think a super simple notebook that uses this code would clear up my confusion about how to use it.
Apparently I need to use Learner and not cnn_learner:
learn = Learner(data, res2net(),
                metrics=[accuracy, dice, Precision(average='macro'), Recall(average='macro'), FBeta(average='macro')],
                callback_fns=[ShowGraph],
                wd=1e-3,
                bn_wd=False,
                true_wd=True,
                loss_func=LabelSmoothingCrossEntropy(),
                opt_func=optar).to_fp16()
TypeError Traceback (most recent call last)