tensorflow / swift-apis

Swift for TensorFlow Deep Learning Library
Apache License 2.0
794 stars, 133 forks

Implement more layers that are available in Keras #54

Open rxwei opened 5 years ago

tanmayb123 commented 5 years ago

@rxwei Quick question on this: do we also want to add layers like Add, Subtract, Multiply, Concatenate, etc.? Two things to note:

  1. If we do, we could just sequence these operations instead of having to run the operations on the Tensors.
  2. How would a layer handle an arbitrary length of inputs? I'm aware of how to send multiple pre-defined inputs using a struct, but could you just pass an array of Tensors (i.e. [Tensor<Scalar>]) for many inputs?
rxwei commented 5 years ago

Hi @tanmayb123, let’s not add those operator-like layers for now and defer them to further discussion. Ideally, we would want functions to be able to conform to the layer protocol, but that’s currently impossible in Swift, so we’ll probably need a single wrapper type that can turn any differentiable function into a layer.

As for arbitrary-length input/output, there’s ongoing work on making Array conform to Differentiable. It should be resolved next week or so.

rxwei commented 5 years ago

Random thought: We can define layer wrappers for common arities (unary and binary), and define a layer(_:) factory function for turning any Differentiable-to-Differentiable function into a layer.

Rough sketch:

public struct ClosureLayer<T: Differentiable, U: Differentiable>: Layer {
    @noDerivative
    public var function: @differentiable (T) -> U
    public func applied(to input: T) -> U {
        return function(input)
    }
}

public struct ClosureLayer2<T: Differentiable, U: Differentiable, V: Differentiable>: Layer {
    ...
}

func layer<T: Differentiable, U: Differentiable>(_ function: @escaping @differentiable (T) -> U) -> ClosureLayer<T, U> {
    return ClosureLayer(function: function)
}

...

Then, you'll be able to use functions in sequenced(in:through:).

input.sequenced(in: context, through: conv, maxPool, layer(sin), ...)

What's better: You can now use a trailing closure to create anonymous layers!

let myLayer = layer { (input: Tensor<Float>) in
    sin(cos(input)) + someParameter
}

We have plans and designs for differentiating w.r.t. closure captures. Therefore, you will even be able to differentiate through this layer and optimize someParameter.
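As a rough plain-Swift illustration of the wrapper-plus-factory idea, the sketch below turns closures into values that can be called and chained like layers. It is not the Swift for TensorFlow API: `Module`, `ClosureModule`, `module(_:)`, and `sequenced(_:_:)` are hypothetical names, and nothing here is differentiable, so it shows only the wrapping and composition, not the optimization of `someParameter`.

```swift
import Foundation

// Hypothetical stand-in for the Layer protocol (no differentiation involved).
protocol Module {
    associatedtype Input
    associatedtype Output
    func callAsFunction(_ input: Input) -> Output
}

// Hypothetical unary wrapper, analogous to the ClosureLayer sketch above.
struct ClosureModule<T, U>: Module {
    let body: (T) -> U
    func callAsFunction(_ input: T) -> U { body(input) }
}

// Factory: turns any unary function into a module.
func module<T, U>(_ body: @escaping (T) -> U) -> ClosureModule<T, U> {
    ClosureModule(body: body)
}

// Chains two modules, feeding the output of `first` into `second`.
func sequenced<A: Module, B: Module>(_ first: A, _ second: B) -> ClosureModule<A.Input, B.Output>
    where A.Output == B.Input {
    module { second(first($0)) }
}

// An anonymous "layer" built from a trailing closure that captures a value.
let someParameter = 0.5
let myModule = module { (x: Double) in sin(cos(x)) + someParameter }

// Two closures chained: multiply by 2, then add 1.
let net = sequenced(module { (x: Double) in x * 2 }, module { (x: Double) in x + 1 })
```

Here the captured `someParameter` is only read; making it trainable is exactly the part that needs differentiation with respect to closure captures.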

Shashi456 commented 5 years ago

@rxwei could we make a list of what is available and what's still to be done? That would give a clearer picture of which layers need to be added. #53 and #52 are also layer-adding issues, so making an issue that references all of them might make the task easier.

Shashi456 commented 5 years ago

The implementations include:

Convolution Layer:

Pooling:

Normalization:

Embedding:

Recurrent:

Core:

Recursive Neural Networks #68

~Activation Layers: Relu, ELU, Leaky Relu~

There are a few more layers in the Core and Activation categories, plus a few more classes: merge layers (Add, Concatenate, Dot, Minimum, Maximum, etc.), convolutional recurrent layers, and the noise and locally connected classes. IMO the ones above are the most important for now, and we can focus on implementing those.

Activation layers will be added as functions; refer to the discussion above.

@rxwei is this okay for a start? I'll make a more comprehensive list in a while.

tanmayb123 commented 5 years ago

@Shashi456 that’s a great list - thanks for compiling it. Just two things:

  1. Dropout has been implemented, but it’s not checked on your list.
  2. According to what Richard and I discussed above, I thought we weren’t planning on creating layers like activation layers? Rather, we’d just pass values through the functions, or pass the functions to layers as an activation function (as you can do right now).
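The two options in point 2 can be modeled in plain Swift. `DenseSketch` and `relu` below are hypothetical stand-ins rather than the library's `Dense` layer or activation functions, and no gradients are computed; the point is only that an activation can either be stored by a layer as an ordinary function value or applied to the layer's output afterwards, with no dedicated activation layer type.

```swift
// Hypothetical stand-in for a dense layer: weights, bias, and an activation
// stored as an ordinary closure (identity by default).
struct DenseSketch {
    var weight: [[Double]]   // [outputSize][inputSize]
    var bias: [Double]       // [outputSize]
    var activation: (Double) -> Double = { $0 }

    func callAsFunction(_ input: [Double]) -> [Double] {
        var output: [Double] = []
        for (row, b) in zip(weight, bias) {
            // Dot product of one weight row with the input, plus the bias.
            let z = zip(row, input).reduce(b) { $0 + $1.0 * $1.1 }
            output.append(activation(z))
        }
        return output
    }
}

// The activation stays a plain function.
func relu(_ x: Double) -> Double { max(0, x) }

// Option A: pass the function to the layer as its activation.
let dense = DenseSketch(weight: [[1, -1], [2, 0]], bias: [0, -7], activation: relu)
let y = dense([3, 1])   // pre-activations are [2, -1]

// Option B: pass the layer's output through the function.
let yManual = DenseSketch(weight: [[1, -1], [2, 0]], bias: [0, -7])([3, 1]).map(relu)
```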
Shashi456 commented 5 years ago

@tanmayb123 Alright, I will remove the activation layers. But is that for sure? Weren't they made into layers to make the process more intuitive in the first place?

tanmayb123 commented 5 years ago

@rxwei what do you think?

rxwei commented 5 years ago

@Shashi456 Thanks a lot for listing these! Looks good to me. I'd suggest starting with the non-recurrent ones first.

aman-bhu commented 5 years ago

@rxwei, I am willing to contribute. Can I implement one of the layers listed above?

rxwei commented 5 years ago

Absolutely! What would you like to implement?

aman-bhu commented 5 years ago

I am planning to implement the Conv3D layer.

rxwei commented 5 years ago

Sounds great. Look forward to your PR.

Shashi456 commented 5 years ago

@rxwei @dan-zheng I wanted to ask whether it would be possible to add more aliases for different kinds of layers in the repo, for example GlobalAvgpooling = GlobalAveragePooling, and maybe also for the losses, like Meansquarederror = MSE and sigmoidcrossentropy = XENT.

rxwei commented 5 years ago

IMO it is ideal to stick with one set of names for consistency in all our models and example code. Currently we are leaning towards consistency with Keras. This will ensure we have overall consistency in our recommendations, while the user has the full freedom to define any aliases they want in their libraries.
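For example, a user who prefers shorter names could declare aliases in their own code: `typealias` for types and a constant for functions. The sketch below uses hypothetical plain-Swift stand-ins (`GlobalAveragePooling2D`, `meanSquaredError(predicted:expected:)`) rather than the real library declarations.

```swift
// Hypothetical stand-ins for library declarations with the canonical names.
struct GlobalAveragePooling2D {}

func meanSquaredError(predicted: [Double], expected: [Double]) -> Double {
    precondition(!predicted.isEmpty && predicted.count == expected.count)
    let squaredErrors = zip(predicted, expected).map { pair in (pair.0 - pair.1) * (pair.0 - pair.1) }
    return squaredErrors.reduce(0, +) / Double(predicted.count)
}

// User-side aliases, declared in the user's own code rather than the library.
typealias GlobalAvgPool2D = GlobalAveragePooling2D
let mse = meanSquaredError
```

Note that a function stored this way is called without its argument labels, e.g. `mse(predictions, targets)`.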

Shashi456 commented 5 years ago

~#130 shows that upsampling 3D doesn't work. We are currently looking at ways to fix it. One way to do it is, to take an approach inspired by the Keras Implementation of the same.~

Solved.

dan-zheng commented 5 years ago

> #130 shows that upsampling doesn't work. We are currently looking at ways to fix it. One way to do it is, to take an approach inspired by the Keras Implementation of the same.

To be precise, only UpSampling3D doesn't work, because it works with 8-D tensors that are too high-dimensional for broadcasting.
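The rank problem can be checked with shape arithmetic. Assuming upsampling is done by inserting a unit axis after each spatial axis, broadcasting it to the scale factor, and reshaping back, a rank-4 UpSampling2D input passes through a rank-6 intermediate, while a rank-5 UpSampling3D input needs rank 8. The plain-Swift sketch below does only that shape bookkeeping plus a 2-D nearest-neighbor repeat; `upsamplingShapes(input:size:)` and `upsample2D(_:size:)` are hypothetical helpers, not library functions, and the [batch, spatial..., channels] layout is an assumption.

```swift
// Shapes for reshape -> broadcast -> reshape upsampling (hypothetical helper).
// Input layout is assumed to be [batch, spatial..., channels].
func upsamplingShapes(input: [Int], size: Int) -> (expanded: [Int], broadcast: [Int], output: [Int]) {
    let batch = input.first!, channels = input.last!
    let spatial = input.dropFirst().dropLast()
    var expanded = [batch], broadcast = [batch], output = [batch]
    for d in spatial {
        expanded += [d, 1]        // insert a unit axis after each spatial axis
        broadcast += [d, size]    // broadcast that unit axis to the scale factor
        output.append(d * size)   // merge the pair back into one axis
    }
    expanded.append(channels); broadcast.append(channels); output.append(channels)
    return (expanded, broadcast, output)
}

// Nearest-neighbor upsampling of a single-channel 2-D grid: every element is
// repeated `size` times along both axes, which is what the broadcast achieves.
func upsample2D(_ grid: [[Int]], size: Int) -> [[Int]] {
    grid.flatMap { (row: [Int]) -> [[Int]] in
        let widened = row.flatMap { Array(repeating: $0, count: size) }
        return Array(repeating: widened, count: size)
    }
}
```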

lakshya-sky commented 5 years ago

Hi @Shashi456, when will SeparableConv2D be available? I'd like to implement MobileNet using S4TF.

Shashi456 commented 5 years ago

@Dash2507, sometime next week. I'm working on it locally right now and will push it once I'm done with the other PRs.

Shashi456 commented 5 years ago

@rxwei I just had a quick question. The convolution layers category also includes a zero-padding layer, but we already have a padded function. Should I write the layers anyway? I'm trying to avoid redundancy, since they would just be wrappers that call this function.

rxwei commented 5 years ago

We already have such layers, Reshape, for example. Adding a layer wrapper for each function is definitely not ideal and would complicate our API surface. Instead of putting a lot of work into implementing those wrapper layers, I'd suggest trying to define a Function (or Lambda) layer that takes any arbitrary differentiable function and uses it inside callAsFunction(_:). Essentially, it's going to look like this:

public struct Function<InputScalar: TensorFlowFloatingPoint, OutputScalar: TensorFlowFloatingPoint>: Layer {
    public typealias Input = Tensor<InputScalar>
    public typealias Output = Tensor<OutputScalar>
    // The stored closure is not a differentiable parameter of the layer.
    @noDerivative public var body: @differentiable (Input) -> Output
    public init(body: @escaping @differentiable (Input) -> Output) {
        self.body = body
    }
    @differentiable
    public func callAsFunction(_ input: Input) -> Output {
        body(input)
    }
}

With this, you can turn any closure into a layer:

let tanhLayer = Function<Float, Float>(body: tanh)
let reshapeLayer = Function<Float, Float> { x in x.reshaped(to: [10, 10]) }
let paddingLayer = Function<Float, Float> { x in x.padded(forSizes: [(0, 1)], with: 0) }

Would you like to prototype this?

Shashi456 commented 5 years ago

Alright, I'll get a PR up later today.

jon-tow commented 5 years ago

I've attempted an implementation of an Embedding layer but am running into problems with the Layer protocol's input type requirements. Given that an Embedding layer consumes tensors of indices (UInt/Int) there's no way to satisfy the differentiability of callAsFunction(_:). Is there a work around to this?

@dan-zheng I've noticed an implementation of a Differentiable Embedding struct in the GPT-2 model found in the swift-models repo (GPT-2 Transformer). This doesn't conform to the Layer protocol but could we bring it into the API since it's quite useful for NLP tasks?

Shashi456 commented 5 years ago

@jon-tow did you also define a vjp for your embedding layer?

rxwei commented 5 years ago

> I've attempted an implementation of an Embedding layer but am running into problems with the Layer protocol's input type requirements. Given that an Embedding layer consumes tensors of indices (UInt/Int) there's no way to satisfy the differentiability of callAsFunction(_:). Is there a work around to this?

For now, you can define a nested Input structure and mark the vocabulary property as @noDerivative. Something like:

struct Embedding<Scalar: TensorFlowFloatingPoint> {
    struct Input: Differentiable {
        @noDerivative var vocabulary: Tensor<Int32>
    }
    func callAsFunction(_ input: Input) -> Tensor<Scalar> {
        ...
    }
}
jon-tow commented 5 years ago

Hey @Shashi456. Yup. It just wouldn't compile as it relied on the Raw.gather(params:, atIndices:) function which requires a BinaryInteger for the second argument. Thanks @rxwei I'll give it a try.

eaplatanios commented 5 years ago

I'm not sure I understood correctly what you're trying to do, but I would do something along the lines of:

struct Embedding<Scalar: TensorFlowFloatingPoint> {
    var embeddings: Tensor<Scalar>

    @differentiable(wrt: self)
    func callAsFunction(_ indices: Tensor<Int32>) -> Tensor<Scalar> {
        ...
    }
}

Cheers, Anthony

On Mon, Jun 17, 2019 at 2:23 PM Jonathan Tow notifications@github.com wrote:

> Hey @Shashi456. Yup. It just wouldn't compile as it relied on the Raw.gather(params:, atIndices:) function which requires a BinaryInteger for the second argument. Thanks @rxwei I'll give it a try.


rxwei commented 5 years ago

Specifying @differentiable(wrt: self) is not possible yet because the Layer protocol requires both input and self to be differentiable. There are definitely a lot of ways to resolve this, e.g. defining a separate protocol that Layer inherits from and making that protocol require only self to be differentiable. However, that requires some non-trivial thunking-related engineering right now.

rxwei commented 5 years ago

> It just wouldn't compile as it relied on the Raw.gather(params:, atIndices:) function which requires a BinaryInteger for the second argument.

Hope we can merge #151 so that you can use gathering(atIndices:alongAxis:).
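For intuition about what differentiating the gather involves, here is a plain-Swift model of an embedding lookup: the forward pass gathers rows of a table by integer index, and the pullback with respect to the table scatter-adds each incoming gradient row back to the row it came from. The indices receive no gradient, which is why only `self` (the table) needs to be differentiable. `EmbeddingSketch` and `tableGradient(indices:upstream:)` are hypothetical names; this is not the library's `gathering(atIndices:alongAxis:)` implementation.

```swift
// Plain-Swift model of an embedding lookup and its pullback w.r.t. the table.
struct EmbeddingSketch {
    var table: [[Float]]   // [vocabularySize][embeddingSize]

    // Forward pass: gather one row per index (a gather along axis 0).
    func callAsFunction(_ indices: [Int32]) -> [[Float]] {
        indices.map { table[Int($0)] }
    }

    // Pullback: scatter-add each upstream row into the row it was gathered
    // from, so repeated indices accumulate their gradients.
    func tableGradient(indices: [Int32], upstream: [[Float]]) -> [[Float]] {
        var grad = table.map { Array(repeating: Float(0), count: $0.count) }
        for (index, g) in zip(indices, upstream) {
            for j in g.indices { grad[Int(index)][j] += g[j] }
        }
        return grad
    }
}
```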

jon-tow commented 5 years ago

Richard's advice resolved the compiler issues I had before regarding input types. Thanks for the suggestion @eaplatanios. The only issue left seems to be differentiating gathering. I'll keep an eye out for that merge. Appreciate the help folks!

bartchr808 commented 5 years ago

Hey @jon-tow ! Actually we were mistaken but #156 already added gathering(atIndices:alongAxis:) so you should have access to it! 😄

jon-tow commented 5 years ago

@bartchr808 I had some tests passing and everything seemed okay. I was wondering what was going on! Thanks for letting me know :). I'll submit a PR sometime today.

Shashi456 commented 5 years ago

@rxwei So I've been working on the Function layer we were talking about the other day:

public struct Function<InputScalar: TensorFlowFloatingPoint, OutputScalar: TensorFlowFloatingPoint>: Layer {
    public typealias Input = Tensor<InputScalar>

    public typealias Output = Tensor<OutputScalar>

    public typealias Body = @differentiable (Input) -> Output

    @noDerivative public let body: Body

    public init(body: @escaping Body) {
        self.body = body
    }

    @differentiable
    public func callAsFunction(_ input: Input) -> Output {
        return body(input)
    }
}

Does this look right? I run into an error saying that the layer doesn't conform to the Layer protocol and that a call function is needed. As far as I understand, for a structure to conform to a protocol, you need to define all of the protocol's requirements, somewhat like an abstract class. Any thoughts on what I might be doing wrong?

tanmayb123 commented 5 years ago
public struct Function<InputScalar: TensorFlowFloatingPoint, OutputScalar: TensorFlowFloatingPoint>: Layer {
    public typealias Input = Tensor<InputScalar>
    public typealias Output = Tensor<OutputScalar>
    public typealias Body = @differentiable (Input) -> Output

    @noDerivative public let body: Body

    public init(body: @escaping Body) {
        self.body = body
    }

    @differentiable
    public func callAsFunction(_ input: Input) -> Output {
        return body(input)
    }

    @differentiable
    public func call(_ input: Input) -> Output {
        return callAsFunction(input)
    }
}

That compiles for me.
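Why the extra method helps can be shown with an ordinary Swift protocol. Going by the error described above, the Layer protocol in that toolchain apparently required a method spelled `call(_:)`; a type that defines only `callAsFunction(_:)` does not satisfy such a requirement, while a forwarding method does. `OldStyleLayer`, `FunctionSketch`, and `apply(_:to:)` are hypothetical names used only for this illustration.

```swift
// Hypothetical protocol whose requirement is spelled `call(_:)`.
protocol OldStyleLayer {
    associatedtype Input
    associatedtype Output
    func call(_ input: Input) -> Output
}

struct FunctionSketch<Input, Output>: OldStyleLayer {
    let body: (Input) -> Output

    func callAsFunction(_ input: Input) -> Output { body(input) }

    // Satisfies the requirement by forwarding, as in the workaround above.
    func call(_ input: Input) -> Output { callAsFunction(input) }
}

// Generic code can rely only on what the protocol requires: `call(_:)`.
func apply<L: OldStyleLayer>(_ layer: L, to input: L.Input) -> L.Output {
    layer.call(input)
}
```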