Closed: alexanderswerdlow closed this issue 7 months ago.
Hi Alexander, thanks for your interest!
One possible alternative to `einx.set_at` with boolean masks would be to use `einx.where`, so maybe something like this?

```python
dropout_mask = torch.rand(
    cond.encoder_hidden_states.shape[0], self.cfg.model.num_conditioning_pairs
) < self.cfg.model.training_layer_dropout
cond.encoder_hidden_states = einx.where(
    "b n, tokens d, b tokens (n d) -> b tokens (n d)",
    dropout_mask, self.uncond_hidden_states, cond.encoder_hidden_states,
)
```
This would also address the caching issue, since the boolean mask keeps a fixed shape (unlike the output of `nonzero()`).
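For completeness, here is a self-contained version of that suggestion with made-up shapes, just to show the call end to end (`b`, `tokens`, `n`, `d` and the dropout rate are placeholders, not taken from the original code):

```python
import torch
import einx

# Hypothetical sizes, only to exercise the suggested pattern
b, tokens, n, d = 4, 16, 2, 8
encoder_hidden_states = torch.randn(b, tokens, n * d)
uncond_hidden_states = torch.randn(tokens, d)

# The mask always has shape (b, n), so einx sees the same static shapes on every call
dropout_mask = torch.rand(b, n) < 0.1
out = einx.where(
    "b n, tokens d, b tokens (n d) -> b tokens (n d)",
    dropout_mask, uncond_hidden_states, encoder_hidden_states,
)
assert out.shape == encoder_hidden_states.shape
```

Because every tensor keeps the same shape across calls, the traced expression can be reused from the cache.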
einx currently only supports static shapes, so the retracing can't be avoided if the input shapes differ. Dynamic shapes would be an interesting extension, but would require lots of changes to how einx handles shapes internally.
I added support for zero-sized updates/coordinates in `set_at` and an option to warn on excessive retraces. Thanks for the suggestions!
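A rough illustration of the zero-sized case (the pattern and shapes below are made up, not taken from this thread):

```python
import torch
import einx

x = torch.zeros(4, 8)
idx = torch.empty(0, dtype=torch.long)   # zero-sized coordinates
updates = torch.empty(0, 8)              # zero-sized updates
# With the change above, a call like this should simply leave x unchanged
x = einx.set_at("[b] f, k, k f -> [b] f", x, idx, updates)
```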
(The tests currently fail, I think due to some changes introduced with Torch 2.2, but it should work with Torch <= 2.1 for now)
That works perfectly—thanks for those updates!
Another related dynamic case I came up with while trying to use the library was support for combining n > 2 tensors. I'm guessing this falls under the same [very reasonable] limitation to only support static shapes, but it might be worth calling out in the "gotchas". I haven't dug into the internals, so I don't know how hard this would be, but perhaps in the future this could be supported in a restricted way [e.g., allowing an iterable of tensors so long as they have the same shape, but allowing varying lengths].
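In the meantime, one workaround (sketched here with made-up shapes) is to stack the same-shaped tensors so the list length becomes a regular axis; each distinct length will still trigger a separate trace, though:

```python
import torch
import einx

# Any number of tensors, as long as they share a shape, e.g. (tokens, d)
tensors = [torch.randn(16, 8) for _ in range(5)]
stacked = torch.stack(tensors)                      # shape (n, tokens, d)

# Combine over the list axis with a single einx call, e.g. an average
combined = einx.mean("n tokens d -> tokens d", stacked)
```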
Anyway I'll close this issue, but I'm very excited for the future of this library!
First of all, thanks so much for making this library! I just discovered it a week ago and it's been super handy!
When I was profiling my code, I noticed that one of my calls took a long time and didn't seem to be cached. Specifically, I have the following call:
I noticed that:

- I couldn't use `set_at` with a boolean mask, so I resorted to `nonzero()` (e.g., `dropout_idx.nonzero()` is `[0, 2]`), and most importantly:
- `dropout_idx.nonzero()` varies in size.

In this case, I imagine that after `dropout_idx.numel()` calls, all subsequent ones will be cached, but this is a bit of a surprise to a naive user.

Is there a simple way around this? It's a pretty common operation [in my code at least] to index a vector like this.
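Concretely, the kind of call I mean looks roughly like this (the shapes and names here are placeholders, not my actual code):

```python
import torch
import einx

x = torch.zeros(4, 8)                    # e.g. (batch, features)
uncond = torch.ones(8)

dropout_idx = torch.rand(4) < 0.5        # boolean mask over the batch axis
idx = dropout_idx.nonzero().squeeze(-1)  # shape (k,), where k changes from call to call
updates = uncond.expand(idx.shape[0], 8) # one update row per selected index

# Because k differs between calls, each new k produces a new trace/cache entry
x = einx.set_at("[b] f, k, k f -> [b] f", x, idx, updates)
```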
Perhaps it might also be a good idea to add an internal mechanism [with an option to disable globally] that warns users after `n` uncached calls?

Thanks again for the library!!