phhusson closed this issue 2 years ago
IIRC the torch version does not turn off gradient tracking; adding @torch.no_grad() to the gen function should help a lot.
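A minimal sketch of that suggestion (the `gen` name comes from the thread; the body here is a placeholder, not the project's actual function):

```python
import torch

# Hypothetical stand-in for the repo's gen() function: any forward-only
# inference call. Decorating it with @torch.no_grad() stops autograd from
# recording the computation graph, so intermediate activations are freed
# immediately instead of being retained for a backward pass.
@torch.no_grad()
def gen(x):
    return x * 2

inp = torch.ones(3, requires_grad=True)
out = gen(inp)
print(out.requires_grad)  # False: the result is detached from the graph
```

Because inference never calls `.backward()`, dropping the graph costs nothing and typically cuts peak memory substantially.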
Maximum resident set size (kbytes): 4013488
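That figure is the format printed by GNU time (`/usr/bin/time -v`). The same peak-RSS number can be read from inside Python with the stdlib `resource` module (Unix only) — a sketch:

```python
import resource

# Peak resident set size of the current process. On Linux ru_maxrss is
# reported in kilobytes (on macOS it is in bytes), matching the
# "Maximum resident set size (kbytes)" line from `/usr/bin/time -v`.
peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
print(f"Maximum resident set size: {peak}")
```

Printing this before and after a change like disabling gradient tracking gives a quick way to compare peak memory between runs.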
Yup, that was it. Thanks. I can run inference with both the mega and non-mega models on my AMD RX 6700 XT (12 GB VRAM). Mega infers in 45 s, non-mega in 16 s.
My local patch for GPU inference is attached (rename .csv to .patch)
I added torch.set_grad_enabled(False), and it seems to have significantly reduced memory usage with torch.
Hi,
Thanks for this project, you gave me hope of giving my AMD GPU some fun :-) (torch is much easier to get working on it than XLA)
That being said, it doesn't work yet, because the torch variant eats much more RAM than the flax one:
Would you have some pointers as to how to optimize that?