torch / nn

Try to apply softmax to a batch of data with variable length #1297

Open · hfxunlp opened this issue 6 years ago

hfxunlp commented 6 years ago

Hi, I want to make SoftMax support variable-length input, so that a batch of sequences with different lengths can be used as the input of this module. This is helpful for Natural Language Processing, especially for the attention model in seq2seq and the Attention-over-Attention model for reading comprehension. I have also raised a corresponding pull request at https://github.com/torch/cunn/pull/489.
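
For context, here is a minimal sketch of what such a length-aware softmax computes, emulated with the existing nn.SoftMax by sending the padded positions to -inf so they come out as exact zeros. The helper name maskedSoftMax and its interface are illustrative only, not part of the patch:

require "nn"

-- Illustrative only: emulate variable-length softmax with nn.SoftMax.
-- input:   bsize x maxlen tensor, padded along dimension 2
-- lengths: LongTensor of size bsize with the true length of each row
local function maskedSoftMax(input, lengths)
    local masked = input:clone()
    for b = 1, input:size(1) do
        local len = lengths[b]
        if len < input:size(2) then
            -- padded positions become -inf, so exp() maps them to exactly zero
            masked[b]:narrow(1, len + 1, input:size(2) - len):fill(-math.huge)
        end
    end
    return nn.SoftMax():forward(masked)
end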

hfxunlp commented 6 years ago

I would appreciate any chance to learn how to contribute to THNN, but I have not found any documentation so far. Sorry for my lack of experience.

hfxunlp commented 6 years ago

Sorry, I missed the declaration in lib/THNN/generic/THNN.h; I think the code works correctly now.

hfxunlp commented 6 years ago

Here is the test script that I used to check whether this patch works correctly:

require "nn"
tmodstd=nn.SoftMax()
tmod=nn.LenSoftMax()
minbsize=20
maxbsize=100
minlen=16
maxlen=128
minpadlen=4
maxpadlen=16
psg=true
firstcycle=100
for t=1, firstcycle do
    if psg then
        bsize=math.random(minbsize, maxbsize)
        lens=math.random(minlen, maxlen)
        plens=math.random(minpadlen, maxpadlen)
        lvec=torch.LongTensor(bsize):fill(lens)
        stdi=torch.randn(bsize, lens)
        i=torch.cat(stdi, torch.randn(bsize, plens))
        stdgo=torch.randn(bsize, lens)
        go=torch.cat(stdgo, torch.randn(bsize, plens))
        stdo=tmodstd:forward(stdi)
        o=tmod:forward({i, lvec})
        if not (o:narrow(2, 1, lens):equal(stdo) and o:narrow(2, lens+1, plens):equal(torch.zeros(bsize, plens)) ) then
            psg=false
            print("forward error")
        end
        stdgi=tmodstd:backward(stdi, stdgo)
        gi=tmod:backward({i, lvec}, go)
        if not (gi:narrow(2, 1, lens):equal(stdgi) and gi:narrow(2, lens+1, plens):equal(torch.zeros(bsize, plens)) ) then
            psg=false
            print("backward error")
        end
    end
    xlua.progress(t, firstcycle)
end
if psg then
    print("test pass")
end
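
For the attention use case mentioned at the top, a hedged usage sketch, assuming the patched nn.LenSoftMax keeps the {input, lengths} calling convention used in the test above (the tensor sizes here are arbitrary examples):

require "nn"

scores = torch.randn(32, 50)                    -- bsize x maxlen attention scores, padded along dim 2
lengths = torch.LongTensor(32):random(10, 50)   -- true length of each example
attw = nn.LenSoftMax():forward({scores, lengths})
-- for each row b, attw[b] sums to 1 over the first lengths[b] positions
-- and is zero on the padded positions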