Hi there, instance-level mixed precision with jmp is not built into Haiku, but you can implement it in a few ways. I've added them in this colab notebook: https://colab.research.google.com/gist/tomhennigan/874de0420c55f7bd062f24f7ec6b0e51/instance-level-half-precision.ipynb
Roughly, the options are:

Option 1: define a trivial subclass of the module and set a policy on the subclass. Policies are keyed by module class, so other hk.Linear instances are unaffected:
```python
import haiku as hk
import jax.numpy as jnp
import jmp

# A subclass is a distinct class for policy lookup, so only
# LowPrecisionLinear instances pick up the half-precision policy.
class LowPrecisionLinear(hk.Linear):
  pass

half_policy = jmp.get_policy('compute=half')
hk.mixed_precision.set_policy(LowPrecisionLinear, half_policy)

def f(x):
  net = hk.Sequential([
      hk.Linear(300), jnp.tan,
      hk.Linear(100), jnp.tan,
      LowPrecisionLinear(10),
  ])
  return net(x)
```
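As a quick check, here is a sketch (assuming the definitions above) confirming where the policy is registered:

```python
# Sketch: the policy is registered against the subclass, not the base
# class, so ordinary hk.Linear modules are unaffected.
assert hk.mixed_precision.get_policy(LowPrecisionLinear) is half_policy
assert hk.mixed_precision.get_policy(hk.Linear) is None
```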
Option 2: wrap a specific instance so that the policy is set only for the duration of its call, then restored:

```python
import functools

def wrap_with_policy(mod: hk.Module, policy: jmp.Policy):
  """Applies `policy` while `mod` is called, restoring any previous policy."""
  cls = type(mod)

  @functools.wraps(mod.__call__)
  def wrapper(*args, **kwargs):
    # Look up any policy already set for this class (not hk.Linear
    # specifically) so it can be restored afterwards.
    old_policy = hk.mixed_precision.get_policy(cls)
    hk.mixed_precision.set_policy(cls, policy)
    try:
      return mod(*args, **kwargs)
    finally:
      if old_policy is not None:
        hk.mixed_precision.set_policy(cls, old_policy)
      else:
        hk.mixed_precision.clear_policy(cls)

  return wrapper

half_policy = jmp.get_policy('compute=half')

def f(x):
  net = hk.Sequential([
      hk.Linear(300), jnp.tan,
      hk.Linear(100), jnp.tan,
      wrap_with_policy(hk.Linear(10), half_policy),
  ])
  return net(x)
```
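Either way, f runs as a normal Haiku function. A minimal sketch of driving it (the input shape here is arbitrary, assuming the definitions above):

```python
import jax
import numpy as np

# Sketch: transform and run `f`, then inspect the output dtype to see
# the effect of the half-precision policy on the final layer.
forward = hk.transform(f)
x = np.ones([8, 128], dtype=np.float32)
params = forward.init(jax.random.PRNGKey(0), x)
out = forward.apply(params, None, x)  # rng=None: `f` uses no randomness at apply time
print(out.dtype)
```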
Thank you for the suggested options. This is of great help to me.

I noticed that Haiku integrates jmp and supports mixed precision. An example for ResNet is as follows:
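(Roughly along these lines; hk.nets.ResNet50 and the exact policy strings here follow the pattern from the Haiku/jmp docs and are illustrative:)

```python
# Roughly the documented pattern: run the ResNet compute in float16
# while keeping parameters (and the final output) in float32 ...
policy = jmp.get_policy('params=float32,compute=float16,output=float32')
hk.mixed_precision.set_policy(hk.nets.ResNet50, policy)

# ... and keep BatchNorm compute in full precision.
bn_policy = jmp.get_policy('params=float32,compute=float32,output=float16')
hk.mixed_precision.set_policy(hk.BatchNorm, bn_policy)
```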
The practice is to keep the computation of BN in full precision. I wonder if there is an approach to configure a designated layer instance to some precision, instead of configuring by module class. For instance, something like this:
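(Purely hypothetical, just to illustrate the question; Haiku modules don't actually take a per-instance policy argument:)

```python
# Hypothetical API, for illustration only: attach a policy to one
# specific layer instance rather than to the whole module class.
net = hk.Sequential([
    hk.Linear(300), jnp.tan,
    hk.Linear(100), jnp.tan,
    hk.Linear(10),  # imagine something like hk.Linear(10, policy=half_policy)
])
```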