henrysky / astroNN

Deep Learning for Astronomers with Keras
http://astronn.readthedocs.io/
MIT License

Problem with "demo_tutorial/galaxy10/Galaxy10_Tutorial.ipynb" #28

Status: Open. DrJieZheng opened this issue 6 months ago.

DrJieZheng commented 6 months ago

System information

Describe the problem


When I tried to run the tutorial, galaxy10net.train() raised an error about an unexpected keyword argument 'sample_weight_mode'.

Source code / logs


%%time
galaxy10net.train(train_images, train_labels)
<timed eval>:3: UserWarning: Call to function train() is deprecated and will be removed in future. Use fit() instead.

Number of Training Data: 17646, Number of Validation Data: 1960
====Message from Normalizer====
You selected mode: 255
Featurewise Center: {'input': False}
Datawise Center: {'input': False} 
Featurewise std Center: {'input': False}
Datawise std Center: {'input': False} 
====Message ends====
====Message from Normalizer====
You selected mode: 0
Featurewise Center: {'output': False}
Datawise Center: {'output': False} 
Featurewise std Center: {'output': False}
Datawise std Center: {'output': False} 
====Message ends====

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File <timed eval>:3

File ~/.local/lib/python3.10/site-packages/astroNN/shared/warnings.py:55, in deprecated_copy_signature.<locals>.deco.<locals>.tgt(*args, **kwargs)
     49 warnings.warn(
     50     f"Call to function {target.__name__}() is deprecated and will be removed in "
     51     + f"future. Use {signature_source.__name__}() instead.",
     52     stacklevel=2,
     53 )
     54 inspect.signature(signature_source).bind(*args, **kwargs)
---> 55 return target(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/astroNN/models/base_cnn.py:702, in CNNBase.train(self, *args, **kwargs)
    700 @deprecated_copy_signature(fit)
    701 def train(self, *args, **kwargs):
--> 702     return self.fit(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/astroNN/models/base_cnn.py:394, in CNNBase.fit(self, input_data, labels, sample_weight)
    380 """
    381 Train a Convolutional neural network
    382 
   (...)
    391 :History: 2017-Dec-06 - Written - Henry Leung (University of Toronto)
    392 """
    393 # Call the checklist to create astroNN folder and save parameters
--> 394 self.pre_training_checklist_child(input_data, labels, sample_weight)
    396 reduce_lr = ReduceLROnPlateau(
    397     monitor="val_loss",
    398     factor=0.5,
   (...)
    403     verbose=self.verbose,
    404 )
    406 early_stopping = EarlyStopping(
    407     monitor="val_loss",
    408     min_delta=self.early_stopping_min_delta,
   (...)
    411     mode="min",
    412 )

File ~/.local/lib/python3.10/site-packages/astroNN/models/base_cnn.py:319, in CNNBase.pre_training_checklist_child(self, input_data, labels, sample_weight)
    315     norm_labels = self.labels_normalizer.normalize(labels, calc=False)
    316 if (
    317     self.keras_model is None
    318 ):  # only compile if there is no keras_model, e.g. fine-tuning does not required
--> 319     self.compile()
    321 norm_data = self._tensor_dict_sanitize(norm_data, self.keras_model.input_names)
    322 norm_labels = self._tensor_dict_sanitize(
    323     norm_labels, self.keras_model.output_names
    324 )

File ~/.local/lib/python3.10/site-packages/astroNN/models/base_cnn.py:235, in CNNBase.compile(self, optimizer, loss, metrics, weighted_metrics, loss_weights, sample_weight_mode)
    229     raise RuntimeError(
    230         'Only "regression", "classification" and "binary_classification" are supported'
    231     )
    233 self.keras_model = self.model()
--> 235 self.keras_model.compile(
    236     loss=loss_func,
    237     optimizer=self.optimizer,
    238     metrics=self.metrics,
    239     weighted_metrics=weighted_metrics,
    240     loss_weights=loss_weights,
    241     sample_weight_mode=sample_weight_mode,
    242 )
    244 # inject custom training step if needed
    245 try:

File ~/.local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File ~/.local/lib/python3.10/site-packages/keras/src/utils/tracking.py:26, in no_automatic_dependency_tracking.<locals>.wrapper(*args, **kwargs)
     23 @wraps(fn)
     24 def wrapper(*args, **kwargs):
     25     with DotNotTrackScope():
---> 26         return fn(*args, **kwargs)

TypeError: Trainer.compile() got an unexpected keyword argument 'sample_weight_mode'
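The traceback shows that astroNN's CNNBase.compile() forwards sample_weight_mode to keras_model.compile(), and the installed Keras (whose Trainer.compile() raises the error) does not accept that keyword. Switching from the deprecated train() to fit() would not help on its own, because train() simply calls fit(). A minimal sketch of one defensive pattern, assuming you are willing to patch the calling code locally: inspect the target's signature and drop keywords it does not accept. drop_unsupported_kwargs and keras3_like_compile are hypothetical names, not part of astroNN or Keras, and silently dropping an argument can change behaviour if it was meant to take effect.

```python
import inspect
from typing import Any, Callable, Dict


def drop_unsupported_kwargs(func: Callable, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of kwargs keeping only keywords that func accepts.

    Hypothetical helper: if func declares **kwargs, everything is kept.
    """
    params = inspect.signature(func).parameters.values()
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return dict(kwargs)
    accepted = {
        p.name
        for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                      inspect.Parameter.KEYWORD_ONLY)
    }
    return {k: v for k, v in kwargs.items() if k in accepted}


def keras3_like_compile(optimizer="rmsprop", loss=None, loss_weights=None,
                        metrics=None, weighted_metrics=None):
    """Hypothetical stand-in with no sample_weight_mode parameter and no **kwargs."""
    return {"optimizer": optimizer, "loss": loss}


# Keywords similar to those astroNN passes in the traceback above.
compile_kwargs = {
    "optimizer": "adam",
    "loss": "categorical_crossentropy",
    "sample_weight_mode": None,
}
safe = drop_unsupported_kwargs(keras3_like_compile, compile_kwargs)
print(sorted(safe))  # ['loss', 'optimizer']
keras3_like_compile(**safe)  # no TypeError now
```

In principle a real model's bound compile method could be passed in place of the stand-in, since inspect.signature follows functools.wraps wrappers; treat that as an assumption to verify against your installed Keras.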

Suggestion

Which versions of the packages does this tutorial use?
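To answer (or report) the version question, the installed versions can be read with the standard library's importlib.metadata, which needs no imports of the packages themselves. A minimal sketch; the package list below is an assumption about what the tutorial depends on, and installed_versions is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Assumed list of relevant distributions (PyPI names); adjust as needed.
PACKAGES = ("astroNN", "tensorflow", "keras", "numpy")


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name:12s} {ver or 'not installed'}")
```

Pasting this output into the "System information" section would make the environment reproducible for the maintainer.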

henrysky commented 6 months ago

Thanks for the bug report!

Indeed, this is an ongoing issue with the latest version of TensorFlow (which is separating Keras out again) and Keras v3. If you want to quickly train a neural network to classify Galaxy10, here is a notebook that fine-tunes ResNet-V2 with Keras v3 on Galaxy10 images loaded with astroNN:
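Since the reply ties the failure to Keras v3, it can help to check which major Keras version is installed before running the tutorial. A minimal sketch, assuming Keras is installed under the distribution name keras; keras_major and installed_keras_major are hypothetical helpers that look only at the leading integer of the version string.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def keras_major(version_string: str) -> int:
    """Return the major component of a version string such as '3.0.5'."""
    head = version_string.split(".", 1)[0]
    digits = ""
    for ch in head:
        if not ch.isdigit():
            break
        digits += ch
    if not digits:
        raise ValueError(f"cannot parse major version from {version_string!r}")
    return int(digits)


def installed_keras_major() -> Optional[int]:
    """Major version of the installed 'keras' distribution, or None if absent."""
    try:
        return keras_major(version("keras"))
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    major = installed_keras_major()
    if major is None:
        print("keras is not installed")
    elif major >= 3:
        print(f"Keras {major}.x detected (the maintainer links this error to Keras v3)")
    else:
        print(f"Keras {major}.x detected")
```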

https://drive.google.com/file/d/1GnrsZAPZFTfBrhuQ09zqh4n8x1QYEPOb/view?usp=sharing

Please let me know if the notebook works for you locally. It is unlikely you can run it online with Google Colab, as you will probably get a resource-exhausted error due to the limited compute resources there.