palvors closed this issue 1 year ago.
It looks like you are using an incompatible callback class.
For Yolo NAS / PPYolo-E models you should use the PPYoloEPostPredictionCallback(score_threshold=VALUE_YOU_CHOOSE, nms_threshold=VALUE_YOU_CHOOSE, nms_top_k=1000, max_predictions=300) class instead of YoloXPostPredictionCallback.
Let us know whether it helps.
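To make the four parameters in this reply more concrete, here is a simplified, class-agnostic sketch in plain Python of what a "score threshold, then top-k, then NMS" post-processing step typically does. It illustrates the general technique only. It is not super-gradients' actual PPYoloEPostPredictionCallback implementation, which works on batched tensors. The assumptions are: boxes are (x1, y1, x2, y2), comparisons use strict ">", and NMS is greedy and ignores classes. The function names iou and post_process are hypothetical.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # assumed format: (x1, y1, x2, y2)


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def post_process(
    boxes: Sequence[Box],
    scores: Sequence[float],
    score_threshold: float,
    nms_threshold: float,
    nms_top_k: int,
    max_predictions: int,
) -> List[int]:
    """Hypothetical sketch: return indices of kept boxes, highest score first."""
    # 1. Drop candidates whose score does not exceed score_threshold.
    candidates = [i for i, s in enumerate(scores) if s > score_threshold]
    # 2. Keep only the nms_top_k highest-scoring candidates before NMS.
    candidates.sort(key=lambda i: scores[i], reverse=True)
    candidates = candidates[:nms_top_k]
    # 3. Greedy NMS: skip a box if its IoU with an already kept box exceeds
    #    nms_threshold; stop once max_predictions boxes have been kept.
    kept: List[int] = []
    for i in candidates:
        if len(kept) >= max_predictions:
            break
        if all(iou(boxes[i], boxes[j]) <= nms_threshold for j in kept):
            kept.append(i)
    return kept


boxes = [(0, 0, 10, 10), (1, 1, 10, 10), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.3]
print(post_process(boxes, scores, 0.1, 0.4, 1000, 300))  # [0, 2]
```

Box 1 is dropped because its IoU with the higher-scoring box 0 is 0.81, which is above nms_threshold=0.4. Raising score_threshold to 0.5 would also remove box 2 before NMS runs.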
Hi, yes, but now I get a new error with this code:
model_nas_s = models.get(Models.YOLO_NAS_S, pretrained_weights="coco")
model_nas_s.eval()
with torch.no_grad():
    raw_predictions_temp = model_nas_s(transformed_image)
predictions_temp = PPYoloEPostPredictionCallback(score_threshold=0.1, nms_threshold=0.4, nms_top_k=1000, max_predictions=300)(raw_predictions_temp)[0].numpy()
Error: "TypeError: forward() missing 1 required positional argument: 'device'"
Sorry, but do you know why?
Sorry for the inconvenience. These classes are not quite polished for use outside the trainer. You need to pass device= with whatever device you placed the model on.
By the way, if you just want to run inference, you can use the predict() method, which works on NumPy images.
Thank you, that works.
So the solution was just to pass device= in the forward call, like this:
predictions_temp = PPYoloEPostPredictionCallback(score_threshold=0.5, nms_threshold=0.6, nms_top_k=1000, max_predictions=300)(raw_predictions_temp, device=None)[0].numpy()
Thank you.
🐛 Describe the bug
Hi,
With the YOLO_NAS_S model, the function non_max_suppression fails because the prediction it receives is a tuple, and a tuple cannot be indexed the way that line expects (tuples accept only integers or slices as indices).
As in other examples, I just want to use YoloXPostPredictionCallback.
Error:
--> 266     candidates_above_thres = prediction[..., 4] > conf_thres  # filter by confidence
    267     output = [None] * prediction.shape[0]
    269     for image_idx, pred in enumerate(prediction):
TypeError: tuple indices must be integers or slices, not tuple
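The error message can be reproduced without super-gradients. Indexing a plain Python tuple with prediction[..., 4] passes the tuple (Ellipsis, 4) as the index, which tuples reject. This is a minimal sketch that assumes the model output reaching non_max_suppression is a tuple, as the traceback indicates. The stand-in data and the helper names are hypothetical.

```python
def index_like_confidence_filter(prediction):
    """Mimics the traceback's line 266, which indexes prediction[..., 4]."""
    return prediction[..., 4]


def reproduce_error():
    # Hypothetical stand-in for a model output that is a tuple of several items
    # rather than a single (batch, boxes, values) array.
    prediction = ([[0.0] * 6], [[0.0] * 6])
    try:
        index_like_confidence_filter(prediction)
    except TypeError as exc:
        return str(exc)
    return None


print(reproduce_error())  # tuple indices must be integers or slices, not tuple
```

This is why a callback written for a single-array output, such as YoloXPostPredictionCallback, fails on this model's output. The maintainer's reply recommends a different callback for that reason.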
Source of line 266: https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/training/utils/detection_utils.py
Can anyone confirm this?
Versions
I am using super-gradients 3.1.3.
See below for the full list of environment requirements:
asttokens=2.2.1