dgasmith / opt_einsum

⚡️Optimizing einsum functions in NumPy, Tensorflow, Dask, and more with contraction order optimization.
https://dgasmith.github.io/opt_einsum/
MIT License
863 stars 68 forks source link

opt_einsum thinks the largest intermediate will be small, but torch.einsum allocates 156 GiB #133

Closed philip-bl closed 3 years ago

philip-bl commented 4 years ago

I want to perform a contraction described by the following code

import torch
import opt_einsum as oe

device = "cuda"
big_core = torch.randn(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, dtype=torch.float64, device=device)
batches_of_small_cores = [torch.randn(512, 25, 25, 2, dtype=torch.float64, device=device) for _ in range(16)]

equation = "αβγi,αβγj,αβγk,αβγl,αβγm,αβγn,αβγo,αβγp,αβγq,αβγr,αβγs,αβγt,αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω->αβγω"
print(oe.contract_path(equation, *batches_of_small_cores, big_core, optimize="auto", memory_limit="max_input"))
result = oe.contract(equation, *batches_of_small_cores, big_core, optimize="auto", memory_limit="max_input")

I am predicting memory problems so I ask opt_einsum to do memory_limit="max_input". Well, it turns out that it doesn't work. opt_einsum reports that this contraction barely allocates anything:

([(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)],   Complete contraction:  αβγi,αβγj,αβγk,αβγl,αβγm,αβγn,αβγo,αβγp,αβγq,αβγr,αβγs,αβγt,αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω->αβγω
         Naive scaling:  20
     Optimized scaling:  20
      Naive FLOP count:  7.130e+11
  Optimized FLOP count:  7.130e+11
   Theoretical speedup:  1.000
  Largest intermediate:  6.400e+5 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
  20              0 ijklmnopqrstuvwxω,αβγx,αβγw,αβγv,αβγu,αβγt,αβγs,αβγr,αβγq,αβγp,αβγo,αβγn,αβγm,αβγl,αβγk,αβγj,αβγi->αβγω                            αβγω->αβγω)

but opt_einsum is wrong. This path contains one operation - one torch.einsum, but this torch.einsum tries to allocate 156.25 GiB of memory:

RuntimeError: CUDA out of memory. Tried to allocate 156.25 GiB (GPU 0; 7.77 GiB total capacity; 82.00 MiB already allocated; 6.98 GiB free; 82.00 MiB reserved in total by PyTorch)

To be clear, if I change device = "cuda" to device = "cpu", approximately the same amount of memory is allocated. Also, no backpropagation and no gradient tracking is happening in this code snippet.

My guess of what is happening:

  1. opt_einsum thinks that torch.einsum is a dumb function which doesn't allocate any memory other than the memory for the output.
  2. Actually, torch.einsum allocates intermediate tensors, and I don't understand what logic it uses to choose how to allocate them.

If this not fixable, I suggest updating the documentation saying that iwth pytorch, memory_limit is not reliable at all.

dgasmith commented 4 years ago

This is a bit of an edge case where it assumes that all einsum operations use memory footprints like NumPy. That is it does not use intermediates, which clearly does not apply to pytorch (or likely any GPU-based design the for loops are too expensive).

What this is really telling you is that there are no contraction paths available that satisfy the memory footprint constraint that you have provided. Playing with it you need to provide scratch space on the order of 1e7 for this to work.

oe.contract_path(equation, *batches_of_small_cores, big_core, optimize="greedy", memory_limit=1e7)[1]
  Complete contraction:  αβγi,αβγj,αβγk,αβγl,αβγm,αβγn,αβγo,αβγp,αβγq,αβγr,αβγs,αβγt,αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω->αβγω
         Naive scaling:  20
     Optimized scaling:  20
      Naive FLOP count:  7.130e+11
  Optimized FLOP count:  2.097e+11
   Theoretical speedup:  3.400
  Largest intermediate:  5.120e+6 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   5              0       αβγj,αβγi->αβγji αβγk,αβγl,αβγm,αβγn,αβγo,αβγp,αβγq,αβγr,αβγs,αβγt,αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω,αβγji->αβγω
   5              0       αβγl,αβγk->αβγlk αβγm,αβγn,αβγo,αβγp,αβγq,αβγr,αβγs,αβγt,αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω,αβγji,αβγlk->αβγω
   5              0       αβγn,αβγm->αβγnm αβγo,αβγp,αβγq,αβγr,αβγs,αβγt,αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω,αβγji,αβγlk,αβγnm->αβγω
   5              0       αβγp,αβγo->αβγpo αβγq,αβγr,αβγs,αβγt,αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω,αβγji,αβγlk,αβγnm,αβγpo->αβγω
   5              0       αβγr,αβγq->αβγrq αβγs,αβγt,αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω,αβγji,αβγlk,αβγnm,αβγpo,αβγrq->αβγω
   5              0       αβγt,αβγs->αβγts αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω,αβγji,αβγlk,αβγnm,αβγpo,αβγrq,αβγts->αβγω
   5              0       αβγv,αβγu->αβγvu αβγw,αβγx,ijklmnopqrstuvwxω,αβγji,αβγlk,αβγnm,αβγpo,αβγrq,αβγts,αβγvu->αβγω
   5              0       αβγx,αβγw->αβγxw ijklmnopqrstuvwxω,αβγji,αβγlk,αβγnm,αβγpo,αβγrq,αβγts,αβγvu,αβγxw->αβγω
   7              0   αβγlk,αβγji->αβγlkji ijklmnopqrstuvwxω,αβγnm,αβγpo,αβγrq,αβγts,αβγvu,αβγxw,αβγlkji->αβγω
   7              0   αβγpo,αβγnm->αβγponm ijklmnopqrstuvwxω,αβγrq,αβγts,αβγvu,αβγxw,αβγlkji,αβγponm->αβγω
   7              0   αβγts,αβγrq->αβγtsrq ijklmnopqrstuvwxω,αβγvu,αβγxw,αβγlkji,αβγponm,αβγtsrq->αβγω
   7              0   αβγxw,αβγvu->αβγxwvu ijklmnopqrstuvwxω,αβγlkji,αβγponm,αβγtsrq,αβγxwvu->αβγω
  20              0 αβγxwvu,αβγtsrq,αβγponm,αβγlkji,ijklmnopqrstuvwxω->αβγω                            αβγω->αβγω

