vlfeat / matconvnet

MatConvNet: CNNs for MATLAB

How to train a network to learn a function using MatConvNet? #418

Closed prashnani closed 7 years ago

prashnani commented 8 years ago

Hi everyone,

I am trying to understand how to use MatConvNet to learn a function given I/O pairs for training.

I want to learn a function that maps a 486-dimensional input vector to a 1D output value. However, I am unable to get the network to work properly and need help locating my mistake. The details of what I have done follow; please let me know if any other information is needed.

Here is my training data layout (1 million samples):

imdb.images
ans =
     data: [4-D single]
    label: [1x1000000 single]
      set: [1x1000000 double]

size(imdb.images.data)
ans =
     1     1   486   1000000
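For context, the data is packed into this layout roughly as follows (a simplified sketch; the variable names, random placeholder values, and the 90/10 split below are illustrative, not my actual preprocessing):

% Sketch of building the imdb struct; X, y, and the split are placeholders.
X = single(randn(486, 1e6)) ;                    % 486 x N matrix of input vectors (placeholder)
y = single(randn(1, 1e6)) ;                      % 1 x N target values (placeholder)
imdb.images.data  = reshape(X, 1, 1, 486, []) ;  % 1 x 1 x 486 x N, as expected by 1x1 convs
imdb.images.label = y ;
imdb.images.set   = [ones(1, 9e5), 2*ones(1, 1e5)] ;  % 1 = train, 2 = val (I believe this is the cnn_train default)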

I am using a multi-layer fully connected neural network with one input layer of 486 neurons, one hidden layer with 100 neurons and one output layer with 1 neuron.

I have defined the network using fully connected layers (implemented as 1x1 convolutions) as follows:

trainOpts.batchSize = 10000 ;
trainOpts.numEpochs = 100 ;
trainOpts.continue = false ;
trainOpts.gpus = [1] ;
trainOpts.learningRate = 0.001 ;
trainOpts.expDir = 'xyz' ;

f = 1/100 ;
net.layers = {} ;
net.layers{end+1} = struct('type', 'conv', ...
    'weights', {{f*randn(1,1,486,100,'single'), zeros(1,100,'single')}}, ...
    'stride', 1, ...
    'pad', 0) ;
net.layers{end+1} = struct('type', 'sigmoid') ;
net.layers{end+1} = struct('type', 'conv', ...
    'weights', {{f*randn(1,1,100,1,'single'), zeros(1,1,'single')}}, ...
    'stride', 1, ...
    'pad', 0) ;
net.layers{end+1} = struct('type', 'sigmoid') ;
net.layers{end+1} = struct('type', 'nnL2') ;

The function in the loss layer, vl_nnL2, is added to vl_simplenn as discussed in #15; it is the L2 loss function. I have also changed the error estimation in cnn_train to the following (this is probably along the lines of the change to cnn_train suggested by @bazilas in #15):
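For completeness, the loss layer looks roughly like this (a minimal sketch along the lines of #15; the exact code posted there may differ):

function y = vl_nnL2(x, c, dzdy)
% VL_NNL2  Squared L2 loss between predictions X (1 x 1 x D x N) and targets C.
c = reshape(single(c), size(x)) ;     % bring the targets into the same 1x1xDxN layout
d = x - c ;
if nargin <= 2 || isempty(dzdy)
  y = sum(d(:).^2) ;                  % forward pass: scalar sum of squared residuals
else
  y = 2 * d * dzdy ;                  % backward pass: gradient of the loss w.r.t. x
end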

% -------------------------------------------------------------------------
function err = error_sqerror(opts, labels, res)
% -------------------------------------------------------------------------
predictions = gather(res(end-1).x) ;
% be resilient to badly formatted labels
if numel(labels) == size(predictions, 4)
  labels = reshape(labels, 1, 1, 1, []) ;
end
error = abs(labels - predictions).^2 ;
err = sum(squeeze(error)) ;
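The training itself is invoked roughly as follows (a sketch; getBatch is reduced to the essentials for the 1 x 1 x 486 x N layout above). I believe newer versions of cnn_train also accept a custom error function through opts.errorFunction, which would avoid editing cnn_train, but here I have edited it directly:

% Sketch of the training call.
[net, info] = cnn_train(net, imdb, @getBatch, trainOpts) ;

function [im, labels] = getBatch(imdb, batch)
% Return a batch of inputs (1 x 1 x 486 x batchSize) and their target values.
im = imdb.images.data(:, :, :, batch) ;
labels = imdb.images.label(batch) ;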

Here is the plot from the training:

[image: training plot]

What am I missing?

Thank you, Ekta

finlay-liu commented 8 years ago

Please see #15, #97, and #115.

h612 commented 7 years ago

Hi. Did you find a solution? @prashnani