## Description

This PR addresses a request to modify `te.make_graphed_callables` so that kwargs, such as attention masks, can be passed in. See https://github.com/vasunvidia/TransformerEngine/commit/d0a10573b7e029b69c7c5d74e42afc9175b27282 for an initial implementation. If kwargs are provided to `te.make_graphed_callables` (via `sample_kwargs`), they must also be provided whenever the graph is replayed. Note that only tensors are accepted as positional args or kwargs: CUDA graphs replay with fixed memory addresses, so tensor arguments can be copied into static buffers before each replay, while there is no general way to handle non-tensor arguments that differ between graph capture and graph replay.
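For concreteness, here is a minimal sketch of the intended usage. The layer, tensor shapes, and mask layout are illustrative assumptions, not taken from this PR:

```python
import torch
import transformer_engine.pytorch as te

# Module whose forward accepts an attention_mask kwarg
layer = te.TransformerLayer(
    hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16
)
x = torch.randn(128, 2, 1024, device="cuda", requires_grad=True)
mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool, device="cuda")

# Capture the graph with both sample args and sample kwargs
graphed_layer = te.make_graphed_callables(
    layer,
    (x,),
    sample_kwargs=dict(attention_mask=mask),
)

# Every replay must supply the same kwargs (tensors only)
y = graphed_layer(x, attention_mask=mask)
```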
To be honest, I don't really like this approach. Ideally `te.make_graphed_callables` should match the API of `torch.cuda.make_graphed_callables`, which only supports positional args. But the best way to handle modules with kwargs in plain PyTorch is to create wrappers that handle the kwargs:
```python
import torch

# Non-graphed module: forward takes a kwarg
y = mymodule(x, key=val)

# Option 1: wrap the module so the kwarg becomes a positional arg
class MyWrapper(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x, val):
        return self.module(x, key=val)

graphed_forward1 = torch.cuda.make_graphed_callables(
    MyWrapper(mymodule),
    (x, val),
)
y = graphed_forward1(x, val)

# Option 2: also wrap the graphed callable to restore the kwarg API
graphed_forward2 = lambda x, *, key: graphed_forward1(x, key)
y = graphed_forward2(x, key=val)
```
This is quite clunky. If we accept API divergence from PyTorch, it becomes much cleaner:
```python
# Option 3: extend the API of make_graphed_callables with sample kwargs
graphed_forward3 = te.make_graphed_callables(
    mymodule,
    (x,),
    sample_kwargs=dict(key=val),
)
y = graphed_forward3(x, key=val)
```
While I was touching the code, I also added comments and tests for the custom integration with Megatron-LM interleaved pipeline parallelism.
## Type of change

- [ ] Documentation change (change only to the documentation, either a fix or new content)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Infra/Build change
- [x] Code refactor
## Changes

Please list the changes introduced in this PR:

- Add option to pass kwargs to CUDA graph module
- Add tests for CUDA graph integration with Megatron-LM interleaved pipeline parallelism