dmlc / MXNet.jl

MXNet Julia Package - flexible and efficient deep learning in Julia

ArrayDataProvider input format issue #236

Open Petterhg opened 7 years ago

Petterhg commented 7 years ago

So I'm trying to do image classification with ArrayDataProvider without success, and I'm pretty confident the problem is that MXNet doesn't feed in the data correctly. The raw input is 66 black-and-white images.

using MXNet, Images   # assuming Images.jl, which provides imresize and (via FileIO) load

# Build class labels from filenames of the form "<class>_<id>.jpg"
imageFolder = "./datatide1/"
filenames = map(x -> replace(x, ".jpg", ""), readdir(imageFolder))
labels = map(x -> split(x, "_")[1], filenames)
classes = unique(labels)
classDict = Dict(classes[i] => i for i=1:length(classes))
labelSize = 1
width = 75
heigth = 75

# One 75x75 Float32 plane per image, plus one integer label per image
data = zeros(Float32, heigth, width, length(filenames))
label = zeros(Int64, length(filenames))

# Load each image, resize it to 75x75, and store the pixels and the class index
for i in 1:length(filenames)
    image = load(string("datatide1/", filenames[i], ".jpg"))
    image_resized = imresize(image, heigth, width)
    temp = convert(Array{Float32}, image_resized)
    data[:,:,i] = temp 
    label[i] = classDict[labels[i]]
end

So the data is first a 75x75x66 array, with two classes (faces and houses).

mxData = mx.Variable(:data)
mxLabel  = mx.Variable(:softmax_label)

batch_size = 2

input = mx.Reshape(mxData, shape=(width, heigth,1, batch_size))
eval = mx.Reshape(mxData, shape=(width, heigth,1, batch_size))

train_provider = mx.ArrayDataProvider(:data => mx.NDArray(data),
    :softmax_label => label, 
    batch_size=batch_size, 
    shuffle=true)

So now the data is reshaped to fit the conv net input, which needs to be a 4-D array.

conv1 = @mx.chain mx.Convolution(input, kernel=(3,3), num_filter=50)  =>
                  mx.Activation(act_type=:tanh) =>
                  mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2))

conv2 = @mx.chain mx.Convolution(conv1, kernel=(5,5), num_filter=50) =>
                  mx.Activation(act_type=:tanh) =>
                  mx.Pooling(pool_type=:max, kernel=(2,2), stride=(1,1))

conv3 = @mx.chain mx.Convolution(conv2, kernel=(3,3), num_filter=30) =>
                  mx.Activation(act_type=:tanh) =>
                  mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2))

fc1   = @mx.chain mx.Flatten(conv3) =>
                  mx.FullyConnected(num_hidden=700) =>
                  mx.Activation(act_type=:relu) 

fc2   = @mx.chain mx.FullyConnected(fc1, num_hidden=400) =>
                  mx.Activation(act_type=:relu) 

fc3   = mx.FullyConnected(fc2, num_hidden=2) 
mlp  = mx.SoftmaxOutput(fc3, name=:softmax)

I have varied the hyperparameters a lot now and always get stuck in the same minimum, at 0.63 accuracy.

What makes me think that there is an error with the data feed is that when I'm predicting on two random samples (one house, one face) not previously seen by the model, it always gives out the same probabilities:

2×2 Array{Float32,2}:
 0.128382  0.128382
 0.871618  0.871618

And this just doesn't make sense, since one sample is a house and one is a face, plus I've changed the hyperparameters drastically (including the optimization algorithm).

Can anyone see if I'm doing something wrong with the data input or if there is something else fundamentally strange with my implementation? @pluskid

vchuravy commented 7 years ago

Why are you converting the Julia array to an mx.NDArray before passing it to ArrayDataProvider?

Petterhg commented 7 years ago

@vchuravy out of desperation. But I can't really see how that would affect anything?

vchuravy commented 7 years ago

I am a bit worried that mx.Reshape gets confused between Julia's ordering of dimensions, (W, H, Channel, Sample), and mxnet's ordering, (C, S, H, W). I would recommend doing the reshaping in Julia.
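
For example, a minimal sketch of doing the reshape on the Julia side instead of with a symbolic mx.Reshape, assuming the 75x75x66 data array and the label vector built above:

# (W, H, Channel, Sample) on the Julia side; the provider handles the
# conversion to C/C++ ordering itself
train_provider = mx.ArrayDataProvider(:data => reshape(data, (75, 75, 1, 66)),
    :softmax_label => label,
    batch_size=batch_size,
    shuffle=true)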

There are two ways to check whether your layout is correct. One is:

# keyword names must match the symbol's argument names (here :data and :softmax_label)
arg_shapes, out_shapes, aux_shapes = mx.infer_shape(net, data=(75, 75, 1, 66), softmax_label=(66,))
println("Arguments:")
for (n,s) in zip(mx.list_arguments(net), arg_shapes)
  println("\t$n => $s")
end
println("Outputs:")
for (n,s) in zip(mx.list_outputs(net), out_shapes)
  println("\t$n => $s")
end

and the other is mx.debug_str:

exec = mx.simple_bind(net, mx.cpu(), data=(75, 75, 1, 66))
dbg_str = mx.debug_str(exec)

Petterhg commented 7 years ago

So net in my case would be mxData right? If I build the data according to (C, S, H, W):

data = zeros(Float32, 1, length(filenames), heigth, width)  # <-------------------------
label = zeros(Int64, length(filenames))

for i in 1:length(filenames)
    image = load(string("datatide1/", filenames[i], ".jpg"))
    image_resized = imresize(image, heigth, width)
    temp = convert(Array{Float32}, image_resized)
    data[1,i,:,:] = temp             # <-------------------------
    label[i] = classDict[labels[i]]
end

mxData = mx.Variable(:data)
mxLabel  = mx.Variable(:softmax_label)

batch_size = 2

train_provider = mx.ArrayDataProvider(:data => data,
    :softmax_label => label, 
    batch_size=batch_size, 
    shuffle=true)

I get shape mismatch (the same if I transpose):

AssertionError: Number of samples in softmax_label is mismatch with data

Or... I assume you want me to reshape the Julia array to the format MXNet wants? Or should I feed it in as the Julia array, i.e. (W, H, Channel, Sample)?

vchuravy commented 7 years ago

MXNet.jl will handle the transformation from Julia order to C/C++ order for you.

dataP = mx.ArrayDataProvider(:data => reshape(data, (75, 75, 1, 66)))

Petterhg commented 7 years ago

Ok, still the same weird results, though...

data = zeros(Float32, heigth, width, length(filenames))
label = zeros(Int64, length(filenames))

for i in 1:length(filenames)
    image = load(string("datatide1/", filenames[i], ".jpg"))
    image_resized = imresize(image, heigth, width)
    temp = convert(Array{Float32}, image_resized)
    data[ :, :, i] = temp 
    label[i] = classDict[labels[i]]
end

mxData = mx.Variable(:data)
mxLabel  = mx.Variable(:softmax_label)
batch_size = 1

train_provider = mx.ArrayDataProvider(:data => reshape(data, (heigth,width,1,length(filenames))),
    :softmax_label => label, 
    batch_size=batch_size, 
    shuffle=true)

conv1 = @mx.chain mx.Convolution(mxData, kernel=(3,3), num_filter=50)  =>
                  mx.Activation(act_type=:tanh) =>
                  mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2))

conv2 = @mx.chain mx.Convolution(conv1, kernel=(5,5), num_filter=50) =>
                  mx.Activation(act_type=:tanh) =>
                  mx.Pooling(pool_type=:max, kernel=(2,2), stride=(1,1))

conv3 = @mx.chain mx.Convolution(conv2, kernel=(3,3), num_filter=30) =>
                  mx.Activation(act_type=:tanh) =>
                  mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2))

fc1   = @mx.chain mx.Flatten(conv3) =>
                  mx.FullyConnected(num_hidden=100) =>
                  mx.Activation(act_type=:relu) 

fc2   = @mx.chain mx.FullyConnected(fc1, num_hidden=100) =>
                  mx.Activation(act_type=:relu) 

fc3   = mx.FullyConnected(fc2, num_hidden=2) 

mlp  = mx.SoftmaxOutput(fc3, name=:softmax)

#------------------------------------------------
model = mx.FeedForward(mlp, context=[mx.gpu(0)])

optimizer = mx.SGD(lr=0.5, momentum=0.8, weight_decay=0.00001) 

mx.fit(model,
    optimizer, 
    train_provider,
    eval_metric=mx.Accuracy(),
    n_epoch=150, 
    )

The end accuracy is 0.61, but just to point out, it started the training at 0.61 as well... Output from predict using one face and one house:

2×2 Array{Float32,2}:
 0.0354191  0.0354191
 0.964581   0.964581

Could it really be that it hasn't learned ANYTHING? Here are two example pictures: face_00001 house_00001

vchuravy commented 7 years ago

Are your labels between 0 and 1? You can use mx.MultiACE(2) to see the log loss associated with each class.
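
For example (a sketch reusing the model, optimizer, and train_provider from the code above):

# Track the per-class average cross-entropy instead of plain accuracy
mx.fit(model,
    optimizer,
    train_provider,
    eval_metric=mx.MultiACE(2),   # 2 = number of classes
    n_epoch=150)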

From a purely network-design perspective, I would use mx.ADAM instead of mx.SGD and an mx.XavierInitializer. The number of channels in your convolutional layers should also increase (for example 64, 64, 128, 128, FC, FC instead of your current 50, 50, 30, FC, FC). I would also use the same activation function throughout the network (relu or LeakyReLU), and I have found BatchNorm to help as well.
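
A rough sketch of that suggestion in the same @mx.chain style (the fully connected widths are placeholders, not tuned values):

conv1 = @mx.chain mx.Convolution(mxData, kernel=(3,3), num_filter=64) =>
                  mx.BatchNorm() =>
                  mx.Activation(act_type=:relu) =>
                  mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2))

conv2 = @mx.chain mx.Convolution(conv1, kernel=(3,3), num_filter=64) =>
                  mx.BatchNorm() =>
                  mx.Activation(act_type=:relu) =>
                  mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2))

conv3 = @mx.chain mx.Convolution(conv2, kernel=(3,3), num_filter=128) =>
                  mx.BatchNorm() =>
                  mx.Activation(act_type=:relu) =>
                  mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2))

conv4 = @mx.chain mx.Convolution(conv3, kernel=(3,3), num_filter=128) =>
                  mx.BatchNorm() =>
                  mx.Activation(act_type=:relu)

fc1   = @mx.chain mx.Flatten(conv4) =>
                  mx.FullyConnected(num_hidden=256) =>
                  mx.Activation(act_type=:relu)

fc2   = mx.FullyConnected(fc1, num_hidden=2)
mlp   = mx.SoftmaxOutput(fc2, name=:softmax)

# mx.ADAM and mx.XavierInitializer as suggested above
model = mx.FeedForward(mlp, context=[mx.gpu(0)])
mx.fit(model, mx.ADAM(), train_provider,
    initializer=mx.XavierInitializer(),
    eval_metric=mx.Accuracy(),
    n_epoch=150)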

Have you experimented with the size of your FC layers?

Petterhg commented 7 years ago

Yes, the labels are: face = 0, house = 1.

I have experimented with the architecture (also with more and wider FC layers), but what strikes me as weird is that it gives the exact same probability for both images with mx.predict. Isn't that super strange, even if it hasn't learned anything? It feels as if it is scoring the same image all the time or something.

vchuravy commented 7 years ago

I agree that it is a bit weird, and it looks like you are hitting a local optimum over and over again.

Is your dataset balanced, and are you using a validation dataset? (Also check that your dataset actually contains what you expect by converting the data back to images.) How different is the exposure between images? I found that BatchNorm helps a lot when the data differs in intensity (you could also normalise beforehand).
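
For the "normalise beforehand" route, a simple per-image sketch (assuming the Float32 data array from the code above, with the sample index last):

# Rescale every image to zero mean and unit variance before training
for i in 1:size(data, 3)
    img = data[:, :, i]
    data[:, :, i] = (img .- mean(img)) ./ (std(img) + 1f-8)
end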

Petterhg commented 7 years ago

Yes I have verified that the data is what it is supposed to be. I have now implemented dropout, batch norm, two more conv layers and expanded the FC's with the same result. When I do temp = convert(Array{Float32}, image_resized) it actually normalizes all values between 0 and 1 (pretty sweet). I will continue investigating and update here if I find the solution! Thanks for all the help!

Petterhg commented 7 years ago

SOLVED

This weird behavior originated from too small a batch size! When I increased the batch size, everything started working as expected and I reached 99% accuracy on both classes within 100 epochs.

- HOW is it possible that the batch size influenced the result of the model in such a way!?

vchuravy commented 7 years ago

What was your batch size, if I may ask? Also, which version of MXNet proper are you using, and which optimizer?


Petterhg commented 7 years ago

I changed it from 1 to 5 and I get good results with all adaptive-learning-rate algorithms! This particular run was with RMSProp, but it was just as good with ADAM.
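
For reference, a sketch of the working setup against the training code above (the batch size and the optimizer are the only changes):

batch_size = 5   # was 1

train_provider = mx.ArrayDataProvider(:data => reshape(data, (heigth, width, 1, length(filenames))),
    :softmax_label => label,
    batch_size=batch_size,
    shuffle=true)

model = mx.FeedForward(mlp, context=[mx.gpu(0)])

mx.fit(model,
    mx.RMSProp(),            # mx.ADAM() worked just as well
    train_provider,
    eval_metric=mx.Accuracy(),
    n_epoch=100)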

vchuravy commented 7 years ago

@pluskid Do you think we are hitting a particular corner case here, or is that just a side effect of using SGD-based methods?

As far as I understand it, SGD estimates the real gradient via a sample of the data, so SGD should still work on one sample.

pluskid commented 7 years ago

Interesting. Batch size = 1 should work theoretically (with properly chosen hyperparameters), despite being inefficient. @vchuravy I'm leaving for a flight in 2 minutes, do you have a chance to run our existing examples (such as MNIST or even simpler ones that use Array provider directly) by changing the batch size to 1? Just to see if it is some bug.