tensorflow / swift

Swift for TensorFlow
https://tensorflow.org/swift
Apache License 2.0
6.12k stars 608 forks source link

X10 Tensor Performance for LSTMs #461

Closed tanmayb123 closed 4 years ago

tanmayb123 commented 4 years ago

I'm running an experiment to compare the performance of LSTMs in Swift for TensorFlow and in TensorFlow for Python. I'm using the following (admittedly badly written) code: first the Python version, then the equivalent Swift version.

import time
import tensorflow as tf

@tf.function(experimental_compile=True)
def lstm(ih, hh, b, ts_input, ts_hidden, ts_cell, hiddensize):
    """Run a single LSTM time step, compiled with XLA.

    Args:
        ih: input-to-gates weight matrix; in this script it is (26, 4 * 256).
        hh: hidden-to-gates weight matrix; in this script it is (256, 4 * 256).
        b: gate bias vector, broadcast over the batch.
        ts_input: input for this time step (batch, input_features).
        ts_hidden: hidden state from the previous step (batch, hiddensize).
        ts_cell: cell state from the previous step (batch, hiddensize).
        hiddensize: width of each of the four gate blocks.

    Returns:
        A tuple ``(h, c)`` holding the new hidden state and new cell state.
    """
    # All four gate pre-activations come out of one fused projection and sit
    # side by side along axis 1, each `hiddensize` columns wide.
    z = tf.linalg.matmul(ts_input, ih) + tf.linalg.matmul(ts_hidden, hh) + b
    z0 = z[:, 0:hiddensize]
    z1 = z[:, hiddensize:hiddensize*2]
    z2 = z[:, hiddensize*2:hiddensize*3]
    z3 = z[:, hiddensize*3:]

    i = tf.math.sigmoid(z0)  # input gate
    f = tf.math.sigmoid(z1)  # forget gate
    # The candidate cell update uses tanh (range -1..1), not sigmoid, so the
    # input path can both increase and decrease the cell state.
    g = tf.math.tanh(z2)
    o = tf.math.sigmoid(z3)  # output gate

    c = f * ts_cell + i * g
    h = o * tf.math.tanh(c)

    return (h, c)

def run_prediction(ih, hh, b, hiddensize, inputs):
    """Unroll ``lstm`` over the leading (time) axis of ``inputs``.

    Args:
        ih, hh, b: LSTM weights and bias, passed straight through to ``lstm``.
        hiddensize: hidden-state width.
        inputs: tensor indexed as ``inputs[t]`` per time step; ``inputs.shape[1]``
            is used as the batch size.

    Returns:
        A list of hidden states, starting with the all-zero initial state, so it
        holds ``inputs.shape[0] + 1`` entries.
    """
    hidden = tf.zeros((inputs.shape[1], hiddensize))
    cell = tf.zeros((inputs.shape[1], hiddensize))
    hiddens = [hidden]
    # Each time step is a separate call into the XLA-compiled `lstm`. The plain
    # Python index is used directly; wrapping it in tf.constant only allocated
    # an extra tensor per step inside the timed loop.
    for t in range(inputs.shape[0]):
        hidden, cell = lstm(ih, hh, b, inputs[t], hidden, cell, hiddensize)
        hiddens.append(hidden)
    return hiddens

# Benchmark data: 380 time steps, batch of 128, 26 input features, hidden size 256.
ih = tf.random.uniform((26, 256*4))
hh = tf.random.uniform((256, 256*4))
b = tf.random.uniform((256*4,))
hiddensize = tf.constant(256)
inputs = tf.random.uniform((380, 128, 26))

def run():
    """Time one full forward pass; print the last hidden state's shape and elapsed seconds."""
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump if the system clock is adjusted mid-measurement.
    s = time.perf_counter()
    print(run_prediction(ih, hh, b, hiddensize, inputs)[-1].shape)
    e = time.perf_counter()
    print(e - s)

# Run twice: the first call presumably also pays tracing/XLA compilation cost,
# so the second call is the steady-state number.
run()
run()
import Foundation
import TensorFlow

// Backend used for every tensor this program creates; switch to
// `.defaultTFEager` to compare against TensorFlow eager execution.
let device = Device.defaultXLA

/// State produced by one `lstm` step.
///
/// `hidden` is the new hidden state `h` and `cell` is the new cell state `c`;
/// the memberwise initializer takes them in that order.
struct LSTMOutput: Differentiable {
    var hidden, cell: Tensor<Float>
}

/// Runs a single LSTM time step.
///
/// The fused gate pre-activations `z` are split along axis 1 into four
/// `hiddenSize`-wide blocks: input gate, forget gate, candidate, output gate.
///
/// - Parameters:
///   - ih: Input-to-gates weights (in this program `[26, 4 * 256]`).
///   - hh: Hidden-to-gates weights (in this program `[256, 4 * 256]`).
///   - b: Gate bias, broadcast over the batch.
///   - tsInput: Input for this time step.
///   - tsHidden: Hidden state from the previous step.
///   - tsCell: Cell state from the previous step.
///   - hiddenSize: Width of each gate block.
/// - Returns: The new hidden and cell states.
@differentiable(wrt: (ih, hh, b))
func lstm(ih: Tensor<Float>,
          hh: Tensor<Float>,
          b: Tensor<Float>,
          tsInput: Tensor<Float>,
          tsHidden: Tensor<Float>,
          tsCell: Tensor<Float>,
          hiddenSize: Int) -> LSTMOutput {
    let z = matmul(tsInput, ih) + matmul(tsHidden, hh) + b
    let batchSize = z.shape[0]

    let z0 = z.slice(lowerBounds: [0, 0], upperBounds: [batchSize, hiddenSize])
    let z1 = z.slice(lowerBounds: [0, hiddenSize], upperBounds: [batchSize, hiddenSize * 2])
    let z2 = z.slice(lowerBounds: [0, hiddenSize * 2], upperBounds: [batchSize, hiddenSize * 3])
    let z3 = z.slice(lowerBounds: [0, hiddenSize * 3], upperBounds: [batchSize, hiddenSize * 4])

    let i = sigmoid(z0)  // input gate
    let f = sigmoid(z1)  // forget gate
    // The candidate cell update uses tanh (range -1...1), not sigmoid, so the
    // input path can both increase and decrease the cell state.
    let g = tanh(z2)
    let o = sigmoid(z3)  // output gate

    let c = f * tsCell + i * g
    let h = o * tanh(c)

    return LSTMOutput(hidden: h, cell: c)
}

