jax-ml / ml_dtypes

A stand-alone implementation of several NumPy dtype extensions used in machine learning.
Apache License 2.0
214 stars 28 forks

[Q] How to properly save and load fp8 NumPy arrays? #207

Open apivovarov opened 1 month ago

apivovarov commented 1 month ago

I would like to save and load a float8_e5m2 array. I initially tried the standard numpy.save() and numpy.load() functions, but loading fails:

.local/lib/python3.10/site-packages/numpy/lib/format.py", line 325, in descr_to_dtype
    return numpy.dtype(descr)
TypeError: data type '<f1' not understood

.local/lib/python3.10/site-packages/numpy/lib/format.py", line 683, in _read_array_header
    raise ValueError(msg.format(d['descr'])) from e
ValueError: descr is not a valid dtype descriptor: '<f1'
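The failure above can be reproduced without any file I/O. np.save records the dtype as a descriptor string in the file header; for this array that descriptor is '<f1' (little-endian, 1-byte float), but NumPy has no built-in 1-byte float, so the descriptor cannot be parsed back at load time. A minimal sketch:

```python
import numpy as np

# NumPy has no built-in 1-byte float, so the '<f1' descriptor that ends up
# in the .npy header cannot be turned back into a dtype on load.
try:
    np.dtype('<f1')
except TypeError as e:
    print(e)  # prints: data type '<f1' not understood
```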

I found that I can save and load float8 arrays using a lower-level API (np.tobytes / np.frombuffer), as shown below:

import ml_dtypes
import numpy as np
import json

# Create the array
x = np.array([.2, .4, .6], dtype=ml_dtypes.float8_e5m2)

# Save the array
with open("a.npy", "wb") as f:
    f.write(x.tobytes())

# Save the array's shape and dtype separately
meta = {"shape": x.shape, "dtype": str(x.dtype)}
with open("a_meta.json", "w") as f:
    json.dump(meta, f)

# Load the array
with open("a.npy", "rb") as f:
    data = f.read()

# Load the metadata
with open("a_meta.json", "r") as f:
    meta = json.load(f)

# Reconstruct the array
x2 = np.frombuffer(data, dtype=getattr(ml_dtypes, meta["dtype"])).reshape(meta["shape"])

print(x2)

Is the solution above (np.tobytes / np.frombuffer) considered best practice for this case?

@jakevdp Jake, can you comment on it?


jakevdp commented 1 month ago

Unfortunately, NumPy's array serialization only works with NumPy's built-in dtypes. Probably the easiest way to serialize arrays with custom dtypes is to view them as unsigned integers:

import ml_dtypes
import numpy as np
import json

# Create the array
x = np.array([.2, .4, .6], dtype=ml_dtypes.float8_e5m2)

np.save('x.npy', x.view('uint8'))
x2 = np.load('x.npy').view(ml_dtypes.float8_e5m2)

print(np.all(x == x2))
# True

Your approach of serializing the raw bytes also works, though I'd recommend not naming the file with a .npy extension with that approach, because this extension typically implies the file is loadable with np.load.
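The uint8-view trick still drops the dtype name, so it has to be tracked out of band. One way to do that is a pair of small wrapper functions; the helper names (save_f8 / load_f8) and the ".dtype" sidecar-file convention below are purely illustrative, not part of ml_dtypes or NumPy:

```python
import numpy as np
import ml_dtypes

# Hypothetical helpers: store the raw bits as uint8 via np.save, and record
# the ml_dtypes dtype name in a sidecar file next to the array.

def save_f8(path, arr):
    np.save(path, arr.view(np.uint8))
    with open(path + ".dtype", "w") as f:
        f.write(str(arr.dtype))  # e.g. "float8_e5m2"

def load_f8(path):
    with open(path + ".dtype") as f:
        dtype = getattr(ml_dtypes, f.read())
    return np.load(path).view(dtype)

x = np.array([.2, .4, .6], dtype=ml_dtypes.float8_e5m2)
save_f8("x.npy", x)
x2 = load_f8("x.npy")
print(np.all(x == x2))  # prints: True
```

The view is safe in both directions because uint8 and the float8 types have the same 1-byte itemsize, so the shape is preserved exactly.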

apivovarov commented 1 month ago

Hi Jake, Thank you for your reply!

I have one additional question

I tried using pickle. It works, and the file size is almost the same as with the default np.save approach:

import ml_dtypes
import numpy as np
import pickle

# Create the array
a = np.array([.2, .4, .6], dtype=ml_dtypes.float8_e5m2)
b = np.array([.2, .4, .6], dtype=ml_dtypes.float8_e4m3)

# Save
with open('a.npy.pkl', "wb") as f:
    pickle.dump(a, f)

with open('b.npy.pkl', "wb") as f:
    pickle.dump(b, f)

# Load back
a2 = np.load('a.npy.pkl', allow_pickle=True)
b2 = np.load('b.npy.pkl', allow_pickle=True)

print(np.all(a == a2))
print(np.all(b == b2))

It seems to work out of the box and saves the ml_dtypes dtype info into the file.

What are the disadvantages of using pickle?

One con I found: loading fails in an environment with a different NumPy version:

>>> b2 = np.load('b.npy.pkl', allow_pickle=True)
Traceback (most recent call last):
  File "/home/user/.local/lib/python3.9/site-packages/numpy/lib/npyio.py", line 441, in load
    return pickle.load(fid, **pickle_kwargs)
ModuleNotFoundError: No module named 'numpy._core'
jakevdp commented 1 month ago

Yes, pickle works, but has downsides. The two main issues: unpickling allows arbitrary code execution, and pickled arrays will often break when loaded in an environment with different package versions.
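The arbitrary-code-execution point is easy to demonstrate: a pickled object can instruct the loader, via __reduce__, to call any function at load time. A minimal sketch (the Payload class below is purely illustrative):

```python
import io
import pickle

class Payload:
    def __reduce__(self):
        # At unpickling time, pickle calls print(...) instead of
        # reconstructing this object; a hostile file could call anything.
        return (print, ("arbitrary code ran during unpickling",))

buf = io.BytesIO()
pickle.dump(Payload(), buf)
buf.seek(0)
result = pickle.load(buf)  # prints: arbitrary code ran during unpickling
```

This is why np.load requires the explicit allow_pickle=True opt-in: loading a pickled file is only safe when you trust its origin.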