torch / nngraph

Graph Computation for nn

overload operator __unm and __sub to support module chaining #113

Closed iamalbert closed 8 years ago

iamalbert commented 8 years ago

By overloading the __unm and __sub metamethods of nn.Module and __sub of nngraph.Node, graph construction becomes easier and more human-readable.

For example,

h1 = nn.Linear(20, 10)()
h2 = nn.Linear(10, 1)(nn.Tanh()(nn.Linear(10, 10)(nn.Tanh()(h1))))
mlp = nn.gModule({h1}, {h2})

can be written as

h1 = - nn.Linear(20,10)
h2 = h1
     - nn.Tanh()
     - nn.Linear(10,10)
     - nn.Tanh()
     - nn.Linear(10, 1)
mlp = nn.gModule({h1}, {h2})
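For readers unfamiliar with Lua metamethods, the mechanism behind this syntax can be sketched roughly as follows (a simplified illustration of the idea, not the PR's actual diff — the real implementation modifies the torch class metatables for nn.Module and nngraph.Node):

-- Sketch only: assumes nn.Module and nngraph.Node metatables accept these entries.
-- Unary minus turns a module into a graph node:  -m   behaves like  m()
nn.Module.__unm = function(m)
   return m()                     -- call with no input, yielding an nngraph node
end

-- Binary minus feeds the left operand into the module on the right:
--   node - module          behaves like  module(node)
--   {node1, node2} - module behaves like module({node1, node2})
nngraph.Node.__sub = function(node, module)
   return module(node)
end

Lua resolves a binary metamethod on the left operand first and falls back to the right operand, which is why a plain table of nodes such as { input, L1 } can appear on the left of the minus sign: the lookup falls through to the module's own __sub.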

In the original syntax, as the graph gets bigger and more complicated, the growing number of nested parentheses can become confusing. Module chaining is clearer and easier to debug, since the data flow is visible at a glance.

For example

input = nn.Identity()()
L1 = nn.Tanh()(nn.Linear(10, 20)(input))
L2 = nn.Tanh()(nn.Linear(30, 60)(nn.JoinTable(1)({input, L1})))
L3 = nn.Tanh()(nn.Linear(80, 160)(nn.JoinTable(1)({L1, L2})))
g = nn.gModule({input},{L3})

can be written as

input = - nn.Identity()
L1 = input
     - nn.Linear(10, 20)
     - nn.Tanh()
L2 = { input, L1 }
     - nn.JoinTable(1)
     - nn.Linear(30, 60)
     - nn.Tanh()
L3 = { L1, L2 }
     - nn.JoinTable(1)
     - nn.Linear(80, 160)
     - nn.Tanh()
g = nn.gModule({input},{L3})

soumith commented 8 years ago

I like it, but can you also add the examples to the main README?

soumith commented 8 years ago

Thank you!