/// Unrolls `lstm` over `inputs`, one element per time step, starting from an
/// all-zero hidden and cell state on `device`.
///
/// The batch size is taken from `inputs[0].shape[0]`, so `inputs` must not be empty.
///
/// - Returns: `inputs.count + 1` hidden states: the zero initial state followed
///   by the state after each step.
@differentiable(wrt: (ih, hh, b))
func runPrediction(ih: Tensor<Float>,
                   hh: Tensor<Float>,
                   b: Tensor<Float>,
                   hiddenSize: Int,
                   inputs: [Tensor<Float>]) -> [Tensor<Float>] {
    let stateShape: TensorShape = [inputs[0].shape[0], hiddenSize]
    var hidden = Tensor<Float>(zeros: stateShape, on: device)
    var cell = Tensor<Float>(zeros: stateShape, on: device)
    var hiddens = [hidden]

    let stepCount = withoutDerivative(at: inputs.count)
    for step in 0..<stepCount {
        let next = lstm(ih: ih, hh: hh, b: b,
                        tsInput: inputs[step],
                        tsHidden: hidden,
                        tsCell: cell,
                        hiddenSize: hiddenSize)
        hidden = next.hidden
        cell = next.cell
        hiddens.append(hidden)
    }
    return hiddens
}

// Benchmark data: 380 time steps of [128, 26] inputs, hidden size 256.
let ih = Tensor<Float>(randomUniform: [26, 256*4], on: device)
let hh = Tensor<Float>(randomUniform: [256, 256*4], on: device)
let b = Tensor<Float>(randomUniform: [256*4], on: device)
let hiddenSize = 256
let inputs: [Tensor<Float>] = (1...380).map { _ in Tensor(randomUniform: [128, 26], on: device) }

/// Times one full forward pass; prints the last hidden state's shape and the
/// elapsed time in seconds.
func run() {
    // systemUptime is monotonic, unlike Date(), so a wall-clock adjustment
    // during the run cannot skew the measured interval.
    let start = ProcessInfo.processInfo.systemUptime
    let hiddens = runPrediction(ih: ih, hh: hh, b: b, hiddenSize: hiddenSize, inputs: inputs)
    // Never empty: runPrediction always includes the initial state.
    let last = hiddens.last!
    // X10 evaluates lazily, so reading `.shape` alone may not run the traced
    // computation. Copying the scalars to the host forces it to execute, so the
    // timing covers the actual compute on every backend.
    _ = last.scalars
    let end = ProcessInfo.processInfo.systemUptime
    print(last.shape)
    print(end - start)
}

// Run twice: the first call also pays any one-time setup/compilation cost.
run()
run()

Python gives me the following output:

2020-05-09 23:32:56.595915: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-05-09 23:32:56.606559: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f8598660a50 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-05-09 23:32:56.606573: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2020-05-09 23:32:56.857759: I tensorflow/compiler/jit/xla_compilation_cache.cc:242] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
(128, 256)
0.544409990310669
(128, 256)
0.27017807960510254

Swift (with .defaultTFEager) gives me:

[128, 256]
0.40241098403930664
[128, 256]
0.37356019020080566

Swift (with .defaultXLA) gives me:

2020-05-09 23:33:09.025920: I tensorflow/compiler/xla/xla_client/xrt_local_service.cc:54] Peer localservice 1 {localhost:35392}
2020-05-09 23:33:09.026451: I tensorflow/core/platform/cpu_feature_guard.cc:143] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.2 AVX AVX2 FMA
2020-05-09 23:33:09.075110: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x11b126e30 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-05-09 23:33:09.075126: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2020-05-09 23:33:09.077361: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job localservice -> {0 -> localhost:35392}
2020-05-09 23:33:09.077748: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:390] Started server with target: grpc://localhost:35392
2020-05-09 23:33:09.084266: W tensorflow/compiler/jit/xla_device.cc:398] XLA_GPU and XLA_CPU devices are deprecated and will be removed in subsequent releases. Instead, use either @tf.function(experimental_compile=True) for must-compile semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 for auto-clustering best-effort compilation.
[128, 256]
1.3604919910430908
[128, 256]
1.3101921081542969

I was wondering what causes so much extra overhead on the XLA device in Swift. Is this an issue with the way I've written my code, or with how the tensors are implemented? If it's the latter, is it a known issue, and are there plans to fix it soon?

If a fix is planned, I'd be able to use S4TF to train the LSTMs in my next project.

This experiment was run on a MacBook Pro; the Swift file was compiled and run with `swiftc -O main.swift && ./main`.

Thanks!