flww213 closed this issue 1 month ago.
I solved it. In model/mae_cvt.py, line 21, the parameters "self.abnormal_score_func" and "self.abnormal_score_func_TS" end up as "L" and "2", because config.abnormal_score_func='L2' in config.py is a string, and indexing a string yields single characters. Each of them should be "L1" or "L2".
class MaskedAutoencoderCvT(nn.Module):
    def __init__(self, img_size=(512,512), patch_size=16, in_chans=9, out_chans=4,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
                 use_only_masked_tokens_ab=False, abnormal_score_func='L1', masking_method="random_masking",
                 grad_weighted_loss=True, student_depth=1):
        super().__init__()
        # --------------------------------------------------------------------------
        # Abnormal specifics
        self.use_only_masked_tokens_ab = use_only_masked_tokens_ab
        self.abnormal_score_func = abnormal_score_func[0]
        self.abnormal_score_func_TS = abnormal_score_func[1]
        # --------------------------------------------------------------------------
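To make the failure concrete, the sketch below shows that indexing the string 'L2' returns 'L' and '2', and pairs it with a defensive normalizer. `normalize_score_funcs` and `VALID_FUNCS` are hypothetical names, not part of the repository; reusing a single string for both scores is an assumption that mirrors the ['L2', 'L2'] workaround mentioned later in this thread.

```python
# Hypothetical set of accepted metric names (assumed from this thread: "L1" or "L2").
VALID_FUNCS = {"L1", "L2"}


def normalize_score_funcs(value):
    """Return (teacher_func, teacher_student_func) from a config value.

    Hypothetical helper: a single string such as 'L2' is reused for both
    scores (an assumption), while a 2-element list/tuple is split as the
    model's __init__ expects.
    """
    if isinstance(value, str):
        funcs = (value, value)
    else:
        funcs = tuple(value)
        if len(funcs) != 2:
            raise ValueError(f"expected 2 score functions, got {len(funcs)}")
    for f in funcs:
        if f not in VALID_FUNCS:
            raise ValueError(f"unknown abnormal score function: {f!r}")
    return funcs


# The bug: indexing a string returns single characters.
bad = "L2"
print(bad[0], bad[1])  # L 2

print(normalize_score_funcs("L2"))          # ('L2', 'L2')
print(normalize_score_funcs(["L1", "L2"]))  # ('L1', 'L2')
```

Validating the names up front turns a silent 'L'/'2' mix-up into an immediate, readable error.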
Thanks for pointing this out. Yes, config.abnormal_score_func should be a list of two error metrics: the first measures the teacher's error, and the second measures the teacher-student discrepancy. I think we left the config set up for inference; we will push a fix for this.
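A minimal sketch of what the two entries could control, assuming "L1" means absolute difference and "L2" means squared difference, each averaged over all elements. `error_map` and `abnormal_scores` are hypothetical names, and the mean reduction is an assumption; the repository's actual scoring code may reduce or normalize differently.

```python
import numpy as np


def error_map(a, b, func):
    # Elementwise error: 'L1' = absolute difference, 'L2' = squared difference
    # (assumed definitions, not taken from the repository).
    diff = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    if func == "L1":
        return np.abs(diff)
    if func == "L2":
        return diff ** 2
    raise ValueError(f"unknown error metric: {func!r}")


def abnormal_scores(target, teacher_recon, student_recon, funcs=("L2", "L2")):
    # funcs[0]: metric for the teacher's error against the target frame.
    # funcs[1]: metric for the teacher-student discrepancy.
    teacher_func, ts_func = funcs
    teacher_err = float(error_map(teacher_recon, target, teacher_func).mean())
    ts_err = float(error_map(student_recon, teacher_recon, ts_func).mean())
    return teacher_err, ts_err


target = np.zeros(4)
teacher = np.array([2.0, 0.0, 0.0, 0.0])
student = np.array([2.0, 1.0, 0.0, 0.0])
print(abnormal_scores(target, teacher, student, ("L1", "L2")))  # (0.5, 0.25)
```

Because the two metrics are chosen independently, a string config value cannot express them, which is why a two-element list is needed.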
I changed config.abnormal_score_func='L2' to ['L2', 'L2'], but got another error: AttributeError: 'list' object has no attribute 'detach' at recon_error = recon_error.detach().cpu().numpy(). I guess this is because MaskedAutoencoderCvT.abnormal_score_TS() in model/mae_cvt.py returns a list of scores when teacher-student mode is used. How can I run it and compute the metrics correctly?
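One way to avoid calling .detach() on a list is to convert each returned score separately. The sketch below assumes abnormal_score_TS() returns a list or tuple of tensors, as guessed above; `to_numpy` is a hypothetical helper, and it only relies on the object exposing `.detach()`, `.cpu()` and `.numpy()` (as `torch.Tensor` does), so it can be tried without PyTorch. It also raises an explicit error for None instead of an AttributeError.

```python
import numpy as np


def to_numpy(x):
    """Convert a tensor-like score, or a list/tuple of them, to NumPy.

    Hypothetical helper: objects with .detach() (e.g. torch.Tensor) go through
    .detach().cpu().numpy(); lists/tuples are converted element by element;
    None raises a descriptive error.
    """
    if x is None:
        raise ValueError("score is None; the model returned no value here")
    if isinstance(x, (list, tuple)):
        return [to_numpy(item) for item in x]
    if hasattr(x, "detach"):
        return x.detach().cpu().numpy()
    return np.asarray(x)


# Assumed shape of the teacher-student output: [teacher_score, ts_score].
teacher_score, ts_score = to_numpy([np.array([0.1, 0.2]), np.array([0.3, 0.4])])
print(teacher_score, ts_score)
```

How the two arrays should then be combined into one per-frame prediction is a design decision for the repository; the helper deliberately leaves them separate.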
import numpy as np
from sklearn import metrics

# Compute statistics
predictions = np.array(predictions)
labels = np.array(labels)
videos = np.array(videos)
aucs = []
filtered_preds = []
filtered_labels = []
for vid in np.unique(videos):
    pred = predictions[np.array(videos) == vid]
    pred = np.nan_to_num(pred, nan=0.)
    if args.dataset=='avenue':
        pred = filt(pred, range=38, mu=11)
    else:
        raise ValueError('Unknown parameters for predictions postprocessing')
    # pred = (pred - np.min(pred)) / (np.max(pred) - np.min(pred))
    filtered_preds.append(pred)
    lbl = labels[np.array(videos) == vid]
    filtered_labels.append(lbl)
    # Pad with one normal (0) and one abnormal (1) frame so each video has both classes
    lbl = np.array([0] + list(lbl) + [1])
    pred = np.array([0] + list(pred) + [1])
    fpr, tpr, _ = metrics.roc_curve(lbl, pred)
    res = metrics.auc(fpr, tpr)
    aucs.append(res)
# Macro-AUC: mean of the per-video AUCs
macro_auc = np.nanmean(aucs)
# Micro-AUC
filtered_preds = np.concatenate(filtered_preds)
filtered_labels = np.concatenate(filtered_labels)
fpr, tpr, _ = metrics.roc_curve(filtered_labels, filtered_preds)
micro_auc = metrics.auc(fpr, tpr)
micro_auc = np.nan_to_num(micro_auc, nan=1.0)
# gather the stats from all processes
print(f"MicroAUC: {micro_auc}, MacroAUC: {macro_auc}")
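To check the macro/micro AUC logic in isolation, here is a self-contained sketch on toy data. It swaps sklearn's `roc_curve` + `auc` for the pairwise (Mann-Whitney) formulation, which gives the same area as the trapezoidal ROC curve, with ties counted as one half. The repository's `filt` smoothing step is omitted because its definition is not shown. `roc_auc` and `macro_micro_auc` are hypothetical names, and the 0/1 padding assumes scores lie in [0, 1].

```python
import numpy as np


def roc_auc(labels, scores):
    # ROC AUC as the probability that a random positive (label 1) outscores a
    # random negative (label 0); ties count 1/2. nan if a class is missing.
    labels = np.asarray(labels)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[labels == 1], scores[labels == 0]
    if pos.size == 0 or neg.size == 0:
        return float("nan")
    diff = pos[:, None] - neg[None, :]
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))


def macro_micro_auc(predictions, labels, videos, pad=True):
    predictions = np.nan_to_num(np.asarray(predictions, dtype=float), nan=0.0)
    labels, videos = np.asarray(labels), np.asarray(videos)
    aucs, all_preds, all_labels = [], [], []
    for vid in np.unique(videos):
        mask = videos == vid
        pred, lbl = predictions[mask], labels[mask]
        all_preds.append(pred)    # micro-AUC uses the unpadded values
        all_labels.append(lbl)
        if pad:
            # Same trick as the snippet: add a normal frame scored 0 and an
            # abnormal frame scored 1 so every video contains both classes.
            lbl = np.concatenate([[0], lbl, [1]])
            pred = np.concatenate([[0.0], pred, [1.0]])
        aucs.append(roc_auc(lbl, pred))
    macro = float(np.nanmean(aucs))                      # mean of per-video AUCs
    micro = roc_auc(np.concatenate(all_labels), np.concatenate(all_preds))
    return macro, micro


preds = [0.1, 0.2, 0.8, 0.9, 0.3, 0.85, 0.4, 0.1]
labels = [0, 0, 1, 1, 0, 0, 0, 0]
videos = [0, 0, 0, 0, 1, 1, 1, 1]
macro, micro = macro_micro_auc(preds, labels, videos)
print(f"MicroAUC: {micro}, MacroAUC: {macro}")
```

Video 1 has no abnormal frames, so without padding its AUC would be undefined; the padding makes it computable, while the micro-AUC, pooled over all frames, still reflects the one negative frame (0.85) that outscores a positive.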
Hello. We pushed a fix.
When I train on the avenue dataset and the training epoch reaches start_TS_epoch, an AttributeError occurs: 'NoneType' object has no attribute 'detach'. The complete error message is below. This bug consistently appears across multiple datasets. It seems to be an issue that arises when starting to train the student model. How can I resolve this?