google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

Add context manager API for setting mixed precision policies. #613

Closed: copybara-service[bot] closed this 1 year ago.

copybara-service[bot] commented 1 year ago

Adds a context manager API, hk.mixed_precision.push_policy, for setting mixed precision policies. For example:

mod = hk.Linear(1)
with hk.mixed_precision.push_policy(hk.Linear, policy):
  ...
  x = mod(x)  # policy is active
  ...
x = mod(x)  # previous policy (if any) is active