This is a bit of an edge case: opt_einsum assumes that all einsum operations have memory footprints like NumPy's, i.e. that a single einsum call does not build intermediates. That clearly does not apply to pytorch (or likely to any GPU-based design, where the for loops would be too expensive).
What this is really telling you is that there are no contraction paths available that satisfy the memory footprint constraint you have provided. Playing with it, you need to allow scratch space on the order of 1e7 elements for this to work:
oe.contract_path(equation, *batches_of_small_cores, big_core, optimize="greedy", memory_limit=1e7)[1]
Complete contraction: αβγi,αβγj,αβγk,αβγl,αβγm,αβγn,αβγo,αβγp,αβγq,αβγr,αβγs,αβγt,αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω->αβγω
Naive scaling: 20
Optimized scaling: 20
Naive FLOP count: 7.130e+11
Optimized FLOP count: 2.097e+11
Theoretical speedup: 3.400
Largest intermediate: 5.120e+6 elements
--------------------------------------------------------------------------------
scaling BLAS current remaining
--------------------------------------------------------------------------------
5 0 αβγj,αβγi->αβγji αβγk,αβγl,αβγm,αβγn,αβγo,αβγp,αβγq,αβγr,αβγs,αβγt,αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω,αβγji->αβγω
5 0 αβγl,αβγk->αβγlk αβγm,αβγn,αβγo,αβγp,αβγq,αβγr,αβγs,αβγt,αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω,αβγji,αβγlk->αβγω
5 0 αβγn,αβγm->αβγnm αβγo,αβγp,αβγq,αβγr,αβγs,αβγt,αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω,αβγji,αβγlk,αβγnm->αβγω
5 0 αβγp,αβγo->αβγpo αβγq,αβγr,αβγs,αβγt,αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω,αβγji,αβγlk,αβγnm,αβγpo->αβγω
5 0 αβγr,αβγq->αβγrq αβγs,αβγt,αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω,αβγji,αβγlk,αβγnm,αβγpo,αβγrq->αβγω
5 0 αβγt,αβγs->αβγts αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω,αβγji,αβγlk,αβγnm,αβγpo,αβγrq,αβγts->αβγω
5 0 αβγv,αβγu->αβγvu αβγw,αβγx,ijklmnopqrstuvwxω,αβγji,αβγlk,αβγnm,αβγpo,αβγrq,αβγts,αβγvu->αβγω
5 0 αβγx,αβγw->αβγxw ijklmnopqrstuvwxω,αβγji,αβγlk,αβγnm,αβγpo,αβγrq,αβγts,αβγvu,αβγxw->αβγω
7 0 αβγlk,αβγji->αβγlkji ijklmnopqrstuvwxω,αβγnm,αβγpo,αβγrq,αβγts,αβγvu,αβγxw,αβγlkji->αβγω
7 0 αβγpo,αβγnm->αβγponm ijklmnopqrstuvwxω,αβγrq,αβγts,αβγvu,αβγxw,αβγlkji,αβγponm->αβγω
7 0 αβγts,αβγrq->αβγtsrq ijklmnopqrstuvwxω,αβγvu,αβγxw,αβγlkji,αβγponm,αβγtsrq->αβγω
7 0 αβγxw,αβγvu->αβγxwvu ijklmnopqrstuvwxω,αβγlkji,αβγponm,αβγtsrq,αβγxwvu->αβγω
20 0 αβγxwvu,αβγtsrq,αβγponm,αβγlkji,ijklmnopqrstuvwxω->αβγω αβγω->αβγω
I am unsure if we will be in the business of supporting specific backends at this level. It's something that we could potentially do, but it would require us to keep a much closer eye on what others are doing in their einsum implementations and how they plan to evolve them. In all likelihood we will go the NEP18 route, where we become even more agnostic to the backend platform, with all the caveats that entails.
Generally speaking, the task of finding a contraction path is to break the contraction into pairwise contractions, exponentially reducing the time complexity at some space cost (usually significantly higher intermediate memory). As you have found, introducing the memory_limit means no part of any path can be found, and so the entire thing is deferred to torch.einsum.
In general I'd say that, due to the complexities of path finding, the memory_limit argument rarely has quite the desired effect, and maybe the docs could be better on this.
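As a minimal sketch of that fallback behaviour (toy shapes only; the paths in the comments are what I would expect, not outputs copied from a run):
import numpy as np
import opt_einsum as oe
a, b, c = (np.ones((8, 8)) for _ in range(3))
path_free, _ = oe.contract_path("ij,jk,kl->il", a, b, c)
print(path_free)   # e.g. [(0, 1), (0, 1)] -- broken into pairwise contractions
path_tight, _ = oe.contract_path("ij,jk,kl->il", a, b, c, memory_limit=1)
print(path_tight)  # e.g. [(0, 1, 2)] -- nothing fits, so a single step deferred to the backend einsum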
Generally the best way to reduce memory for a contraction is 'slicing' (#95, #125) certain indices once a good path has been found, which I've just checked works well for this case:
import torch
import opt_einsum as oe
device = "cuda"
big_core = torch.randn(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, dtype=torch.float64, device=device)
batches_of_small_cores = [torch.randn(512, 25, 25, 2, dtype=torch.float64, device=device) for _ in range(16)]
equation = "αβγi,αβγj,αβγk,αβγl,αβγm,αβγn,αβγo,αβγp,αβγq,αβγr,αβγs,αβγt,αβγu,αβγv,αβγw,αβγx,ijklmnopqrstuvwxω->αβγω"
path, info = oe.contract_path(equation, *batches_of_small_cores, big_core, optimize="auto-hq")
Note that optimize="auto-hq" finds a lower-memory path.
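If it's useful, the peak size a given path implies can be read straight off the returned PathInfo (just a quick check; I haven't reproduced the actual numbers for this case here):
print(info.largest_intermediate)  # predicted largest intermediate, in elements
print(info.opt_cost)              # predicted FLOP count for this path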
Now we find the best indices to explicitly sum over:
import cotengra as ctg
sf = ctg.SliceFinder(info, target_size=2**26)
inds_to_slice, cost_of_slicing = sf.search()
cost_of_slicing.size # the new largest intermediate
# 40960000.0
cost_of_slicing.overhead # theoretical 'slowdown'
# 1.0273594262607766
Finally, actually perform the contraction:
import tqdm
sc = sf.SlicedContractor([*batches_of_small_cores, big_core])
result = sum(sc.contract_slice(i) for i in tqdm.trange(sc.nslices))
# 100%|██████████| 512/512 [00:55<00:00, 9.30it/s]
Maybe at some point the memory_limit keyword argument could use this approach automatically.
@dgasmith Actually, with the path you found, the problem appears as well. opt_einsum.contract_path reports that the largest intermediate will have 5.120e+6 elements. I am using float64, so that's 5.12e6 * 8 / 1024 / 1024 == 39.0625 megabytes. But actually, when contracting using that path, during the very last operation αβγxwvu,αβγtsrq,αβγponm,αβγlkji,ijklmnopqrstuvwxω->αβγω pytorch tries to allocate 9.77 gigabytes and fails (because my GPU doesn't have that much memory).
To be clear, I know how to calculate this particular contraction somewhat efficiently. I guess I'll have to construct the path on my own.
--
For the sake of giving you more knowledge about how torch.einsum works, I also want to mention that its memory allocation and performance depend on the order of the arguments: https://github.com/pytorch/pytorch/issues/35299.
Is there documentation on how pytorch calculates its intermediates and forms the computation? It seems that opt_einsum doesn't know enough to give a reasonable bound there. A naive guess says it just runs from left to right, building whatever intermediates it requires along the way.
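A rough back-of-the-envelope check (my own arithmetic, not anything profiled from pytorch) is at least consistent with that guess: multiplying the first three grouped factors of the last step together before touching the big core would create an intermediate over α, β, γ and twelve of the size-2 indices, which in float64 is almost exactly the 9.77 GiB reported above.
batch = 512 * 25 * 25             # α, β, γ from the shapes above
small = 2 ** 12                   # twelve leftover size-2 indices (m through x)
print(batch * small * 8 / 2**30)  # float64 bytes -> GiB; ≈ 9.77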
To echo @jcmgray, you can slice along your large 512 index to obtain a linear decrease in memory footprint; however, we still will not be able to tell you exactly how much memory it will take.
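For example, a hand-rolled version of that slicing (a sketch only, reusing equation, batches_of_small_cores and big_core from the snippet above; the chunk size is a hypothetical knob, and memory scales roughly linearly with it):
chunk = 64                                   # hypothetical chunk of the α index
pieces = []
for start in range(0, 512, chunk):
    # α survives into the output αβγω, so each α-chunk can be contracted independently
    sliced = [t[start:start + chunk] for t in batches_of_small_cores]
    pieces.append(oe.contract(equation, *sliced, big_core, optimize="auto-hq"))
result = torch.cat(pieces, dim=0)            # stitch the chunks back together along α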
@philip-bl If you find a good contraction path that can contract this better than e.g. 'auto-hq', do please let us know! That is useful information for improving the optimizers.
I want to perform a contraction described by the following code. I am predicting memory problems, so I ask opt_einsum to use memory_limit="max_input". Well, it turns out that this doesn't work: opt_einsum reports that this contraction barely allocates anything, but opt_einsum is wrong. The resulting path contains one operation, a single torch.einsum, and that torch.einsum tries to allocate 156.25 GiB of memory.
To be clear, if I change device = "cuda" to device = "cpu", approximately the same amount of memory is allocated. Also, no backpropagation and no gradient tracking is happening in this code snippet.
My guess of what is happening: opt_einsum assumes that torch.einsum is a dumb function which doesn't allocate any memory other than the memory for the output, but in reality torch.einsum allocates intermediate tensors, and I don't understand what logic it uses to choose how to allocate them.
If this is not fixable, I suggest updating the documentation to say that with pytorch, memory_limit is not reliable at all.