lukasruff / Deep-SVDD-PyTorch

A PyTorch implementation of the Deep SVDD anomaly detection method
MIT License

Wrong get_radius function #3

Closed: fsk119 closed this issue 5 years ago

fsk119 commented 5 years ago

https://github.com/lukasruff/Deep-SVDD-PyTorch/blob/1919546eda13b6b70e1701bde08dedf13b92bc11/src/optim/deepSVDD_trainer.py#L179-L181

Hi, I think line 181 is wrong. It could be changed to:

np.quantile(np.sqrt(dist.clone().data.cpu().numpy()), 1 - nu)

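The reason is that dist holds squared distances, (feat - center)**2, and line 131 compares it against R^2: https://github.com/lukasruff/Deep-SVDD-PyTorch/blob/1919546eda13b6b70e1701bde08dedf13b92bc11/src/optim/deepSVDD_trainer.py#L131. Taking the (1 - nu)-quantile of dist directly therefore gives R^2 rather than R.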
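A minimal sketch of the effect, assuming a standalone get_radius(dist, nu) like the one referenced above and hypothetical random features (the names outputs, c, R_bad, and the toy data are illustrative, not from the repository). With the square root, roughly a fraction nu of the points satisfy dist > R^2; without it, the quantile is already R^2, so squaring it again pushes almost every point inside the sphere.

```python
import numpy as np
import torch


def get_radius(dist: torch.Tensor, nu: float) -> float:
    """Return R as the (1 - nu)-quantile of the Euclidean distances.

    `dist` holds *squared* distances, so take the square root first.
    """
    return float(np.quantile(np.sqrt(dist.clone().data.cpu().numpy()), 1 - nu))


# Hypothetical toy data: 1000 8-dimensional features around a zero center.
torch.manual_seed(0)
outputs = torch.randn(1000, 8)
c = torch.zeros(8)
dist = torch.sum((outputs - c) ** 2, dim=1)  # squared distances to c

nu = 0.1
R = get_radius(dist, nu)
frac_outside = (dist > R ** 2).float().mean().item()

# Old behaviour: quantile of the squared distances, which is already R^2.
R_bad = float(np.quantile(dist.clone().data.cpu().numpy(), 1 - nu))
frac_bad = (dist > R_bad ** 2).float().mean().item()

print(f"fixed: R = {R:.3f}, fraction outside = {frac_outside:.3f}")
print(f"old:   R = {R_bad:.3f}, fraction outside = {frac_bad:.3f}")
```

Comparing the two printed fractions against nu shows which radius is consistent with the dist > R^2 test.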

lukasruff commented 5 years ago

Thanks for catching this error! You're correct and I've fixed line 181 in commit 77b5d33816cb454aa30264a0353401f6f6e8b59e according to your suggestion.

jhl13 commented 4 years ago


Hi, according to the paper, the distances in dist should be sorted, but I can't find any code that does this.
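One possible explanation, not stated in this thread: np.quantile accepts unsorted input and handles the ordering itself, so an explicit sort may be unnecessary. The sketch below, using hypothetical random data, compares np.quantile against a manual sort followed by linear interpolation at position q * (n - 1), which is NumPy's default "linear" method.

```python
import numpy as np

rng = np.random.default_rng(0)
d = rng.random(10)  # hypothetical, unsorted distances
q = 0.9

# np.quantile orders the data internally; no pre-sorting is needed.
via_quantile = np.quantile(d, q)

# Explicit version: sort, then interpolate linearly at position q * (n - 1).
s = np.sort(d)
pos = q * (len(s) - 1)
lo = int(np.floor(pos))
hi = min(lo + 1, len(s) - 1)
via_sort = s[lo] + (pos - lo) * (s[hi] - s[lo])

print(np.isclose(via_quantile, via_sort))  # True
```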