kevinstephano opened this issue 5 months ago
Two changes need to be applied to the script above.

First, in the function `_single_tensor_adam`, we can rewrite the program as:
```python
for i, param in enumerate(params):
    grad = grads[i] if not maximize else -grads[i]
    exp_avg = exp_avgs[i]
    exp_avg_sq = exp_avg_sqs[i]
    step_t = state_steps[i]

    # update step (out-of-place; the in-place write is deferred to the end)
    # step_t = thunder.prims.copy_(step_t + 1, step_t)
    step_t_ = step_t + 1

    if weight_decay != 0:
        grad = grad.add(param, alpha=weight_decay)

    # Decay the first and second moment running average coefficient
    # exp_avg = thunder.prims.copy_(exp_avg * (beta1) + grad * (1 - beta1), exp_avg)
    exp_avg_ = exp_avg * (beta1) + grad * (1 - beta1)
    # exp_avg_sq = thunder.prims.copy_((exp_avg_sq * beta2) + (1 - beta2) * grad * grad, exp_avg_sq)
    exp_avg_sq_ = (exp_avg_sq * beta2) + (1 - beta2) * grad * grad

    # step = step_t
    step = step_t_

    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step
    step_size = lr / bias_correction1
    bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)

    if amsgrad:
        # Maintains the maximum of all 2nd moment running avg. till now
        torch.maximum(max_exp_avg_sqs[i], exp_avg_sq_, out=max_exp_avg_sqs[i])
        # Use the max. for normalizing running avg. of gradient
        denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt) + (eps)
    else:
        denom = (exp_avg_sq_.sqrt() / bias_correction2_sqrt) + (eps)

    # param = thunder.prims.copy_(param + (-step_size) * exp_avg / denom, param)
    param_ = param + (-step_size) * exp_avg_ / denom

    # Apply all in-place updates together at the end of the iteration so
    # the copies don't split the fusion.
    thunder.prims.copy_(param_, param)
    thunder.prims.copy_(exp_avg_, exp_avg)
    thunder.prims.copy_(exp_avg_sq_, exp_avg_sq)
    thunder.prims.copy_(step_t_, step_t)
```
Second, `n_param` needs to be updated: the example has 2 parameters, so `n_param = 2` should be the right value to use in the script above. Right now it's running with a single parameter (so there's no segmentation).
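The key property of the rewrite is that every state update is computed out-of-place and all of the `thunder.prims.copy_` writes are deferred to the end of the loop body. A hypothetical driver for the rewritten step might look like this (a sketch only: `adam_step` stands in for a function wrapping the loop above, and the shapes/hyperparameters are placeholders, not the original repro script):

```python
# Hypothetical driver for the rewrite above. `adam_step` is assumed to
# wrap the rewritten loop; all sizes here are illustrative.
import torch
import thunder

n_param, dim, device = 2, 1024, "cuda"
params = [torch.randn(dim, dim, device=device) for _ in range(n_param)]
grads = [torch.randn(dim, dim, device=device) for _ in range(n_param)]
exp_avgs = [torch.zeros(dim, dim, device=device) for _ in range(n_param)]
exp_avg_sqs = [torch.zeros(dim, dim, device=device) for _ in range(n_param)]
state_steps = [torch.zeros((), device=device) for _ in range(n_param)]

jitted = thunder.jit(adam_step)  # trace once, let nvFuser fuse per parameter
jitted(params, grads, exp_avgs, exp_avg_sqs, state_steps)
```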
I tried Jie's revised version of the script and it is segmented due to:

```
Scheduler _pointwise_ ***rejected*** because : Connected fusion graph check failed!
```

This is a known issue, as we don't fuse ops horizontally. Do we want to support this type of fusion?
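To illustrate what the check is complaining about (a minimal sketch, not the repro): two element-wise updates on unrelated tensors have no dataflow between them, so the fusion graph is disconnected, and putting them in one kernel would be a horizontal fusion.

```python
# Minimal sketch of a disconnected fusion graph. out_a and out_b share no
# dataflow, so combining them would be a "horizontal" fusion, which the
# pointwise scheduler rejects; they end up as separate kernels.
import torch

def two_islands(a: torch.Tensor, b: torch.Tensor):
    out_a = a * 2.0 + 1.0  # island 1
    out_b = b * 3.0 - 1.0  # island 2
    return out_a, out_b
```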
Recording an offline discussion: @liqiangxl mentioned that after the code change we seem to be getting only 2 segments. I'll confirm, and maybe we can close this one. :crossed_fingers:
I think there's a fundamental piece of information lost when converting from Thunder to nvFuser: at the Thunder level, the program is clearly SPMD, whereas once it's converted to an nvFuser program, that fact no longer exists. We could try to rediscover it, but it seems more robust, and easier, not to lose the information in the first place by adding some nvFuser primitives to represent such properties.
Could we convert the Thunder program to a batched representation?
> Could we convert the Thunder program to a batched representation?

What would a *batched representation* look like? A for-loop with a block? Or are we assuming all entries in `params` share the same shape and we should convert that as `batched_params = torch.cat(params, 0)`? I'm not sure if this assumption holds though.
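For concreteness, here's a sketch of the `torch.cat` interpretation, under the (questionable) assumption that every entry shares one shape; the sizes are made up:

```python
# If every params[i] were (dim, dim), the per-parameter loop would
# collapse into a single pointwise update over the concatenated tensors.
import torch

n_param, dim, lr = 2, 1024, 1e-3
params = [torch.randn(dim, dim) for _ in range(n_param)]
grads = [torch.randn(dim, dim) for _ in range(n_param)]

batched_params = torch.cat(params, 0)  # shape: (n_param * dim, dim)
batched_grads = torch.cat(grads, 0)
updated = batched_params - lr * batched_grads  # one kernel's worth of work
```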
> Or are we assuming all entries in `params` share the same shape and we should convert that as `batched_params = torch.cat(params, 0)`? I'm not sure if this assumption holds though.
I thought these are the inputs, so I assumed it does:
```python
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step_t = state_steps[i]
```
> I thought these are the inputs, so I assumed it does:
A little lost here. `grad[i]`, `exp_avg[i]`, `exp_avg_sq[i]`, and `params[i]` all share the same shape, but that's not how the program is parallelized. We need `params[i]` and `params[j]` to have the same shape in order to batch the program. I admit that they happen to be the same shape in this example, but that's not guaranteed:

```python
params = [torch.randn(dim, dim, device=device, requires_grad=True) for _ in range(n_param)]
```
> We need to have `params[i]` and `params[j]` to have the same shape in order to batch the program

This is what I assumed. Isn't it guaranteed? Isn't `params` here just a normal multidimensional tensor?
> Isn't `params` here just a normal multidimensional tensor?
Sorry, that was my misunderstanding, but isn't it a common case?
> Isn't `params` here just a normal multidimensional tensor?
No. `params` is a list of tensors (trainable parameters?!), where each tensor could be of a different shape. Think about how you define a single linear layer, where the weight and bias have different shapes.
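For example (a quick illustration, not from the repro):

```python
# A single linear layer already yields parameters of different shapes,
# so torch.cat(params, 0) is not well-defined in general.
import torch

layer = torch.nn.Linear(1024, 4096)
params = list(layer.parameters())
print(params[0].shape)  # torch.Size([4096, 1024]) -- weight
print(params[1].shape)  # torch.Size([4096])       -- bias
```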
> we should convert that as `batched_params = torch.cat(params, 0)`

This reads like the code will use much more memory.
> this reads like the code will use much more memory
Not necessarily. We don't need to instantiate concatenated tensors.
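One way to read that (a conceptual sketch only, not nvFuser code): the batched view can stay symbolic, with a single hypothetical horizontally-fused kernel indexing into the original buffers segment by segment, so nothing is ever concatenated in memory.

```python
# CPU-side sketch of the semantics: one logical "batched" update over a
# ragged list of buffers, with no concatenated tensor ever materialized.
import torch

def conceptually_batched_step(params, grads, lr=1e-3):
    for p, g in zip(params, grads):  # conceptually one launch, many segments
        p.add_(g, alpha=-lr)
```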
I want to emphasize that the goal of this issue was not to solve the broader situation where the optimizer gives nvFuser multiple parameters and we turn those multiple parameters into one kernel. We can address fusing multiple parameters separately. I wanted to address, first, the fact that we were segmenting across what should have been one kernel for each parameter, so the performance was not horrific.
I realized that I just forgot to respond to the original comment.

I can confirm that the rewrite seems to resolve the segmenter issue. I tried the rewritten program with `n_params = 5` and I'm seeing 5 generated kernels.

Numerics seem to be a bit flaky, but since it's random numbers I don't know how much we can read into this:
```
Mismatched elements: 3072 / 1048576 (0.3%)
Greatest absolute difference: 0.00025594234466552734 at index (988, 516) (up to 1e-05 allowed)
Greatest relative difference: 0.551850438117981 at index (902, 778) (up to 1.3e-06 allowed)
```
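That report has the shape of `torch.testing.assert_close` output, and the tolerances match its float32 defaults (`rtol=1.3e-6`, `atol=1e-5`); a sketch of the comparison, with placeholder tensors standing in for the Thunder result and the eager reference:

```python
import torch

# Placeholders for the Thunder output and the eager-mode reference; with
# the real tensors, a mismatch prints a report like the one above.
result = torch.randn(1024, 1024)
reference = result.clone()
torch.testing.assert_close(result, reference)
```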
So I think for @kevinstephano's question here:

> I wanted to address, first, the fact that we were segmenting across what should have been one kernel for each parameter, so the performance was not horrific.

we can rely on the rewrite here: https://github.com/NVIDIA/Fuser/issues/2112#issuecomment-2071070854 and close this issue for now? We can revisit it later when we need to squeeze out more perf.

Created a new label `optimizer` so it's easier for us to find this one.
**Goal**

The goal of this issue is to determine why we segment the optimizer code. We will likely need to determine an appropriate solution with @wujingyue after determining why the segmentation is occurring.

**Background**

Thunder is attempting to implement a fused Adam optimizer, and each parameter should result in a kernel. The example has 2 parameters; therefore, we would expect 2 kernels, one for each parameter.

**Example**

Here is the example code:

@jjsjann123 could you make this inplace-safe?