torch / torch7

http://torch.ch

ParallelTable backward strange behaviour #960

Closed · tastyminerals closed 7 years ago

tastyminerals commented 7 years ago

Here is the model I am working with:

nn.Sequential {
  [input -> (1) -> (2) -> (3) -> output]
  (1): ParallelTable {
    input
      |`-> (1): nn.Sequential {
      |      [input -> (1) -> (2) -> output]
      |      (1): nn.LookupTable
      |      (2): SplitTable
      |    }
       `-> (2): SplitTable
       ... -> output
  }
  (2): ZipTable
  (3): nn.Sequencer @ nn.Recursor @ nn.Sequential {
    [input -> (1) -> (2) -> (3) -> (4) -> output]
    (1): ParallelTable {
      input
        |`-> (1): nn.Sequential {
        |      [input -> (1) -> output]
        |      (1): GRU(200 -> 200, 0.00)
        |    }
         `-> (2): Linear(54 -> 200)
         ... -> output
    }
    (2): CAddTable
    (3): Linear(200 -> 764)
    (4): nn.LogSoftMax
  }
}

The input to this model is a batch table:

{
  1 : 
    {
      1 : IntTensor - size: 5x32
      2 : DoubleTensor - size: 5x32x54
    }
}
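For concreteness, a batch table of this shape could be built like so (a minimal sketch mirroring the dump above; the sizes are 5 timesteps, batch size 32, and 54 features per step):

require 'torch'

local input = {
  {
    torch.IntTensor(5, 32):fill(1),       -- 1: token indices for nn.LookupTable
    torch.DoubleTensor(5, 32, 54):zero(), -- 2: dense per-timestep features
  }
}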

The forward pass succeeds, but the backward pass crashes with the following error:

In 1 module of nn.Sequential:
In 2 module of ParallelTable:
./model/SplitTable.lua:45: bad argument #1 to 'resizeAs' (torch.DoubleTensor expected, got torch.IntTensor)
stack traceback:
    [C]: in function 'resizeAs'
    ./model/SplitTable.lua:45: in function <./model/SplitTable.lua:37>
    [C]: in function 'xpcall'
    ...styminerals/torch/install/share/lua/5.1/nn/Container.lua:63: in function 'rethrowErrors'
    ./model/ParallelTable.lua:30: in function 'updateGradInput'
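For reference, the same type error can be reproduced in isolation, since resizeAs requires its argument to be a tensor of the same type as self (a minimal sketch, separate from the model above):

require 'torch'

local grad = torch.DoubleTensor()
local inp  = torch.IntTensor(32)
-- errors with: bad argument #1 to 'resizeAs'
-- (torch.DoubleTensor expected, got torch.IntTensor)
grad:resizeAs(inp)

This is what happens inside SplitTable's updateGradInput when the input slot it is handed is an IntTensor rather than the DoubleTensor it saw during forward.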

The container that crashes is the second SplitTable in the first ParallelTable. I copied the source files and removed the nn. prefix so I could load them manually and print whatever each module receives and emits as INPUT and OUTPUT. Here is the strange part:

--------------------------- ZipTable backward ------------------------  
<<< INPUT   
{
  1 : 
    {
      1 : 
        {
          1 : DoubleTensor - size: 32x200
          2 : DoubleTensor - size: 32x200
          3 : DoubleTensor - size: 32x200
          4 : DoubleTensor - size: 32x200
          5 : DoubleTensor - size: 32x200
        }
      2 : 
        {
          1 : DoubleTensor - size: 32x54
          2 : DoubleTensor - size: 32x54
          3 : DoubleTensor - size: 32x54
          4 : DoubleTensor - size: 32x54
          5 : DoubleTensor - size: 32x54
        }
    }
}
>>> OUTPUT  
{
  1 : 
    {
      1 : 
        {
          1 : DoubleTensor - size: 32x200
          2 : DoubleTensor - size: 32x200
          3 : DoubleTensor - size: 32x200
          4 : DoubleTensor - size: 32x200
          5 : DoubleTensor - size: 32x200
        }
      2 : 
        {
          1 : DoubleTensor - size: 32x54
          2 : DoubleTensor - size: 32x54
          3 : DoubleTensor - size: 32x54
          4 : DoubleTensor - size: 32x54
          5 : DoubleTensor - size: 32x54
        }
    }
}
------------------------- ParallelTable backward -------------------------  
first ParallelTable 
<<< INPUT   
{
  1 : IntTensor - size: 5x32
}
----------------------- SplitTable backward ---------------------   
WHO Upper SplitTable    
>>> SplitTable updateGradInput
{
  1 : DoubleTensor - size: 5x32x200
}
----------------------- SplitTable backward ---------------------   
WHO Lower SplitTable    
>>> SplitTable updateGradInput 
{
  1 : IntTensor - size: 32
}

For some reason the first ParallelTable receives only a single IntTensor of size 5x32, and the lower SplitTable then gets an IntTensor of size 32. That should not happen: this SplitTable should receive a DoubleTensor of size 32x54, unless I misunderstand how backward works in this pipeline.

Why doesn't the first ParallelTable receive whatever ZipTable outputs?
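For context, ParallelTable's backward dispatches per child, pairing input[i] with gradOutput[i]; roughly (a simplified paraphrase of nn.ParallelTable:updateGradInput, not the exact source):

function ParallelTable:updateGradInput(input, gradOutput)
  for i, module in ipairs(self.modules) do
    -- child i is handed input[i] and gradOutput[i]; if `input` is not the
    -- same table forward saw, the pairing is wrong, and a child such as
    -- SplitTable resizes its gradInput against a tensor of the wrong type
    self.gradInput[i] = module:updateGradInput(input[i], gradOutput[i])
  end
  return self.gradInput
end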

tastyminerals commented 7 years ago

I figured it out: in the backward call I wasn't passing the exact same table I had passed to forward. I was passing

{
  1 : 
    {
      1 : IntTensor - size: 5x32
    }
}

instead of

{
  1 : 
    {
      1 : IntTensor - size: 5x32
      2 : DoubleTensor - size: 5x32x54
    }
}
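With that fixed, a minimal training step looks like this (a sketch; model is the network above, while input, criterion, and target are assumed to be defined elsewhere):

require 'nn'
require 'rnn' -- assumes the Element-Research rnn package, for nn.Sequencer / GRU

local output = model:forward(input)  -- input = the full two-tensor table
local gradOutput = criterion:backward(output, target)
-- backward must receive the exact same `input` table that forward saw;
-- dropping the DoubleTensor shifts which tensor each ParallelTable child
-- is paired with, which produces the resizeAs type error above
model:backward(input, gradOutput)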