IDSIA / brainstorm

Fast, flexible and fun neural networks.

Handler-specific storage requirements in layers #29

Closed · untom closed this issue 9 years ago

untom commented 9 years ago

While implementing convolutions/pooling, I've stumbled upon the following problem: depending on how you implement the operations, you might need more or less storage. Specifically, the GPU and CPU implementations might need different amounts of it, or might not need it at all. Currently I have two examples for this:

- the argmax of max-pooling, which the CPU implementation has to record in the forward pass and reuse in the backward pass, while the cuDNN implementation doesn't need it at all
- the cuDNN descriptors (tensor/pooling/convolution descriptors), which only the GPU implementation needs

In both cases, one of the two handlers needs additional storage, while the other doesn't. What's even weirder: the argmax can be seen as a buffer, and could be handled by the buffer manager. However, that'd lead to wasting memory on the GPU, where we'd allocate the buffer but never use it (which might also be confusing to users who inspect these buffers expecting them to mean something). The descriptors OTOH are cudnn-specific structures and probably not meant to be stored in buffers.
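To make the asymmetry concrete, here is a minimal 1-D sketch of the CPU case (the function names and signatures are made up for illustration, not brainstorm's actual handler API); a cuDNN-based handler would need none of this argmax storage, but would instead have to create and keep alive pooling/tensor descriptors that the CPU side has no use for:

```python
import numpy as np

# Hypothetical CPU-side 1-D max pooling with non-overlapping windows:
# the backward pass reuses the argmax positions recorded in the forward
# pass, so the CPU handler needs an extra index buffer that a
# cuDNN-based handler would never touch.
def maxpool_forward_cpu(inputs, window, outputs, argmax):
    for i in range(outputs.shape[0]):
        patch = inputs[i * window:(i + 1) * window]
        argmax[i] = i * window + np.argmax(patch)   # handler-specific storage
        outputs[i] = patch.max()

def maxpool_backward_cpu(out_deltas, argmax, in_deltas):
    in_deltas[:] = 0
    for i in range(out_deltas.shape[0]):
        in_deltas[int(argmax[i])] += out_deltas[i]  # consume the saved indices
```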

I can think of two solutions:

1. Add something like `handler.allocate_pool/conv_specific_memory(...)` that returns some sort of opaque data structure (maybe a list of descriptors/allocations), which is then stored within each layer and always passed to the conv/pooling methods:

   ```python
   # in layer ctor:
   self._pooling_data = self.handler.allocate_pool_specific_memory()

   # in forward pass
   def forward_pass(...):
       # each handler implementation is free to ignore the last argument if it doesn't need it
       self.handler.conv2d_forward_batch(inputs, window, outputs, pad, stride, self._pooling_data)
   ```

2. Allocate/deallocate the cuDNN-specific stuff in each call, and make `argmax` an internal buffer of the pooling layer (see the sketch below).
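A rough sketch of what option 2 could look like on the layer side (all names, the shape template, and the buffer layout are assumptions for illustration, not the actual brainstorm API):

```python
# Illustrative sketch of option 2: the pooling layer declares "argmax"
# as an ordinary internal buffer, the CPU handler fills and reuses it,
# and the GPU handler ignores it and (re)creates its cuDNN descriptors
# inside every call instead of storing them in the layer.
class Pooling2DLayerSketch:
    def get_internal_structure(self):
        # extra per-timestep/per-sample storage requested from the buffer manager
        # (the shape spec below is purely illustrative)
        return {'argmax': ('T', 'B', self.output_size)}

    def forward_pass(self, buffers):
        self.handler.pool2d_forward_batch(
            buffers.inputs.default, self.window, buffers.outputs.default,
            self.stride, buffers.internals.argmax)  # ignored by the GPU handler
```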

I'm not super happy with either solution, since both are slightly ugly. I like solution 1 a bit more, but it has the additional drawback of making the API more complicated. What do you guys think?

flukeskywalker commented 9 years ago

How about this strategy:

untom commented 9 years ago

One problem with that: argmax contains integers. However, all of our internals are assumed to be floating point numbers, and there currently is no way to request a different dtype.

flukeskywalker commented 9 years ago

Yes, these and dropout masks, for example, will have to be float for now. We can discuss plans to work around this in the future, but it would involve additional kernels.
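For what it's worth, a minimal standalone sketch of that workaround (purely illustrative, plain NumPy): the integer indices are written into a float buffer and cast back to int when they are consumed. Single-precision floats represent integers exactly up to 2^24, so the round trip is lossless for realistic buffer sizes.

```python
import numpy as np

# Integer indices round-trip losslessly through a float32 buffer as long
# as they stay below 2**24, which is what "argmax will have to be float
# for now" amounts to in practice.
indices = np.array([0, 7, 1023, 2**20], dtype=np.int64)
float_buffer = np.zeros(indices.shape, dtype=np.float32)   # what the buffer manager hands out
float_buffer[:] = indices                                   # forward pass writes indices as floats
recovered = float_buffer.astype(np.int64)                   # backward pass casts them back
assert np.array_equal(recovered, indices)
```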