triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Feature Request: Optionally clone tensors before autotune for in-place operations #1563

Open orsharir opened 1 year ago

orsharir commented 1 year ago

There's no (apparent?) way to use autotune for in-place kernels that read one of the tensor arguments and then write the result back to it: each benchmark evaluation applies the kernel again to the same arguments, so the output depends on how many evaluations were run.

Suggestion: just as autotune() has an optional reset_to_zero argument, there should be an optional clone_before_evaluation argument that clones the specified arguments, autotunes on the clones, and then performs a final run on the original arguments.
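A minimal sketch of the suggested behavior, in plain Python so it runs without a GPU. Everything here is hypothetical: `autotune_with_clones`, the list-based "tensors", and the toy `scale_inplace` kernel stand in for Triton's real autotuner and a `@triton.jit` kernel. Only the idea (benchmark on clones, then run once on the originals) comes from the suggestion above.

```python
import copy
import time

def scale_inplace(v, s, block_size):
    # Toy stand-in for an in-place kernel: v[:] *= s, processed block by block.
    for start in range(0, len(v), block_size):
        for i in range(start, min(start + block_size, len(v))):
            v[i] *= s

def autotune_with_clones(kernel, nargs, configs, clone_names, reps=3):
    """Hypothetical clone_before_evaluation: time each config on clones of the
    arguments in clone_names, then launch the fastest config once on the
    original arguments."""
    best_cfg, best_time = None, float("inf")
    for cfg in configs:
        elapsed = 0.0
        for _ in range(reps):
            # Fresh clones for every run, created outside the timed region.
            bench_args = {k: (copy.deepcopy(val) if k in clone_names else val)
                          for k, val in nargs.items()}
            start = time.perf_counter()
            kernel(**bench_args, **cfg)
            elapsed += time.perf_counter() - start
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    # Final, real launch on the original (uncloned) arguments.
    kernel(**nargs, **best_cfg)
    return best_cfg

v = [1.0, 2.0, 3.0]
configs = [{"block_size": 1}, {"block_size": 2}]
best = autotune_with_clones(scale_inplace, {"v": v, "s": 2.0}, configs, clone_names={"v"})
print(v)  # [2.0, 4.0, 6.0]: scaled exactly once despite 6 benchmark runs
```

The cost of this approach is one temporary copy per cloned argument during benchmarking only; regular calls after tuning would not pay it.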

As a temporary workaround, one can of course always avoid in-place updates, but then the kernel requires additional memory allocations.

pommedeterresautee commented 1 year ago

In triton.ops, you can see how to zero the output before each autotune config:

https://github.com/openai/triton/blob/main/python/triton/ops/matmul.py

It may help with your use case.

orsharir commented 1 year ago

I'm aware of the reset_to_zero option, but it doesn't help with in-place operations. For example, let's say I want to implement in-place scalar multiplication, i.e., v[:] *= s. Using reset_to_zero will zero out v. If we don't do anything, we'll get v * s^k, where k is the number of configurations autotune evaluates. That's why I'm suggesting adding a clone_before_evaluation option to be used for benchmarking in autotune.
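The arithmetic in this example can be checked with a toy in-place kernel in plain Python (no Triton involved; `scale_inplace` is a hypothetical stand-in). If k launches all mutate the same `v`, the result is v * s^k rather than v * s. With real benchmarking, every warmup and timed repetition is also a launch, so the exponent can be larger than the number of configurations.

```python
def scale_inplace(v, s):
    # Toy in-place kernel: v[:] *= s
    for i in range(len(v)):
        v[i] *= s

v = [1.0, 2.0]
s = 3.0
k = 4  # launches that all hit the same v, e.g. one per config
for _ in range(k):
    scale_inplace(v, s)
print(v)  # [81.0, 162.0], i.e. v * s**4 instead of v * s
```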

pommedeterresautee commented 1 year ago

I'm not sure I follow you, but reset-to-zero is just one specific use for this specific kernel; the real feature is the hook. You can provide whatever function you want, including, for instance, copying one tensor into another (to reset its state).

orsharir commented 1 year ago

I was referring to the optional reset_to_zero flag that you can pass to triton.autotune (https://triton-lang.org/main/python-api/generated/triton.autotune.html). I wasn't aware of the additional pre_hook functionality in triton.Config. Forgive me for assuming you were referring only to the built-in reset_to_zero.

However, I'm still not sure this pre_hook functionality will work for in-place ops. First, pre_hook is implemented such that it can only apply in-place operations on the arguments, so it doesn't seem possible to use clone in this setting. Suppose it were changed to also return the arguments, so that cloning becomes possible. Then, it appears that pre_hook will be called every time the kernel is launched, regardless if it's launched during benchmarking or during a regular call. Using clone() in pre_hook would then result in a kernel that has no effect on the inputs to an in-place operation (it would effectively be a no-op), even after the kernel has been benchmarked.
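A toy illustration of this concern, in plain Python. `launch` and `clone_hook` are hypothetical: the sketch assumes, as described above, that the hook runs on every launch, and that it has been changed to return replacement arguments instead of only mutating them.

```python
import copy

def scale_inplace(v, s):
    # Toy in-place kernel: v[:] *= s
    for i in range(len(v)):
        v[i] *= s

def clone_hook(nargs):
    # Hypothetical hook that returns cloned arguments instead of mutating them.
    return {**nargs, "v": copy.deepcopy(nargs["v"])}

def launch(kernel, nargs, hook):
    # Assumption: the hook runs on *every* launch, not only while benchmarking.
    kernel(**hook(nargs))

v = [1.0, 2.0, 3.0]
launch(scale_inplace, {"v": v, "s": 2.0}, clone_hook)
print(v)  # [1.0, 2.0, 3.0]: the kernel only ever scaled a clone
```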

It seems that the role of pre_hook is to specify optional behavior for the kernel itself rather than how to benchmark it correctly. The reset_to_zero flag uses a different mechanism in Autotuner() that resets the specified arguments only during benchmarking, so that multiple evaluations don't produce incorrect output the first time the kernel is used (for given input dimensions). If the cache key for the config already exists, no action is taken. I think that's the kind of functionality needed here.

Building on the existing pre_hook argument to Config, perhaps a more general pre_benchmark_hook (or setup_benchmark_hook) could be added, if the specific case of cloning seems too narrow. This hook would receive self.nargs as input (like pre_hook) but would also return a new dictionary of named arguments to be used only while benchmarking a specific config, not during regular calls.
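A rough sketch of how the proposed hook could be wired in, written as a toy pure-Python autotuner rather than Triton's actual `Autotuner` class. `ToyAutotuner`, `pre_benchmark_hook`, and `clone_args` are hypothetical names. The sketch assumes the hook receives the named arguments and returns a (possibly new) dict used only for benchmarking, and that benchmarking happens only on a cache miss, with `key` naming the arguments whose values form the cache key.

```python
import copy
import time

def clone_args(*names):
    # Hypothetical hook factory: the hook returns a new nargs dict in which
    # the named arguments are replaced by clones; other arguments pass through.
    def hook(nargs):
        return {k: (copy.deepcopy(val) if k in names else val) for k, val in nargs.items()}
    return hook

def scale_inplace(v, s, n, block_size):
    # Toy in-place kernel: v[:n] *= s, processed block by block.
    for start in range(0, n, block_size):
        for i in range(start, min(start + block_size, n)):
            v[i] *= s

class ToyAutotuner:
    def __init__(self, kernel, configs, key, pre_benchmark_hook=None):
        self.kernel = kernel
        self.configs = configs
        self.key = key                  # names of arguments whose values form the cache key
        self.pre_benchmark_hook = pre_benchmark_hook
        self.cache = {}                 # cache key -> best config
        self.bench_count = 0            # number of benchmarking launches so far

    def _bench(self, nargs, cfg):
        if self.pre_benchmark_hook is not None:
            # Only benchmarking launches use the hook's returned arguments.
            nargs = self.pre_benchmark_hook(nargs)
        start = time.perf_counter()
        self.kernel(**nargs, **cfg)
        self.bench_count += 1
        return time.perf_counter() - start

    def __call__(self, **nargs):
        cache_key = tuple(nargs[name] for name in self.key)
        if cache_key not in self.cache:
            # Cache miss: benchmark every config and remember the fastest one.
            timings = [(self._bench(nargs, cfg), i) for i, cfg in enumerate(self.configs)]
            self.cache[cache_key] = self.configs[min(timings)[1]]
        # Regular launch: always on the caller's original arguments, no hook.
        self.kernel(**nargs, **self.cache[cache_key])

tuner = ToyAutotuner(scale_inplace, [{"block_size": 1}, {"block_size": 2}],
                     key=["n"], pre_benchmark_hook=clone_args("v"))
v = [1.0, 2.0, 3.0]
tuner(v=v, s=2.0, n=3)  # cache miss: benchmarks on clones, then scales v once
tuner(v=v, s=2.0, n=3)  # cache hit: no benchmarking, scales v again
print(v)  # [4.0, 8.0, 12.0]
```

Because the hook returns a dict rather than mutating in place, cloning is just one possible hook; the same slot could hold any other benchmark-only setup.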

If there's interest and the proposed approach seems reasonable, I'm willing to prepare a PR, implementing either the more general pre_benchmark_hook or the more specific clone_before_benchmarking flag.

pommedeterresautee commented 1 year ago

The reset function looks like:

def init_to_zero(name):
    return lambda nargs: nargs[name].zero_()

> First, pre_hook is implemented such that it can only apply in-place operations on the arguments

Basically, init_to_zero returns a callable that receives a dict of args (nargs) and is called during the benchmark process (through kernel_call): https://github.com/openai/triton/blob/19e7238d50ae1736cf4b5185d36599b80c4d12ef/python/triton/runtime/autotuner.py#L76 You are free to put whatever you want inside the callable; the mechanism is quite generic, so it can apply more than one operation to more than one argument. It can even copy the contents of tensor A into tensor B.
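Following the same pattern as init_to_zero, a hook that restores one argument from another could look like the sketch below. `init_to_copy` and the `v_backup` argument are hypothetical, and plain Python lists stand in for tensors so the snippet runs anywhere (with PyTorch tensors the body would be `nargs[dst].copy_(nargs[src])`). Note that the backup has to be passed in as an extra named argument.

```python
def init_to_copy(dst, src):
    # Same shape as init_to_zero: returns a callable that receives the dict of
    # named arguments (nargs) and overwrites nargs[dst] in place with nargs[src].
    def hook(nargs):
        nargs[dst][:] = nargs[src]
    return hook

v = [2.0, 4.0]          # state left behind by a previous in-place launch
v_backup = [1.0, 2.0]   # hypothetical extra argument holding the original values
nargs = {"v": v, "v_backup": v_backup}
init_to_copy("v", "v_backup")(nargs)
print(v)  # [1.0, 2.0]
```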

> Then, it appears that pre_hook will be called every time the kernel is launched, regardless if it's launched during benchmarking or during a regular call.

AFAIK, pre_hook is called only during benchmarking. And even if it weren't, nothing stops you from putting an if inside the callable, based on some variable captured in a closure or passed in nargs, so that it doesn't run when it shouldn't.

https://github.com/openai/triton/blob/19e7238d50ae1736cf4b5185d36599b80c4d12ef/python/triton/runtime/autotuner.py#L74-L80C5
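The "if inside the callable" idea can be sketched as a hook that checks a flag captured in a closure. `make_guarded_hook`, `zero_v`, and the `state` dict are hypothetical, and the sketch assumes the caller flips the flag around benchmarking itself, since nothing here ties into the autotuner's internals.

```python
def make_guarded_hook(inner, state):
    # Wraps `inner` so it only runs while state["benchmarking"] is True.
    def hook(nargs):
        if state["benchmarking"]:
            inner(nargs)
    return hook

def zero_v(nargs):
    # Example inner hook: zero the "v" argument in place.
    nargs["v"][:] = [0.0] * len(nargs["v"])

state = {"benchmarking": True}
hook = make_guarded_hook(zero_v, state)

v = [1.0, 2.0]
hook({"v": v})
print(v)  # [0.0, 0.0]: flag set, so the inner hook ran

state["benchmarking"] = False
w = [3.0, 4.0]
hook({"v": w})
print(w)  # [3.0, 4.0]: flag cleared, so the argument is left alone
```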

orsharir commented 1 year ago
  1. It definitely appears to be called every time, even after benchmarking; see here: https://github.com/openai/triton/blob/19e7238d50ae1736cf4b5185d36599b80c4d12ef/python/triton/runtime/autotuner.py#LL109C1-L109C1
  2. Yes, it's possible to copy data between tensors, but that means I'd need to pass additional "dummy" tensors to the kernel just to support autotuning in-place operations.
  3. Similarly, I could implement a number of workarounds, but they would end up depending heavily on the autotuner's internal implementation, whereas I prefer to work with a public API.

Ultimately, I think this is a very common use case, not some rare edge case, so there's no reason Triton shouldn't support it. In-place operations are common and help avoid excessive memory allocations. I'm very happy to implement this and submit a PR, but I would appreciate comments on the proposed solution before I start, to make sure the work isn't for nothing.

pommedeterresautee commented 1 year ago

Oh, you are right; sorry, I went too fast.

If I may comment on the proposal: I agree that implementing tricks that rely on the autotuner's internal implementation is a bit fragile.

Adding (for instance) a new hook specific to some kinds of kernels may make things a bit more complicated to use and maintain.

However, if you could design a single API that adapts easily to different use cases (and is well documented and easy to understand), it could be a great win for everybody.