wonjeon opened 4 months ago
Thanks! Yeah, this is a known issue (similar to what's reported in https://github.com/google/jax/discussions/8494).
Unfortunately, NumPy's serialization only recognizes NumPy's built-in dtypes, and NumPy currently offers no way to extend that set. The best workaround for the time being would be something like this:
>>> np.save('a.npy', a.view('uint8'))
>>> np.load('a.npy').view(float8_e5m2)
array(1.5, dtype='float8_e5m2')
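For reference, here is a minimal sketch that wraps the same trick in two helpers. The names `save_as_bits` and `load_from_bits` are hypothetical (they are not part of NumPy or any library). The sketch only assumes NumPy. Instead of always using `uint8`, it views the data as an unsigned integer dtype with the same itemsize, so the array shape is preserved even for dtypes wider than one byte. The demo uses the built-in `float16` as a stand-in for a custom dtype.

```python
import io
import numpy as np

def save_as_bits(file, arr):
    """Save arr's raw bits under a built-in unsigned-integer dtype.

    Hypothetical helper: assumes arr.dtype.itemsize is 1, 2, 4, or 8 bytes,
    since those are the only sizes NumPy has unsigned-integer dtypes for.
    """
    bits_dtype = np.dtype(f"u{arr.dtype.itemsize}")
    # Same itemsize, so the view keeps the shape (0-d arrays included).
    np.save(file, arr.view(bits_dtype))

def load_from_bits(file, dtype):
    """Load bits written by save_as_bits and reinterpret them as `dtype`."""
    # The .npy header only records the integer dtype; the caller supplies the real one.
    return np.load(file).view(dtype)

# Round trip through an in-memory file, using float16 as a stand-in dtype.
buf = io.BytesIO()
x = np.array([1.5, -2.0, 0.25], dtype=np.float16)
save_as_bits(buf, x)
buf.seek(0)
y = load_from_bits(buf, np.float16)
```

With a 1-byte dtype like `float8_e5m2`, this reduces to the `uint8` view above: call `save_as_bits('a.npy', a)`, then `load_from_bits('a.npy', float8_e5m2)`. The `.npy` file only knows about the unsigned integer dtype, so the original dtype has to be remembered or stored separately.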
@jakevdp Thanks for your response and for the workaround. I can confirm that it works.
I tried the following code snippet, and it doesn't seem to work. Is this a known issue?