tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.

For Larger Tensors, Memory Efficiency Is Terrible #6212

Open SeanNijjar opened 8 months ago

SeanNijjar commented 8 months ago

I've noticed that when running some of the larger all-gather configs in test_all_gather.py, the host memory usage explodes relative to the size of the tensors in the test.

This behaviour seems to be somewhat non-deterministic, in that it's harder to see when running the test configuration standalone. However, when running it in a sequence with the other configs in the sweep, it seems a lot more likely to happen.

For example, this is one such config: input_shape=[1, 1, 32768, 32768], dim=3, layout=ttl.tensor.Layout.TILE, num_links=1, num_devices=4, mem type=DRAM, bfloat16

What I would expect for this test is roughly 10 GB consumed plus overhead (2 GB for the input tensor, then 2 GB for each of the four output tensors = 10 GB total). For the runs that show this behaviour, I'll sometimes see the pytest process reach 70-90 GB.
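For reference, the arithmetic behind that expectation, as a quick sketch (bfloat16 is 2 bytes per element; the four gathered outputs are each full-size):

# Expected host footprint for input_shape=[1, 1, 32768, 32768] in bfloat16.
elements = 1 * 1 * 32768 * 32768              # 1,073,741,824 elements
input_gib = elements * 2 / 1024**3            # 2 bytes/element -> 2.0 GiB input tensor
outputs_gib = 4 * input_gib                   # 4 gathered outputs, each full-size -> 8.0 GiB
print(f"expected total: {input_gib + outputs_gib:.1f} GiB")  # -> 10.0 GiB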

I assume this issue is somewhere in the tt_lib interface. Alternatively, given that it's most likely to show up when running a series of configs, maybe we're not garbage collecting/freeing old tensor data from prior tests, and that's creating a pathological case for the allocator or something similar.

I haven't profiled memory usage at all yet, given that this currently isn't on the critical path. However, it may rear its head as we start bringing up models with larger sequence lengths (e.g. up to 32K sequence length in some cases).

SeanNijjar commented 8 months ago

FYI @jliangTT, I'm not sure who is the appropriate person to assign this to.

jliangTT commented 8 months ago

We should use the TT-NN board if we are not sure. Hey @eyonland and @arakhmati, this issue could lie deeper in the stack, but I'm triaging it to your queue at the top of the funnel to help route it.

SeanNijjar commented 8 months ago

btw, I was debating a P2 or P3 on this one. Maybe we move it to P3 and keep it in mind for when people complain about hitting out-of-memory errors in the future.

eyonland commented 8 months ago

We need to run valgrind and see where the memory usage is going. I'll add this to my todo list this week.
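As a lighter-weight first pass before valgrind (which is still needed to catch native allocations inside tt_lib's C++ layer), a quick tracemalloc check could at least tell us whether the growth is visible at the Python level; this is only a sketch:

# Sketch: check whether the growth shows up in Python-level allocations at all.
# tracemalloc only sees allocations made through the Python allocator, so memory
# held by the C++ side of tt_lib will NOT appear here -- which would itself be a
# useful data point before a full valgrind/massif run.
import tracemalloc

tracemalloc.start()
# ... run one sweep iteration / one test config here ...
snapshot = tracemalloc.take_snapshot()
for stat in snapshot.statistics("lineno")[:10]:
    print(stat)
current, peak = tracemalloc.get_traced_memory()
print(f"python-level peak: {peak / 1024**2:.1f} MiB")
tracemalloc.stop()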

arakhmati commented 8 months ago

@SeanNijjar can you show us what the test looks like?

SeanNijjar commented 8 months ago

Sure @arakhmati, it's test_all_gather: ___.

Since I've opened the issue, the test file has changed slightly. The main test body looks like this:

def run_all_gather_on_t3000_impl(
    all_devices,
    num_devices,
    input_shape,
    dim,
    num_links,
    input_dtype,
    layout,
    mem_config,
    use_program_cache,
    function_level_defaults,
    num_iters=1,
):
    # omitted: misc checks for test case validity ....

    devices = get_devices_for_t3000(all_devices, num_devices)

    # Full-size input on host; it gets split along `dim` into one shard per device.
    input_tensor = torch.rand(input_shape).bfloat16()

    input_tensors = torch.chunk(input_tensor, num_devices, dim)
    tt_input_tensors = []
    for i, t in enumerate(input_tensors):
        # Convert each shard to a tt_lib tensor and move it onto its device.
        tt_input_tensors.append(ttl.tensor.Tensor(t, input_dtype).to(layout).to(devices[i], mem_config))

    # all_gather returns one full-size (gathered) tensor per device.
    tt_out_tensors = ttl.tensor.all_gather(tt_input_tensors, dim, num_links, output_mem_config=mem_config)

    # readback/compare vvv
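The readback/compare portion is omitted above; roughly, it has the shape of the sketch below. This is an approximation rather than a quote from the current test file, and it assumes the usual ttl.tensor readback path (cpu(), conversion to row-major layout, to_torch()):

    # Approximate sketch of the omitted readback/compare step (not a verbatim
    # quote): copy each gathered output back to host, convert it to a torch
    # tensor, and compare it against the original full-size input.
    for i, t in enumerate(tt_out_tensors):
        tt_output_tensor = t.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
        assert torch.equal(tt_output_tensor.to(torch.bfloat16), input_tensor), f"output {i} mismatch"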

SeanNijjar commented 8 months ago

Hey @eyonland, I had a quick offline chat with @arakhmati. He suggested some further triage I can try. I think before I send you down a reproducibility rabbit hole, I'll collect some more info and send an update when I've got something more conclusive.

SeanNijjar commented 8 months ago

I've got some more info for you guys: I was only able to reproduce this when running a full sweep. A few things I noticed:

1) Memory consumption slowly crept up over the lifetime of the sweep (this isn't the memory explosion described in this issue, but I wanted to bring it up because we might have a leak somewhere).

2) The memory consumption explodes after my all-gather runs on device; it's happening some time during copy back or compare (see the peak-RSS logging sketch at the end of this comment). I've included a snapshot of htop while the test is running.

[screenshot: htop snapshot of the pytest process while the test is running]

For reference, at the start of this test case, we were sitting at around 20GB of memory usage for the process.
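One cheap way to narrow down whether the jump happens during the copy back or during the compare (a sketch using only the Python standard library; the stage markers are placeholders, not the actual test code):

# Sketch: bracket the readback/compare stages with peak-RSS readings to see
# which step the jump happens in. Standard library only; on Linux, ru_maxrss
# is reported in KiB, and it is a high-water mark (it only ever increases).
import resource

def peak_rss_gib():
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024**2

print("peak RSS before readback:", round(peak_rss_gib(), 1), "GiB")
# ... copy the tt_out_tensors back to host / convert to torch here ...
print("peak RSS after readback: ", round(peak_rss_gib(), 1), "GiB")
# ... run the torch comparison against input_tensor here ...
print("peak RSS after compare:  ", round(peak_rss_gib(), 1), "GiB")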