torch / nn


How to separate a network into submodules #1295

Closed SeunghyunMoon closed 7 years ago

SeunghyunMoon commented 7 years ago

Hi,

I tried separating a DNN into submodules as shown below. I thought the two versions were equivalent, but the original achieves 90% accuracy while the separated one reaches only 9%. What should I do to make the separated network truly equivalent?

--- Original ---

```lua
local model = nn.Sequential()
model:add(nn.View(-1, 784))
model:add(BinaryLinear(784, numHid))
model:add(BatchNormalization(numHid, opt.runningVal))
model:add(nn.HardTanh())
model:add(BinarizedNeurons(opt.stcNeurons))

model:add(BinaryLinear(numHid, numHid, opt.stcWeights))
model:add(BatchNormalization(numHid, opt.runningVal))
model:add(nn.HardTanh())
model:add(BinarizedNeurons(opt.stcNeurons))

model:add(BinaryLinear(numHid, numHid, opt.stcWeights))
model:add(BatchNormalization(numHid, opt.runningVal))
model:add(nn.HardTanh())
model:add(BinarizedNeurons(opt.stcNeurons))

model:add(BinaryLinear(numHid, 10, opt.stcWeights))
model:add(nn.BatchNormalization(10))
```

--- Separated ---

```lua
local firstLayer = nn.Sequential()
firstLayer:add(nn.View(-1, 784))
firstLayer:add(BinaryLinear(784, numHid))
firstLayer:add(BatchNormalization(numHid, opt.runningVal))
firstLayer:add(nn.HardTanh())
firstLayer:add(BinarizedNeurons(opt.stcNeurons))

local secondLayer = nn.Sequential()
secondLayer:add(BinaryLinear(numHid, numHid, opt.stcWeights))
secondLayer:add(BatchNormalization(numHid, opt.runningVal))
secondLayer:add(nn.HardTanh())
secondLayer:add(BinarizedNeurons(opt.stcNeurons))

local thirdLayer = nn.Sequential()
thirdLayer:add(BinaryLinear(numHid, numHid, opt.stcWeights))
thirdLayer:add(BatchNormalization(numHid, opt.runningVal))
thirdLayer:add(nn.HardTanh())
thirdLayer:add(BinarizedNeurons(opt.stcNeurons))

local fourthLayer = nn.Sequential()
fourthLayer:add(BinaryLinear(numHid, 10, opt.stcWeights))
fourthLayer:add(nn.BatchNormalization(10))

local model = nn.Sequential()
    :add(firstLayer)
    :add(secondLayer)
    :add(thirdLayer)
    :add(fourthLayer)
```
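For what it's worth, nesting `nn.Sequential` containers is functionally identical to the flat model as far as `forward`/`backward` go. A minimal sketch (assuming only the stock `nn` package, with `nn.Linear` standing in for the custom `BinaryLinear`/`BinarizedNeurons` modules) that checks this by regrouping the very same module instances into submodules:

```lua
require 'nn'
torch.manualSeed(0)

-- Flat model: Linear -> Tanh -> Linear
local flat = nn.Sequential()
flat:add(nn.Linear(4, 3))
flat:add(nn.Tanh())
flat:add(nn.Linear(3, 2))

-- Separated model: the SAME module instances, regrouped into submodules
local firstLayer  = nn.Sequential():add(flat:get(1)):add(flat:get(2))
local secondLayer = nn.Sequential():add(flat:get(3))
local nested = nn.Sequential():add(firstLayer):add(secondLayer)

local x  = torch.randn(5, 4)
local y1 = flat:forward(x):clone()  -- clone: output buffers are shared
local y2 = nested:forward(x)
assert((y1 - y2):abs():max() == 0, 'flat and nested outputs differ')
```

Since the structures really are equivalent, a common culprit for a training gap like 90% vs. 9% is the surrounding training code: for example, `getParameters()` must be called exactly once, on the outermost container, because it flattens parameter storage; calling it separately per submodule (or building the optimizer before the final regrouping) can silently detach the tensors the optimizer updates from the ones the network uses.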