deel-ai / oodeel

Simple, compact, and hackable post-hoc deep OOD detection for already trained tensorflow or pytorch image classifiers.
https://deel-ai.github.io/oodeel/
MIT License

Feat add pytorch gradient + universal tools #9

Closed: y-prudent closed this pull request 1 year ago.

y-prudent commented 1 year ago

Main features:

y-prudent commented 1 year ago

Ready for reviewing!

Indeed, repeatedly importing torch or tf considerably slows down execution; that's a very good point. I removed universal_tools.py and integrated all of its functions as methods of the OODModel class (base.py). When an OODModel is fit, the torch or tensorflow library is now imported globally depending on the model's framework, and that choice determines the behaviour of each function.

paulnovello commented 1 year ago

Putting the operations (argmax, sign, etc.) as methods of OODModel was not satisfying, so I tried putting them in tf_tools and torch_tools instead and lazily importing the appropriate one in the OODModel.fit method. The imported module is assigned to a class attribute called "op" so that child classes can use it.
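A minimal sketch of the lazy-import approach, assuming the module names tf_tools and torch_tools and a fit(model, framework) signature; MaxLogitScorer and the pure-Python stand-in modules are hypothetical, added only so the sketch runs without tensorflow or torch installed. The real modules would wrap tf/torch calls instead.

```python
import importlib
import sys
import types


def _fake_backend(name):
    """Build a pure-Python stand-in for a hypothetical tf_tools/torch_tools module."""
    mod = types.ModuleType(name)
    mod.BACKEND = name
    mod.argmax = lambda xs: max(range(len(xs)), key=xs.__getitem__)
    mod.sign = lambda xs: [(x > 0) - (x < 0) for x in xs]
    return mod


# Register the stand-ins so importlib finds them; the real modules would
# import tensorflow / torch at their top level instead.
for _name in ("tf_tools", "torch_tools"):
    sys.modules.setdefault(_name, _fake_backend(_name))

# Map the detected framework to the operations module to import (assumed names).
_OPS_MODULES = {"tensorflow": "tf_tools", "torch": "torch_tools"}


class OODModel:
    op = None  # class attribute: the operations module, set by fit()

    def fit(self, model, framework):
        # Lazy import: nothing framework-specific is loaded until fit() runs.
        OODModel.op = importlib.import_module(_OPS_MODULES[framework])
        self.model = model
        return self


class MaxLogitScorer(OODModel):  # hypothetical child class
    def score(self, logits):
        # Child classes call self.op.* without knowing which framework is active.
        return logits[self.op.argmax(logits)]
```

Because op lives on the class rather than the instance, fitting a second model with a different framework switches op for every existing instance as well; assigning self.op instead would keep each detector tied to its own framework.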

I hesitated between this solution and another one: create a class, e.g. UniversalOperator, initialized with a "framework" argument, whose methods are operations containing an if statement that calls tf syntax when self.framework == "tensorflow" and torch syntax when self.framework == "torch". An object of this class would be instantiated in the OODModel.fit method and assigned as an attribute (similarly to self.op).
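A rough sketch of the UniversalOperator alternative, assuming only argmax and sign are needed. The lib parameter is a hypothetical hook for injecting a stand-in module; without it, the class imports tensorflow or torch on construction. The argmax branch shows where the syntax genuinely differs: tf.argmax takes axis while torch.argmax takes dim.

```python
import importlib

# Framework name -> library imported (lazily) when no module is injected.
_LIBS = {"tensorflow": "tensorflow", "torch": "torch"}


class UniversalOperator:
    """Hypothetical single class whose methods branch on self.framework."""

    def __init__(self, framework, lib=None):
        if framework not in _LIBS:
            raise ValueError(f"unsupported framework: {framework!r}")
        self.framework = framework
        # tf or torch is imported only here, i.e. when OODModel.fit builds
        # this object; `lib` lets a caller inject a stand-in module instead.
        self._lib = lib if lib is not None else importlib.import_module(_LIBS[framework])

    def argmax(self, x, axis=-1):
        if self.framework == "tensorflow":
            return self._lib.argmax(x, axis=axis)  # tf.argmax uses `axis`
        return self._lib.argmax(x, dim=axis)  # torch.argmax uses `dim`

    def sign(self, x):
        # tf.sign and torch.sign share name and call shape, so no branch is needed.
        return self._lib.sign(x)


class OODModel:
    def fit(self, model, framework):
        self.model = model
        # Instantiated at fit time and stored as the `op` attribute.
        self.op = UniversalOperator(framework)
        return self
```

The trade-off is that every new operation adds another if/else pair inside one class, whereas the module-based approach keeps each framework's syntax in its own file.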

What is your opinion about this?