tensorflow / swift-apis

Swift for TensorFlow Deep Learning Library
Apache License 2.0

Multiple return parameters #67

Open tanmayb123 opened 5 years ago

tanmayb123 commented 5 years ago

I'm trying to implement a simple/limited version of the meshgrid op. This is what I've got:

func meshgrid(x: Tensor<Float>, y: Tensor<Float>) -> (Tensor<Float>, Tensor<Float>) {
    let outputX = x.reshaped(to: [-1, 1])
    let outputY = y.reshaped(to: [-1, 1])
    let multFactX = Tensor<Float>(ones: [x.scalarCountTensor.scalarized()])
    let multFactY = Tensor<Float>(ones: [y.scalarCountTensor.scalarized()])
    return ((outputX * multFactX).transposed(), outputY * multFactY)
}

Of course, that can't be made differentiable because tuples aren't differentiable. So, I had to implement this:

struct TensorPair<T: TensorFlowFloatingPoint>: Differentiable {

    var first: Tensor<T>
    var second: Tensor<T>

    @differentiable
    init(_ first: Tensor<T>, _ second: Tensor<T>) {
        self.first = first
        self.second = second
    }

}

@differentiable
func meshgrid(x: Tensor<Float>, y: Tensor<Float>) -> TensorPair<Float> {
    let outputX = x.reshaped(to: [-1, 1])
    let outputY = y.reshaped(to: [-1, 1])
    let multFactX = Tensor<Float>(ones: [x.scalarCountTensor.scalarized()])
    let multFactY = Tensor<Float>(ones: [y.scalarCountTensor.scalarized()])
    return TensorPair((outputX * multFactX).transposed(), outputY * multFactY)
}

It works, but it's not a very elegant solution. Thoughts:

  1. Will tuples ever be differentiable?
  2. This could be made a bit more elegant by returning an array of tensors with the two values.

rxwei commented 5 years ago

  • Will tuples ever be differentiable?

Extensions and protocol conformances for compound types (also called structural types) are part of Swift's roadmap: https://github.com/apple/swift/blob/master/docs/GenericsManifesto.md#extensions-of-structural-types. When this is possible, tuples of types that conform to Differentiable will conform to Differentiable.

That said, I may decide to define a bunch of differential operators overloaded for functions that return 2-tuples. The potential downside is that there will be a lot of API duplications that can screw up code completion.

  • This could be made a bit more elegant by returning an array of tensors with the two values.

Array will be differentiable soon: https://github.com/apple/swift/pull/23183.

Here's a general guideline: If the API you'd like to define returns a fixed number of outputs, I'd recommend defining a struct that meaningfully distinguishes the semantics of the return aggregate from general pairs. Each struct element should be labeled according to its role. In other words, until tuples can conform to protocols, use structs with well-named elements instead of overgeneralized tuple-like structs.
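As a rough illustration of that guideline applied to the meshgrid function from the original post, the generic pair could be replaced by a struct whose fields say what each tensor holds. This is only a sketch: the names `MeshGrid`, `xCoordinates`, and `yCoordinates` are hypothetical and not part of the library, it uses `scalarCount` in place of `scalarCountTensor.scalarized()` to build the shape, and it assumes a Swift for TensorFlow toolchain where the `@differentiable` attribute is spelled as in the snippets above.

```swift
import TensorFlow

// Hypothetical result type: field names describe the role of each tensor,
// unlike the generic `first`/`second` of a pair.
struct MeshGrid<Scalar: TensorFlowFloatingPoint>: Differentiable {
    var xCoordinates: Tensor<Scalar>
    var yCoordinates: Tensor<Scalar>

    @differentiable
    init(xCoordinates: Tensor<Scalar>, yCoordinates: Tensor<Scalar>) {
        self.xCoordinates = xCoordinates
        self.yCoordinates = yCoordinates
    }
}

// Same limited meshgrid as in the original post, returning the named struct.
@differentiable
func meshgrid(x: Tensor<Float>, y: Tensor<Float>) -> MeshGrid<Float> {
    // Turn each input into a column vector.
    let columnX = x.reshaped(to: [-1, 1])
    let columnY = y.reshaped(to: [-1, 1])
    // Multiplying by a row of ones broadcasts each column across the grid.
    let onesX = Tensor<Float>(ones: [x.scalarCount])
    let onesY = Tensor<Float>(ones: [y.scalarCount])
    return MeshGrid(
        xCoordinates: (columnX * onesX).transposed(),
        yCoordinates: columnY * onesY)
}

// Call sites read by role rather than by position.
let grid = meshgrid(x: Tensor<Float>([1, 2, 3]), y: Tensor<Float>([4, 5, 6]))
print(grid.xCoordinates.shape, grid.yCoordinates.shape)
```

Compared with `TensorPair`, the only change is naming, but a caller can no longer mix up which result is which, and the struct stays differentiable in the same way.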

joaqo commented 4 years ago

I'm using the struct solution at the moment, but just out of curiosity, will this be implemented? It's useful for one-off cases that don't warrant the verbosity of defining a new struct type.