google / prettytensor

Pretty Tensor: Fluent Networks in TensorFlow

How to implement residual blocks with prettytensor? #39

Closed ghost closed 7 years ago

ghost commented 7 years ago

Hi, can anyone tell me how to implement residual blocks or skip-connections with prettytensor? An idea or a code example would be helpful.

eiderman commented 7 years ago

In non-sequential mode, the easiest way to do it is:

x = x.fully_connected(...)
# 2-layer residual block; the second layer's size must match x's so the add is valid.
x += x.fully_connected(...).fully_connected(...)

In sequential mode you can do this:

x = x.fully_connected(...)
# 2-layer residual block; as_layer() causes a snapshot of the current layer to be taken.
x += x.as_layer().fully_connected(...).fully_connected(...)

You can also register this as a method using the @Register decorator:

@Register()
def fc_residual(input_, size, residual_sizes, activation_fn=tf.nn.relu):
  # Because we used Register, input_ is always non-sequential.
  x = input_.fully_connected(size, activation_fn=activation_fn)
  residual = input_
  for rsize in residual_sizes:
    residual = residual.fully_connected(rsize, activation_fn=activation_fn)
  # The last entry of residual_sizes must equal size for the add to be valid.
  return x + residual

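For intuition, the same skip-connection pattern can be sketched in plain NumPy, independent of Pretty Tensor. All names here (`fully_connected`, the weight variables) are illustrative, not part of any library API; the point is only that the residual path's final layer must produce the same shape as the main path so the element-wise add works:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def fully_connected(x, w, b, activation=relu):
    # Dense layer: activation(x @ w + b).
    return activation(x @ w + b)

rng = np.random.default_rng(0)
size = 4
x = rng.standard_normal((2, size))

# Weights for the main path (w1) and a 2-layer residual path (w2, w3).
# w3 must map back to `size` units so the skip-connection add is valid.
w1, b1 = rng.standard_normal((size, size)), np.zeros(size)
w2, b2 = rng.standard_normal((size, size)), np.zeros(size)
w3, b3 = rng.standard_normal((size, size)), np.zeros(size)

main = fully_connected(x, w1, b1)
residual = fully_connected(fully_connected(x, w2, b2), w3, b3)
out = main + residual  # skip-connection: shapes must match
```

This mirrors the registered `fc_residual` above: one layer on the main path, a chain of layers on the residual path, summed at the end.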
ghost commented 7 years ago

Thanks, I will try it.