awslabs / keras-apache-mxnet

[DEPRECATED] Amazon Deep Learning's Keras with Apache MXNet support
https://github.com/awslabs/keras-apache-mxnet/wiki

Keras switch (K.switch) with mxnet backend #247

Closed: kirimaru-jp closed this issue 5 years ago

kirimaru-jp commented 5 years ago

Hi all,

I'd like to use the Keras switch function (K.switch) in Keras-MXNet 2.2.4.2 with mxnet-cu101 on Windows 10 x64. I tried to run the following code, but it didn't work:

from keras import backend as K
from keras.layers import Dense
from keras.models import Sequential

def customized_loss(y_true, y_pred):
    loss = K.switch(K.equal(y_true, -1), 0, K.square(y_true-y_pred))
    return K.sum(loss)

model = Sequential([ Dense(3, input_shape=(4,)) ])
model.compile(loss=customized_loss, optimizer='sgd')

The code is from here, but I'm using MXNet instead of Theano.
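For reference, here is a minimal NumPy sketch of what the custom loss is meant to compute: entries whose label is -1 contribute 0, and every other entry contributes its squared error, summed over the whole batch. The function name `customized_loss_reference` is hypothetical and is only meant for checking expected values outside Keras. Note that `np.where` happily broadcasts the scalar `0.0` against an array, which is the behaviour the original code relies on from `K.switch`.

```python
import numpy as np


def customized_loss_reference(y_true, y_pred):
    """NumPy reference for the loss above (hypothetical helper, not part of Keras)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Elementwise: 0 where the label is the "ignore" marker -1, squared error elsewhere.
    loss = np.where(y_true == -1, 0.0, np.square(y_true - y_pred))
    # Sum over every element, like K.sum(loss) with no axis.
    return loss.sum()


# (1 - 0.5)^2 + 0 (ignored) + (2 - 2)^2 = 0.25
print(customized_loss_reference([[1.0, -1.0, 2.0]], [[0.5, 3.0, 2.0]]))
```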

It raised the following error:

Using MXNet backend
Traceback (most recent call last):

  File "<ipython-input-2-22cee122aac0>", line 10, in <module>
    model.compile(loss=customized_loss, optimizer='sgd')

  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\backend\mxnet_backend.py", line 5392, in compile
    sample_weight_mode, **kwargs)

  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training.py", line 342, in compile
    sample_weight, mask)

  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training_utils.py", line 412, in weighted
    score_array = fn(y_true, y_pred)

  File "<ipython-input-2-22cee122aac0>", line 6, in customized_loss
    loss = K.switch(K.equal(y_true, -1), 0, K.square(y_true-y_pred))

  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\backend\mxnet_backend.py", line 94, in func_wrapper
    train_symbol = func(*args, **kwargs)

  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\backend\mxnet_backend.py", line 3162, in switch
    else_expression, KerasSymbol))

AssertionError

How can I fix this? Thank you in advance!

kirimaru-jp commented 5 years ago

I found a solution here, and it worked! https://stackoverflow.com/a/48906681
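The assertion at the end of the traceback is raised on a line checking `else_expression` against `KerasSymbol`, which suggests the MXNet backend's `switch` requires its branches to be backend tensors; the plain Python `0` passed as the "then" branch is therefore a likely cause. One backend-agnostic workaround (not necessarily the one in the linked answer) is to avoid the switch entirely and multiply the squared error by a 0/1 mask, so both operands are tensors. The sketch below, under that assumption, demonstrates the idea in NumPy and checks that it matches the `where`-based formulation; `masked_squared_error` and its `ignore_value` parameter are hypothetical names.

```python
import numpy as np


def masked_squared_error(y_true, y_pred, ignore_value=-1.0):
    """Sum of squared errors, skipping entries whose label equals ignore_value.

    Uses a 0/1 mask instead of a switch/where, so both operands are arrays.
    (Hypothetical helper for illustration.)
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # 1.0 where the label should count, 0.0 where it equals the ignore marker.
    mask = (y_true != ignore_value).astype(y_true.dtype)
    return np.sum(mask * np.square(y_true - y_pred))


y_true = np.array([[1.0, -1.0, 2.0], [0.0, 4.0, -1.0]])
y_pred = np.array([[0.5, 3.0, 2.0], [1.0, 2.0, 7.0]])

# Same loss written with an explicit switch-style selection, for comparison.
with_where = np.sum(np.where(y_true == -1, 0.0, np.square(y_true - y_pred)))

print(masked_squared_error(y_true, y_pred), with_where)  # both 5.25
```

Translated to the Keras backend API, the same idea would be roughly `mask = K.cast(K.not_equal(y_true, -1), K.floatx())` followed by `K.sum(mask * K.square(y_true - y_pred))`. Alternatively, keeping `K.switch` but passing `K.zeros_like(y_true)` instead of `0` should give both branches as backend tensors. Neither variant has been verified here against Keras-MXNet 2.2.4.2.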