apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

Simplify `1DConv` to `RNN` in `Sequential` Block #12506

Open thomelane opened 6 years ago

thomelane commented 6 years ago

As far as I'm aware, it's not possible to create a Sequential Block going from Conv1D to RNN without implementing a custom transpose block. Although that block is simple for the user to implement, I think this network configuration is common enough that we should look for an easier way.

Issue

One potential solution would be to select compatible layouts, but Conv1D only supports 'NCW' (a.k.a. 'NCT'), and RNN only supports 'TNC' and 'NTC'. So the two are incompatible without a transpose.

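Although the following code runs without error, the final shape is incorrect because incompatible layouts have been used: the Conv1D output is in 'NCT', but the RNN reads it as 'TNC'. One would intuitively expect the output shape to be (8, 1, 3), where T=8, N=1, C=3. Instead we get (1, 2, 3), because the Conv1D output of shape (1, 2, 8) is interpreted as T=1, N=2, C=8.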

import mxnet as mx

net = mx.gluon.nn.Sequential()
net.add(mx.gluon.nn.Conv1D(channels=2, kernel_size=3))
net.add(mx.gluon.rnn.RNN(hidden_size=3))
net.initialize()

data = mx.ndarray.random.uniform(shape=(1,1,10))
output = net(data)
print(output.shape)
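
To make the shape mismatch above concrete, here is a small pure-Python sketch that traces only the shapes, without running MXNet. The helper names (`conv1d_output_shape`, `rnn_output_shape`, `transpose_shape`) are hypothetical and written for illustration; they assume the standard unpadded convolution length formula and that the RNN simply reads T and N from the positions given by its layout string.

```python
def conv1d_output_shape(input_shape, channels, kernel_size,
                        stride=1, padding=0, dilation=1):
    """Shape of a 1D convolution output for an 'NCW' input (hypothetical helper)."""
    n, _, w = input_shape
    out_w = (w + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    return (n, channels, out_w)


def rnn_output_shape(input_shape, hidden_size, layout="TNC"):
    """Shape of an RNN output, reading T and N from the positions named by layout."""
    out = list(input_shape)
    # Only the channel axis changes size; T and N are passed through as found.
    out[layout.find("C")] = hidden_size
    return tuple(out)


def transpose_shape(shape, axes):
    """Shape after permuting axes, as a transpose operator would."""
    return tuple(shape[a] for a in axes)


conv_out = conv1d_output_shape((1, 1, 10), channels=2, kernel_size=3)
# Without a transpose, the RNN treats (N, C, T) as if it were (T, N, C).
without_transpose = rnn_output_shape(conv_out, hidden_size=3)
# Moving the time axis to the front gives the layout the RNN expects.
with_transpose = rnn_output_shape(transpose_shape(conv_out, (2, 0, 1)), hidden_size=3)

print(conv_out, without_transpose, with_transpose)
```

Under these assumptions, the untransposed path reproduces the (1, 2, 3) shape reported above, and the transposed path gives the intended (8, 1, 3).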

Options/Ideas

1) Support more layout types in Conv1D or RNN, so that the layers could perform the transposes themselves.

2) Create a 'Transpose' Block that can be added between the Conv1D and RNN, e.g.

net = mx.gluon.nn.Sequential()
net.add(mx.gluon.nn.Conv1D(channels=2, kernel_size=3))
net.add(mx.gluon.nn.Transpose(axes=(2,0,1)))
net.add(mx.gluon.rnn.RNN(hidden_size=3))

3) Create a generic 'Operator' Block that can add an arbitrary operator to a Sequential Block, e.g.

net = mx.gluon.nn.Sequential()
net.add(mx.gluon.nn.Conv1D(channels=2, kernel_size=3))
net.add(mx.gluon.nn.Operator('transpose', axes=(2,0,1)))
net.add(mx.gluon.rnn.RNN(hidden_size=3))

thomelane commented 6 years ago

@mxnet-label-bot [Feature Request]