I am unsure if we will be in the business of supporting specific backends at this level. Its something that we could potentially do, but would require us keeping a much closer eye on what others are doing in their einsum and how they plan to evolve it. In all likely hood we will go the NEP18 route where we will become even more agnostic to the backend platform and the caveats that entails.

jcmgray commented 4 years ago

Generally speaking, the task of finding a contraction path is to break the contraction into pairwise contractions, exponentially reducing the time complexity at some space cost - usually not insignificantly higher intermediate memory. As you have found, introducing the memory_limit means no part of any path can be found and so the entire thing is deferred to torch.einsum.

In general I'd say due to the complexities of path finding that the memory_limit argument rarely has quite the desired effect and maybe the docs could be better on this.

Generally the best way to reduce memory for a contraction is 'slicing' (#95, #125) certain indices once a good path has been found, which I've just checked works well for this case:

import torch
import opt_einsum as oe

device = "cuda"
big_core = torch.randn(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, dtype=torch.float64, device=device)
batches_of_small_cores = [torch.randn(512, 25, 25, 2, dtype=torch.float64, device=device) for _ in range(16)]

equation = "αβγi,αβγj,αβγk,αβγl,αβγm,αβγn,αβγo,αβγp,αβγq,αβγr,αβγs,αβγt,αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω->αβγω"
path, info = oe.contract_path(equation, *batches_of_small_cores, big_core, optimize="auto-hq")

Note optimize="auto-hq" finds lower memory.

Now we find best indices to explicitly sum over:

import cotengra as ctg

sf = ctg.SliceFinder(info, target_size=2**26)
inds_to_slice, cost_of_slicing = sf.search()

cost_of_slicing.size   # the new largest intermediate
# 40960000.0

cost_of_slicing.overhead  # theoretical 'slowdown'
1.0273594262607766

Finally actually perform the contraction:

import tqdm as tqdm
sc = sf.SlicedContractor([*batches_of_small_cores, big_core])
result = sum(sc.contract_slice(i) for i in tqdm.trange(sc.nslices))
# 100%|██████████| 512/512 [00:55<00:00,  9.30it/s]

Maybe at some point the memory_limit keyword argument could use this approach automatically.

philip-bl commented 4 years ago

@dgasmith Actually, with the path you found the problem appears as well. opt_einsum.contract_path reports that the largest intermediate will have 5.120e+6 elements. I am using float64, so that's 5.12e6 * 8 / 1024 / 1024 == 39.0625 megabytes. But actually, when contracting using that path, during the very last operation αβγxwvu,αβγtsrq,αβγponm,αβγlkji,ijklmnopqrstuvwxω->αβγω pytorch tries to allocate 9.77 gigabytes, and fails (because my GPU doesn't have that much memory).

To be clear, I know how to calculate this particular contraction somewhat efficiently. I guess I'll have to construct the path on my own.

--

For the sake of giving you more knowledge about how torch.einsum works, I also want to mention that its memory allocation and performance depends on the arguments order https://github.com/pytorch/pytorch/issues/35299.

dgasmith commented 4 years ago

Is there documentation on how pytorch calculates its intermediates and forms the computation? It seems that opt_einsum doesn't know enough to give a reasonable bound there. A naive guess says it just runs from left to right building any intermediate along the way it requires.

To echo @jcmgray you can slice along your large 512 index to obtain a linear decrease in memory footprint, we still will not be able to tell you exactly how much memory it will take however.

jcmgray commented 4 years ago

@philip-bl If you find a good contraction path that can contract this better than e.g. 'auto-hq' do please let us know! That is useful information for improving the optimizers.