Haoming02 opened 1 year ago
Additionally, with the ToMe patch applied, I was able to upscale from 1024x1024 to 1536x1536 using the `1024` / `128` tile settings, just not to 2048x2048, no matter the settings.
It's unfortunate, but the current implementation of ToMe uses more memory when also using xformers/flash attn/torch 2.0 sdp attn or whatever.
Without those implementations, ToMe reduces memory usage by reducing the size of the attention matrices (which are absolutely massive to begin with). But flash attn-like methods already make computing attention a (linear?) space operation, because they don't compute the whole matrix at once.
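Rough illustration of the difference (the tensor sizes here are just made up for the example):

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only: batch 1, 8 heads, 4096 tokens, head dim 64.
q = torch.randn(1, 8, 4096, 64)
k, v = torch.randn_like(q), torch.randn_like(q)

# Naive attention materializes the full 4096 x 4096 score matrix per head,
# so its memory grows quadratically with the number of tokens.
scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
out_naive = scores.softmax(dim=-1) @ v

# Flash-attn-style kernels (exposed via torch 2.0 sdp attention) compute the
# same result in tiles and never hold the whole score matrix at once.
out_fused = F.scaled_dot_product_attention(q, k, v)
```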
That leaves ToMe in an awkward spot, because it computes the similarities for merging all at once, creating a `(3*#tokens/4) x (#tokens/4)` matrix before immediately argmaxing it down to a `(3*#tokens/4)` vector. I think that first matrix is the problem here. Normally, the smaller attention matrices more than make up for the extra space taken by that similarity matrix, but flash attn-like methods make that not the case anymore.
Now, ToMe doesn't actually need to compute this whole matrix, so there is hope. We only need the argmax over the similarities, not the similarities themselves. I'm just not sure how to implement that in native pytorch (flash attn et al. implement it using custom cuda kernels, which I don't want to use because that's what makes it require compilation).
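For what it's worth, here is a rough sketch of the kind of chunked argmax I mean, in plain pytorch. The function name and chunk size are made up, and this is not what tomesd currently does; it just shows that only a slice of the similarity matrix needs to exist at any one time:

```python
import torch

def chunked_argmax_similarity(src, dst, chunk=4096):
    """For each src token, find the value and index of its most similar dst
    token without materializing the full (num_src x num_dst) similarity matrix.

    src: (B, N_src, C) token features, dst: (B, N_dst, C) token features.
    """
    best_vals, best_idxs = [], []
    for start in range(0, src.shape[1], chunk):
        # Only a (B, chunk, N_dst) slice of the similarity matrix exists here.
        sim = src[:, start:start + chunk] @ dst.transpose(-1, -2)
        vals, idxs = sim.max(dim=-1)
        best_vals.append(vals)
        best_idxs.append(idxs)
    return torch.cat(best_vals, dim=1), torch.cat(best_idxs, dim=1)
```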
In img2img, without the ToMe patch, I was able to upscale a 1024x1024 image to 2048x2048 using Tiled VAE, with `Encoder Tile Size` set to `1024` and `Decoder Tile Size` set to `96`. The VRAM usage was around 6~7 GB.

However, if I apply the ToMe patch, regular generation does become faster, but when I try to upscale a 1024x1024 image again, it starts throwing an Out of Memory error, even when I lower `Encoder Tile Size` to `512` and `Decoder Tile Size` to `64`.

The implementation I used was this, which simply calls `tomesd.apply_patch(sd_model, ratio=0.3)` inside `on_model_loaded(sd_model)`.

Is this a problem on my part? Did I write the implementation wrong? Or is it something else?
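(For reference, the patch boils down to something like the following sketch; the callback function name here is just a placeholder.)

```python
import tomesd
from modules import script_callbacks

def _patch_tome(sd_model):
    # Apply token merging to the freshly loaded model.
    tomesd.apply_patch(sd_model, ratio=0.3)

# Run the patch every time the webui loads a checkpoint.
script_callbacks.on_model_loaded(_patch_tome)
```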
Full Error Below: