LSSTDESC / tomo_challenge

2020 Tomographic binning challenge

CBPF IA Team models LSTM, Deep Ensemble, Jax/Flax Deep Learning Ensemble and others #25

Open cdebom opened 4 years ago

cdebom commented 4 years ago

We added two models. One bidirectional LSTM and a conv1D.

EiffL commented 4 years ago

@cdebom Fantastic! Thank you for your entry :-D Could you document in this thread some of your metric results if you have them?

cdebom commented 4 years ago

Dear @EiffL, sorry for the late answer. We performed a couple of tests and retrieved the SNR_3x2 and FOM_3x2 for these two solutions. Here follows a table for the LSTM; I will send the conv1D results later.

image

I updated the previous table since we found some issues in the calculations we had before.

cdebom commented 4 years ago

I just included the LightGBM model. These are the results we got from it:

image

cdebom commented 4 years ago

Here is a summary of the tensorflow/keras-based models our team developed and uploaded:

Here are the contributors of our team: Clecio R. Bom (CBPF), Bernardo Fraga (CBPF), Gabriel Teixeira (CBPF), Eduardo Cypriano (USP) and Elizabeth Gonzalez (IATE-CONICET).

cdebom commented 4 years ago

Additionally, we also added a solution based on JAX/FLAX. It is an LSTM that optimizes for the SNR score. With this we have uploaded a total of 8 models.

cdebom commented 4 years ago

We added a JAX/FLAX Deep Ensemble model. We now have 7 TensorFlow-based entries and 2 JAX/FLAX neural-network entries.

joezuntz commented 4 years ago

Hi @cdebom - thanks for your entries!

I'm currently making sure I can run things properly on my system. There are a couple of imports that I'm missing:

import jax_random as rand
import jax_numpy as jnp

should these be jax.random and jax.numpy?

joezuntz commented 4 years ago

Also it looks like at least one file may be missing, lightgbm.py

joezuntz commented 4 years ago

And hopefully finally, there are two different classes called ENSEMBLE - in DEEP_ENESMBLE_main.py and Deep_Ensemble_jax.py. Is one of them the right one to use, or should I just rename one so it doesn't clash?

cdebom commented 4 years ago

Dear Joe, I will check the files and get back to you as soon as possible. About the ensembles: there are two ensemble models. The one in DEEP_ENESMBLE_main.py is based on tensorflow/keras models and optimizes to reduce the classification error. The second, Deep_Ensemble_jax.py, uses a different set of models and optimizes for SNR and FOM. However, I might have uploaded it with the SNR optimization hardcoded; it is just a matter of swapping the jax metric for the FOM one, which is already implemented.
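A minimal sketch of how the hardcoded metric could be made configurable instead (the function names `snr_loss` / `fom_loss` and the `options` dict are placeholders, not the actual jax-metrics code):

```python
def snr_loss(scores):
    # Placeholder for the SNR-based loss (negated so we minimize it).
    return -sum(scores)

def fom_loss(scores):
    # Placeholder for the FOM-based loss.
    return -sum(s * s for s in scores)

LOSSES = {"SNR": snr_loss, "FOM": fom_loss}

def get_loss(options):
    # Select the optimization target from the run configuration
    # rather than hardcoding SNR in the training loop.
    return LOSSES[options.get("metric", "SNR")]
```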


-- Clécio R. Bom, PhD Assistant Professor / Scientist Brazilian Center for Physics Research clearnightsrthebest.com

joezuntz commented 4 years ago

I just realized lightgbm is an external module. I'll install it. And I'll rename the ENSEMBLEs to ENSEMBLE1 and ENSEMBLE2. The challenge has closed now, so it's too late to modify any scientific things.

cdebom commented 4 years ago

OK, about the

import jax_random as rand
import jax_numpy as jnp

yes, you are correct, it should be jax.random and jax.numpy; there must be a typo somewhere. Sorry about that.

By the way, we also experimented with a custom data loader with different cuts. As we could not put anything outside tomo_challenge/tomo_challenge/, we left it in tomo_challenge/tomo_challenge/custom_loader.py and tomo_challenge/tomo_challenge/challenge_custom_loader.py. This last file can be used in tomo_challenge/bin/ instead of challenge.py for the custom loader to work.
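A minimal sketch of the kind of cut such a custom loader could apply (the column layout and function name are assumptions, not the actual custom_loader.py code):

```python
import numpy as np

def remove_undetected(data):
    # Drop rows where any magnitude is non-finite, i.e. galaxies
    # undetected in at least one band.
    mask = np.isfinite(data).all(axis=1)
    return data[mask]

mags = np.array([[24.1, 23.9],
                 [np.inf, 22.5],   # undetected in one band
                 [25.0, 24.7]])
# remove_undetected(mags) keeps only the first and last rows
```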


joezuntz commented 4 years ago

Hi @cdebom - I'm just putting together the full suite of runs we will do. I don't quite see what your custom loader does differently - is it just removing all the galaxies with inf values for the magnitudes? If so I will modify your entries so they just don't provide an estimate for undetected galaxies (which is what inf means)

joezuntz commented 4 years ago

And also, are you sure you want to do this? In the standard data loader we just set the inf values to 30 instead, so if your problem was with them being inf specifically, that shouldn't be an issue. Possibly we added this feature after you wrote this.
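The standard-loader behavior described above could be sketched like this (a hedged illustration of the idea, not the actual loader code):

```python
import numpy as np

def cap_magnitudes(mags, limit=30.0):
    # Replace non-detections (inf magnitudes) with a fixed sentinel
    # value, as the standard loader reportedly does with 30.
    return np.where(np.isinf(mags), limit, mags)
```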

cdebom commented 4 years ago

Dear Joe, even setting them to 30, we still see some difference in our deep models. For instance, using 10 bins, we found for the tensorflow-based ensemble (resnet, autolstm, fcn) with the original loader and jax_metrics: SNR_3x2 = 1920.7 and FOM_3x2 = 10615.2; removing those galaxies with the custom loader, we found SNR_3x2 = 1943.4 and FOM_3x2 = 10797.1.

joezuntz commented 4 years ago

Okay - I will tweak your code to assign those objects to no bin.
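The "no bin" tweak could look something like this sketch (the `-1` label for unassigned objects is an assumption about the challenge convention, not confirmed code):

```python
import numpy as np

def assign_bins(classifier_bins, mags, no_bin=-1):
    # Overwrite the classifier's bin assignment for undetected
    # galaxies (any inf magnitude) with a 'no bin' label.
    undetected = np.isinf(mags).any(axis=1)
    out = classifier_bins.copy()
    out[undetected] = no_bin
    return out
```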

cdebom commented 4 years ago

OK, thanks. Just to emphasize: this is something we noticed in our models when we removed those galaxies in both the training and testing phases. So this only makes sense if these galaxies are also removed from the training dataset.

joezuntz commented 4 years ago

When I run the Flax_LSTM code I get this error:

  File "/home/jzuntz/tomo_challenge/bin/../tomo_challenge/classifiers/lstm_flax.py", line 82, in __init__
    self.n_features = options['n_feats']
KeyError: 'n_feats'

What value should n_feats have in the configuration? Is it just the total number of input columns or is it more complicated than that?

cdebom commented 4 years ago

Yes, it is the total number of input columns.

Cheers


joezuntz commented 4 years ago

Why is that a configuration input? Why is it not calculated from the data?
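One way the value could be derived from the data rather than configured (a sketch under the assumption that the training features arrive as a 2-D array; `infer_n_feats` is a hypothetical helper, not part of the repo):

```python
import numpy as np

def infer_n_feats(training_data):
    # n_feats is just the total number of input columns, so it can
    # be read off the feature array's shape instead of the config.
    return training_data.shape[1]
```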

cdebom commented 3 years ago

Just for the record, here are some of the results we obtained with 10 bins. Since training with jax was very compute-intensive, we tested using 10% of the data with a couple of random choices of training set. To have a fair comparison with the other methods, we also trained those with 10%.

image

It is interesting to note that across the couple of runs of the jax LSTM (optimized for SNR) there were variations in SNR_3X2 and FOM_3X2. However, there seems to be a trade-off between the two: when one gets a bit higher, the other gets a bit lower.

This is the binning plot for JAX autolstm run2

image