NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[PyTorch] Add option to pass kwargs to CUDA graph module #945

Open timmoon10 opened 2 weeks ago

timmoon10 commented 2 weeks ago

Description

This PR addresses a request to modify te.make_graphed_callables so that kwargs such as attention masks can be passed in. See https://github.com/vasunvidia/TransformerEngine/commit/d0a10573b7e029b69c7c5d74e42afc9175b27282 for an initial implementation. If kwargs are provided to te.make_graphed_callables (via the sample_kwargs argument), they must also be provided whenever the graph is replayed. Note that only tensors are accepted as positional args or kwargs, since anything else opens another pile of design problems (what happens if the args differ between graph capture and graph replay?).
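As a rough sketch of the intended usage (the layer construction, tensor shapes, and mask layout below are illustrative placeholders, not taken from this PR; sample_kwargs is the keyword proposed here):

import torch
import transformer_engine.pytorch as te

layer = te.TransformerLayer(1024, 4096, 16)
x = torch.randn(128, 2, 1024, device="cuda")
mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool, device="cuda")

# Capture with a kwarg: the mask is a tensor, so it is accepted.
# A non-tensor kwarg (e.g. a Python bool or str) would be rejected.
graphed_layer = te.make_graphed_callables(
    layer,
    (x,),
    sample_kwargs=dict(attention_mask=mask),
)

# Every replay must pass the same kwarg again.
y = graphed_layer(x, attention_mask=mask)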

To be honest, I don't really like this approach. Ideally, te.make_graphed_callables would match the API of torch.cuda.make_graphed_callables, which only supports positional args. But the best way to handle modules with kwargs in plain PyTorch is to create wrappers that convert the kwargs into positional args:

# Non-graphed module
y = mymodule(x, key=val)

# Option 1: wrap the module so the kwarg becomes a positional arg
class MyWrapper(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
    def forward(self, x, val):
        return self.module(x, key=val)
graphed_forward1 = torch.cuda.make_graphed_callables(
    MyWrapper(mymodule),
    (x, val),
)
y = graphed_forward1(x, val)

# Option 2: also wrap the graphed callable to restore the kwarg interface
graphed_forward2 = lambda x, *, key: graphed_forward1(x, key)
y = graphed_forward2(x, key=val)

This is quite clunky. If we accept API divergence from PyTorch, it becomes much cleaner:

# Option 3: modify API for make_graphed_callables
graphed_forward3 = te.make_graphed_callables(
    mymodule,
    (x,),
    sample_kwargs=dict(key=val),
)
y = graphed_forward3(x, key=val)
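A hedged sketch of what replay then looks like under option 3: the kwarg names and the shapes/dtypes of their tensors are fixed at capture time, while tensor contents may change between calls (loader below is a placeholder):

# Replay loop: each call replays the captured graph. The same kwarg
# name must be supplied, with tensors matching the capture-time
# shapes/dtypes; only their contents may differ.
for x_new, val_new in loader:
    y = graphed_forward3(x_new, key=val_new)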

While I was touching the code, I also added comments and tests for the custom integration with Megatron-LM interleaved pipeline parallelism.

timmoon10 commented 1 day ago

/te-ci pytorch