IDSIA / brainstorm

Fast, flexible and fun neural networks.

bugfix in PyCudaHandler merge and split operations #113

Open Osambezy opened 8 years ago

Osambezy commented 8 years ago

PyCUDA already reads the gpudata attribute of each array argument inside the kernel function call, so there is no need to access it before the call. The arrays can be passed to the kernel directly.

Before this change, using a merge layer with pycuda 2016.1 raised the following errors in the forward and backward passes:

  File "test.py", line 214, in main
    trainer.train(network, getter_tr, valid_getter=getter_va)
  File "brainstorm/training/trainer.py", line 99, in train
    self.stepper.run()
  File "brainstorm/training/steppers.py", line 103, in run
    self.net.forward_pass(training_pass=True)
  File "brainstorm/structure/network.py", line 430, in forward_pass
    layer.forward_pass(self.buffer[layer_name], training_pass)
  File "brainstorm/layers/merge_layer.py", line 52, in forward_pass
    buffers.outputs.default)
  File "brainstorm/handlers/pycuda_handler.py", line 323, in merge_tt
    block=block, grid=grid)
  File "/lib/python2.7/site-packages/pycuda/driver.py", line 383, in function_call
    handlers, arg_buf = _build_arg_buf(args)
  File "/lib/python2.7/site-packages/pycuda/driver.py", line 158, in _build_arg_buf
    raise TypeError("invalid type on parameter #%d (0-based)" % i)
TypeError: invalid type on parameter #0 (0-based)

  File "test.py", line 214, in main
    trainer.train(network, getter_tr, valid_getter=getter_va)
  File "brainstorm/training/trainer.py", line 99, in train
    self.stepper.run()
  File "brainstorm/training/steppers.py", line 104, in run
    self.net.backward_pass()
  File "brainstorm/structure/network.py", line 444, in backward_pass
    layer.backward_pass(self.buffer[layer_name])
  File "brainstorm/layers/merge_layer.py", line 59, in backward_pass
    buffers.input_deltas.inputs_2)
  File "brainstorm/handlers/pycuda_handler.py", line 364, in split_add_tt
    block=block, grid=grid)
  File "/lib/python2.7/site-packages/pycuda/driver.py", line 383, in function_call
    handlers, arg_buf = _build_arg_buf(args)
  File "/lib/python2.7/site-packages/pycuda/driver.py", line 158, in _build_arg_buf
    raise TypeError("invalid type on parameter #%d (0-based)" % i)
TypeError: invalid type on parameter #0 (0-based)
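Both tracebacks end with the kernel launcher rejecting parameter #0. The sketch below is a simplified, hypothetical model of that argument handling. It is not PyCUDA's actual `_build_arg_buf`. It assumes, as the description above states, that the launcher unwraps array-like arguments through their `.gpudata` attribute. It also assumes that a pointer extracted beforehand can arrive as a plain Python int, which is neither a NumPy scalar nor an object with `.gpudata`. `FakeGPUArray` and `build_arg_buf` are illustrative names only.

```python
# Simplified, hypothetical model of kernel-argument marshalling.
# NOT PyCUDA's real implementation; it only mirrors the behaviour
# described in this PR: array-like arguments are unwrapped via
# their .gpudata attribute *inside* the call.
import numpy as np


class FakeGPUArray:
    """Stand-in for a GPU array; .gpudata holds a raw device pointer."""

    def __init__(self, ptr):
        self.gpudata = ptr


def build_arg_buf(args):
    """Return the values the kernel would receive, or raise TypeError."""
    out = []
    for i, arg in enumerate(args):
        if isinstance(arg, np.number):
            # Sized NumPy scalars (e.g. np.int32) are passed by value.
            out.append(arg)
        else:
            try:
                # The launcher itself reads .gpudata from array-like args.
                out.append(int(arg.gpudata))
            except AttributeError:
                # A bare Python int (pre-extracted pointer) lands here.
                raise TypeError("invalid type on parameter #%d (0-based)" % i)
    return out


# After the fix: pass the array object and let the launcher unwrap it.
print(build_arg_buf([FakeGPUArray(4096), np.int32(10)]))

# Before the fix: passing .gpudata yourself hands over a plain int.
try:
    build_arg_buf([FakeGPUArray(4096).gpudata, np.int32(10)])
except TypeError as err:
    print(err)  # invalid type on parameter #0 (0-based)
```

In this model the fix amounts to dropping the explicit `.gpudata` access at the call site, so the unwrapping happens in exactly one place.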