coreweave / tensorizer

Module, Model, and Tensor Serialization/Deserialization
MIT License

fix(serialization): Track tensor encryption status per-tensor #164

Closed: Eta0 closed this 2 weeks ago

Eta0 commented 2 weeks ago

Track per-tensor encryption status during serialization

This PR complements #163 by adding new tests and by tracking per-tensor encryption status during serialization, so that tensors that were never encrypted are not decrypted. Decrypting a tensor that is not encrypted normally raises an exception in tensorizer._crypt due to a MAC mismatch, so the mistake is fairly safe, but it adds noise: after an initial error, the output fills with concerning-looking decryption errors.

If decryption fails during error handling in a bulk write operation, its CryptographyError is raised as the top-level exception, with the previous top-level exception chained to it. Any other (non-CryptographyError) exception from this last-ditch decryption step is ignored, since it is likely a duplicate of the exception that led to the error path in the first place, and the code is about to exit with an exception anyway. (If we ever get something like Python 3.11's ExceptionGroup, this would be a good place to use it.)
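The error-handling policy can be sketched as a small wrapper. This is a hedged illustration rather than the PR's code: bulk_write, write_all, and decrypt_all are hypothetical names, and CryptographyError is a local stand-in class. It uses explicit chaining (raise ... from ...), which sets __cause__ on the elevated exception.

```python
class CryptographyError(Exception):
    """Hypothetical stand-in for tensorizer's CryptographyError."""


def bulk_write(write_all, decrypt_all):
    # write_all / decrypt_all are hypothetical callables standing in for the
    # bulk write and the step that restores in-memory tensors to plaintext.
    try:
        write_all()
    except Exception as original:
        try:
            decrypt_all()  # last-ditch attempt to restore the tensors
        except CryptographyError as crypt_err:
            # Elevate the decryption failure; chain the original to it.
            raise crypt_err from original
        except Exception:
            # Likely a duplicate of `original`; we re-raise that below anyway.
            pass
        raise  # re-raises `original`, since the inner handler has finished
    else:
        decrypt_all()  # assumed normal path: restore plaintext after writing
```

On Python 3, the bare raise after the inner try statement re-raises the exception being handled by the outer except clause, so an ignored secondary error does not replace the original one.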