adriangb / scikeras

Scikit-Learn API wrapper for Keras.
https://www.adriangb.com/scikeras/
MIT License

bug when the metric is a string #300

Closed: michael-shire-rs closed this issue 1 year ago

michael-shire-rs commented 1 year ago

https://github.com/adriangb/scikeras/blob/32fc6af3ffc85884ee8ce455c33b0bfe39126491/scikeras/utils/__init__.py#LL109C1-L112C30

When fn_or_cls is a string, this function breaks because a str has no __name__ attribute.

I propose something like this:

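# replaces the final "return fn_or_cls.__name__" in metric_name()
if hasattr(fn_or_cls, '__name__'):
    return fn_or_cls.__name__
return fn_or_cls  # e.g. a plain string such as "loss"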

I ran into the error when the metric was "loss".

adriangb commented 1 year ago

That seems reasonable to me. A PR with a test would be very welcome 😀

michael-shire-rs commented 1 year ago

Haha, of course. I'll get on that. Thanks.

Edit: having trouble with Poetry on my MacBook and with making a PR. My software-developer-fu is not strong. :-(

djsegal commented 1 year ago

Running into this now:

File ~/anaconda3/envs/x/lib/python3.11/site-packages/scikeras/wrappers.py:735, in BaseWrapper.fit(self, X, y, sample_weight, **kwargs)
    730 kwargs["epochs"] = kwargs.get(
    731     "epochs", getattr(self, "fit__epochs", self.epochs)
    732 )
    733 kwargs["initial_epoch"] = kwargs.get("initial_epoch", 0)
--> 735 self._fit(
    736     X=X, y=y, sample_weight=sample_weight, warm_start=self.warm_start, **kwargs,
    737 )
    739 return self

File ~/anaconda3/envs/x/lib/python3.11/site-packages/scikeras/wrappers.py:900, in BaseWrapper._fit(self, X, y, sample_weight, warm_start, epochs, initial_epoch, **kwargs)
    896 X = self.feature_encoder_.transform(X)
    898 self._check_model_compatibility(y)
--> 900 self._fit_keras_model(
    901     X,
    902     y,
    903     sample_weight=sample_weight,
    904     warm_start=warm_start,
    905     epochs=epochs,
    906     initial_epoch=initial_epoch,
    907     **kwargs,
    908 )

File ~/anaconda3/envs/x/lib/python3.11/site-packages/scikeras/wrappers.py:510, in BaseWrapper._fit_keras_model(self, X, y, sample_weight, warm_start, epochs, initial_epoch, **kwargs)
    508 for key, val in hist.history.items():
    509     try:
--> 510         key = metric_name(key)
    511     except ValueError as e:
    512         # Keras puts keys like "val_accuracy" and "loss" and
    513         # "val_loss" in hist.history
    514         if "Unknown metric function" not in str(e):

File ~/anaconda3/envs/x/lib/python3.11/site-packages/scikeras/utils/__init__.py:112, in metric_name(metric)
    110 if isinstance(fn_or_cls, Metric):
    111     return _camel2snake(fn_or_cls.__class__.__name__)
--> 112 return fn_or_cls.__name__

AttributeError: 'str' object has no attribute '__name__'
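The traceback also shows why the surrounding try/except in _fit_keras_model does not help: it only catches ValueError, while a string reaching the last line of metric_name raises AttributeError. The sketch below reproduces that interaction in isolation. rename_history_keys is a hypothetical simplification of the loop at wrappers.py:508-514, and it assumes the history key reaches the final "return fn_or_cls.__name__" unchanged; the real metric_name first looks the key up via Keras, which is skipped here.

```python
def name_without_fallback(fn_or_cls):
    # Mirrors the failing line in the traceback
    return fn_or_cls.__name__


def name_with_fallback(fn_or_cls):
    # Equivalent to the proposed hasattr check
    return getattr(fn_or_cls, "__name__", fn_or_cls)


def rename_history_keys(history, name_fn):
    """Hypothetical simplification of the hist.history loop in _fit_keras_model."""
    renamed = {}
    for key, val in history.items():
        try:
            key = name_fn(key)
        except ValueError:
            # Only ValueError is handled here; AttributeError propagates to the caller
            pass
        renamed[key] = val
    return renamed


history = {"loss": [0.9, 0.7], "val_loss": [1.0, 0.8]}

try:
    rename_history_keys(history, name_without_fallback)
except AttributeError as exc:
    print("without fallback:", type(exc).__name__)

print("with fallback:", rename_history_keys(history, name_with_fallback))
```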