Feel free to add more questions, and add answers below them.
network["layer_name"]
, example:model.add(DenseLayer(20, None, name="seq_layer_2"))
def __getitem__(self, key):
of core layer. Is there any other advantages of using Network class?# PyTorch-Like Layer Access
layer_1 = model["seq_layer_1"]
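Below is a minimal sketch of how such `__getitem__`-based access could be implemented; the `Network` internals shown here are my assumption, not the actual TL code:

```python
class Network(object):
    """Hypothetical sketch of name-based layer access."""

    def __init__(self, name):
        self.name = name
        self._layers_by_name = {}  # layer name -> layer object

    def add(self, layer):
        if layer.name in self._layers_by_name:
            raise ValueError("duplicate layer name: %s" % layer.name)
        self._layers_by_name[layer.name] = layer
        return layer

    def __getitem__(self, key):
        # model["seq_layer_1"] -> the layer registered under that name
        try:
            return self._layers_by_name[key]
        except KeyError:
            raise KeyError("unknown layer name: %s" % key)
```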
Here https://github.com/tensorlayer/tensorlayer/pull/751
A list of dictionaries, in the same order as layer construction, that contains all arguments.
It is available in the master branch; try `tl.distributed.Trainer`.
I am more concerned about the general idea than the technical details. For ResNet, I am working on this and for your second point, we can always create an "OutputLayer" that does just this ^^
I like the new layer API, in particular because it implements the idea that a layer is a (unary, in most cases) function of tensors.
With such a layer signature implemented, we can stack layers in an easier way:
```python
layers = [
    tl.layers.DenseLayer(n_units=20, act=None, name="seq_layer_2"),
    tl.layers.PReluLayer(channel_shared=True, name="prelu_layer_2"),
    tl.layers.DenseLayer(n_units=50, act=None, name="seq_layer_3"),
    tl.layers.PRelu6Layer(channel_shared=False, name="prelu6_layer_3")
]

net_out = transform(layers, net_in)
```
where
```python
def transform(layers, net):
    y = net
    for l in layers:
        y = l(y)
    return y
```
We can even introduce a combinator for layers:
```python
# stack :: [Layer] -> Layer
def stack(layers):
    return StackedLayers(layers)
```
which allows the following:
```python
vgg_unit = stack(...)
net_out = vgg_unit(net_in)
```
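`StackedLayers` is not defined above; here is a minimal sketch of what it could look like (an assumption on my part, reusing the same loop as `transform`):

```python
class StackedLayers(object):
    """A stack of layers that is itself callable, i.e. usable as a single layer."""

    def __init__(self, layers):
        self.layers = list(layers)

    def __call__(self, net):
        # same logic as transform(self.layers, net) above
        y = net
        for l in self.layers:
            y = l(y)
        return y
```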
After some discussion with @lgarithm, we came up with a conceptual design that would greatly clarify the TL architecture and make future development much easier and more integrable.
@lgarithm: Please edit my post if you want to add new stuff or correct anything I might have got wrong.
Layer and Network classes will be converted to factory-like objects. They contain the critical information needed to generate Layers and Networks; however, they no longer contain any information about the generated Layer.
The Layer class becomes a factory encapsulating the information required to generate TF Ops, TF Tensors and TF Vars. Layer objects no longer contain any information about the network or TF objects.
This means the following attributes will be removed:
- `Layer.outputs`
- `Layer.inputs`
- `Layer.all_params`
- `Layer.all_layers`
- `Layer.all_graph`
- and so on ...
In order to ensure a stable and consistent API, connecting Layers will now happen through the functional API. Calling Layers this way allows us to use them as factories without having to pass all the parameters again.
```python
tl.layers.Conv2D(prev_layer, ...)  # Deprecated
tl.layers.Conv2D(...)(prev_layer)  # Using the Layer.__call__() method, factory style.
```
Calling a TL Layer now returns a TF Tensor, not a TL Layer anymore.
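As a rough illustration of the factory idea (a sketch under my own assumptions, not the actual TL implementation): the constructor merely records hyper-parameters, and `__call__` builds the TF objects:

```python
import tensorflow as tf

class DenseLayerFactory(object):
    """Sketch of a factory-style layer: nothing touches the TF graph in __init__."""

    def __init__(self, n_units, act=None, name="dense"):
        self.n_units = n_units  # hyper-parameters only
        self.act = act
        self.name = name

    def __call__(self, prev_tensor):
        # __call__ creates the TF Ops/Tensors/Vars and returns a tf.Tensor
        with tf.variable_scope(self.name):
            n_in = int(prev_tensor.get_shape()[-1])
            W = tf.get_variable("W", shape=(n_in, self.n_units), dtype=prev_tensor.dtype)
            b = tf.get_variable("b", shape=(self.n_units,), dtype=prev_tensor.dtype)
            out = tf.matmul(prev_tensor, W) + b
            return self.act(out) if self.act is not None else out

x = tf.placeholder(tf.float32, (None, 32))
y = DenseLayerFactory(n_units=64, act=tf.nn.relu, name="d1")(x)  # y is a tf.Tensor
```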
Network classes (CustomNetwork and Sequential) also become factories, encapsulating the information required to generate the model using TL Layers.
This means the following attributes will be removed:
- `Network.outputs`
- `Network.inputs`
- `Network.all_params`
- `Network.all_layers`
- `Network.all_graph`
- and so on ...
The model generation will be launched using the `.compile()` method:

```python
Network.compile(input_placeholder, is_train=True, reuse=True)
```

The compile step outputs a new class, `CompiledNetwork`; this model will have all the old attributes and methods that make TL Layers so easy and pleasant to use.
@DEKHTIARJonathan, @lgarithm and I just discussed the network API.
For compatibility, @DEKHTIARJonathan suggests putting the existing layers into `tl.deprecated.layers`; all new code should then be written in the new way.
However, I am worried about the network API, and suggest supporting both ways at the same time.
In my opinion, even though some people feel the existing layers don't have a good network abstraction, the existing way can build any network, while the compile way doesn't have a big advantage over it and may not be able to handle some complex situations.
Let me discuss this with more users; I will post more feedback later.
```python
# existing way
net = InputLayer(net)
net = DenseLayer(net, 100)

# new way
net = InputLayer()(net)
net = DenseLayer(100)(net)
net.compile()
```
In my opinion, the main difference between our way and the Keras way is that they compile the network at the end, while we simply compile the network in real time, and we mix layer and network together. The other parts are the same.
However, to build a complex network, in some cases the next layer needs information (like shape) from the inputs or outputs of previous layers, and the input and `net.outputs` are unknown before we compile. Take the dynamic RNN layer below as an example: we need the input placeholder/tensor to compute the `sequence_length`. Therefore, if we compile the network at the end, we can't give the `sequence_length` to `DynamicRNNLayer`.
```python
net = EmbeddingInputlayer(
    inputs=x,
    vocabulary_size=1000,
    embedding_size=200,
    name='ebb')
net = DynamicRNNLayer(net,
    cell_fn=LSTMCell,
    n_hidden=200,
    dropout=(keep_prob if is_train else None),
    sequence_length=tl.layers.retrieve_seq_length_op2(x),
    return_last=True,
    name='ann')
```
So how do we give the `sequence_length` to `DynamicRNNLayer`? Do we have to put a `ConcatLayer` at the end? Are there other ways, e.g. returning two networks via compile?
You don't really need `sequence_length` when constructing the Layer. In the new API, it would simply be moved to the `compile` method, like the following:
```python
def transform(net, layers):
    """Apply a sequence of layers to a net to form a new net."""
    for layer in layers:
        net = layer(net)
    return net

def rnn_unit(net):
    """A combination of several layers."""
    embed_layer = EmbeddingInputlayer(
        vocabulary_size=1000,
        embedding_size=200,
        name='ebb')
    d_rnn_layer = DynamicRNNLayer(
        cell_fn=LSTMCell,
        n_hidden=200,
        dropout=(keep_prob if is_train else None),
        # sequence_length is moved to compile
        # sequence_length=tl.layers.retrieve_seq_length_op2(x),
        return_last=True,
        name='ann')
    return transform(net, [embed_layer, d_rnn_layer])

# where DynamicRNNLayer is defined as the following
class DynamicRNNLayer(Layer):
    def compile(self, prev_layer):
        sequence_length = tl.layers.retrieve_seq_length_op2(prev_layer)
        ...
```
Or, you can pass `sequence_length` as a lambda to the constructor of `DynamicRNNLayer` and call it once you have the `net` variable, which also happens in the compile method.
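For example (a sketch of the lambda approach; the `compile`-time resolution shown in the comment is my assumption about how the layer could handle it):

```python
# sequence_length is given as a callable instead of a concrete tensor
d_rnn_layer = DynamicRNNLayer(
    cell_fn=LSTMCell,
    n_hidden=200,
    sequence_length=lambda inputs: tl.layers.retrieve_seq_length_op2(inputs),
    return_last=True,
    name='ann')

# inside a hypothetical DynamicRNNLayer.compile(self, prev_layer), the callable
# would be resolved once the previous tensor is known:
#     if callable(self.sequence_length):
#         self.sequence_length = self.sequence_length(prev_layer)
```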
@lgarithm in this example, `sequence_length` is not obtained from the previous layer but from the input tensor. This is just one example; there will be cases where the next layer needs information from previous layers that can't be obtained before compiling.
For example, the tensor information can come from any previous layer:
```python
n1 = Layer(x)
n2 = Layer(n1)
n3 = Layer(n2, n1.outputs.shape())
n4 = Layer(n3)
```
I think this is the main drawback of compiling.
Could anyone explain in more detail what advantages the new Network API has over the existing one? It would be better to consider thoroughly whether these advantages are really critical before choosing to break compatibility. Thanks.
@zsdonghao The new API is still capable of building such a network; you just need to use it in the following way:
```python
n1 = Layer1()(x)
n2 = Layer2()(n1)
n3 = Layer3(n1.outputs.shape())(n2)
n4 = Layer4()(n3)
```
Basically, the new API isn't less expressive than the old one.
Every expression in the old API can be converted into an equivalent one in the new API by just moving the first argument `net` to the end: `ALayer(net, ...)` -> `ALayer(...)(net)`.
So we won't drop anything when upgrading to the new API.
The new layer design will please the NLP researchers, but does it make much difference for CNN users?
Basically, the TL 2.0 API has been motivated by a few points:
New deep learning users tend to favor Keras by a large margin. Why? Because Keras has been designed with an intuitive API that is very easy to manipulate and use.
I won't go into much detail, but I think we can agree that the Keras API is a piece of good design.
By trying to support different backends and to be (excessively) simple, Keras is unfortunately very slow (compared to other solutions) and incredibly difficult to use when you want to build something not supported out of the box.
How many times have we seen people on StackOverflow struggling to do simple stuff with TF, unable to make it work in Keras?
I'm in no way an experienced PyTorch user; however, this library also has an amazing API, and we try to get inspiration from it as well. Feel free to give us suggestions about the things you like in PyTorch.
I will take a simple example.
```python
net_in = tl.layers.InputLayer(...)
net_1 = tl.layers.Conv2d(net_in, ...)
net_2 = tl.layers.Conv2d(net_1, ...)

print(net_2.all_params)  # W and b for both Conv2d Layers => 4 params
print(net_2.all_layers)  # net_in, net_1 and net_2
print(isinstance(net_2, tl.layers.Layer))  # True
```
So we have Layer objects which behave like Network objects. Am I supposed to understand Layer objects as Network objects? Why, when you save a Layer, do you actually save a Network?
It's not consistent, not logical, not intuitive. We are used to this mindset, but that doesn't mean it's good.
This design also brings a ton of redundancy into the Layers (each layer contains an exact copy of the information contained in all the previous ones), which is not memory efficient.
This redundancy causes a ton of bugs; I can't even count how many of them we solved over the last 6 months with @zsdonghao (I prefer not knowing ^^).
More clarity and more consistency will bring less complexity and fewer bugs.
To summarise, what we are trying to achieve with TL 2.0 is simply to make TensorLayer an easier library to use on a day-to-day basis.
We also want to reduce the library maintenance workload: fewer than 8 people work actively on this, we are not supported by any large company (GAFAM/BATX), and we do this in our free time, so we need to keep it as efficient as possible.
It's true, TL 2.0 will break code compatibility. However, you can solve it in two ways:
TL 2.0 will be a huge step toward a more professional and efficient library. We are really excited about it and hope you will share our excitement ;)
In any case, if you feel like some features should be added to TensorLayer, or you've always wished TL could do something, please feel free to share pseudo-code; we'll try our best to fit it into the new API if it matches our roadmap ;)
Hello, I am a big fan of TensorLayer who used Keras and Tflearn. Please consider my following comments.
Making TensorLayer more like Keras is not a good decision. Actually, practice shows that compiling the network after declaration is a disadvantage: it causes problems when building complex models, especially dynamic models that need tensors from previous layers to define the next layers.. the code becomes dirty.. Many people would not build networks that way unless the network is very simple... That is the reason I moved from Keras to TensorLayer and Tflearn.
Besides, another disadvantage I can come up with now concerns `tl.models`, which has a `restore` function. It needs the `all_params` from intermediate layers. In other words, not only the latest layer needs `all_xxx`; all layers in the network should have `all_xxx`.
```python
n1 = Layer1()(x)
n2 = Model1()(n1)
n3 = Layer3(n1.outputs.shape())(n2)
n4 = Layer4()(n3)
n4.compile()

n2.all_params     # <--- tl.models needs the all_params list to restore the parameters.
n2.restore(sess)
n4.all_params
```
I guess this is the reason that Keras' models do not support `end_with` like TensorLayer does.
Layer-as-network is the key abstraction of TensorLayer, and it makes sense to me and my colleagues. Some Keras fans may say that TensorLayer's abstraction level is not high enough, but TensorLayer's abstraction has successfully helped us build every complex network we designed, while Keras cannot. This is why I think TensorLayer's abstraction is actually better than Keras'.
As for the bugs @DEKHTIARJonathan mentions above, I think that is just engineering work; we should not change TensorLayer's usage because of it.
To summarise, if the release of the network API deprecates the existing TensorLayer usage, I believe it is a bad decision. For existing users, if they need to rewrite their code like Keras, why wouldn't they rewrite it in Keras or Tflearn? (They will lose confidence in TensorLayer as it becomes more like Keras.) For new users, if the abstraction is similar to Keras, why wouldn't they choose Keras? (They don't need many advanced layers at the beginning stage.)
Overall, my colleagues and I believe that the existing TensorLayer abstraction is better than other libraries'. If the network API must be released, it is fine for me as long as the existing TensorLayer usage is maintained; otherwise, it is very weird. No one will use `tl.deprecated.layers`, and I will probably move to Tflearn instead.. honestly.
For the next version, I suggest the community put more effort into the following functions:

- `tl.models`: it is one of our big advantages over Keras and other TensorFlow-based libraries.
- `tl.prepro` is for numpy-array preprocessing (`threading_data`), while the TensorFlow dataset API and `tl.image` have better performance; we should let users know how to choose the best approach, and update the example code.

Best wishes,
@wagamamaz I think your remarks come from a misunderstanding. We are perfectly aware of the use case you describe, and it works perfectly with the new API.
I said we would like to have a more Keras-like API because it is more intuitive and a cleaner approach. I never said it has to be identical and work the same way.
The few times I tried Keras to see how it works, I also struggled in situations similar to the ones you point out. The plan is by no means to complexify things or prevent users from doing what they currently do easily.
Unfortunately, I can't share any code with you yet, because it's still not working perfectly. However, as I said in our previous meeting:
If any important feature (e.g. non-sequential models) can't be made to work easily and smoothly, we will cancel and abort the update. This would be an absolute no-go. I 100% agree with you.
I would also like to highlight something: we used the term `Network.compile()`; maybe this was unfortunate and it should be renamed (any propositions?). TL compile and Keras compile are by no means identical.
Keras compile takes a loss and an optimizer. On the other hand, at the current stage, TL compile is only a function that generates the corresponding TF Ops/Tensors/Vars of each Layer.
In the new API, Networks and Layers shall be understood as "blueprints" of networks and layers. The `.compile()` method creates the corresponding TF Ops/Tensors/Vars for each Layer/Network.
I hope this makes it clearer. @wagamamaz I would like to involve you in testing the new API. Would you be free for some discussions and maybe some tests?
TL 2.0 is still under heavy discussion/development, and I would love to have an honest opinion from someone like you. I can show you the differences in the new TL API and you'll be free to try it ;) If something important is not working, I can try to see how to solve it with your help ;)
```python
model = tl.networks.Sequential(name="My_Sequential_1D_Network")

model.add(tl.layers.DenseLayer(n_units=10, act=tf.nn.relu, name="seq_layer_1"))
model.add(tl.layers.DenseLayer(n_units=20, act=None, name="seq_layer_2"))
model.add(tl.layers.PReluLayer(channel_shared=True, name="prelu_layer_2"))
model.add(tl.layers.DenseLayer(n_units=30, act=None, name="seq_layer_3"))
model.add(tl.layers.PReluLayer(channel_shared=False, name="prelu_layer_3"))

plh = tf.placeholder(tf.float16, (100, 32))

train_model = model.compile(plh, reuse=False, is_train=True)
test_model = model.compile(plh, reuse=True, is_train=False)

print(type(train_model))  # tl.models.CompiledNetwork
print(type(test_model))   # tl.models.CompiledNetwork
```
```python
# What can you do with them?

# 1. Get Layers by name
## Compiled_Layers are generated by a factory and are immutable.
print(type(train_model["seq_layer_3"]))    # tl.models.Compiled_DenseLayer
print(type(train_model["prelu_layer_3"]))  # tl.models.Compiled_PReluLayer

# 2. You can still do all the cool stuff you are used to doing
print(train_model["seq_layer_3"].outputs)         # output of this specific layer => tf.Tensor
print(train_model["seq_layer_3"]._local_weights)  # weight Tensors of this specific layer
print(train_model["seq_layer_3"])                 # returns the string describing the object
print(train_model["seq_layer_3"].__dict__)        # all the hyperparameters of the layer

print(train_model.inputs)      # input of the model => tf.placeholder
print(train_model.outputs)     # output of the model => tf.Tensor
print(train_model.all_params)  # all the parameters of the model
print(train_model.all_layers)  # all the Compiled_Layers of the model
print(train_model.all_drop)    # all the dropout placeholders of the model
```
```python
def fire_module(inputs, squeeze_depth, expand_depth, name):
    """Fire module: squeeze input filters, then apply spatial convolutions."""
    with tf.variable_scope(name, "fire", [inputs]):
        squeezed = tl.layers.Conv2d(
            n_filter=squeeze_depth,
            filter_size=(1, 1),
            strides=(1, 1),
            padding='SAME',
            act=tf.nn.relu,
            name='squeeze'
        )(inputs)

        e1x1 = tl.layers.Conv2d(
            n_filter=expand_depth,
            filter_size=(1, 1),
            strides=(1, 1),
            padding='SAME',
            act=tf.nn.relu,
            name='e1x1'
        )(squeezed)

        e3x3 = tl.layers.Conv2d(
            n_filter=expand_depth,
            filter_size=(3, 3),
            strides=(1, 1),
            padding='SAME',
            act=tf.nn.relu,
            name='e3x3'
        )(squeezed)

        return tl.layers.ConcatLayer(concat_dim=3, name='concat')([e1x1, e3x3])
```
```python
class MyCustomNetwork(tl.networks.CustomModel):

    def model(self):
        input_layer = tl.layers.InputLayer(name='input_layer')

        net = fire_module(input_layer, 32, 24, "fire_module_1")
        net = fire_module(net, 32, 24, "fire_module_2")

        return input_layer, net

model = MyCustomNetwork(name="my_custom_network")

plh = tf.placeholder(tf.float16, (100, 16, 16, 3))

train_model = model.compile(plh, reuse=False, is_train=True)
test_model = model.compile(plh, reuse=True, is_train=False)

print(type(train_model))  # tl.models.CompiledNetwork
print(type(test_model))   # tl.models.CompiledNetwork
```
And of course you can do all the same stuff as above with the Sequential Network. Do you really think the API is changing so much that it becomes unfixable?
Maybe I am missing some edge cases; I'd be glad to see some examples, and I'll work to make them possible ;)
@luomai and I actually talked about this, namely after what happened to Caffe2. However, this is a completely different situation: the change won't be huge or complex to make.
Moreover, I would like to point out that, as you probably know, TensorFlow is going through its 2nd radical change (the 1st being TF 1.0, now heading for TF 2.0). People are still using it.
Sometimes we have to break backward compatibility to be able to improve something. There might be only a very small number of projects that never made a radical design change in the course of their existence, and not only in deep learning. Take jQuery in the web design domain: with every big new version, so many functionalities break that many plugins and libraries need to be rewritten. Take Django 2.0, released a few months ago: same story. It might be true that you would lose confidence in TL, but honestly, have you ever seen an active project that doesn't deprecate a ton of features over the years?
Even some of the most important projects in computer science, like Python, completely broke everything between Python 2 and Python 3; yet you are still using Python, I believe...
Supporting both the new and the old way to create TL Layers won't work for long. The approaches are completely different, and it would require a ton of work to keep maintaining both, time that we don't have.
I understand that you don't want to update your codebase; that's why we plan to keep fixing bugs on TL 1.x.
@DEKHTIARJonathan Thank you for the quick reply. Let me point out a case that I think the compile method can't handle: to build a complex network, all intermediate layers should have `all_xxx` and `outputs` before compiling.
Take the example @lgarithm and I described above: the tensors of a layer (`all_xxx` and `outputs`) are usually required to build a complex network. One example is the dynamic RNN that @zsdonghao described; you will find many problems when building RNN networks, and that is why it is hard to use dynamic RNNs with Keras (we can find a lot of examples like that). Another disadvantage I can come up with now is that `tl.models` needs the `all_params` tensors.
I think this problem needs to be solved before we move forward; otherwise, it is just a waste of time..
```python
n1 = Layer1()(x)
n2 = Model1()(n1)                    # <== 2) tl.models
n3 = Layer3(n1.outputs.shape())(n2)  # <== 1) needs a previous tensor, but with compile, n1 doesn't have `outputs`.
n4 = Layer4()(n3)
n4.compile()

n2.all_params     # <== 2) tl.models needs `all_params` to restore the parameters. How does `n4.compile()` give the `all_params` to `n2`?
n2.restore(sess)  # <== 2) restore parameters
n4.all_params
```
Using compile to build a simple CNN network is fine, but when you try to build a complex network, you will always find problems. This is very important for academic users like me. Layer-as-network is a good abstraction. Besides, the code you show above is not simpler than the existing way. I am not saying I don't want to change my codebase; rather, I can't see a reason to change it. Therefore, if the existing usage moves to TL 1.x, I believe no one in my lab will use it, as everyone would be forced to update their code. To be honest, I personally would move to Tflearn, as it is similar to TL. Alternatively, copy the existing TL to build another library...
@wagamamaz alright, let me think about this. I prefer to think the problem you highlight through first, instead of saying something wrong ;)
Just to be sure: the new API was designed to be more intuitive, but also to improve the heavy lifting happening under the hood; the current version of TL is quite incredibly inefficient, with a lot of redundancy and quite a lot of complexity in the internals.
Can you give me a quick example of the code above with the current version of TL? That way I have something to work with ;)
Do you mean something like this (a dummy AE):
```python
def get_ae_model(plh, is_train, reuse):  # is_train not useful here, no dropout or BN
    with tf.variable_scope("my_scope", reuse=reuse):
        net_encoder = tl.layers.InputLayer(plh)
        net_encoder = tl.layers.DenseLayer(net_encoder, n_units=50)
        net_encoder = tl.layers.DenseLayer(net_encoder, n_units=10)

        net_decoder = tl.layers.DenseLayer(net_encoder, n_units=50)
        net_decoder = tl.layers.DenseLayer(net_decoder, n_units=100)

    return net_encoder, net_decoder

plh = tf.placeholder(tf.float16, (None, 100))
encoder, decoder = get_ae_model(plh, is_train=True, reuse=False)
```
Could this kind of workaround solve the issue:
```python
class Encoder_Network(tl.networks.CustomModel):

    def model(self):
        input_layer = tl.layers.InputLayer(name='input_layer')

        net_encoder = tl.layers.DenseLayer(n_units=50)(input_layer)
        net_encoder = tl.layers.DenseLayer(n_units=10)(net_encoder)

        return input_layer, net_encoder

class Decoder_Network(tl.networks.CustomModel):

    def model(self):
        input_layer = tl.layers.InputLayer(name='input_layer')

        net_decoder = tl.layers.DenseLayer(n_units=50)(input_layer)
        net_decoder = tl.layers.DenseLayer(n_units=100)(net_decoder)

        return input_layer, net_decoder

model_encoder = Encoder_Network(name="my_encoder")
model_decoder = Decoder_Network(name="my_decoder")

plh = tf.placeholder(tf.float16, (None, 100))

encoder = model_encoder.compile(plh, reuse=False, is_train=True)  # type == CompiledNetwork

# two possibilities here, it doesn't change anything ;)
decoder = model_decoder.compile(encoder, reuse=False, is_train=True)          # type == CompiledNetwork
decoder = model_decoder.compile(encoder.outputs, reuse=False, is_train=True)  # type == CompiledNetwork

### And then you can do
encoder.save()
encoder.restore()
encoder.all_layers
encoder.all_params

decoder.save()
decoder.restore()
decoder.all_layers
decoder.all_params
```
You can even do this:
```python
class AE_Network(tl.networks.CustomModel):

    def model(self):
        input_layer = tl.layers.InputLayer(name='input_layer')

        net_encoder = tl.layers.DenseLayer(n_units=50)(input_layer)
        net_encoder = tl.layers.DenseLayer(n_units=10)(net_encoder)

        net_decoder = tl.layers.DenseLayer(n_units=50)(net_encoder)
        net_decoder = tl.layers.DenseLayer(n_units=100)(net_decoder)

        return input_layer, net_decoder

net_AE = AE_Network(name="my_autoencoder")

plh = tf.placeholder(tf.float16, (None, 100))

model_ae = net_AE.compile(plh, reuse=False, is_train=True)

### And then you can do
model_ae.save()
model_ae.restore()
model_ae.all_layers
model_ae.all_params
```
Could you give me a precise example where this would not work? I must be missing something ;)
Concerning dynamic RNNs, to be very honest, I use RNNs very rarely. Could you provide an example? I can't design one myself with the current API.
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import tensorflow as tf
import tensorlayer as tl

tf.logging.set_verbosity(tf.logging.DEBUG)
tl.logging.set_verbosity(tl.logging.DEBUG)

# 1) simple dynamic rnn
input_seqs = tf.placeholder(dtype=tf.int64, shape=[None, None], name="input")
is_train = True

ne = tl.layers.EmbeddingInputlayer(inputs=input_seqs, vocabulary_size=100, embedding_size=100, name='embedding')

# inside DynamicRNNLayer, batch_size needs to be computed automatically from the previous outputs.
# there are also many cases that require a tensor from previous layers as input to the next layer
fixed_batch_size = ne.outputs.get_shape().with_rank_at_least(1)[0]
if fixed_batch_size.value:
    batch_size = fixed_batch_size.value
else:
    from tensorflow.python.ops import array_ops
    batch_size = array_ops.shape(ne.outputs)[0]
print(batch_size)

n = tl.layers.DynamicRNNLayer(
    ne,
    cell_fn=tf.contrib.rnn.BasicLSTMCell,
    n_hidden=100,
    dropout=(0.7 if is_train else None),
    sequence_length=tl.layers.retrieve_seq_length_op2(input_seqs),  # previous tensor is required
    return_last=False,
    return_seq_2d=True,
    name='dynamicrnn'
)
n = tl.layers.DenseLayer(n, n_units=100, name="output")
print(input_seqs, ne.outputs)

# 2) intermediate tensors (IMPORTANT)
x = tf.placeholder(tf.float32, [None, 100, 100, 3], name='x')
nin = tl.layers.InputLayer(x, name='in')

n1 = tl.layers.Conv2d(nin, 32, (3, 3), (1, 1), act=tf.nn.relu, name='c1')
n1 = tl.layers.MaxPool2d(n1, (3, 3), (2, 2), 'SAME', name='pad1')
n1 = tl.layers.Conv2d(n1, 32, (3, 3), (1, 1), act=tf.nn.relu, name='c2')
print(n1.all_layers)  # intermediate layers should have this before compiling
print(n1.all_params)  # not only for debugging, but for building complex networks
print(n1.outputs)     # I can give more examples later

n2 = tl.layers.Conv2d(nin, 32, (3, 3), (1, 1), act=tf.nn.relu, name='c1')
n2 = tl.layers.MaxPool2d(n2, (3, 3), (2, 2), 'SAME', name='pad1')
n2 = tl.layers.Conv2d(n2, 32, (3, 3), (1, 1), act=tf.nn.relu, name='c2')
print(n1.all_layers)
print(n1.all_params)
print(n1.outputs)

n = tl.layers.ElementwiseLayer([n1, n2], tf.add, name='add')
n.n1_output = n1.outputs  # advanced usage of TL which is useful for complex networks
n.n2_output = n2.outputs

# 3) tl.models
x2 = tf.placeholder(tf.float32, [None, 224, 224, 3])
# get VGG without the last layer
vgg = tl.models.VGG16(x2, end_with='fc2_relu')
# add one more layer
n3 = tl.layers.DenseLayer(vgg, 100, name='out')

# initialize all parameters
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# download the vgg parameters and restore them
vgg.restore_params(sess)  # vgg needs all_params to restore its params

train_params = tl.layers.get_variables_with_name('out')
print(train_params)
```
I believe there should be some way to solve this by following Keras. But if the code ends up quite different from the existing usage, why not just build another library?
@wagamamaz Thanks a lot. I will try to reimplement this and come back to you ;)
If you are right and it can't be done with the new paradigm, I think we should abort or completely rethink the change ;)
I may need some time; some layers are not compliant with the new API yet ;)
@wagamamaz Actually, you can get the tensor output immediately if you build your model like this, without using Sequential:
```python
n = InputLayer()(plh)
n = DenseLayer(100)(n)
n.all_xxx
n.outputs
n = DenseLayer(100)(n)
# no compile is required
```
But with Sequential, you need to compile at the end. We could actually make a small tool to help people automatically change existing code to the new way, as sketched below.
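A rough sketch of what such a migration helper could look like; this is a naive, illustrative regex rewrite (a real tool should use Python's `ast` or `lib2to3` machinery, and this pattern only catches class names ending in `Layer`):

```python
import re

# matches e.g. "DenseLayer(net, 100)" -> groups ("DenseLayer", "net", "100)")
_OLD_STYLE = re.compile(r"\b([A-Z]\w*Layer)\(\s*(\w+)\s*,\s*(.*\))")

def migrate_line(line):
    """Rewrite ALayer(net, args...) into ALayer(args...)(net) on one line."""
    def repl(m):
        layer, net, rest = m.groups()
        return "%s(%s(%s)" % (layer, rest, net)
    return _OLD_STYLE.sub(repl, line)

print(migrate_line("net = DenseLayer(net, 100)"))
# -> net = DenseLayer(100)(net)
```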
@DEKHTIARJonathan As we discussed, you want to add `local_params` and `local_drop` to a layer.
I am wondering: should we keep the names the same?
Someone suggests renaming `all_layers` --> `all_outputs`.
If we rename it, we can still allow users to use `all_layers`, but give a warning:
```python
@property
def all_layers(self):
    logging.warning("all_layers --> all_outputs")
    return self.all_outputs
```
As discussed with @DEKHTIARJonathan @luomai, there are some cases to consider:
1) A sub-network with `is_train=False` while the other layers use `is_train=True`:

```python
cnn = tl.models.MobileNetV1(x, end_with='depth13', is_train=False)
net = tl.layers.BatchNormLayer(cnn, is_train=True, name='bn1')
net = tl.layers.Conv2d(net, 32, (3, 3), name='cnn')

sess = tf.InteractiveSession()
tl.layers.initialize_global_variables(sess)
cnn.restore_params(sess)
```
Manual compile mode may become difficult here; we need to think about it.
2) Reusing the same layer in a model: the compiled model can't support it, since TF Vars and Ops are only created at compilation... the reuse here is not useful.
```python
with tf.variable_scope('test'):
    model.add(tl.layers.DenseLayer(n_units=50, act=tf.nn.relu, name="seq_layer_9"))
with tf.variable_scope('test', reuse=True):
    model.add(tl.layers.DenseLayer(n_units=50, act=tf.nn.relu, name="seq_layer_9"))
```
@DEKHTIARJonathan one thing that I would like to do in the new API is
```python
class VGG16Base(object):
    """The VGG16 model."""

    @staticmethod
    def vgg16_simple_api(net_in, end_with):
        with tf.name_scope('preprocess'):
            # Notice that we include a preprocessing layer that takes the RGB image
            # with pixel values in the range of 0-255 and subtracts the mean image
            # values (calculated over the entire ImageNet training set).

            # Rescale the input tensor with pixel values in the range of 0-255
            net_in.outputs = net_in.outputs * 255.0

            mean = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32, shape=[1, 1, 1, 3], name='img_mean')
            net_in.outputs = net_in.outputs - mean

        maxpool = MaxPool2d(filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool')

        conv1 = Conv2d(n_filter=64, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv1')
        conv2 = Conv2d(n_filter=128, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv2')
        conv3 = Conv2d(n_filter=256, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv3')
        conv4_5 = Conv2d(n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv4')

        layers = [
            # conv1
            conv1, conv1, maxpool,
            # conv2
            conv2, conv2, maxpool,
            # conv3
            conv3, conv3, conv3, maxpool,
            # conv4
            conv4_5, conv4_5, conv4_5, maxpool,
            # conv5
            conv4_5, conv4_5, conv4_5, maxpool,
            FlattenLayer(name='flatten'),
            DenseLayer(n_units=4096, act=tf.nn.relu, name='fc1_relu'),
            DenseLayer(n_units=4096, act=tf.nn.relu, name='fc2_relu'),
            DenseLayer(n_units=1000, name='fc3_relu'),
        ]

        net = net_in
        for l in layers:
            net = l(net)
            # if end_with in net.name:
            if net.name.endswith(end_with):
                return net

        raise Exception("unknown layer name (end_with): {}".format(end_with))
```
It could simplify the current VGG implementation a lot.
Let's see if we can do something like this:
class VGG16Base(object): """The VGG16 model."""
@staticmethod
def vgg16_simple_api(net_in, end_with):
with tf.name_scope('preprocess'):
# Notice that we include a preprocessing layer that takes the RGB image
# with pixels values in the range of 0-255 and subtracts the mean image
# values (calculated over the entire ImageNet training set).
# Rescale the input tensor with pixels values in the range of 0-255
net_in.outputs = net_in.outputs * 255.0
mean = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32, shape=[1, 1, 1, 3], name='img_mean')
net_in.outputs = net_in.outputs - mean
maxpool = MaxPool2d(filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool')
conv1 = Conv2d(n_filter=64, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv1')
conv2 = Conv2d(n_filter=128, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv2')
conv3 = Conv2d(n_filter=256, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv3')
conv4_5 = Conv2d(n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv4')
layers = [
# conv_1
("conv_1", [conv1, conv1, maxpool]),
# conv_2
("conv_2", [ conv2, conv2, maxpool]),
# conv_3
("conv_3", [conv3, conv3, conv3, maxpool]),
# conv_4
("conv_4", [conv4_5, conv4_5, conv4_5, maxpool]),
# conv_5
("conv_5", [conv4_5, conv4_5, conv4_5, maxpool]),
FlattenLayer(name='flatten'),
DenseLayer(n_units=4096, act=tf.nn.relu, name='fc1_relu'),
DenseLayer(n_units=4096, act=tf.nn.relu, name='fc2_relu'),
DenseLayer(n_units=1000, name='fc3_relu'),
]
net = net_in
for l in layers:
net = l(net)
# if end_with in net.name:
if net.name.endswith(end_with):
return net
raise Exception("unknown layer name (end_with): {}".format(end_with))
@DEKHTIARJonathan @lgarithm could we support these?
1) An advanced decorator that supports the existing way of defining layers:

```python
net = DenseLayer(net, 100)
model.add(DenseLayer(100))
```
https://github.com/ildoonet/tf-pose-estimation/blob/master/tf_pose/network_base.py#L170 https://github.com/ildoonet/tf-pose-estimation/blob/master/tf_pose/network_mobilenet.py#L23
2) Auto naming: add a name automatically when not using reuse (a sketch of a possible helper follows after this list):

```python
with tl.layers.auto_name():
    model.add(DenseLayer(100))
    model.add(DenseLayer(100))
```
3) Skip connections: support ResNet with the model API:

```python
model.add(Conv2d(name='conv1'))
model.add(Conv2d(name='conv2'))
model.add(Conv2d(name='conv3'))
model("conv1", "conv3").add(Concat())
```
As TensorFlow 2.0 is coming later this year with the new eager execution mode, this PR shall be merged after TF 2.0 and must be compatible with TF 2.0. We also need to check how eager mode affects the existing API, and measure its performance overheads.
A bunch of old TF APIs will be removed, and we need to check how that affects us.
Also, according to a recent discussion with other TL contributors @zsdonghao @wagamamaz @nebulaV @Windaway @lgarithm @fangde, an API design doc is required to build consensus on the proposed API change.
I just had a long conversation with @luomai, and I've learnt that the biggest concern about changing the API is backward compatibility.
My favourite change in TL 2.0 would be moving the `net` argument to the `__call__` method of a layer instance; that means `y = Layer(net, otherParams)` would change to `y = Layer(otherParams)(net)`.
This would make it easier to reuse layers with the same `otherParams` (like padding and strides), as the sketch below shows.
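For example (a sketch only; the layer arguments are illustrative):

```python
# one configured layer instance applied twice, reusing the same hyper-parameters
conv3x3 = Conv2d(n_filter=64, filter_size=(3, 3), strides=(1, 1),
                 padding='SAME', act=tf.nn.relu)

net = conv3x3(net_in)  # first application
net = conv3x3(net)     # second application, same configuration
```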
To see how it could simplify the code, just compare https://github.com/tensorlayer/tensorlayer/blob/master/tensorlayer/models/vgg16.py and https://github.com/tensorlayer/tensorlayer/issues/770#issuecomment-416543581
I think this should be the only user-visible change.
But I don't know how existing users would react to this change if it has to happen, and I would like to hear some opinions.
I think other changes, like the model API, should be additional features, and it should be technically possible to keep the original API unchanged.
Dear friends,
I will try to write down and summarize a few thoughts regarding the upcoming changes in TensorLayer, namely for version 2.0.
Main features that will be released.
I may forget some; if so, I will update my post.
TensorLayer 1.x recurring issues.
TensorLayer has always been quite fuzzy and messy (though it's really improving). The result has been an incredible number of bugs in the different Layers. Implementing a single feature often means partly rewriting the code of absolutely every Layer (I have done it twice already, and I'm doing it for the 3rd time with the Network API). This is extremely dangerous, with a high risk of introducing an incredible number of bugs (remember that we had to publish a very large number of release candidates to fix all the bugs: 1.8.6rc0, 1.8.6rc1, 1.8.6rc2, 1.8.6rc3, 1.8.6rc4, 1.8.6rc5, 1.8.6rc6).
Every so often we find new bugs just by reading the code:

- `tl.layers.DeformableConv2d`: fixed (PR #573)
- `tl.layers.ConvLSTMLayer`: fixed (PR #676)
- `tl.layers.TernaryConv2d`: fixed - `self.inputs` not defined (PR #658)

Additionally, the current Layer API is slightly counter-intuitive:

- `layer.all_params` returns the params of a network, not of a layer
- `layer.all_layers` is quite ironic, right?
- `layer.count_params()` gives you the number of params in the graph, not inside the layer

Proposition: breaking backward compatibility with the Network API
As TensorFlow did when releasing TF 1.0, or Django (with Django 2.0) in a non-deep-learning context, very big libraries sometimes decide to leave the good old times behind them and clean the code base. I believe it is not a problem to break backward compatibility if it is for the better and done very rarely.
Here are the changes that I believe would highly improve TL maintainability and clarity for TL 2.0:
A few words on the Network API
I believe the network API should NOT be mandatory in TL. It should bring additional, non-essential features.
The following should be possible:
However, the following functionalities should be removed from Layer and moved to the Network API:
The list above is not exhaustive. At the same time, new functionalities can be added:
The list above is not exhaustive
Presentation of the current Network API
It is not finialized, is subject to changes. I plan on releasing two Network Classes:
Sequential (similar idea to Keras)
For easy and sequential models, the Sequential Network API is here for rapid prototyping.
Custom Model API
CustomModel/CustomNetwork/Model/Network: I haven't really decide on the name yet. This Class haven't been created yet, it is subject to change.