feiwang3311 / Lantern

BSD 3-Clause "New" or "Revised" License

Accelerated Backends (CPU & GPU) #8

Open TiarkRompf opened 5 years ago

TiarkRompf commented 5 years ago

High priority:

CUDA backend todos:

Medium priority:

Other potential back-ends to look into: TVM, XLA, ...

dan-zheng commented 5 years ago

A robust solution for GPU support needs to reconcile CPU-only operations (e.g. printf) and operations that should be run on GPU (e.g. cuBLAS library functions).

For cuBLAS/cuDNN support, I think I'll start with the naive implementation of allocating all tensors on GPU memory. This is the shortest path to testing GPU code generation.

However, this essentially breaks all operations that aren't defined with GPU memory in mind: printf certainly won't work (unless it's modified to copy tensors to CPU memory) and even ops like elementwise addition need to be rewritten using library functions like cublasSaxpy.

Redefining many ops for GPU support greatly increases the surface area of the Backend trait, which is not ideal. If you have ideas for avoiding this, or if you have other ideas/feedback about backend support, please share!
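
To make the surface-area concern concrete, here is a minimal sketch, assuming Lantern's Tensor type and using illustrative method names rather than the actual Backend trait: every op that touches tensor data ends up needing one definition per device.

// Illustrative sketch only: each data-touching op needs a per-device definition,
// so the trait grows with every op that must also run on GPU.
trait Backend {
  def vectorPlusVector(x: Tensor, y: Tensor): Tensor
  def matMul(x: Tensor, y: Tensor): Tensor
  def printTensor(x: Tensor): Unit
}

class BackendNative extends Backend {
  def vectorPlusVector(x: Tensor, y: Tensor): Tensor = ??? // plain C loop
  def matMul(x: Tensor, y: Tensor): Tensor = ???           // CPU gemm
  def printTensor(x: Tensor): Unit = ???                   // printf over host memory
}

class BackendCudnn extends Backend {
  def vectorPlusVector(x: Tensor, y: Tensor): Tensor = ??? // e.g. cublasSaxpy
  def matMul(x: Tensor, y: Tensor): Tensor = ???           // e.g. cublasSgemm
  def printTensor(x: Tensor): Unit = ???                   // must copy device to host first
}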

dan-zheng commented 5 years ago

@jmd1011 had the idea of running all ops on the CPU by default, and only using GPU ops within an explicitly demarcated section of code (e.g. a training loop). I feel like this design is facilitated by the flexible Backend trait implementation: simply change the backend value to change the codegen target.

This approach leads to a better separation of concerns: rather than handling arbitrary mixes of CPU and GPU ops (which effectively requires each op to worry about the device placement of its arguments and result), only "chunks" of CPU and GPU code are handled (within a chunk, ops can assume all tensors live on the same device, either CPU or GPU). This means that the backend-swapping code, rather than every single op, is responsible for copying tensors between devices.

// Adapted from mnistCNN.scala.
val mnist = new DslDriverC[String, Unit] with TensorExp {
  def snippet(a: Rep[String]): Rep[Unit] = {
    // The backend is initially CPU (`var backend: BackendNative`).
    val data = new DataLoader("mnist")
    ...

    // Start training loop. Generate GPU ops!
    backend = new BackendCudnn
    for (epoch <- 0 until epochCount: Rep[Range]) {
      data foreach { (input: Tensor, target: Rep[Int]) =>
        // It's nice to have a way to print values within the training loop.
        // Some ad-hoc mechanism for communication would be good.
        // Strawman syntax:
        // `printf("Loss: %f\n", loss.toCPU())`
        ...
      }
    }

    // Change backend back to CPU.
    backend = new BackendNative
    printf(...)
  }
}

This idea seems similar to "device placement" in TensorFlow:

results = []
a = tf.get_variable("a", ...)
b = tf.get_variable("b", ...)

# GPU 0 performs matmul.
with tf.device('/gpu:0'):
    results.append(tf.matmul(a, b))

# GPU 1 performs addition.
with tf.device('/gpu:1'):
    results.append(a + b)

# TensorFlow handles copying tensors between devices.
with tf.device('/cpu:0'):
    sum = tf.add_n(results)

Here's the equivalent feature in Swift for TensorFlow. It should be possible to implement a similar API in Lantern:

Original incomplete prototype
// Not sure what the type of `f` should be. Any tips?
def withBackend(b: Backend, f: ??? -> ???) = {
  val originalBackend = backend
  // Copy tensors to the new backend.
  // Question: what tensors need to be copied?
  // Answer: the ones that are passed as arguments to `f`.
  // Change the backend (i.e. codegen target).
  backend = b
  // Call `f`.
  val result = f(...)
  // Copy `result` to the old backend, then reset the backend.
  backend = originalBackend
}


// Revised based on @GSAir's suggestion below.
def withBackend[T, U](b: Backend, input: T)(f: T => U) = {
  val originalBackend = backend
  // Transfer input to the new backend.
  transferBetweenBackends(originalBackend, b, input)

  // Change the backend (i.e. codegen target), then call `f`.
  backend = b
  val result = f(input)

  // Transfer `result` to the old backend, reset the backend, then return the result.
  transferBetweenBackends(b, originalBackend, result)
  backend = originalBackend
  result
}

// Usage:
def withGPU[T, U](input: T)(f: T => U) = withBackend(new BackendCudnn, input)(f)

// Type-inference: `withGPU[Tensor, Tensor]` below.
withGPU(Tensor.ones(2, 3)) { x => x + x }

If you have feedback or related ideas, please share!

GSAir commented 5 years ago

For the type of f, you may want to be flexible:

def withBackend[T,U](b: Backend, input: T)(f: T => U) = {
}

The curried form allows you to do:

withBackend[Int, Unit](CPU, 0) { in =>
    printf("%d\n", in)
}
dan-zheng commented 5 years ago

I propose to change the cuDNN backend into a cuBLAS+cuDNN backend.

cuDNN by itself defines high-level NN operations, like convolutions and activation functions. However, it doesn't define lower-level primitives, like matrix multiplication or basic elementwise ops. Thus, a standalone cuDNN backend is not particularly useful.

A cuBLAS+cuDNN backend can use cuBLAS for low-level linear algebra primitives and cuDNN for optimized high-level NN ops.
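
A rough sketch of how that split might look, reusing the illustrative trait from the sketch above (not the actual Lantern class definitions):

// Illustrative sketch: cuBLAS provides the linear-algebra primitives,
// and the cuDNN backend extends it to add high-level NN ops on top.
class BackendCublas extends Backend {
  def vectorPlusVector(x: Tensor, y: Tensor): Tensor = ??? // cublasSaxpy
  def matMul(x: Tensor, y: Tensor): Tensor = ???           // cublasSgemm
  def printTensor(x: Tensor): Unit = ???                   // copy device to host, then printf
}

class BackendCudnn extends BackendCublas {
  // Inherits the cuBLAS primitives above; adds NN-specific ops.
  def conv2D(input: Tensor, kernel: Tensor): Tensor = ???  // cudnnConvolutionForward
  def relu(x: Tensor): Tensor = ???                        // cudnnActivationForward
}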

feiwang3311 commented 5 years ago

https://github.com/feiwang3311/Lantern/blob/master/src/main/scala/lantern/ad_lms_vector.scala#L522

@dan-zheng should this line be comparing this.shape(1) with that.shape(0)?
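
For reference, the inner dimensions of a matrix-matrix product must agree; a small hypothetical sketch of such a check inside the Tensor class (not the actual code at that line):

// For this of shape (m, k) and that of shape (k, n),
// require the inner dimensions to match: this.shape(1) == that.shape(0).
def dot(that: Tensor): Tensor = {
  assert(this.shape(1) == that.shape(0),
    "matrix multiplication: inner dimensions must agree")
  ??? // the multiplication itself is elided
}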

dan-zheng commented 5 years ago

Thanks for the catch! Fixed in https://github.com/feiwang3311/Lantern/commit/db0a80fc93761510c99a80b1daf914d76b33a182.

TiarkRompf commented 5 years ago

I propose to change the cuDNN backend into a cuBLAS+cuDNN backend.

Absolutely makes sense. The use case I had in mind was cuBLAS without cuDNN, but that's covered with BackendCudnn extends BackendCublas.
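
A small illustration of how that hierarchy keeps the cuBLAS-only use case available (assuming the sketched classes above):

// GPU linear algebra without pulling in cuDNN:
backend = new BackendCublas
// Full NN ops; BackendCudnn extends BackendCublas, so the BLAS ops are inherited:
backend = new BackendCudnn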

dan-zheng commented 5 years ago

FYI: I added a concrete todo list to the issue description. Preliminary MNIST CNN support is nearly done. Afterwards, we can evaluate performance and optimize.