deel-ai / oodeel

Simple, compact, and hackable post-hoc deep OOD detection for already trained tensorflow or pytorch image classifiers.
https://deel-ai.github.io/oodeel/
MIT License

feat: add training func for torch models #8

Closed y-prudent closed 1 year ago

y-prudent commented 1 year ago

Features:

Note: The function allow_growth is called when the oodeel.models.training_funcs.torch_models module is imported. It prevents tensorflow from allocating all of the GPU memory when a tfds dataset is loaded, so that some VRAM remains free to train torch models. The torch_models module must therefore be imported BEFORE calling data_handler.load_tfds.
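For readers unfamiliar with the mechanism, here is a minimal sketch of what a function like allow_growth might do. It is not the code from this PR: the body, the return value, and the ImportError fallback are assumptions made for illustration. It relies only on the public TensorFlow calls tf.config.list_physical_devices and tf.config.experimental.set_memory_growth.

```python
def allow_growth():
    """Hypothetical sketch: make TensorFlow grow GPU memory on demand.

    Returns the number of GPUs that were configured (0 if TensorFlow is
    not installed or no GPU is visible).
    """
    try:
        import tensorflow as tf
    except ImportError:
        # Assumption for this sketch: silently do nothing without TensorFlow.
        return 0

    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        # Must run before TensorFlow initializes the GPU; otherwise
        # TensorFlow raises a RuntimeError.
        tf.config.experimental.set_memory_growth(gpu, True)
    return len(gpus)


# Call it before anything touches the GPU through TensorFlow,
# e.g. before loading a tfds dataset.
configured = allow_growth()
print(f"memory growth enabled on {configured} GPU(s)")
```

Memory growth has to be set before TensorFlow first initializes the GPU, which is why the import order described in the note matters.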

Run training on CIFAR-10:

 python -m notebooks.train_cifar10 --framework 'torch' --save_dir 'saved_models/cifar10-torch'
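As an illustration only, the two flags used in the command above could be parsed as follows. The parse_args helper and the choice to make both flags required are assumptions; the actual script in notebooks/train_cifar10 may define more options or different defaults.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical parser for the two flags shown in the command above."""
    parser = argparse.ArgumentParser(description="Train a CIFAR-10 classifier")
    # Flag names come from the command line in this PR; everything else
    # (help text, required=True) is an assumption of this sketch.
    parser.add_argument("--framework", type=str, required=True,
                        help="deep learning framework to train with, e.g. 'torch'")
    parser.add_argument("--save_dir", type=str, required=True,
                        help="directory where the trained model is saved")
    return parser.parse_args(argv)


args = parse_args(["--framework", "torch",
                   "--save_dir", "saved_models/cifar10-torch"])
print(args.framework, args.save_dir)
```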