hitsave-io / xyz

Monorepo for HitSave

client: better pytorch support #122

Open EdAyers opened 1 year ago

EdAyers commented 1 year ago

This is a list of the things that the client needs to support in PyTorch. We need to be able to pickle, hash, and restore each of the following:

A very high-priority feature is TensorBoard-like logging.

The main thing I am noticing is that you frequently want to pass stateful objects like nn.Module and DataLoader to memo'd functions, and at the moment entire datasets etc. are being re-hashed every time. Hashing dataloaders is fine provided we can hash the dataset cheaply. I think the answer is to keep special in-memory hash caches holding the hashes of objects that we know are immutable. We need to be very careful not to cache the hash of an object that is actually mutable.

The example that comes up a lot is:

```python
@memo
def load_dataset(**params) -> Dataset:
    ...

my_dataloader = DataLoader(load_dataset(...), ...)
```

Now, my_dataloader is an object, so it gets hashed by hashing all of its child objects. The problem is that hashing the dataset takes ages and you don't really want to do it every time; moreover, we know the object was the direct result of calling load_dataset. The proposal here is that, for known-immutable types, we keep a few in-memory caches of precomputed hashes.
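As a rough illustration (the names here, like ImmutableHashCache, are hypothetical and not existing HitSave API), such a cache could be an identity-keyed, in-memory map that is only ever populated with objects we have declared immutable:

```python
import weakref

class ImmutableHashCache:
    """In-memory cache of digests for objects we know are immutable."""

    def __init__(self):
        # id(obj) -> (weak reference to obj, digest)
        self._entries = {}

    def register(self, obj, digest: str) -> None:
        """Record a precomputed digest for a known-immutable object.

        Assumes obj supports weak references (true for ordinary class
        instances such as Dataset subclasses).
        """
        key = id(obj)
        # Drop the entry when the object dies, so a recycled id() can
        # never alias a stale digest.
        ref = weakref.ref(obj, lambda _r, key=key: self._entries.pop(key, None))
        self._entries[key] = (ref, digest)

    def lookup(self, obj):
        """Return the cached digest, or None if obj was never registered."""
        entry = self._entries.get(id(obj))
        if entry is not None and entry[0]() is obj:
            return entry[1]
        return None
```

The idea would be that the @memo machinery calls register(result, call_digest) on every value it returns, and the deep-hashing code consults lookup(obj) before walking an object graph, so hashing my_dataloader no longer requires re-hashing the whole dataset.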

Another problem with hashing dataloaders is that if there is a random sampler attached, it gets hashed too. We should really have a policy of not hashing it.
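One way to express such a policy is a per-type table of hash overrides that is consulted before the generic deep hash. This is only a sketch: HASH_OVERRIDES and hash_contribution are hypothetical names, and only RandomSampler is a real torch.utils.data class.

```python
from torch.utils.data import RandomSampler

# Hypothetical policy table: type -> the bytes that should be fed to the
# hasher instead of the object's actual contents.
HASH_OVERRIDES = {
    # Policy: don't hash random samplers; every instance collapses to the
    # same constant token.  (A real policy would also match subclasses.)
    RandomSampler: lambda sampler: b"torch.utils.data.RandomSampler",
}

def hash_contribution(obj):
    """Bytes an object contributes to a deep hash (sketch)."""
    override = HASH_OVERRIDES.get(type(obj))
    if override is not None:
        return override(obj)
    ...  # otherwise recurse into the object's fields as usual
```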

We also need to fix downloads so that they can happen offline.

Make pickling support robust binding references.

Something I am realising is that unpickling isn't always enough to exactly reproduce an object, because it may depend on a class or function whose code has since changed. Ideally this shows up as a dependency change in the function that produced the object, so in theory we are safe, but in practice there can be edge cases. There is also the problem that the pickler complains when trying to pickle a function. The plan is:

  1. Write a custom pickle handler for functions and classes that emits a code-dependency (symbol, total digest) pair instead of the value itself. This means that we can pickle code robustly (sketched after this list).
  2. When this is unpickled, cross-check against the binding graph to make sure that the code hasn't changed. If it has, we get an unpickling failure and return a StoreMiss.
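A rough sketch of how the two steps could fit together, using pickle.Pickler.reducer_override (Python 3.8+). Everything here apart from the standard-library calls is hypothetical: symbol_of and current_digest_of stand in for the binding-graph lookups, and StoreMiss is modelled as an exception that the memo layer would catch and turn into a cache miss.

```python
import io
import pickle
import types

# Hypothetical hooks into the binding graph -- not existing HitSave API.
def symbol_of(obj) -> str:
    return f"{obj.__module__}.{obj.__qualname__}"

def current_digest_of(symbol: str) -> str:
    ...  # look up the symbol's current total digest in the binding graph

class StoreMiss(Exception):
    """Raised when a pickled code reference no longer matches the live code."""

def _restore_code_object(symbol: str, digest: str):
    """Unpickle hook: re-resolve the symbol and cross-check its digest."""
    if current_digest_of(symbol) != digest:
        raise StoreMiss(f"code for {symbol} has changed since it was pickled")
    module_name, _, attr = symbol.rpartition(".")
    module = __import__(module_name, fromlist=[attr])
    return getattr(module, attr)

class CodeAwarePickler(pickle.Pickler):
    def reducer_override(self, obj):
        if obj is _restore_code_object:
            # The restore hook itself must pickle by reference as usual.
            return NotImplemented
        # Step 1: for functions and classes, emit a (symbol, total digest)
        # pair instead of the value.  (A real implementation would skip
        # stdlib/builtin symbols.)
        if isinstance(obj, (types.FunctionType, type)):
            symbol = symbol_of(obj)
            return _restore_code_object, (symbol, current_digest_of(symbol))
        return NotImplemented  # everything else pickles as normal

def dumps(obj) -> bytes:
    """Dump with the code-aware pickler.  A plain pickle.loads then runs the
    digest cross-check (step 2) as part of restoring the object."""
    buf = io.BytesIO()
    CodeAwarePickler(buf).dump(obj)
    return buf.getvalue()
```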