fferflo / einx

Universal Tensor Operations in Einstein-Inspired Notation for Python.
https://einx.readthedocs.io/en/stable/
MIT License

set_at calls cannot be cached in some cases #4

Closed · alexanderswerdlow closed this 7 months ago

alexanderswerdlow commented 7 months ago

First of all, thanks so much for making this library! I just discovered it a week ago and it's been super handy!

When I was profiling my code, I noticed that one of my calls took a long time and didn't seem to be cached. Specifically, I have the following call:

import torch
from einx import set_at

# cond.encoder_hidden_states is [1, 77, 8192]
# dropout_idx is [1, 8]
# self.uncond_hidden_states is [77, 1024]
dropout_idx = torch.rand(cond.encoder_hidden_states.shape[0], self.cfg.model.num_conditioning_pairs) < self.cfg.model.training_layer_dropout
if dropout_idx.sum() > 0:
    set_at("[b] tokens ([n] d), masked [2], tokens d -> b tokens ([n] d)", cond.encoder_hidden_states, dropout_idx.nonzero(), self.uncond_hidden_states)

I noticed that:

  1. I was unable to use set_at with a boolean mask, so I resorted to nonzero()
  2. The solving fails if the index is empty (e.g., dropout_idx.nonzero() has shape [0, 2])

and most importantly:

  3. The expression solving is not cached in this case, as dropout_idx.nonzero() varies in size.

In this case, I imagine that once every possible size of dropout_idx.nonzero() has been seen (roughly dropout_idx.numel() calls), all subsequent ones will hit the cache, but this is a bit of a surprise for a naive user.
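
For concreteness, here is a toy illustration of why the coordinate shape varies from call to call (dummy values, not my actual config):

import torch

# The leading axis of nonzero() equals the number of True entries, so the
# coordinate tensor's shape changes between calls even though the mask itself
# is always [1, 8]:
for p in (0.1, 0.5, 0.9):
    dropout_idx = torch.rand(1, 8) < p
    print(dropout_idx.nonzero().shape)  # e.g. torch.Size([0, 2]), ([3, 2]), ([7, 2])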

Is there a simple way around this? It's a pretty common operation [in my code at least] to index a vector like this.

Perhaps it would also be a good idea to add an internal mechanism [with an option to disable it globally] that warns users after n uncached calls?

Thanks again for the library!!

fferflo commented 7 months ago

Hi Alexander, thanks for your interest!

One possible alternative to einx.set_at with boolean masks would be to use einx.where, so maybe something like this?

dropout_mask = torch.rand(cond.encoder_hidden_states.shape[0], self.cfg.model.num_conditioning_pairs) < self.cfg.model.training_layer_dropout
cond.encoder_hidden_states = einx.where("b n, tokens d, b tokens (n d) -> b tokens (n d)", dropout_mask, self.uncond_hidden_states, cond.encoder_hidden_states)

This would also address the caching issue.
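
With dummy tensors matching the shapes in your comment (b=1, tokens=77, n=8, d=1024), the call would look roughly like this (untested sketch; replace the constants with your actual config values):

import torch
import einx

encoder_hidden_states = torch.randn(1, 77, 8 * 1024)  # b tokens (n d)
uncond_hidden_states = torch.randn(77, 1024)          # tokens d
dropout_mask = torch.rand(1, 8) < 0.5                 # b n, boolean

# Wherever the mask is True, the n-th d-sized chunk of every token is replaced
# by the unconditional hidden states; everywhere else the original values are kept.
# The mask shape is always [1, 8], so the expression is traced once and then cached.
encoder_hidden_states = einx.where(
    "b n, tokens d, b tokens (n d) -> b tokens (n d)",
    dropout_mask, uncond_hidden_states, encoder_hidden_states,
)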

einx currently only supports static shapes, so the retracing can't be avoided if the input shapes differ. Dynamic shapes would be an interesting extension, but would require lots of changes to how einx handles shapes internally.

I added support for zero-sized updates/coordinates in set_at and an option to warn on excessive retraces. Thanks for the suggestions!
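
With that change, the empty-index case from your snippet should now go through, e.g. (rough sketch with dummy tensors):

import torch
import einx

hidden = torch.randn(1, 77, 8 * 1024)         # b tokens (n d)
uncond = torch.randn(77, 1024)                # tokens d
coords = torch.zeros(0, 2, dtype=torch.long)  # no masked entries, shape [0, 2]

# A zero-sized coordinate tensor is now accepted and simply leaves the values
# unchanged, so the `if dropout_idx.sum() > 0` guard becomes optional.
out = einx.set_at(
    "[b] tokens ([n] d), masked [2], tokens d -> b tokens ([n] d)",
    hidden, coords, uncond,
)
assert out.shape == hidden.shape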

(The tests currently fail, I think due to some changes introduced with Torch 2.2, but it should work with Torch <= 2.1 for now)

alexanderswerdlow commented 7 months ago

That works perfectly—thanks for those updates!

Another related dynamic case I came up with while trying out the library was support for combining n > 2 tensors. I'm guessing this falls under the same [very reasonable] limitation of only supporting static shapes, but it might be worth calling out in the "gotchas". I haven't dug into the internals, so I don't know how hard this would be, but perhaps in the future it could be supported in a restricted way [e.g., accepting an iterable of tensors as long as they all have the same shape, while allowing the number of tensors to vary].
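
To illustrate what I mean, here is a rough sketch of the restricted case, using stacking as a stand-in (same-shaped tensors combined by one einx call, with the caveat that each distinct count is still a distinct static shape):

import torch
import einx

# Any number of tensors, as long as they share a shape.
tensors = [torch.randn(77, 1024) for _ in range(5)]
stacked = torch.stack(tensors)                 # shape [k, 77, 1024], here k = 5

# One einx call covers the whole collection, e.g. averaging over the stacked axis.
combined = einx.mean("[k] tokens d", stacked)  # shape [77, 1024]

# Under the static-shape model, each distinct k would still trigger one retrace.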

Anyway, I'll close this issue, but I'm very excited for the future of this library!