alexbw / bayarea-dl-summerschool

Torch notebooks and slides for the Bay Area Deep Learning Summer School
Apache License 2.0

Training error. Help #4

Open · ilichev-andrey opened this issue 7 years ago

ilichev-andrey commented 7 years ago

Hello. I am training a neural network on two classes of my own. An error occurs during training. How can I fix it?

```
th> require 'nn';
th> trainset = torch.load('animals_peoples2.t7')
th> testset = torch.load('animals_peoples2.t7')
th> classes = {'animals', 'peoples'}
th> print(trainset)
{
  data : ByteTensor - size: 17299x3x96x96
  label : ByteTensor - size: 17299
}
th> print(#trainset.data)
 17299
     3
    96
    96
[torch.LongStorage of size 4]
th> setmetatable(trainset,
..>     {__index = function(t, i)
..>         return {
..>             t.data[i],
..>             t.label[i]
..>         }
..>     end}
..> );
th> function trainset:size()
..>     return self.data:size(1)
..> end
```
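The metatable and `size` method above give `trainset` the interface that `nn.StochasticGradient` relies on: it calls `dataset:size()` and then reads `dataset[i][1]` as the input and `dataset[i][2]` as the target. This works because `__index` is consulted only for keys missing from the table, so `trainset.data`, `trainset.label` and `trainset.size` resolve normally while numeric keys fall through to the function. The sketch below reproduces the pattern with plain Lua tables standing in for tensors, so it runs without Torch; `makeDataset` and the sample values are hypothetical.

```lua
-- Hypothetical helper: plain tables stand in for trainset.data / trainset.label.
local function makeDataset(data, labels)
  local dataset = { data = data, label = labels }
  -- __index fires only for absent keys: dataset.data, dataset.label and
  -- dataset.size are stored directly, while dataset[1], dataset[2], ...
  -- are built on demand as {input, target} pairs.
  setmetatable(dataset, {
    __index = function(t, i)
      return { t.data[i], t.label[i] }
    end
  })
  -- In Torch this would be self.data:size(1); here the "data" is a list.
  function dataset:size()
    return #self.data
  end
  return dataset
end

local ds = makeDataset({ "img1", "img2", "img3" }, { 1, 2, 1 })
print(ds:size())
print(ds[2][1], ds[2][2])  -- the input "img2" and its target 2
```

Note that the metatable has no `__newindex`, so defining `size` afterwards stores the function directly in the table rather than going through `__index`.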

```
th> trainset.data = trainset.data:double()
th> print(trainset:size())
17299
th> print(trainset[33])
{
  1 : DoubleTensor - size: 3x96x96
  2 : 1
}
th> redChannel = trainset.data:select(2, 1)
th> print(#redChannel)
 17299
    96
    96
[torch.LongStorage of size 3]
```

```
th> mean = {} -- store the mean, to normalize the test set in the future
th> stdv = {} -- store the standard-deviation for the future
th> for i=1,3 do -- over each image channel
..>     mean[i] = trainset.data:select(2, 1):mean() -- mean estimation
..>     print('Channel ' .. i .. ', Mean: ' .. mean[i])
..>     trainset.data:select(2, 1):add(-mean[i]) -- mean subtraction
..>
..>     stdv[i] = trainset.data:select(2, i):std() -- std estimation
..>     print('Channel ' .. i .. ', Standard Deviation: ' .. stdv[i])
..>     trainset.data:select(2, i):div(stdv[i]) -- std scaling
..> end
Channel 1, Mean: 0
Channel 1, Standard Deviation: 0
Channel 2, Mean: nan
Channel 2, Standard Deviation: 0
Channel 3, Mean: nan
Channel 3, Standard Deviation: 0
```
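Two things in this normalization loop stand out. The mean is estimated from, and subtracted from, `select(2, 1)` on every pass, so channel 1 is processed three times while channels 2 and 3 are never centered; in the Torch loop the fix appears to be `select(2, i)` in those two lines. Also, every printed standard deviation is 0, which suggests each channel is constant (channel 1 apparently all zeros); dividing by 0 then turns the data into NaN or infinity, which would make training meaningless even after the shape error is fixed, so it may be worth checking how `animals_peoples2.t7` was built. The plain-Lua sketch below shows per-channel normalization that always uses the loop index and skips scaling for a constant channel. `normalizeChannels` is a hypothetical helper, plain tables stand in for tensor channels, and the unbiased (n - 1) estimate is assumed to match Torch's default `std()`.

```lua
-- channels[c] is a flat list of pixel values for channel c
-- (a hypothetical stand-in for trainset.data:select(2, c)).
local function normalizeChannels(channels)
  local mean, stdv = {}, {}
  for i = 1, #channels do
    local ch = channels[i]             -- always channel i, never a fixed index
    local n = #ch
    local sum = 0
    for k = 1, n do sum = sum + ch[k] end
    mean[i] = sum / n
    local sq = 0
    for k = 1, n do
      ch[k] = ch[k] - mean[i]          -- mean subtraction, in place
      sq = sq + ch[k] * ch[k]
    end
    -- unbiased estimate: divide by n - 1
    stdv[i] = n > 1 and math.sqrt(sq / (n - 1)) or 0
    if stdv[i] > 0 then
      for k = 1, n do ch[k] = ch[k] / stdv[i] end  -- std scaling
    else
      -- constant channel: dividing by 0 would produce NaN or inf, so skip it
      print('Channel ' .. i .. ' is constant; skipping std scaling')
    end
  end
  return mean, stdv
end

local channels = { { 1, 2, 3 }, { 4, 4, 4 } }
local mean, stdv = normalizeChannels(channels)
print(mean[1], stdv[1], mean[2], stdv[2])
```

The returned `mean` and `stdv` tables play the same role as the ones in the issue: they can be reused to normalize the test set with the training statistics.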

```
th> net = nn.Sequential()
th> net:add(nn.SpatialConvolution(3, 6, 9, 9)) -- 3 input image channels, 6 output channels, 9x9 convolution kernel
th> net:add(nn.ReLU()) -- non-linearity
th> net:add(nn.SpatialMaxPooling(2,2,2,2)) -- A max-pooling operation that looks at 2x2 windows and finds the max.
th> net:add(nn.SpatialConvolution(6, 16, 9, 9))
th> net:add(nn.ReLU()) -- non-linearity
th> net:add(nn.SpatialMaxPooling(2,2,2,2))
th> net:add(nn.View(16*9*9)) -- reshapes from a 3D tensor of 16x9x9 into 1D tensor of 16*9*9
th> net:add(nn.Linear(16*9*9, 120)) -- fully connected layer (matrix multiplication between input and weights)
th> net:add(nn.ReLU()) -- non-linearity
th> net:add(nn.Linear(120, 84))
th> net:add(nn.ReLU()) -- non-linearity
th> net:add(nn.Linear(84, 10)) -- 10 is the number of outputs of the network (in this case, 10 digits)
th> net:add(nn.LogSoftMax()) -- converts the output to a log-probability. Useful for classification problems
```
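The size given to `nn.View` and to the first `nn.Linear` has to equal the number of values coming out of the last pooling layer, and with 96x96 images and 9x9 kernels that is not 16x9x9. The plain-Lua sketch below follows the two conv + pool stages of this network, assuming the default stride 1 and zero padding of `nn.SpatialConvolution` and the default floor rounding of `nn.SpatialMaxPooling`; `convOut`, `poolOut` and `featureSide` are hypothetical helper names.

```lua
-- Side length after nn.SpatialConvolution with a k x k kernel, assuming the
-- default stride 1 and no padding: o = i - k + 1.
local function convOut(i, k)
  return i - k + 1
end

-- Side length after nn.SpatialMaxPooling(k, k, s, s) in its default floor mode:
-- o = floor((i - k) / s) + 1.
local function poolOut(i, k, s)
  return math.floor((i - k) / s) + 1
end

-- Follows the layer order of the net above (hypothetical helper).
local function featureSide(inputSide, kernel)
  local side = convOut(inputSide, kernel) -- first SpatialConvolution
  side = poolOut(side, 2, 2)              -- first SpatialMaxPooling(2,2,2,2)
  side = convOut(side, kernel)            -- second SpatialConvolution
  side = poolOut(side, 2, 2)              -- second SpatialMaxPooling(2,2,2,2)
  return side
end

local planes = 16                  -- output planes of the second convolution
local side = featureSide(96, 9)    -- 96x96 images, 9x9 kernels
local flat = planes * side * side  -- size needed by nn.View and the first nn.Linear
print(side, flat)
```

The side length goes 96 -> 88 -> 44 -> 36 -> 18, so the flattened size is 16*18*18 = 5184 rather than 16*9*9 = 1296. Presumably the two affected lines would then read `net:add(nn.View(16*18*18))` and `net:add(nn.Linear(16*18*18, 120))`.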

```
th> criterion = nn.ClassNLLCriterion()
```
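`nn.ClassNLLCriterion` expects each target to be a class index from 1 up to the number of outputs of the last layer. With two classes, the final layer would presumably be `nn.Linear(84, 2)` instead of `nn.Linear(84, 10)`; ten outputs do not by themselves cause the assertion, but only two of them would ever be targeted. The labels in the .t7 file also need to be 1 and 2, not 0 and 1. A quick check is sketched below in plain Lua; `checkLabels` is a hypothetical helper and a plain table stands in for `trainset.label`.

```lua
-- Returns true if every label is an integer class index in 1..nClasses,
-- which is what nn.ClassNLLCriterion expects (Torch indices start at 1).
-- On failure, also returns the position of the first bad label.
local function checkLabels(labels, nClasses)
  for i = 1, #labels do
    local y = labels[i]
    if y ~= math.floor(y) or y < 1 or y > nClasses then
      return false, i
    end
  end
  return true
end

print(checkLabels({ 1, 2, 2, 1 }, 2))
print(checkLabels({ 0, 1, 1 }, 2))  -- 0-based labels are rejected
```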

```
th> trainer = nn.StochasticGradient(net, criterion)
th> trainer.learningRate = 0.001
th> trainer.maxIteration = 5 -- just do 5 epochs of training.
th> trainer:train(trainset)
StochasticGradient: training
/root/facedetect/torch/install/share/lua/5.1/nn/THNN.lua:110: Assertion `THIndexTensor_(size)(target, 0) == batch_size' failed. at /tmp/luarocks_nn-scm-1-1625/nn/lib/THNN/generic/ClassNLLCriterion.c:50
stack traceback:
	[C]: in function 'v'
	/root/facedetect/torch/install/share/lua/5.1/nn/THNN.lua:110: in function 'ClassNLLCriterion_updateOutput'
	...ect/torch/install/share/lua/5.1/nn/ClassNLLCriterion.lua:43: in function 'forward'
	...ct/torch/install/share/lua/5.1/nn/StochasticGradient.lua:35: in function 'train'
	[string "_RESULT={trainer:train(trainset)}"]:1: in main chunk
	[C]: in function 'xpcall'
	/root/facedetect/torch/install/share/lua/5.1/trepl/init.lua:661: in function 'repl'
	...tect/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:199: in main chunk
	[C]: at 0x004064f0
```
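The assertion is ClassNLLCriterion checking that the target has one entry per row of a batched input. Because `nn.View(16*9*9)` actually receives 16x18x18 = 5184 values per image, it appears to treat the extra factor as a batch dimension (5184 / 1296 = 4), so the network returns a 4x10 matrix while the target from `trainset[i][2]` is a single label. The plain-Lua sketch below is a simplified, hypothetical model (`inferBatch`) of that batch inference, not nn.View's actual source.

```lua
-- Simplified model of nn.View's batch inference (hypothetical helper):
-- if the input holds k times as many elements as the requested view,
-- the output is treated as a batch of k rows.
local function inferBatch(inputElements, viewElements)
  if inputElements % viewElements ~= 0 then
    error('input view and desired view do not match')
  end
  return math.floor(inputElements / viewElements)
end

local inputElements = 16 * 18 * 18  -- real output of the second pooling layer
local viewElements  = 16 * 9 * 9    -- what nn.View was told to expect
local batch = inferBatch(inputElements, viewElements)
print(batch)  -- rows seen by ClassNLLCriterion, versus a single target label
```

When the View size matches the real feature map size (5184 here), the ratio is 1, no batch dimension is introduced, and the output is a single vector that matches the single label.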

ilichev-andrey commented 7 years ago

How do I determine these parameters?

The 5x5 convolution kernel: `net:add(nn.SpatialConvolution(3, 6, 5, 5))`

The View size, which reshapes from a 3D tensor of 16x5x5 into a 1D tensor of 16*5*5: `net:add(nn.View(16*5*5))`
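The kernel size (5x5 in the tutorial, 9x9 in the network above) is a design choice rather than something computed from the data; the 16x5x5 then follows from it. For a layer with input side $i$, kernel size $k$, stride $s$ and padding $p$, the output side is given by the formula below. Assuming the 3x32x32 CIFAR-10 images the tutorial network was written for, stride 1 and no padding in the convolutions, and 2x2 pooling with stride 2, the side length shrinks as follows:

$$
o = \left\lfloor \frac{i + 2p - k}{s} \right\rfloor + 1
$$

$$
32 \xrightarrow{\text{conv } 5\times5} 28 \xrightarrow{\text{pool } 2\times2} 14 \xrightarrow{\text{conv } 5\times5} 10 \xrightarrow{\text{pool } 2\times2} 5
$$

The last pooling layer therefore outputs 16 feature maps of 5x5, where 16 is the number of output planes chosen for the second `nn.SpatialConvolution`, giving 16*5*5 = 400 values for `nn.View` and the first `nn.Linear`. The same steps applied to 96x96 images with 9x9 kernels give 96, 88, 44, 36, 18.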