NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Adam Optimizer Code from a Thunder example segments with >2 Parameters #2112

Open kevinstephano opened 5 months ago

kevinstephano commented 5 months ago

Goal

The goal of this issue is to determine why we segment the optimizer code. Once we know why the segmentation occurs, we will likely need to work out an appropriate solution with @wujingyue.

Background

Thunder is attempting to implement a fused Adam optimizer, where each parameter should result in a single kernel. The example has 2 parameters, so we would expect 2 kernels, one per parameter.
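For reference, the per-parameter update that the loop body of _single_tensor_adam computes (with amsgrad=False and maximize=False, as in the script) can be written as below. Here t is the step after it has been incremented, m is exp_avg, v is exp_avg_sq, ε is eps, and g_t is the gradient, plus λθ_{t-1} when weight_decay = λ is nonzero. Each line is elementwise within one parameter, and nothing couples one parameter to another, which is why one kernel per parameter is the natural expectation.

```latex
\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2\, v_{t-1} + (1 - \beta_2)\, g_t^{2} \\
\theta_t &= \theta_{t-1} - \frac{\mathrm{lr}}{1 - \beta_1^{t}} \cdot
  \frac{m_t}{\sqrt{v_t} \,/\, \sqrt{1 - \beta_2^{t}} + \epsilon}
\end{aligned}
```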

Example

Here is the example code. @jjsjann123, could you make this in-place safe?

import torch
from torch import Tensor
from typing import List, Union, Optional
import math
import thunder
from torch.utils.benchmark import Timer
from thunder.tests.litgpt_model import Config, GPT

def _dispatch_sqrt(x: float):
    if isinstance(x, Tensor):
        return x.sqrt()
    return math.sqrt(x)

def _single_tensor_adam(
    params: list[Tensor],
    grads: list[Tensor],
    exp_avgs: list[Tensor],
    exp_avg_sqs: list[Tensor],
    max_exp_avg_sqs: list[Tensor],
    state_steps: list[Tensor],
    grad_scale: Tensor | None,
    found_inf: Tensor | None,
    *,
    amsgrad: bool,
    has_complex: bool,
    beta1: float,
    beta2: float,
    lr: float | Tensor,
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        # update step
        step_t = thunder.prims.copy_(step_t + 1, step_t)

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg = thunder.prims.copy_(exp_avg * (beta1) + grad * (1 - beta1), exp_avg)
        # exp_avg = exp_avg * (beta1) + grad * (1 - beta1)

        exp_avg_sq = thunder.prims.copy_((exp_avg_sq * beta2) + (1 - beta2) * grad * grad, exp_avg_sq)
        # exp_avg_sq = (exp_avg_sq * beta2) + (1 - beta2) * grad * grad

        step = step_t
        bias_correction1 = 1 - beta1**step
        bias_correction2 = 1 - beta2**step

        step_size = lr / bias_correction1

        bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)

        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])

            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt) + (eps)
        else:
            denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt) + (eps)

        param = thunder.prims.copy_(param + (-step_size) * exp_avg / denom, param)

def get_jit_state(jit_params):
    steps = []
    exp_avgs = []
    exp_avg_sqs = []
    max_exp_avg_sqs = []
    for param in jit_params:
        step = torch.tensor(0.0, dtype=torch.float, device=device)
        exp_avg = torch.zeros_like(param, memory_format=torch.preserve_format)
        exp_avg_sq = torch.zeros_like(param)
        max_exp_avg_sq = None  # torch.zeros_like(param, memory_format=torch.preserve_format)

        steps.append(step)
        exp_avgs.append(exp_avg)
        exp_avg_sqs.append(exp_avg_sq)
        max_exp_avg_sqs.append(max_exp_avg_sq)

    return steps, exp_avgs, exp_avg_sqs, max_exp_avg_sqs

def jit_step(optim_func, jit_params, jit_state):
    jit_grads = [param.grad for param in jit_params]
    steps, exp_avgs, exp_avg_sqs, max_exp_avg_sqs = jit_state
    # with torch.no_grad():
    optim_func(
        jit_params,
        jit_grads,
        exp_avgs=exp_avgs,
        exp_avg_sqs=exp_avg_sqs,
        max_exp_avg_sqs=max_exp_avg_sqs,
        state_steps=steps,
        grad_scale=None,
        found_inf=None,
        amsgrad=False,
        has_complex=False,
        beta1=0.9,
        beta2=0.999,
        lr=0.001,
        weight_decay=0,
        eps=1e-8,
        maximize=False,
        capturable=False,
        differentiable=False,
    )

device = "cuda"
dim = 1024
n_param = 1
params = [torch.randn(dim, dim, device=device, requires_grad=True) for _ in range(n_param)]

# model_name = "open_llama_3b"
# m = GPT(Config.from_name(model_name)).to(device)
# params = list(m.parameters())
# params = params[:50]

with torch.no_grad():
    orig_param = params[0].clone().detach()

with torch.no_grad():
    jit_params = [param.clone().detach() for param in params]

for param in jit_params:
    param.requires_grad_(True)

def computation_and_backward(params):
    result = torch.empty_like(params[0])
    for param in params:
        result = result + param

    result.sum().backward()

def attach_grads(params):
    for param in params:
        param.grad = torch.randn_like(param)

def copy_grads(params, jit_params):
    for param, jit_param in zip(params, jit_params):
        jit_param.grad = param.grad.clone().detach()

adam = torch.optim.Adam(params, fused=True)

# computation_and_backward(params)
# computation_and_backward(jit_params)

attach_grads(params)
copy_grads(params, jit_params)

# 2 iterations
adam.step()
adam.step()

# # thunder.set_execution_callback_file("foo.py")
optim_func = thunder.jit(_single_tensor_adam)
jit_state = get_jit_state(jit_params)

# 2 iterations
jit_step(optim_func, jit_params, jit_state)
jit_step(optim_func, jit_params, jit_state)

torch.testing.assert_close(params, jit_params)

# print(params[0][0], jit_params[0][0])

# Verify that params have changed.
# This should crash
# torch.testing.assert_close(params[0], orig_param)

native_time = Timer(stmt="adam.step()", globals={"adam": adam}).timeit(number=100)
jit_time = Timer(
    stmt="jit_step(optim_func, jit_params, jit_state)",
    globals={"jit_step": jit_step, "jit_params": jit_params, "jit_state": jit_state, "optim_func": optim_func},
).timeit(number=100)

# Sanity after 100 iterations
torch.testing.assert_close(params, jit_params, rtol=1e-4, atol=1e-4)
# print(params[0][0], jit_params[0][0])

print(native_time)
print(jit_time)

exec_trace = thunder.last_traces(optim_func)[-1]

with open("generated_thunder_trace.py", "w") as f:
    f.write(str(exec_trace))

with open("generated_fusion_definition.py", "w") as f:
    f.write(str(exec_trace.python_ctx()["nvFusion0"].last_used))

with open("generated_kernels.cu", "w") as f:
    f.write(exec_trace.python_ctx()["nvFusion0"].last_used.last_cuda_code())

print("Done")

jjsjann123 commented 5 months ago

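Two changes need to be applied to the script above:

  1. Make the in-place update safe. We just need to ensure there's a clear dependency in the IO buffers: compute all new values out of place first, then write them back with thunder.prims.copy_ at the end of the loop body.

In function _single_tensor_adam, we can rewrite the loop as: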

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        # update step
        #step_t = thunder.prims.copy_(step_t + 1, step_t)
        step_t_ = step_t + 1

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        # exp_avg = thunder.prims.copy_(exp_avg * (beta1) + grad * (1 - beta1), exp_avg)
        exp_avg_ = exp_avg * (beta1) + grad * (1 - beta1)

        # exp_avg_sq = thunder.prims.copy_((exp_avg_sq * beta2) + (1 - beta2) * grad * grad, exp_avg_sq)
        exp_avg_sq_ = (exp_avg_sq * beta2) + (1 - beta2) * grad * grad

        #step = step_t
        step = step_t_
        bias_correction1 = 1 - beta1**step
        bias_correction2 = 1 - beta2**step

        step_size = lr / bias_correction1

        bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)

        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq_, out=max_exp_avg_sqs[i])

            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt) + (eps)
        else:
            denom = (exp_avg_sq_.sqrt() / bias_correction2_sqrt) + (eps)

        #param = thunder.prims.copy_(param + (-step_size) * exp_avg / denom, param)
        param_ = param + (-step_size) * exp_avg_ / denom
        thunder.prims.copy_(param_, param)
        thunder.prims.copy_(exp_avg_, exp_avg)
        thunder.prims.copy_(exp_avg_sq_, exp_avg_sq)
        thunder.prims.copy_(step_t_, step_t)
  2. n_param needs to be updated.

> The example has 2 parameters,

FYI, n_param = 2 is the right value to use in the script above. Right now it runs with a single parameter, so there is no segmentation.
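To illustrate the ordering in the rewrite, here is a minimal pure-Python sketch of the same two-phase structure. Plain lists stand in for tensors, and adam_step_deferred is a hypothetical helper, not a Thunder or nvFuser API. Every new value is computed from the old buffers first, and the buffers are overwritten only at the end, so no read can observe a partially updated buffer.

```python
import math


def adam_step_deferred(state, grads, *, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step on plain Python lists, mirroring the rewritten loop body.

    Phase 1 reads only the old buffers and produces new values out of place.
    Phase 2 commits them, analogous to the trailing thunder.prims.copy_ calls.
    """
    param, exp_avg, exp_avg_sq = state["param"], state["exp_avg"], state["exp_avg_sq"]

    # Phase 1: compute everything out of place; no buffer is written yet.
    step_ = state["step"] + 1
    exp_avg_ = [beta1 * m + (1 - beta1) * g for m, g in zip(exp_avg, grads)]
    exp_avg_sq_ = [beta2 * v + (1 - beta2) * g * g for v, g in zip(exp_avg_sq, grads)]
    step_size = lr / (1 - beta1**step_)
    bias_correction2_sqrt = math.sqrt(1 - beta2**step_)
    param_ = [
        p - step_size * m / (math.sqrt(v) / bias_correction2_sqrt + eps)
        for p, m, v in zip(param, exp_avg_, exp_avg_sq_)
    ]

    # Phase 2: write back in place, only after all reads are done.
    param[:] = param_
    exp_avg[:] = exp_avg_
    exp_avg_sq[:] = exp_avg_sq_
    state["step"] = step_


state = {"param": [1.0, -2.0], "exp_avg": [0.0, 0.0], "exp_avg_sq": [0.0, 0.0], "step": 0}
adam_step_deferred(state, [0.5, -0.25])
# The first Adam step moves each element by about lr against the sign of its gradient.
print(state["param"])  # roughly [0.999, -1.999]
```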

liqiangxl commented 5 months ago

I tried Jie's revised version of the script, and it is segmented with "Scheduler _pointwise_ ***rejected*** because : Connected fusion graph check failed!". This is a known issue, as we don't fuse ops horizontally. Do we want to support this type of fusion?
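As a rough illustration of what a connectivity check sees, here is a toy model (not nvFuser's actual segmenter code; adam_dataflow and connected_components are hypothetical helpers). The dataflow graph of the loop is one subgraph per parameter, with no tensor shared between them, so with n_param = 2 it has two connected components. A scheduler that requires a single connected graph has to split it, unless ops are fused horizontally.

```python
def connected_components(nodes, edges):
    """Group nodes into connected components with a small union-find."""
    parent = {n: n for n in nodes}

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for a, b in edges:
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_a] = root_b

    groups = {}
    for n in nodes:
        groups.setdefault(find(n), []).append(n)
    return list(groups.values())


def adam_dataflow(n_param):
    """Toy producer -> consumer graph of the rewritten Adam loop (hypothetical model)."""
    nodes, edges = [], []
    for i in range(n_param):
        p, g, m, v, s = (f"{name}[{i}]" for name in ("param", "grad", "exp_avg", "exp_avg_sq", "step"))
        p_, m_, v_, s_ = (f"{t}_new" for t in (p, m, v, s))
        nodes += [p, g, m, v, s, p_, m_, v_, s_]
        edges += [
            (m, m_), (g, m_),  # exp_avg_ from exp_avg and grad
            (v, v_), (g, v_),  # exp_avg_sq_ from exp_avg_sq and grad
            (s, s_),  # step_t_ = step_t + 1
            (p, p_), (m_, p_), (v_, p_), (s_, p_),  # param_ uses the new moments and step
        ]
    return nodes, edges


for n in (1, 2, 5):
    print(n, len(connected_components(*adam_dataflow(n))))  # one component per parameter
```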

jjsjann123 commented 5 months ago

Recording an offline discussion: @liqiangxl mentioned that after the code change we seem to get only 2 segments. I'll confirm, and maybe we can close this one. 🤞

naoyam commented 5 months ago

I think a fundamental piece of information is lost when the program is converted from Thunder to nvFuser: at the Thunder level, the program is clearly SPMD (the same computation applied to every parameter), whereas in the converted nvFuser program that fact no longer exists. We could try to rediscover it, but it seems more robust and easier not to lose the information in the first place, by adding some nvFuser primitives that represent such properties.

naoyam commented 5 months ago

Could we convert the Thunder program to a batched representation?

jjsjann123 commented 5 months ago

> Could we convert the Thunder program to a batched representation?

What would a batched representation look like?

A for-loop with a block? Or are we assuming all entries in params share the same shape, and we should convert that as batched_params = torch.cat(params, 0)? I'm not sure if this assumption holds, though.
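A sketch of the concatenation idea, with plain Python lists standing in for flattened tensors (all names are illustrative, not Thunder or nvFuser APIs). Because the update is purely elementwise and uses the same hyperparameters for every parameter, running it once over the concatenation and splitting the result back gives the same values as the per-parameter loop. Under the assumption that flattening each tensor first is acceptable, the tensors would not need identical shapes; stacking along a new dimension is what would require that.

```python
from itertools import accumulate


def elementwise_update(p, g, lr=0.1):
    # Stand-in for any purely elementwise optimizer update (Adam's is one).
    return p - lr * g


def per_parameter(params, grads):
    return [[elementwise_update(p, g) for p, g in zip(ps, gs)] for ps, gs in zip(params, grads)]


def batched(params, grads):
    # Concatenate the flattened tensors, run one "kernel" over all elements, then split.
    flat_p = [x for ps in params for x in ps]
    flat_g = [x for gs in grads for x in gs]
    flat_out = [elementwise_update(p, g) for p, g in zip(flat_p, flat_g)]
    offsets = [0, *accumulate(len(ps) for ps in params)]
    return [flat_out[a:b] for a, b in zip(offsets, offsets[1:])]


# A "weight" with 6 elements (e.g. shape (2, 3)) and a "bias" with 2 elements (shape (2,)).
params = [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [0.5, -0.5]]
grads = [[1.0] * 6, [2.0, -2.0]]
print(batched(params, grads) == per_parameter(params, grads))  # True
```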

naoyam commented 5 months ago

> > Could we convert the Thunder program to a batched representation?
>
> What would a batched representation look like?
>
> A for-loop with a block? Or are we assuming all entries in params share the same shape, and we should convert that as batched_params = torch.cat(params, 0)? I'm not sure if this assumption holds, though.

I thought these are the inputs, so I assumed it does:

        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

jjsjann123 commented 5 months ago

> > > Could we convert the Thunder program to a batched representation?
> >
> > What would a batched representation look like?
> >
> > A for-loop with a block? Or are we assuming all entries in params share the same shape, and we should convert that as batched_params = torch.cat(params, 0)? I'm not sure if this assumption holds, though.
>
> I thought these are the inputs, so I assumed it does:
>
>         grad = grads[i] if not maximize else -grads[i]
>         exp_avg = exp_avgs[i]
>         exp_avg_sq = exp_avg_sqs[i]
>         step_t = state_steps[i]

A little lost here.

So grad[i], exp_avg[i], exp_avg_sq[i], and params[i] all share the same shape.

But that's not how the program is parallelized. We need params[i] and params[j] to have the same shape in order to batch the program.

I admit that they happen to have the same shape in this example:

    params = [torch.randn(dim, dim, device=device, requires_grad=True) for _ in range(n_param)]

But that's not guaranteed.

naoyam commented 5 months ago

> We need params[i] and params[j] to have the same shape in order to batch the program

This is what I assumed. Isn't it guaranteed? Isn't params here just a normal multidimensional tensor?

naoyam commented 5 months ago

> > We need params[i] and params[j] to have the same shape in order to batch the program
>
> This is what I assumed. Isn't it guaranteed? Isn't params here just a normal multidimensional tensor?

Sorry, that was my misunderstanding, but isn't it a common case?

jjsjann123 commented 5 months ago

> > > We need params[i] and params[j] to have the same shape in order to batch the program
> >
> > This is what I assumed. Isn't it guaranteed? Isn't params here just a normal multidimensional tensor?
>
> Sorry, that was my misunderstanding, but isn't it a common case?

No. params is a list of tensors (trainable parameters?!), where each tensor can have a different shape. Think about how you define a single linear layer, whose weight and bias have different shapes.
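To make the shape point concrete: torch.nn.Linear(in_features, out_features) stores weight with shape (out_features, in_features) and bias with shape (out_features,), so even a tiny model mixes shapes. One hypothetical way to still batch by stacking would be to bucket parameters by shape and batch within each bucket. group_by_shape below is an illustrative helper, not part of Thunder or nvFuser.

```python
from collections import defaultdict


def group_by_shape(shapes):
    """Map each distinct shape to the indices of the parameters that have it."""
    groups = defaultdict(list)
    for i, shape in enumerate(shapes):
        groups[tuple(shape)].append(i)
    return dict(groups)


# Two linear layers, e.g. Linear(4, 3) followed by Linear(3, 3):
# weights are (out_features, in_features), biases are (out_features,).
shapes = [(3, 4), (3,), (3, 3), (3,)]
print(group_by_shape(shapes))  # {(3, 4): [0], (3,): [1, 3], (3, 3): [2]}
```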

crcrpar commented 5 months ago

> we should convert that as batched_params = torch.cat(params, 0)

This reads like the code will use much more memory.
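A rough way to quantify this concern, under stated assumptions: torch.cat returns a newly allocated tensor, and we assume the four per-parameter tensors the update reads (param, grad, exp_avg, exp_avg_sq) are all concatenated. The extra footprint is then the total size of those tensors. concat_copy_bytes is a hypothetical helper for the arithmetic only.

```python
import math


def concat_copy_bytes(shapes, bytes_per_element=4, tensors_per_param=4):
    """Extra bytes if each per-parameter tensor list is concatenated into a new buffer.

    Assumptions: every concatenation materializes a fresh copy, and the update
    concatenates tensors_per_param lists (param, grad, exp_avg, exp_avg_sq).
    """
    numel = sum(math.prod(shape) for shape in shapes)
    return numel * bytes_per_element * tensors_per_param


# The script's setting with n_param = 2: two float32 tensors of shape (1024, 1024).
extra = concat_copy_bytes([(1024, 1024)] * 2)
print(extra / 2**20, "MiB")  # 32.0 MiB, before copying results back
```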

naoyam commented 5 months ago

> > we should convert that as batched_params = torch.cat(params, 0)
>
> This reads like the code will use much more memory.

Not necessarily. We don't need to instantiate concatenated tensors.
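One way to read "we don't need to instantiate concatenated tensors" is as an index remapping: a single loop walks a global index space, and each global index is translated into (tensor, offset) via prefix sums of the tensor sizes. Below is a minimal pure-Python sketch of that idea. VirtualConcat is a hypothetical helper, not an nvFuser construct, and a GPU implementation would presumably do this lookup per thread or per block rather than per element in Python.

```python
from bisect import bisect_right
from itertools import accumulate


class VirtualConcat:
    """View that indexes several flat buffers as if they were concatenated.

    Nothing is copied: a global index is translated into (buffer, local offset)
    using the prefix sums of the buffer lengths.
    """

    def __init__(self, buffers):
        self.buffers = buffers
        # starts[k] is the global index of the first element of buffers[k].
        self.starts = [0, *accumulate(len(b) for b in buffers)]

    def __len__(self):
        return self.starts[-1]

    def locate(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        k = bisect_right(self.starts, idx) - 1
        return k, idx - self.starts[k]

    def __getitem__(self, idx):
        k, off = self.locate(idx)
        return self.buffers[k][off]

    def __setitem__(self, idx, value):
        k, off = self.locate(idx)
        self.buffers[k][off] = value


weight = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
bias = [0.5, -0.5]
view = VirtualConcat([weight, bias])
for i in range(len(view)):  # one pass over all parameters
    view[i] = view[i] * 2.0
# weight and bias were updated in place; no concatenated copy was created.
print(weight, bias)
```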

kevinstephano commented 5 months ago

I want to emphasize that the goal of this issue was not to solve the broader problem of the optimizer giving nvFuser multiple parameters and nvFuser turning them into one kernel. We can address fusing multiple parameters separately. First, I wanted to address the fact that we were segmenting what should have been one kernel per parameter, so that the performance would not be horrific.

jjsjann123 commented 4 months ago

> I want to emphasize that the goal of this issue was not to solve the broader problem of the optimizer giving nvFuser multiple parameters and nvFuser turning them into one kernel. We can address fusing multiple parameters separately. First, I wanted to address the fact that we were segmenting what should have been one kernel per parameter, so that the performance would not be horrific.

I realized that I forgot to respond to the original comment.

I can confirm that the rewrite seems to resolve the segmentation issue. I tried the rewritten program with n_param = 5, and I see 5 generated kernels.

The numerics seem a bit flaky, but since the inputs are random numbers, I don't know how much we can read into this:

Mismatched elements: 3072 / 1048576 (0.3%)
Greatest absolute difference: 0.00025594234466552734 at index (988, 516) (up to 1e-05 allowed)
Greatest relative difference: 0.551850438117981 at index (902, 778) (up to 1.3e-06 allowed)
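For reading the log above: torch.testing.assert_close treats an element as close when |actual - expected| <= atol + rtol * |expected|, and its float32 defaults are rtol = 1.3e-6 and atol = 1e-5, which are the "allowed" values printed. The pure-Python sketch below (is_close is a hypothetical helper mirroring that documented criterion; the numbers are illustrative, not from the log) shows why a large relative difference need not mean a large error: when the expected value is near zero, a small absolute difference becomes a large relative one. The script's second check uses the looser rtol=1e-4, atol=1e-4.

```python
def is_close(actual, expected, *, rtol, atol):
    # Elementwise criterion documented for torch.testing.assert_close.
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Illustrative values, not taken from the log: an element whose expected value is near zero.
expected, actual = 1.0e-4, 1.55e-4
abs_diff = abs(actual - expected)
rel_diff = abs_diff / abs(expected)
print(abs_diff, rel_diff)  # about 5.5e-05 and 0.55

print(is_close(actual, expected, rtol=1.3e-6, atol=1e-5))  # False with the float32 defaults
print(is_close(actual, expected, rtol=1e-4, atol=1e-4))  # True with the script's relaxed check
```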
jjsjann123 commented 4 months ago

So, regarding @kevinstephano's point here:

> I want to emphasize that the goal of this issue was not to solve the broader problem of the optimizer giving nvFuser multiple parameters and nvFuser turning them into one kernel. We can address fusing multiple parameters separately. First, I wanted to address the fact that we were segmenting what should have been one kernel per parameter, so that the performance would not be horrific.

Can we rely on the rewrite here, https://github.com/NVIDIA/Fuser/issues/2112#issuecomment-2071070854, and close this issue for now?

We can revisit it later when we need to squeeze out more performance. I created a new label, optimizer, so it's easier for us to find this one.