aamini / evidential-deep-learning

Learn fast, scalable, and calibrated measures of uncertainty using neural networks!
https://proceedings.neurips.cc/paper/2020/file/aab085461de182608ee9f607f3f7d18f-Paper.pdf
Apache License 2.0

Loss goes to NaN #5

Open markus-hinsche opened 3 years ago

markus-hinsche commented 3 years ago

For a regression task, I am using a mid-size CNN with Conv and MaxPool layers at the front and Dense layers at the end.

This is how I integrate the evidential loss (previously I used MSE loss):

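import tensorflow as tf
import evidential_deep_learning as edl

optimizer = tf.keras.optimizers.Adam(learning_rate=7e-7)

def EvidentialRegressionLoss(true, pred):
    # CONFIG.EDL_COEFF is the regularizer weight from my own project config (not shown)
    return edl.losses.EvidentialRegression(true, pred, coeff=CONFIG.EDL_COEFF)

model.compile(
    optimizer=optimizer,
    loss=EvidentialRegressionLoss,
    metrics=["mae"]
)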
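For reference, here is a minimal scalar sketch of the per-sample objective this loss is built from, following the formulas in the linked NeurIPS 2020 paper: the Normal-Inverse-Gamma (NIG) negative log-likelihood plus `coeff` times the evidence regularizer |y − γ|·(2ν + α). The function names are hypothetical and this is plain Python, not the library's own tensor code, so treat it as an illustration of the math rather than of `edl.losses.EvidentialRegression` itself.

```python
import math


def nig_nll(y, gamma, v, alpha, beta):
    """Negative log-likelihood of y under the NIG evidential distribution."""
    omega = 2.0 * beta * (1.0 + v)
    return (
        0.5 * math.log(math.pi / v)
        - alpha * math.log(omega)
        + (alpha + 0.5) * math.log(v * (y - gamma) ** 2 + omega)
        + math.lgamma(alpha)
        - math.lgamma(alpha + 0.5)
    )


def nig_reg(y, gamma, v, alpha):
    """Evidence regularizer: penalizes errors made with high evidence."""
    return abs(y - gamma) * (2.0 * v + alpha)


def evidential_regression_loss(y, gamma, v, alpha, beta, coeff=1e-2):
    """Hypothetical scalar helper: NLL + coeff * regularizer."""
    return nig_nll(y, gamma, v, alpha, beta) + coeff * nig_reg(y, gamma, v, alpha)


if __name__ == "__main__":
    # A well-behaved prediction gives a finite loss.
    print(evidential_regression_loss(y=1.0, gamma=0.9, v=1.0, alpha=2.0, beta=1.0))
```

Note that every `math.log` call here requires a strictly positive argument, which is why ν and β reaching zero matters.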

This is how I integrated the DenseNormalGamma layer:

    # lots of ConvLayers
    model.add(layers.Conv2D(filters=256, kernel_size=(3, 3), padding="same", activation="relu"))
    model.add(layers.Conv2D(filters=256, kernel_size=(3, 3), padding="same", activation="relu"))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(layers.Flatten())
    model.add(layers.Dense(1024, activation="relu"))
    model.add(layers.Dense(128, activation="relu"))

    model.add(edl.layers.DenseNormalGamma(1))  # Instead of Dense(1)

    return model
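A rough sketch of what a Normal-Gamma output head does with its four raw outputs, assuming the constraints from the paper (ν > 0, α > 1, β > 0) are enforced with a softplus, as the library's `DenseNormalGamma` layer appears to do. The helper names are hypothetical. What it illustrates: softplus is positive in exact arithmetic, but in floating point it rounds to exactly 0.0 for very negative inputs, so ν or β can reach zero and any later log of them diverges. Float32, the Keras default, hits this far sooner than the float64 used here.

```python
import math


def softplus(x):
    """Numerically stable softplus: log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def normal_gamma_params(raw_gamma, raw_v, raw_alpha, raw_beta):
    """Map four unconstrained outputs to (gamma, v, alpha, beta).

    gamma stays linear; v and beta are pushed to >= 0 and alpha to >= 1.
    """
    gamma = raw_gamma
    v = softplus(raw_v)
    alpha = softplus(raw_alpha) + 1.0
    beta = softplus(raw_beta)
    return gamma, v, alpha, beta


if __name__ == "__main__":
    print(normal_gamma_params(0.0, 0.0, 0.0, 0.0))
    # Very negative raw outputs: v and beta round to exactly 0.0.
    print(normal_gamma_params(0.0, -800.0, 0.0, -800.0))
```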

Here is the issue I am facing: the loss goes to NaN during training.

Is there an obvious mistake I am making? Any thoughts and help are appreciated.

wanzysky commented 3 years ago

This may be caused by https://github.com/aamini/evidential-deep-learning/blob/7a22a2c8f35f5a2ec18fd37068b747935ff85376/evidential_deep_learning/losses/continuous.py#L35, where the log is not numerically safe.
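One common workaround, sketched here under the assumption that the NaN comes from taking logs once ν or β underflow to zero, is to clamp those values to a small positive floor before any log is taken. This is a scalar pure-Python illustration of the idea, not a patch to the library; `EPS` and `safe_nig_nll` are hypothetical names, and a suitable floor for float32 training may differ.

```python
import math

EPS = 1e-6  # hypothetical floor; tune for your dtype and data


def safe_nig_nll(y, gamma, v, alpha, beta, eps=EPS):
    """NIG negative log-likelihood with v and beta clamped away from zero."""
    v = max(v, eps)
    beta = max(beta, eps)
    omega = 2.0 * beta * (1.0 + v)
    return (
        0.5 * math.log(math.pi / v)
        - alpha * math.log(omega)
        + (alpha + 0.5) * math.log(v * (y - gamma) ** 2 + omega)
        + math.lgamma(alpha)
        - math.lgamma(alpha + 0.5)
    )


if __name__ == "__main__":
    # Without the clamp, v = 0 makes pi / v divide by zero. In TensorFlow the
    # same terms typically become inf, and inf - inf then yields NaN.
    print(safe_nig_nll(y=1.0, gamma=0.5, v=0.0, alpha=1.5, beta=0.0))
```

Clamping with `max` zeroes the gradient below the floor; adding `eps` instead (for example `v + eps`) keeps gradients flowing at the cost of a slight bias.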

Bunnybeibei commented 22 hours ago


> This is maybe because of https://github.com/aamini/evidential-deep-learning/blob/7a22a2c8f35f5a2ec18fd37068b747935ff85376/evidential_deep_learning/losses/continuous.py#L35 , where the log is not safe.

I have run into the same problem. Could you tell me how to solve it?