henrysky / astroNN

Deep Learning for Astronomers with Tensorflow
http://astronn.readthedocs.io/
MIT License

ApogeeBCNN() dimensions #14

Open luantunez opened 3 years ago

luantunez commented 3 years ago

Hello and thank you for sharing your work. I want to classify images with color depth with a Bayesian Neural Network. Though, with this model, I am getting a dimensions error:

Input 0 of layer max_pooling1d_13 is incompatible with the layer: expected ndim=3, found ndim=4. Full shape received: (None, 75, 75, 3)

My input is a dataset loaded with

training_dataset = tf.keras.preprocessing.image_dataset_from_directory

and converted to tensors with

images, labels = next(iter(training_dataset))

so I am trying to train the model with

bcnn_net = ApogeeBCNN()
bcnn_net.fit(images, labels )

Why am I getting this error? Is there a specific way to pass the data?
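For context, the shape mismatch in the traceback can be reproduced with shapes alone (a minimal numpy sketch; the batch size 32 is hypothetical, the 75x75x3 image shape is taken from the error message above):

```python
import numpy as np

# a batch of RGB images, shaped as image_dataset_from_directory produces them
images = np.zeros((32, 75, 75, 3), dtype=np.float32)

# a Conv1D/MaxPooling1D stack expects ndim=3 input: (batch, steps, channels),
# while an image batch has ndim=4: (batch, height, width, channels)
print(images.ndim)  # 4, hence "expected ndim=3, found ndim=4"
```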

Thank you, Lucia

henrysky commented 3 years ago

Yes, ApogeeBCNN uses 1D convolution; you should use MNIST_BCNN for images.

i.e.

from astroNN.models import MNIST_BCNN
henrysky commented 3 years ago

For example see here: https://github.com/henrysky/astroNN/blob/master/demo_tutorial/NN_uncertainty_analysis/Uncertainty_Demo_MNIST.ipynb

luantunez commented 3 years ago

Thank you! It is working now.

However, I am getting very strange losses:

() is deprecated and will be removed in future. Use fit() instead.
Number of Training Data: 2008, Number of Validation Data: 223
====Message from Normalizer==== 
 You selected mode: 255 
 Featurewise Center: {'input': False, 'input_err': False, 'labels_err': False} 
 Datawise Center: {'input': False, 'input_err': False, 'labels_err': False} 
 Featurewise std Center: {'input': False, 'input_err': False, 'labels_err': False} 
 Datawise std Center: {'input': False, 'input_err': False, 'labels_err': False} 
 ====Message ends====
 (Normalizer message above repeated twice more)
====Message from Normalizer==== 
 You selected mode: 0 
 Featurewise Center: {'output': False, 'variance_output': False} 
 Datawise Center: {'output': False, 'variance_output': False} 
 Featurewise std Center: {'output': False, 'variance_output': False} 
 Datawise std Center: {'output': False, 'variance_output': False} 
 ====Message ends====
 (Normalizer message above repeated once more)
Epoch 1/30
31/31 - 7s - loss: 564069007360.0000 - output_loss: 1128138014720.0000 - variance_output_loss: 12.7896 - output_categorical_accuracy: 1.0000 - val_loss: 8876439109632.0000 - val_output_loss: 17752878219264.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000
Epoch 2/30
31/31 - 5s - loss: 4427500054642688.0000 - output_loss: 8855000109285376.0000 - variance_output_loss: 12.9506 - output_categorical_accuracy: 1.0000 - val_loss: 29915935422808064.0000 - val_output_loss: 59831870845616128.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00002: ReduceLROnPlateau reducing learning rate to 0.0024999999441206455.
Epoch 3/30
31/31 - 5s - loss: 142776859086553088.0000 - output_loss: 285553718173106176.0000 - variance_output_loss: 12.9884 - output_categorical_accuracy: 1.0000 - val_loss: 371173135304622080.0000 - val_output_loss: 742346270609244160.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00003: ReduceLROnPlateau reducing learning rate to 0.0012499999720603228.
Epoch 4/30
31/31 - 5s - loss: 662793099646337024.0000 - output_loss: 1325586199292674048.0000 - variance_output_loss: 13.0272 - output_categorical_accuracy: 1.0000 - val_loss: 1010213485256114176.0000 - val_output_loss: 2020426970512228352.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00004: ReduceLROnPlateau reducing learning rate to 0.0006249999860301614.
Epoch 5/30
31/31 - 5s - loss: 1300403534725906432.0000 - output_loss: 2600807069451812864.0000 - variance_output_loss: 12.8977 - output_categorical_accuracy: 1.0000 - val_loss: 1544848471641554944.0000 - val_output_loss: 3089696943283109888.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00005: ReduceLROnPlateau reducing learning rate to 0.0003124999930150807.
Epoch 6/30
31/31 - 5s - loss: 1762389595976105984.0000 - output_loss: 3524779191952211968.0000 - variance_output_loss: 12.9758 - output_categorical_accuracy: 1.0000 - val_loss: 1870810238567841792.0000 - val_output_loss: 3741620477135683584.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00006: ReduceLROnPlateau reducing learning rate to 0.00015624999650754035.
Epoch 7/30
31/31 - 4s - loss: 2023441006761869312.0000 - output_loss: 4046882013523738624.0000 - variance_output_loss: 12.9688 - output_categorical_accuracy: 1.0000 - val_loss: 2073333271017553920.0000 - val_output_loss: 4146666542035107840.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00007: ReduceLROnPlateau reducing learning rate to 7.812499825377017e-05.
Epoch 8/30
31/31 - 4s - loss: 2158791300457955328.0000 - output_loss: 4317582600915910656.0000 - variance_output_loss: 12.9607 - output_categorical_accuracy: 1.0000 - val_loss: 2115241981343956992.0000 - val_output_loss: 4230483962687913984.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00008: ReduceLROnPlateau reducing learning rate to 3.9062499126885086e-05.
Epoch 9/30
31/31 - 4s - loss: 2224229834496671744.0000 - output_loss: 4448459668993343488.0000 - variance_output_loss: 12.9960 - output_categorical_accuracy: 1.0000 - val_loss: 2171087963625095168.0000 - val_output_loss: 4342175927250190336.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00009: ReduceLROnPlateau reducing learning rate to 1.9531249563442543e-05.
Epoch 10/30
31/31 - 4s - loss: 2257878601158361088.0000 - output_loss: 4515757202316722176.0000 - variance_output_loss: 12.9763 - output_categorical_accuracy: 1.0000 - val_loss: 2190370786236170240.0000 - val_output_loss: 4380741572472340480.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00010: ReduceLROnPlateau reducing learning rate to 9.765624781721272e-06.
Epoch 11/30
31/31 - 4s - loss: 2272228464851419136.0000 - output_loss: 4544456929702838272.0000 - variance_output_loss: 12.9738 - output_categorical_accuracy: 1.0000 - val_loss: 2218048105247408128.0000 - val_output_loss: 4436096210494816256.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00011: ReduceLROnPlateau reducing learning rate to 4.882812390860636e-06.
Epoch 12/30
31/31 - 4s - loss: 2274702778330775552.0000 - output_loss: 4549405556661551104.0000 - variance_output_loss: 12.9753 - output_categorical_accuracy: 1.0000 - val_loss: 2198842523328184320.0000 - val_output_loss: 4397685046656368640.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00012: ReduceLROnPlateau reducing learning rate to 2.441406195430318e-06.
Epoch 13/30
31/31 - 5s - loss: 2289002339366862848.0000 - output_loss: 4578004678733725696.0000 - variance_output_loss: 12.9713 - output_categorical_accuracy: 1.0000 - val_loss: 2203501566411931648.0000 - val_output_loss: 4407003132823863296.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00013: ReduceLROnPlateau reducing learning rate to 1.220703097715159e-06.
Epoch 14/30
31/31 - 4s - loss: 2287103757663600640.0000 - output_loss: 4574207515327201280.0000 - variance_output_loss: 12.9667 - output_categorical_accuracy: 1.0000 - val_loss: 2208039938094530560.0000 - val_output_loss: 4416079876189061120.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00014: ReduceLROnPlateau reducing learning rate to 6.103515488575795e-07.
Epoch 15/30
31/31 - 4s - loss: 2287088776817672192.0000 - output_loss: 4574177553635344384.0000 - variance_output_loss: 12.8942 - output_categorical_accuracy: 1.0000 - val_loss: 2213005744922427392.0000 - val_output_loss: 4426011489844854784.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00015: ReduceLROnPlateau reducing learning rate to 3.0517577442878974e-07.
Epoch 16/30
31/31 - 4s - loss: 2291830420712456192.0000 - output_loss: 4583660841424912384.0000 - variance_output_loss: 12.9360 - output_categorical_accuracy: 1.0000 - val_loss: 2224411116476301312.0000 - val_output_loss: 4448822232952602624.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00016: ReduceLROnPlateau reducing learning rate to 1.5258788721439487e-07.
Epoch 17/30
31/31 - 4s - loss: 2284606766756921344.0000 - output_loss: 4569213533513842688.0000 - variance_output_loss: 12.9793 - output_categorical_accuracy: 1.0000 - val_loss: 2233647289027526656.0000 - val_output_loss: 4467294578055053312.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00017: ReduceLROnPlateau reducing learning rate to 7.629394360719743e-08.
Epoch 18/30
31/31 - 4s - loss: 2289540412869705728.0000 - output_loss: 4579080825739411456.0000 - variance_output_loss: 12.9652 - output_categorical_accuracy: 1.0000 - val_loss: 2200598443397742592.0000 - val_output_loss: 4401196886795485184.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00018: ReduceLROnPlateau reducing learning rate to 3.814697180359872e-08.
Epoch 19/30
31/31 - 4s - loss: 2290412325590532096.0000 - output_loss: 4580824651181064192.0000 - variance_output_loss: 12.9728 - output_categorical_accuracy: 1.0000 - val_loss: 2234517827358818304.0000 - val_output_loss: 4469035654717636608.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00019: ReduceLROnPlateau reducing learning rate to 1.907348590179936e-08.
Epoch 20/30
31/31 - 6s - loss: 2284846597730729984.0000 - output_loss: 4569693195461459968.0000 - variance_output_loss: 12.9950 - output_categorical_accuracy: 1.0000 - val_loss: 2203913470955487232.0000 - val_output_loss: 4407826941910974464.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000

Epoch 00020: ReduceLROnPlateau reducing learning rate to 1e-08.
Epoch 21/30
31/31 - 6s - loss: 2288547004114010112.0000 - output_loss: 4577094008228020224.0000 - variance_output_loss: 12.9708 - output_categorical_accuracy: 1.0000 - val_loss: 2234764667719254016.0000 - val_output_loss: 4469529335438508032.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000
Epoch 22/30
31/31 - 5s - loss: 2291752630264791040.0000 - output_loss: 4583505260529582080.0000 - variance_output_loss: 12.9839 - output_categorical_accuracy: 1.0000 - val_loss: 2213216163960193024.0000 - val_output_loss: 4426432327920386048.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000
Epoch 23/30
31/31 - 5s - loss: 2291957276866510848.0000 - output_loss: 4583914553733021696.0000 - variance_output_loss: 12.9879 - output_categorical_accuracy: 1.0000 - val_loss: 2231234135882465280.0000 - val_output_loss: 4462468271764930560.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000
Epoch 24/30
31/31 - 4s - loss: 2285794926509686784.0000 - output_loss: 4571589853019373568.0000 - variance_output_loss: 12.9607 - output_categorical_accuracy: 1.0000 - val_loss: 2206375827245891584.0000 - val_output_loss: 4412751654491783168.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000
Epoch 25/30
31/31 - 4s - loss: 2292182264433344512.0000 - output_loss: 4584364528866689024.0000 - variance_output_loss: 12.9320 - output_categorical_accuracy: 1.0000 - val_loss: 2210463536600055808.0000 - val_output_loss: 4420927073200111616.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000
Epoch 26/30
31/31 - 5s - loss: 2282710384076914688.0000 - output_loss: 4565420768153829376.0000 - variance_output_loss: 13.0136 - output_categorical_accuracy: 1.0000 - val_loss: 2210336817884954624.0000 - val_output_loss: 4420673635769909248.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000
Epoch 27/30
31/31 - 5s - loss: 2291233523337527296.0000 - output_loss: 4582467046675054592.0000 - variance_output_loss: 13.0116 - output_categorical_accuracy: 1.0000 - val_loss: 2217416573256204288.0000 - val_output_loss: 4434833146512408576.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000
Epoch 28/30
31/31 - 5s - loss: 2282051364295016448.0000 - output_loss: 4564102728590032896.0000 - variance_output_loss: 12.9879 - output_categorical_accuracy: 1.0000 - val_loss: 2215625468814557184.0000 - val_output_loss: 4431250937629114368.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000
Epoch 29/30
31/31 - 4s - loss: 2288319954962874368.0000 - output_loss: 4576639909925748736.0000 - variance_output_loss: 13.0106 - output_categorical_accuracy: 1.0000 - val_loss: 2219236677317033984.0000 - val_output_loss: 4438473354634067968.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000
Epoch 30/30
31/31 - 5s - loss: 2287157083977547776.0000 - output_loss: 4574314167955095552.0000 - variance_output_loss: 12.9335 - output_categorical_accuracy: 1.0000 - val_loss: 2226792658662064128.0000 - val_output_loss: 4453585317324128256.0000 - val_variance_output_loss: 12.8177 - val_output_categorical_accuracy: 1.0000
Completed Training, 140.00s in total

I see that issue #6 refers to this, though I have not been able to find the solution.

Thank you, Lucia

henrysky commented 3 years ago

I think I have actually fixed that issue (as you can see, the notebook was updated earlier this year, no issue was observed, and the uncertainty looks reasonable).

Are you using MNIST or your own images? Are you sure there are no invalid values in your images or classes (like nan, or 99999, something like that)?

luantunez commented 3 years ago

Yes, I am using the model on my own dataset. They are color image patches of skin lesions, from the 2018 and 2019 datasets of this competition: https://challenge.isic-archive.com/data and I want to classify them into 7 classes (7 diseases). Without a Bayesian network I get very good metrics (ResNet50 with transfer learning). With a regular Bayesian neural network I get a reasonable loss drop but a non-improving accuracy (I think that is because of an issue between TensorFlow Probability layers and TensorFlow 2.4.1, where the gradients are not being updated; maybe you have dealt with that issue). But with this astroNN model, I have this problem with the loss. Do you have any ideas?

Thank you very much for your dedication!

luantunez commented 3 years ago

I managed to get lower loss numbers, though they still don't change consistently, and the accuracy is always 1.0 even though the net predicts 0 for every image. Do you see the problem?
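As a quick sanity check (a hedged numpy-only sketch with hypothetical labels; astroNN is not invoked here), an accuracy stuck at 1.0 while the net predicts a single class is often worth cross-checking against the label array itself:

```python
import numpy as np

# hypothetical integer class labels for a 7-class problem
labels = np.array([0, 2, 1, 0, 3])

# distribution check: a degenerate label set trivially yields accuracy 1.0
values, counts = np.unique(labels, return_counts=True)
assert len(values) > 1, "all labels belong to one class"

# one-hot encoding, in case the model expects (N, num_classes) labels
num_classes = 7
one_hot = np.eye(num_classes, dtype=np.float32)[labels]
print(one_hot.shape)  # (5, 7)
```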

henrysky commented 3 years ago

Do you have a notebook or code to show what you have been doing? It is very weird that output_categorical_accuracy and val_output_categorical_accuracy are always 1.

luantunez commented 3 years ago

Yes, sure, here it is: https://colab.research.google.com/drive/1DdNOz41441KGkW2_53ODx-u_kPtwsf1-?usp=sharing

luantunez commented 3 years ago

Sorry, I changed it to this link: https://colab.research.google.com/drive/1KOk954wweKBCu_bm7XQa6jsrqyIEY6Fu?usp=sharing. You have access now.

henrysky commented 3 years ago

One immediate issue I have noticed is that your images are astype(int) when they should be float.
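A minimal sketch of the cast (numpy only; the 75x75x3 shape and the hypothetical batch size are for illustration). Note that with the mode-255 normalizer shown in the logs above, astroNN itself scales by 1/255, so only the dtype cast is needed here, not manual rescaling:

```python
import numpy as np

# hypothetical uint8 image batch, as images are typically loaded from disk
images = np.random.randint(0, 256, size=(8, 75, 75, 3), dtype=np.uint8)

# cast to float32 before training; integer inputs can break the normalizer
images_f = images.astype(np.float32)
print(images_f.dtype)  # float32
```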

luantunez commented 3 years ago

I actually did that because I have a third dimension that is color depth, and my images wouldn't be plotted otherwise. I have tried training with float32, though, and I get the same result:


Number of Training Data: 1800, Number of Validation Data: 200
Epoch 1/30
28/28 - 32s - loss: 16389206.0000 - output_loss: 32778396.0000 - variance_output_loss: 12.9051 - output_categorical_accuracy: 1.0000 - val_loss: 16378944.0000 - val_output_loss: 32757874.0000 - val_variance_output_loss: 13.0521 - val_output_categorical_accuracy: 1.0000
Epoch 2/30
28/28 - 34s - loss: 16658005.0000 - output_loss: 33315986.0000 - variance_output_loss: 12.9062 - output_categorical_accuracy: 1.0000 - val_loss: 16560405.0000 - val_output_loss: 33120800.0000 - val_variance_output_loss: 13.0521 - val_output_categorical_accuracy: 1.0000

Epoch 00002: ReduceLROnPlateau reducing learning rate to 7.81249980263965e-07.
Epoch 3/30
28/28 - 33s - loss: 16761271.0000 - output_loss: 33522528.0000 - variance_output_loss: 12.8862 - output_categorical_accuracy: 1.0000 - val_loss: 16862254.0000 - val_output_loss: 33724492.0000 - val_variance_output_loss: 13.0521 - val_output_categorical_accuracy: 1.0000

Epoch 00003: ReduceLROnPlateau reducing learning rate to 3.906249901319825e-07.
Epoch 4/30