devanshkv / fetch

A set of deep learning models for FRB/RFI binary classification.
GNU General Public License v3.0

Convert to PyTorch #38

Open aweaver1fandm opened 1 day ago

aweaver1fandm commented 1 day ago

Is your feature request related to a problem? Please describe.
FETCH currently uses TensorFlow/Keras, but those frameworks have problems with newer releases of TensorFlow and especially with Keras 3. The current model format does not work with Keras 3, and newer versions of TensorFlow may build but do not fully execute because of issues with TensorRT.

Describe the solution you'd like
Port FETCH to use PyTorch instead.

Describe alternatives you've considered
I attempted to update the current models to work with Keras 3. It is unclear whether the updated models are correct, because newer versions of TensorFlow have trouble running: TensorRT is not found correctly. The workarounds for that problem are very hacky and are not automated in any way.

Additional context
I have started this conversion. I used onnx and onnx2torch to convert the initial models from TensorFlow to ONNX format, and then from ONNX to PyTorch format. The problem is that the PyTorch models are much larger than the TensorFlow models: hundreds of MB instead of hundreds of KB. This means the models cannot be stored on GitHub (or at least not with a free GitHub account). The size does not seem to be a by-product of the conversion, as I have seen similar complaints from others about the size of PyTorch models in general.
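One way to check whether the size really comes from the weights is to compare the serialized file with the raw bytes held by the model's tensors. The sketch below assumes the converted model is a `torch.nn.Module` (which is what `onnx2torch.convert` returns); the helper names `tensor_bytes` and `saved_bytes` and the stand-in `nn.Sequential` model are hypothetical. It counts registered parameters and buffers only, so tensors kept as plain attributes are not included.

```python
import io

import torch
from torch import nn


def tensor_bytes(model: nn.Module) -> int:
    """Total bytes held by a model's registered parameters and buffers."""
    total = 0
    # Converted models may keep constants as buffers rather than parameters,
    # so count both. Tensors stored as plain attributes are NOT counted here.
    for t in list(model.parameters()) + list(model.buffers()):
        # element_size() is 4 for float32 and 8 for float64, so an accidental
        # upcast during conversion doubles this number.
        total += t.numel() * t.element_size()
    return total


def saved_bytes(obj) -> int:
    """Size of `obj` when written with torch.save, measured in memory."""
    buf = io.BytesIO()
    torch.save(obj, buf)
    return buf.getbuffer().nbytes


if __name__ == "__main__":
    # Hypothetical stand-in; replace with the model produced by onnx2torch.
    model = nn.Sequential(nn.Linear(256, 64), nn.ReLU(), nn.Linear(64, 2))
    print(f"raw tensor bytes:  {tensor_bytes(model):,}")  # raw tensor bytes:  66,312
    print(f"torch.save(model): {saved_bytes(model):,}")
```

If `saved_bytes(model)` is close to `tensor_bytes(model)`, the file size is intrinsic to the weights; if it is far larger, the extra bytes are coming from something other than the registered tensors and are worth tracking down before deciding how to store the models.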

Currently, both the network definition and the weights are stored in the PyTorch file. It is possible to define the model in a Python file as its own class and store only the weights, but it is not clear whether that will reduce the size enough.
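A minimal sketch of the weights-only approach, assuming standard PyTorch: the architecture lives in a Python class and only `state_dict()` is written to disk. `TinyClassifier`, its layer sizes, and the helpers `save_weights`/`load_weights` are hypothetical placeholders, not FETCH's actual architecture.

```python
import io

import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for one FETCH model; not the real architecture."""

    def __init__(self, in_features: int = 256, n_classes: int = 2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, 64),
            nn.ReLU(),
            nn.Linear(64, n_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def save_weights(model: nn.Module, f) -> None:
    # Only the tensors are written; the architecture stays in this source file.
    torch.save(model.state_dict(), f)


def load_weights(f, **kwargs) -> nn.Module:
    model = TinyClassifier(**kwargs)
    # weights_only=True limits unpickling to tensors and plain containers;
    # map_location="cpu" lets GPU-saved weights load on CPU-only machines.
    state = torch.load(f, map_location="cpu", weights_only=True)
    model.load_state_dict(state)
    model.eval()
    return model


if __name__ == "__main__":
    original = TinyClassifier().eval()
    buf = io.BytesIO()  # a path such as "model_a.pt" works the same way
    save_weights(original, buf)
    buf.seek(0)
    restored = load_weights(buf)

    x = torch.randn(4, 256)
    with torch.no_grad():
        assert torch.equal(original(x), restored(x))
    print(f"state_dict bytes: {buf.getbuffer().nbytes:,}")
```

Both a full-model file and a `state_dict` file contain the same parameter tensors, so the saving is limited to the pickled module structure. The main gains are that loading no longer depends on unpickling the class and that the file can be opened with `weights_only=True`.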

aweaver1fandm commented 1 day ago

After some digging, I found the original model descriptions in .h5 format, as well as the training and test data. One question: was each model trained with the same train/test data? I would assume so.