dmlc / MXNet.jl

MXNet Julia Package - flexible and efficient deep learning in Julia

Symbol.InterShapeKeyword argument name label not found. #20

Closed: Andy-P closed this issue 8 years ago

Andy-P commented 8 years ago

I have just begun experimenting with MXNet in hopes of porting all of my https://github.com/Andy-P/RecurrentNN.jl work to it. I'm still getting used to the symbolic approach, so I thought I would start with a simple example.

I created some synthetic data and tried to run an MLP based on your mnist example, using the LinearRegressionOutput function instead of softmax. Unfortunately, it seems I don't understand how the various parts and labels fit together: I get a Symbol.InterShapeKeyword argument name mylabel not found. error. Obviously I am missing something, but I am not sure what.

Any help would be appreciated in getting over the initial hump.

Thanks,

Andre

using MXNet

# create some synthetic data
xs = collect(-0.49:0.005:0.5)
rnd() = rand()*0.1-0.05
y = map(x->x+0.3*sin(4*pi*x),xs)
ynoise = map(x->x + rnd()*1.,y);

# put it in data providers
train_provider = mx.ArrayDataProvider(:mydata => xs, :mylabel => ynoise; batch_size=10)
eval_provider  = mx.ArrayDataProvider(:mydata => xs, :mylabel => y;      batch_size=10)

MXNet.mx.ArrayDataProvider(Array{Float32,N}[Float32[-0.49,-0.485,-0.48,-0.475,-0.47,-0.465,-0.46,-0.455,-0.45,-0.445 … 0.455,0.46,0.465,0.47,0.475,0.48,0.485,0.49,0.495,0.5]],[:mydata],Array{Float32,N}[Float32[-0.4524,-0.428786,-0.405393,-0.382295,-0.359563,-0.337266,-0.315474,-0.294252,-0.273664,-0.253773 … 0.294252,0.315474,0.337266,0.359563,0.382295,0.405393,0.428786,0.4524,0.476163,0.5]],[:mylabel],10,199,false,0.0f0,0.0f0)

#-- Option 2: using the mx.chain macro
data = mx.Variable(:mydata)
fc1  = mx.FullyConnected(data = data, name=:fc1, num_hidden=128)
act1 = mx.Activation(data = fc1, name=:relu1, act_type=:relu)
fc2  = mx.FullyConnected(data = act1, name=:fc2, num_hidden=64)
act2 = mx.Activation(data = fc2, name=:relu2, act_type=:relu)
fc3  = mx.FullyConnected(data = act2, name=:fc3, num_hidden=10)
mlp  = mx.LinearRegressionOutput(data = fc3,label=:mylabel, name=:linear) 

MXNet.mx.Symbol(MXNet.mx.MX_SymbolHandle(Ptr{Void} @0x00007fe6bc0f61d0))

# setup model
model = mx.FeedForward(mlp, context=mx.cpu())

MXNet.mx.FeedForward(MXNet.mx.Symbol(MXNet.mx.MX_SymbolHandle(Ptr{Void} @0x00007fe6bc0f61d0)),[CPU0],#undef,#undef,#undef)

optimizer = mx.SGD(lr=0.1, momentum=0.9, weight_decay=0.00001)

MXNet.mx.SGD(0,0,MXNet.mx.SGDOptions(0.1,0.9,1.0e-5,1.0,0,MXNet.mx.FixedLearningRateScheduler(0.1),MXNet.mx.FixedMomentumScheduler(0.9)))

# fit parameters
mx.fit(model, optimizer, train_provider, n_epoch=1)
INFO: Start training on [CPU0]
INFO: Initializing parameters...

LoadError: MXNet.mx.MXError("[17:53:07] src/symbol/symbol.cc:103: Symbol.InterShapeKeyword argument name mylabel not found.\nCandidate arguments:\n\t[0]mydata\n\t[1]fc1_weight\n\t[2]fc1_bias\n\t[3]fc2_weight\n\t[4]fc2_bias\n\t[5]fc3_weight\n\t[6]fc3_bias\n\t[7]linear_label\n")
while loading In[11], in expression starting on line 2

 [inlined code] from /Users/andrep/.julia/v0.4/MXNet/src/base.jl:57

 in __infer_shape#2__ at /Users/andrep/.julia/v0.4/MXNet/src/symbol.jl:56

 in init_model at /Users/andrep/.julia/v0.4/MXNet/src/model.jl:66

 in _init_model at /Users/andrep/.julia/v0.4/MXNet/src/model.jl:93

 in fit at /Users/andrep/.julia/v0.4/MXNet/src/model.jl:226

[17:53:07] ./dmlc-core/include/dmlc/logging.h:208: [17:53:07] src/symbol/symbol.cc:103: Symbol.InterShapeKeyword argument name mylabel not found.
Candidate arguments:
    [0]mydata
    [1]fc1_weight
    [2]fc1_bias
    [3]fc2_weight
    [4]fc2_bias
    [5]fc3_weight
    [6]fc3_bias
    [7]linear_label
pluskid commented 8 years ago

I think instead of passing a :mylabel, you might need to create a Variable object and pass that instead, just like what was done for mydata.

pluskid commented 8 years ago

Note that, according to the doc here, the label argument is of type Symbol. This is really confusing, because Julia has a built-in type Symbol (which we refer to as Base.Symbol, and which is what the name argument takes). The Symbol here should be an MXNet symbol, which can be created like

label_sym = mx.Variable(:mylabel)
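
and then passed to the output layer in place of the Base.Symbol. A minimal sketch, reusing the fc3 layer from the snippet above:

mlp = mx.LinearRegressionOutput(data = fc3, label = label_sym, name = :linear)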

side notes

I somehow want to rename the mx.Symbol type in Julia to something else, because it seriously conflicts with the Julia built-in Symbol type. I'm sure newcomers will always get confused, even though I explicitly stated at the beginning of the document the difference between Symbol and Base.Symbol. Maybe I will rename it to Node, or do you have other suggestions / concerns? @tqchen @antinucleon @piiswrong @mli

mli commented 8 years ago

Agreed that even Symbol is a little bit too general in mxnet. It is more like a symbolic expression. I think nose is fine

mli commented 8 years ago

I mean "node"

tqchen commented 8 years ago

mx.SymbolNode then. Since this is a type that is rarely used explicitly, it is fine for now, I guess.

tqchen commented 8 years ago

We can open a vote on mxnet issue as usual.

pluskid commented 8 years ago

I was talking about the Julia side only, but I'm open to renaming in libmxnet in general; that makes things more consistent, though it involves more work.

tqchen commented 8 years ago

I see. I think this part should be fine, as it is more of a type-concept name (which is rarely used, as opposed to a function name); as long as the naming is clear and we can see the connection between projects, it is good.

Andy-P commented 8 years ago

Indeed, my confusion was mostly caused by mixing up Julia's Symbol with mx.Symbol. Thanks. Changing the model to...

    data = mx.Variable(:mydata)
    label_sym = mx.Variable(:mylabel)
    fc1  = mx.FullyConnected(data = data, name=:fc1, num_hidden=128)
    act1 = mx.Activation(data = fc1, name=:relu1, act_type=:relu)
    fc2  = mx.FullyConnected(data = act1, name=:fc2, num_hidden=64)
    act2 = mx.Activation(data = fc2, name=:relu2, act_type=:relu)
    fc3  = mx.FullyConnected(data = act2, name=:fc3, num_hidden=1)
    mlp  = mx.LinearRegressionOutput(data = fc3, label = label_sym, name=:linear)  

...solved the problem (I think). However, I am now getting the following error.

LoadError: MXNet.mx.MXError("[22:29:14] src/ndarray/ndarray.cc:159: 
Check failed: from.shape() == to->shape() operands shape mismatch")

which suggests that there is a mismatch in the size of my layers, no? I understand my model's dimensions as:

fc1 =  1 x 128
fc2 =  128 x 64
fc3 =  64 x 1

This seems correct, as I am giving it data that is a vector of length 199, and labels with the same dimensions. I am not sure where I am still going wrong.

I have to say, I find the feedback hard to understand: there is no indication of which layer is misshaped. Reporting the expected shape vs. the actual shape would make this easier to diagnose.
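
One way to get that feedback up front is to run shape inference on the symbol before training. A minimal sketch, assuming MXNet.jl's mx.infer_shape accepts the input shapes as keyword arguments (as in the package examples):

    # Propagate shapes through the symbol for one mini-batch; any mismatch is
    # reported against a named argument instead of failing later inside fit.
    arg_shapes, out_shapes, aux_shapes =
        mx.infer_shape(mlp, mydata = (1, 10), mylabel = (10,))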

Thanks for your help,

Andre

pluskid commented 8 years ago

@Andy-P I tweaked your code a bit, and the following works for me:

using MXNet

# create some synthetic data
xs = collect(-0.49:0.005:0.5)
xs = reshape(xs, 1, length(xs))
rnd() = rand()*0.1-0.05
y = vec(map(x->x+0.3*sin(4*pi*x),xs))
ynoise = map(x->x + rnd()*1.,y);

# put it in data providers
train_provider = mx.ArrayDataProvider(:mydata => xs, :mylabel => ynoise; batch_size=10)
eval_provider  = mx.ArrayDataProvider(:mydata => xs, :mylabel => y;      batch_size=10)

#-- Option 2: using the mx.chain macro
data = mx.Variable(:mydata)
lbl  = mx.Variable(:mylabel)
fc1  = mx.FullyConnected(data = data, name=:fc1, num_hidden=128)
act1 = mx.Activation(data = fc1, name=:relu1, act_type=:relu)
fc2  = mx.FullyConnected(data = act1, name=:fc2, num_hidden=64)
act2 = mx.Activation(data = fc2, name=:relu2, act_type=:relu)
fc3  = mx.FullyConnected(data = act2, name=:fc3, num_hidden=1)
mlp  = mx.LinearRegressionOutput(data = fc3,label=lbl, name=:linear) 

# setup model
model = mx.FeedForward(mlp, context=mx.cpu())

optimizer = mx.SGD(lr=0.1, momentum=0.9, weight_decay=0.00001)

# fit parameters
mx.fit(model, optimizer, train_provider, n_epoch=10)

There are several things that might be causing problems:

  1. There was a bug in ArrayDataProvider that produced incorrect mini-batches. This is fixed on master, and I'm going to prepare a v0.0.4 release of MXNet.jl today. You could wait for that or use the latest master.
  2. The input needs to be a 2D tensor (i.e. a matrix), so I used reshape explicitly to make it of size (1,N) instead of (N,) (see the shape sketch after this list).
  3. However, the singleton dimension in the label gets dropped automatically, so I had to use vec explicitly to convert the label from (1,N) back to shape (N,). I think this is very inconsistent, and will file an issue in the upstream libmxnet.
  4. Though not relevant to this example, I also realized that the linear regression output only accepts scalar labels, which is a bit limiting. We will try to extend this to multi-dimensional output soon.
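
To make points 2 and 3 concrete, here is a minimal shape sketch using the definitions above (the range -0.49:0.005:0.5 has 199 elements):

xs = reshape(collect(-0.49:0.005:0.5), 1, 199)  # data as a (1, N) matrix
y  = vec(map(x -> x + 0.3*sin(4*pi*x), xs))     # labels flattened back to (N,)
@assert size(xs) == (1, 199)
@assert size(y)  == (199,)
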
Andy-P commented 8 years ago

Reshaping the inputs did indeed fix the shape error generated by fit!. However, the accuracy is always 0.00. Is that due to the batch-size problem mentioned above?

pluskid commented 8 years ago

@Andy-P I'm not sure. Which version are you running? Can you do a Pkg.update() to update to the latest MXNet.jl? Also, the default evaluation metric (Accuracy) is for classification; you might want to write one for regression, and this part can be done in Julia. See https://github.com/dmlc/MXNet.jl/blob/master/src/metric.jl . We will try to add more built-in metrics soon.

Andy-P commented 8 years ago

@pluskid Will upgrade. I suspect my accuracy problem is, as you pointed out, just that I am defaulting to a classification metric when I should write a custom one. Will try it out today.

Andy-P commented 8 years ago

Pkg.update() solved the problem. I also wrote a custom MSE evaluation metric, shown below. I can make a pull request if you like.

    # Mean squared error metric for regression (Julia 0.4 `type` syntax)
    type MSE <: AbstractEvalMetric
      acc_sum  :: Float64   # running sum of squared errors
      n_sample :: Int       # number of samples seen so far

      MSE() = new(0.0, 0)
    end

    function _update_single_output(metric :: MSE, label :: NDArray, pred :: NDArray)
      label = copy(label) 
      pred  = copy(pred)

      n_sample = size(pred)[end]
      metric.n_sample += n_sample

      for i = 1:n_sample
        metric.acc_sum += (label[i] - pred[i])^2
        # println("$(label[i]) - $(pred[i]) ^2 => $((label[i]-pred[i])^2)")
      end
    end

    function update!(metric :: MSE, labels :: Vector{NDArray}, preds :: Vector{NDArray})
      @assert length(labels) == length(preds)
      for i = 1:length(labels)
        _update_single_output(metric, labels[i], preds[i])
      end
    end

    function get(metric :: MSE)
      return [(:MSE, metric.acc_sum / metric.n_sample)]
    end

    function reset!(metric :: MSE)
      metric.acc_sum  = 0.0
      metric.n_sample = 0
    end
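
For reference, a metric like this is plugged into training via the eval_metric keyword of mx.fit. A minimal sketch, assuming the eval_data and eval_metric keyword names of MXNet.jl's fit:

    mx.fit(model, optimizer, train_provider,
           n_epoch     = 10,
           eval_data   = eval_provider,
           eval_metric = MSE())
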
pluskid commented 8 years ago

@Andy-P Thanks! Please make a PR, but change acc_sum to mse_sum and remove the println statement.