tensorflow / swift-apis

Swift for TensorFlow Deep Learning Library
Apache License 2.0

Parameter counting API for Layer #412

Open t-ae opened 5 years ago

t-ae commented 5 years ago

Currently there is no easy way to find out how many parameters a Layer instance has. A parameter count would be useful for estimating model size.
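
To make the size estimate concrete, here is a small worked example in plain Swift. It is only a sketch: `denseParameterCount` and `estimatedSizeInBytes` are hypothetical helper names, not part of swift-apis, and the estimate covers parameter storage only (no optimizer state or activations).

```swift
// Hypothetical helpers for illustration; not part of swift-apis.

/// Number of trainable scalars in a fully connected layer:
/// an (inputSize x outputSize) weight matrix plus one bias per output.
func denseParameterCount(inputSize: Int, outputSize: Int) -> Int {
    return inputSize * outputSize + outputSize
}

/// Approximate storage needed for the parameters alone, in bytes.
func estimatedSizeInBytes(parameterCount: Int, bytesPerScalar: Int) -> Int {
    return parameterCount * bytesPerScalar
}

let count = denseParameterCount(inputSize: 784, outputSize: 10)
let bytes = estimatedSizeInBytes(
    parameterCount: count,
    bytesPerScalar: MemoryLayout<Float>.size)  // Float is 4 bytes
print(count, bytes)  // 7850 31400
```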

Keras's Layer has this feature. https://github.com/keras-team/keras/blob/c10d24959b0ad615a21e671b180a1b2466d77a2b/keras/engine/base_layer.py#L1105-L1123

rxwei commented 5 years ago

This can be done once we flesh out the property wrapper solution for parameters (along the lines of #250).

mikowals commented 5 years ago

I have been counting parameters based on how variables used to be updated in optimisers:

extension Layer {
    /// Total number of scalars in every Tensor<Float> and Tensor<Double>
    /// reachable through the layer's writable key paths.
    var parameterCount: Int {
        let floatKeyPaths = self.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self)
        let doubleKeyPaths = self.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self)
        var parameters = 0
        for kp in floatKeyPaths {
            parameters += self[keyPath: kp].shape.contiguousSize
        }
        for kp in doubleKeyPaths {
            parameters += self[keyPath: kp].shape.contiguousSize
        }
        return parameters
    }
}

Is there anything inherently wrong in doing it this way or is it just not ideal as more variable types get used?

rxwei commented 5 years ago

One problem with this approach is duplication for each variable type as you mentioned. A bigger problem is that you might include tensors that are not parameters of a model, e.g. ones that are marked @noDerivative.
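
One way to picture the over-counting problem, and how an explicit parameter marker could avoid it, is the following self-contained sketch. Everything in it is an assumption made for illustration: `ToyTensor`, the `@Parameter` wrapper, `ToyBatchNorm`, and the `parameterCount(of:)` function are hypothetical names, and the reflection uses the standard library's `Mirror` rather than anything from swift-apis. It also inspects only the top-level stored properties, not nested layers.

```swift
// Toy stand-in for a tensor; only the shape matters for counting.
struct ToyTensor {
    var shape: [Int]
    var scalarCount: Int { shape.reduce(1, *) }
}

// Hypothetical marker for trainable tensors; swift-apis does not define this.
@propertyWrapper
struct Parameter {
    var wrappedValue: ToyTensor
    init(wrappedValue: ToyTensor) { self.wrappedValue = wrappedValue }
}

struct ToyBatchNorm {
    @Parameter var scale: ToyTensor
    @Parameter var offset: ToyTensor
    // Running statistics are tensors, but they are not trained.
    var runningMean: ToyTensor
    var runningVariance: ToyTensor

    init(featureCount: Int) {
        let vector = ToyTensor(shape: [featureCount])
        _scale = Parameter(wrappedValue: vector)
        _offset = Parameter(wrappedValue: vector)
        runningMean = vector
        runningVariance = vector
    }
}

/// Counts only the stored properties that were declared with @Parameter.
func parameterCount<Model>(of model: Model) -> Int {
    var total = 0
    for child in Mirror(reflecting: model).children {
        if let parameter = child.value as? Parameter {
            total += parameter.wrappedValue.scalarCount
        }
    }
    return total
}

let layer = ToyBatchNorm(featureCount: 64)
let marked = parameterCount(of: layer)
// Counting every tensor, as a key-path search over all tensors would.
let naive = [layer.scale, layer.offset, layer.runningMean, layer.runningVariance]
    .map { $0.scalarCount }
    .reduce(0, +)
print(marked, naive)  // 128 256
```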

mikowals commented 5 years ago

Thanks for explaining this, Richard.

eaplatanios commented 5 years ago

@mikowals @rxwei another problem is that you might miss tensors that are parameters, e.g. tensors of other data types or, potentially, arrays of tensors (I believe this approach would not account for tensor arrays).
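
As a rough illustration of the missed cases, the sketch below counts through a protocol instead of listing scalar types one by one, so other data types and arrays of tensors are included. It is not how swift-apis works: `ParameterCounting`, `ToyTensor`, and `ToyModel` are hypothetical names, and the model spells out its own sum by hand.

```swift
// Hypothetical protocol for illustration; not an API in swift-apis.
protocol ParameterCounting {
    var parameterCount: Int { get }
}

// Toy stand-in for a tensor, generic over its element type.
struct ToyTensor<Scalar>: ParameterCounting {
    var shape: [Int]
    var parameterCount: Int { shape.reduce(1, *) }
}

// An array of countable values is itself countable: sum over the elements.
extension Array: ParameterCounting where Element: ParameterCounting {
    var parameterCount: Int { reduce(0) { $0 + $1.parameterCount } }
}

struct ToyModel: ParameterCounting {
    var embedding: ToyTensor<Double>
    var kernels: [ToyTensor<Float>]
    var parameterCount: Int {
        embedding.parameterCount + kernels.parameterCount
    }
}

let model = ToyModel(
    embedding: ToyTensor<Double>(shape: [100, 8]),
    kernels: [ToyTensor<Float>(shape: [8, 4]), ToyTensor<Float>(shape: [4, 2])])
print(model.parameterCount)  // 840
```

The count here is 800 (embedding) + 32 + 8 (kernels). A search restricted to Tensor<Float> key paths would see neither the array elements as a group nor any scalar type it was not told about.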

Kshitij09 commented 4 years ago

I'd like to work on this one. Could you please assign this to me? 🙂