masa-su / pixyz

A library for developing deep generative models in a more concise, intuitive and extendable way
https://pixyz.io
MIT License

Support "Automatic Mixed Precision (AMP)" #180


ItoMasaki commented 1 year ago

## What I did

## What I did not

## What you can do

## What you cannot do
## Operation check

```python
import torch
from pixyz.distributions import Normal
from pixyz.losses import LogProb
from pixyz.models import Model

p = Normal()
p.to("cuda")
model = Model(loss=-LogProb(p).mean(), distributions=[p], use_amp=False)

x = torch.ones(10000, 1).to("cuda")
y = torch.zeros(10000, 1).to("cuda")
loss = model.train({"x": x, "y": y})
```
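For context, a minimal sketch of what an AMP-enabled training step generally looks like in PyTorch, which a `use_amp` flag like the one above would presumably wrap. `train_step` and its arguments are illustrative names, not pixyz internals:

```python
import torch

def train_step(net, optimizer, scaler, x, use_amp=True, device_type="cuda"):
    # Illustrative AMP training step (not pixyz's actual implementation).
    optimizer.zero_grad()
    # autocast runs eligible forward-pass ops in reduced precision
    with torch.autocast(device_type=device_type, enabled=use_amp):
        loss = net(x).mean()
    # GradScaler scales the loss so low-precision gradients do not underflow
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```

The memory savings reported below come from autocast storing activations in half precision where it is numerically safe.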


- The results are as follows:

| `use_amp` | GPU RAM |
|-----------|---------|
| False     | 1305MiB / 15360MiB |
| True      | 737MiB / 15360MiB |
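The readings above look like `nvidia-smi` output; a programmatic alternative is PyTorch's allocator statistics. `peak_gpu_mib` is an illustrative helper name, not part of pixyz:

```python
import torch

def peak_gpu_mib():
    # max_memory_allocated tracks the high-water mark of tensor memory
    # held by the caching allocator on the current CUDA device
    if not torch.cuda.is_available():
        return None  # no CUDA device available
    return torch.cuda.max_memory_allocated() / (1024 ** 2)
```

Note that `nvidia-smi` reports total reserved memory per process, so its numbers are typically higher than `max_memory_allocated`.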



## Reference information

- [AUTOMATIC MIXED PRECISION PACKAGE - TORCH.AMP](https://pytorch.org/docs/stable/amp.html)