nicholas-leonard / dp

A deep learning library for streamlining research and development using the Torch7 distribution.

Should dp.TextSource provide an impl of frequencyTree()? #154

Open: hsheil opened this issue 9 years ago

hsheil commented 9 years ago

When running recurrentlanguagemodel.lua with a custom text dataset and --softmaxtree, the following error occurs:

/home/hsheil/torch/install/bin/luajit: recurrentlanguagemodel.lua:222: attempt to call method 'frequencyTree' (a nil value)
stack traceback:
    recurrentlanguagemodel.lua:222: in main chunk
    [C]: in function 'dofile'
    ...heil/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:131: in main chunk
    [C]: at 0x00406670

Line 222 is the culprit:

   elseif opt.softmaxtree then -- uses frequency based tree
      local tree, root = ds:frequencyTree()
      softmax = nn.SoftMaxTree(inputSize, tree, root, opt.accUpdate)
   end

I guess --softmaxtree is most beneficial for the Billion Word dataset anyway, but to avoid this error, should dp.TextSource provide an implementation of the frequencyTree() method?
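In plain Lua terms, ds:frequencyTree() is shorthand for ds.frequencyTree(ds). When neither the object nor its class (reached through the metatable's __index) defines frequencyTree, the lookup yields nil and calling it raises the error above. The sketch below reproduces this without Torch and adds a hypothetical guard that fails early with a clearer message. TextSourceLike and requireFrequencyTree are illustrative names, not part of dp. The exact wording of the nil-call error differs between LuaJIT and newer Lua versions, so the example only relies on the call failing.

```lua
-- Plain Lua, no Torch needed. TextSourceLike and requireFrequencyTree are
-- hypothetical names used only for this illustration.

-- A minimal class: methods live in the class table, reached through __index.
local TextSourceLike = {}
TextSourceLike.__index = TextSourceLike

function TextSourceLike.new()
   return setmetatable({}, TextSourceLike)
end
-- Note: no frequencyTree method is defined yet.

-- Hypothetical guard for the --softmaxtree branch: check that the method
-- exists and fail with an explicit message instead of a nil-call error.
local function requireFrequencyTree(ds, datasetName)
   if type(ds.frequencyTree) ~= "function" then
      error(string.format(
         "--softmaxtree requires %s to implement frequencyTree()", datasetName), 2)
   end
   return ds:frequencyTree()
end

local ds = TextSourceLike.new()

-- Unguarded call: fails the same way as recurrentlanguagemodel.lua:222
-- (the exact message text differs between LuaJIT and newer Lua versions).
local ok, err = pcall(function() return ds:frequencyTree() end)
print(ok, err)

-- Guarded call: fails with a message naming the missing method.
local ok2, err2 = pcall(requireFrequencyTree, ds, "TextSource")
print(ok2, err2)

-- Defining the method on the class, as proposed later in this thread, also
-- fixes existing instances, because method lookup goes through __index.
function TextSourceLike:frequencyTree()
   return {}, 0  -- placeholder tree and root id
end
local tree, root = requireFrequencyTree(ds, "TextSource")
print(type(tree), root)
```

The guard only improves the error message; the dataset still needs a real frequencyTree() for --softmaxtree to work.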

Command-line parameters to reproduce (I think only the combination of --softmaxtree and --dataset TextSource matters):

th recurrentlanguagemodel.lua --lstm --cuda --dataPath data --dataset TextSource --softmaxtree

adonisues commented 9 years ago

I had the same problem. It seems that TextSource has no frequencyTree() method, so the SoftMaxTree cannot be initialized. I solved it as follows.

Add the method below (copied from dp/data/penntreebank.lua) to dp/data/textsource.lua:

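-- this can be used to initialize a SoftMaxTree
function TextSource:frequencyTree(binSize)
   binSize = binSize or 100
   -- word frequencies, indexed by word id
   local wf = torch.IntTensor(self:wordFrequency())
   -- word ids sorted by ascending frequency
   local _, indices = wf:sort()
   local tree = {}
   -- internal node ids start after the last word id
   local id = indices:size(1)
   local function recursiveTree(indices)
      if indices:size(1) < binSize then
         -- few enough nodes left: they become the children of the root
         id = id + 1
         tree[id] = indices
         return
      end
      -- group the nodes into bins of at most binSize children each
      local parents = {}
      for start = 1, indices:size(1), binSize do
         local stop = math.min(indices:size(1), start + binSize - 1)
         local bin = indices:narrow(1, start, stop - start + 1)
         assert(bin:size(1) <= binSize)
         id = id + 1
         table.insert(parents, id)
         tree[id] = bin
      end
      -- build the next level up from the new parent nodes
      recursiveTree(indices.new(parents))
   end
   recursiveTree(indices)
   -- id is now the root node id
   return tree, id
end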
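The recursion is easier to follow on a toy vocabulary. The sketch below is a plain-Lua transcription of the same binning logic, with tensors replaced by arrays so it runs without Torch. frequencyTreeSketch and the word counts are made up for illustration, and ties in frequency are broken by word id here so that the result is deterministic.

```lua
-- Plain-Lua sketch of the frequency-binned tree built by frequencyTree().
-- Assumption: word ids are 1..n and freq[i] is the count of word i. In the
-- Torch version each tree value is an IntTensor/LongTensor; here a plain array.
local function frequencyTreeSketch(freq, binSize)
   binSize = binSize or 100
   local n = #freq
   -- word ids sorted by ascending frequency (ties broken by id)
   local order = {}
   for i = 1, n do order[i] = i end
   table.sort(order, function(a, b)
      if freq[a] ~= freq[b] then return freq[a] < freq[b] end
      return a < b
   end)

   local tree = {}  -- parent id -> array of child ids
   local id = n     -- internal node ids start after the last word id

   local function recursiveTree(children)
      if #children < binSize then
         -- few enough nodes left: they become the children of the root
         id = id + 1
         tree[id] = children
         return
      end
      local parents = {}
      for start = 1, #children, binSize do
         local stop = math.min(#children, start + binSize - 1)
         local bin = {}
         for k = start, stop do bin[#bin + 1] = children[k] end
         id = id + 1
         parents[#parents + 1] = id
         tree[id] = bin
      end
      recursiveTree(parents)
   end

   recursiveTree(order)
   return tree, id  -- id is now the root
end

-- Toy vocabulary of 7 words with made-up counts, binSize = 3
local tree, root = frequencyTreeSketch({50, 3, 20, 1, 9, 40, 5}, 3)
for parent = 8, root do
   print(parent, table.concat(tree[parent], " "))
end
```

With these numbers the words are grouped into node 8 (words 4, 2, 7), node 9 (words 5, 3, 6) and node 10 (word 1), and those three nodes become the children of node 11. Because 3 parents is not fewer than binSize, the strict < check adds one more level, so the root (node 12) has node 11 as its only child.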