poets-ai / elegy

A High Level API for Deep Learning in JAX
https://poets-ai.github.io/elegy/
MIT License

[Bug] Problem with computing metrics #216

Closed organic-chemistry closed 2 years ago

organic-chemistry commented 2 years ago

Describe the bug

Hi, when I use the fit function, I get an error saying that the update function was not given y_true and y_pred. It seems to come from the model's metrics, because if I comment out the metrics line, the error goes away.

TypeError: update() missing 2 required positional arguments: 'y_true' and 'y_pred'

Minimal code to reproduce

import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
import elegy as eg

class eCNN(eg.Module):
    """A simple CNN model."""

    @eg.compact
    def __call__(self, x):
        x=eg.Conv(10,kernel_size=(10,))(x)
        x=jax.nn.relu(x)
        x = eg.Linear(1)(x)
        x=jax.nn.sigmoid(x)
        return x

n=200
X_train = np.random.rand(n*100).reshape(n,100)
y_train = np.random.rand(n).reshape(n,1)
print(X_train.shape)
print(y_train.shape)

model = eg.Model(
    module=eCNN(),
    loss=[
        eg.losses.MeanSquaredError(),
    ],
    metrics=eg.metrics.MeanSquareError(),  # Comment out this line to get rid of the error
    optimizer=optax.rmsprop(1e-3),
)

model.fit(X_train,y_train,
    epochs=10,
    batch_size=20,
    #validation_data=0.1,
    shuffle=False,
    callbacks=[eg.callbacks.TensorBoard("summaries")]
    )

Library Info

import elegy
print(elegy.__version__) 
# 0.8.4

organic-chemistry commented 2 years ago

This is probably related: I get the same error when I define a custom loss and replace the model definition with the following code:

class BCE(eg.Loss):
    def call(self, y_true, y_pred):
        return -jnp.mean(y_true*jnp.log(y_pred+1e-7) + (1-y_true)*jnp.log(1-y_pred+1e-7))

model = eg.Model(
    module=eCNN(),
    loss=[
        BCE(),
    ],
    # metrics=eg.metrics.MeanSquareError(),
    optimizer=optax.rmsprop(1e-3),
)

cgarciae commented 2 years ago

Hey @organic-chemistry! I think the use of Losses and Metrics within Elegy needs to be properly documented.

Elegy uses a simple name-based dependency injection system, meaning there is a fixed set of names you can use for the arguments of the Loss.call and Metric.update methods:

https://github.com/poets-ai/elegy/blob/546c50475ac55dcbf9d8dd811c0536d1ad589f38/elegy/model/model.py#L220-L233

where labels usually contains the target key. The problem is that metrics.MeanSquareError and metrics.MeanAbsoluteError, which come from Treex, recently used the unsupported y_true and y_pred names (cgarciae/treex#55); this should be fixed soon.
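To illustrate why an unsupported argument name produces the TypeError reported above, here is a minimal sketch of name-based injection in plain Python. It is not Elegy's actual implementation: `inject`, `supported_update`, `unsupported_update`, and the contents of `available` are hypothetical, and only the names `target`, `preds`, and `labels` are taken from this thread.

```python
import inspect


def inject(fn, **available):
    """Call fn with only those available values whose names fn declares."""
    wanted = inspect.signature(fn).parameters
    kwargs = {name: value for name, value in available.items() if name in wanted}
    return fn(**kwargs)


def supported_update(target, preds):
    # Uses names the injector knows about, so both arguments get filled in.
    return sum((t - p) ** 2 for t, p in zip(target, preds)) / len(target)


def unsupported_update(y_true, y_pred):
    # Uses names the injector does not know about, so nothing is passed.
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


# Illustrative values only; the real set of injectable names lives in model.py.
available = {
    "target": [1.0, 0.0],
    "preds": [0.5, 0.5],
    "labels": {"target": [1.0, 0.0]},
}

mse = inject(supported_update, **available)

try:
    inject(unsupported_update, **available)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(mse)
print(error_message)
```

Because the injector matches purely on parameter names, a function that asks for `y_true` and `y_pred` is called with no arguments at all, which is what yields the "missing 2 required positional arguments" message.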

On the other hand, to fix your BCE loss just change y_true -> target and y_pred -> preds.
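As a rough sketch of that rename, the loss body from the earlier comment can be written as a standalone JAX function using the `target` and `preds` names. Inside an `eg.Loss` subclass, the same two names would go on `call(self, target, preds)`. The function name `binary_cross_entropy`, the `eps` parameter, and the sample arrays are assumptions added for illustration.

```python
import jax.numpy as jnp


def binary_cross_entropy(target, preds, eps=1e-7):
    """Mean binary cross-entropy, using the argument names suggested in this thread."""
    return -jnp.mean(
        target * jnp.log(preds + eps) + (1 - target) * jnp.log(1 - preds + eps)
    )


# Illustrative data: two samples with confident, correct predictions.
target = jnp.array([[1.0], [0.0]])
preds = jnp.array([[0.9], [0.1]])

loss = binary_cross_entropy(target, preds)
print(float(loss))
```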

BTW: unless it's for pedagogical reasons, you can use eg.losses.Crossentropy(binary=True) if you want binary cross entropy.

organic-chemistry commented 2 years ago

OK, thank you, it worked. The example was indeed for pedagogical reasons. I based it on an example from the docs here: https://poets-ai.github.io/elegy/basic-api/modules-losses-metrics/ (the paragraph about losses). Thank you.

cgarciae commented 2 years ago

@organic-chemistry thanks for the report! A recent refactor broke a lot of the documentation; I'll open an issue to remove the old links. Some of this now lives in Treex and should be documented there. Sorry for the confusion 😅

organic-chemistry commented 2 years ago

Ok. Do you mean that you are going to stop working on elegy, and that the 'new' version is Treex? By the way, it is not related, but thank you for the libraries that you develop and the articles that you write; I found the one on quantile regression really interesting.

cgarciae commented 2 years ago

> Ok. Do you mean that you are going to stop working on elegy, and that the 'new' version is Treex?

Oh no, sorry for the confusion. Treex is a low-level library that implements Modules, Losses, and Metrics, while Elegy is a high-level API. Elegy existed before Treex, but once Treex was stable it made sense to refactor Elegy on top of Treex, as that simplified the codebase.

> By the way, it is not related, but thank you for the libraries that you develop and the articles that you write

Thanks ☺