CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
MIT License

Faster & memory-efficient logprobs calculation #583

Open · li-plus opened this pull request 12 months ago

li-plus commented 12 months ago

The current `logprobs_of_labels` computes logprobs with a `log_softmax` followed by a `gather`. When the input logits are not contiguous, `log_softmax` makes a copy of the logits, which is very large: for typical settings, batch_size × seq_len × vocab_size × 2 bytes = 32 × 2048 × 64000 × 2 B ≈ 8.4 GB in fp16.
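For reference, a minimal sketch of the existing pattern (the real function lives in `trlx/utils/modeling.py` and may differ in details). The key point is that slicing off the last position yields a non-contiguous view, which `log_softmax` then materializes as a full copy:

```python
import torch
import torch.nn.functional as F

def logprobs_of_labels(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Gather the log-probability of each label token from the logits."""
    logprobs = F.log_softmax(logits, dim=-1)
    return torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(2, 8, 16)
# Dropping the last position gives a non-contiguous view of the logits...
assert not logits[:, :-1, :].is_contiguous()
# ...so log_softmax must first copy the whole sliced tensor to make it contiguous.
```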

This PR feeds the contiguous logits directly into `log_softmax`, reducing peak CUDA memory usage and removing the redundant copy.
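One plausible shape for the change (an assumption on my part; the merged diff is authoritative): run `log_softmax` on the full contiguous logits first, then drop the last position afterwards, so the slice is a free view rather than a copy:

```python
import torch
import torch.nn.functional as F

def logprobs_of_labels(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Variant accepting the full contiguous logits (B, T, V) with labels (B, T-1)."""
    # log_softmax on a contiguous tensor needs no internal copy.
    logprobs = F.log_softmax(logits, dim=-1)
    # Slicing after the softmax is a view, not a copy.
    logprobs = logprobs[:, :-1, :]
    return torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(2, 8, 16)           # contiguous
labels = torch.randint(0, 16, (2, 7))    # shifted labels, one shorter
assert logprobs_of_labels(logits, labels).shape == (2, 7)
```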

Test script:

```python
import torch
from torch.utils.benchmark import Timer
from trlx.utils.modeling import logprobs_of_labels

def perf():
    batch_size, seq_len, vocab_size = 32, 2048, 64000
    logits = torch.randn((batch_size, seq_len, vocab_size), dtype=torch.half, device='cuda')
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long, device='cuda')

    # correctness
    assert torch.allclose(logprobs_of_labels(logits[:, :-1, :], input_ids[:, 1:]), logprobs_of_labels(logits, input_ids[:, 1:]))

    # peak memory test
    torch.cuda.empty_cache()
    logprobs_of_labels(logits[:, :-1, :], input_ids[:, 1:])
    print(f'original allocated: {torch.cuda.memory_allocated() / 1e9:.3f} GB, reserved: {torch.cuda.memory_reserved() / 1e9:.3f} GB')

    torch.cuda.empty_cache()
    logprobs_of_labels(logits, input_ids[:, 1:])
    print(f'optimized allocated: {torch.cuda.memory_allocated() / 1e9:.3f} GB, reserved: {torch.cuda.memory_reserved() / 1e9:.3f} GB')

    # speed test
    timer = Timer(stmt="logprobs_of_labels(logits[:, :-1, :], input_ids[:, 1:])", globals={**globals(), **locals()})
    elapsed_org = timer.timeit(100).mean
    print(f'original costs: {elapsed_org:.4f} s')

    timer = Timer(stmt="logprobs_of_labels(logits, input_ids[:, 1:])", globals={**globals(), **locals()})
    elapsed_opt = timer.timeit(100).mean
    print(f'optimized costs: {elapsed_opt:.4f} s')

perf()
```
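One caveat worth noting (my observation, not from the PR): `torch.cuda.memory_allocated()` reports memory held at the moment of the call, and `torch.cuda.memory_reserved()` includes the allocator's cache, so the transient peak is only measured indirectly. PyTorch's peak counters measure it directly; a hypothetical drop-in for the memory checks inside `perf()`:

```python
# Measure the true peak across the call, not just memory held afterwards.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
logprobs_of_labels(logits, input_ids[:, 1:])
print(f'peak allocated: {torch.cuda.max_memory_allocated() / 1e9:.3f} GB')
```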

Tested on a Tesla V100, the method in this PR is both faster (a 1.6x speedup) and more memory-efficient. The ~8.4 GB drop in reserved memory matches the size of the eliminated fp16 copy of the logits (32 × 2048 × 64000 × 2 B ≈ 8.39 GB):

```
original allocated: 8.389 GB, reserved: 25.164 GB
optimized allocated: 8.389 GB, reserved: 16.779 GB
original costs: 0.0700 s
optimized costs: 0.0435 s
```
codecov-commenter commented 12 months ago

Codecov Report

Attention: 6 lines in your changes are missing coverage. Please review.

Comparison is base (91a0f43) 43.58% compared to head (730d900) 43.58%. Report is 1 commit behind head on main.

:exclamation: Current head 730d900 differs from the pull request's most recent head aa1031a. Consider uploading reports for commit aa1031a to get more accurate results.

| Files | Patch % | Lines |
| --- | --- | --- |
| trlx/models/modeling_nemo_ppo.py | 0.00% | 3 Missing :warning: |
| trlx/trainer/accelerate_ppo_trainer.py | 57.14% | 3 Missing :warning: |


Additional details and impacted files

```diff
@@           Coverage Diff           @@
##             main     #583   +/-   ##
=======================================
  Coverage   43.58%   43.58%
=======================================
  Files          33       33
  Lines        4974     4974
=======================================
  Hits         2168     2168
  Misses       2806     2806
```

:umbrella: View full report in Codecov by Sentry.