patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0
2.13k stars, 143 forks

Allow StateIndex to be passed dynamically #843

Closed: NeilGirdhar closed this 1 month ago

NeilGirdhar commented 2 months ago

Fixes https://github.com/patrick-kidger/equinox/issues/842

NeilGirdhar commented 2 months ago

If this is acceptable, I'm happy to add some methods if you want to make the code more polished:

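class StateIndex:
    def initial_value(self):
        # Proposed accessor: `init` holds the initial value wrapped in a 1-tuple.
        return self.init[0]

    def initial_value_deleted(self):
        # Proposed check: an empty tuple means the initial value has been removed.
        return self.init == ()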
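
The two accessors above only work if `init` is stored as a one-element tuple `(value,)` that becomes `()` once the initial value has been dropped. That reading is inferred from the snippet and is not confirmed elsewhere in the thread. One plausible motivation, which is an assumption, is that `None` then remains usable as a genuine initial value. `ToyStateIndex` and its `delete_initial_value` helper are hypothetical stand-ins, not Equinox's `StateIndex`:

```python
from typing import Any


class ToyStateIndex:
    """Hypothetical stand-in whose `init` is either (value,) or ()."""

    def __init__(self, value: Any) -> None:
        # Wrap the value so that even `None` counts as "present".
        self.init: tuple = (value,)

    def delete_initial_value(self) -> None:
        # Hypothetical helper: an empty tuple marks the value as removed.
        self.init = ()

    def initial_value(self) -> Any:
        return self.init[0]

    def initial_value_deleted(self) -> bool:
        return self.init == ()


idx = ToyStateIndex(None)
print(idx.initial_value_deleted())  # False: None is a real initial value here
idx.delete_initial_value()
print(idx.initial_value_deleted())  # True
```

With a bare `None` sentinel instead of a tuple, "no initial value" and "initial value is None" would be indistinguishable.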

NeilGirdhar commented 2 months ago

Errors fixed, but I have some questions about the code. I don't understand why StateIndex.marker's sentinel can't be changed to None. Is it because of replacement with eqx.combine, etc.? Also, in State.set, why can't we assert that item.marker is an integer? I guess I don't understand what the code is doing.

patrick-kidger commented 2 months ago

Indeed, I think .marker being an object() might be unnecessary. There is a potential footgun, though, if such a 'raw' StateIndex is passed to State: with object() then we'd at least get a unique dictionary key, but with None then they'd all overwrite each other. Probably the appropriate solution is indeed to explicitly raise an error if such a 'raw' StateIndex is passed to State, i.e. your suggestion of asserting that it is an integer.
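
The footgun described above comes down to dictionary-key identity. The following is a hypothetical toy sketch, not Equinox's `State`. Distinct `object()` sentinels give distinct keys, while a shared `None` sentinel collapses every raw index onto a single key. The `strict_set` helper, invented here for illustration, rejects any marker that is not an integer, along the lines of the assertion discussed in the thread:

```python
from typing import Any, Dict


def strict_set(store: Dict[int, Any], marker: object, value: Any) -> Dict[int, Any]:
    """Hypothetical helper: refuse 'raw' markers that were never assigned an int."""
    if not isinstance(marker, int):
        raise TypeError(f"expected an integer marker, got {marker!r}")
    new_store = dict(store)  # copy so the caller's dict is left untouched
    new_store[marker] = value
    return new_store


# Two raw indices with object() sentinels: both keys survive.
a, b = object(), object()
with_objects = {a: "first", b: "second"}
print(len(with_objects))  # 2

# Two raw indices sharing a None sentinel: the second write clobbers the first.
with_none: Dict[object, str] = {}
with_none[None] = "first"
with_none[None] = "second"
print(len(with_none), with_none[None])  # 1 second

# Raising early turns the silent overwrite into a loud error.
store = strict_set({}, 0, "first")
try:
    strict_set(store, None, "second")
except TypeError as err:
    print("rejected:", err)
```

Checking the marker type at the point of insertion means a raw index fails immediately, instead of corrupting the store in a way that only shows up later.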

NeilGirdhar commented 2 months ago

> then we'd at least get a unique dictionary key, but with None then they'd all overwrite each other.

Interesting, okay. You may want to consider adding a comment if that's behavior you're counting on for some use cases. (Sorry, I've been struggling with COVID all week, and my brain's a bit slower than usual.)

> i.e. your suggestion of asserting that it is an integer.

I tried that, but couldn't get it to pass the tests.

Anyway, I do love how the interface of State protects users from setting twice on the same state, or trying to use an expired state. Nice user-facing design.
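
The protection described above can be illustrated with a toy "use-once" state object. In this hypothetical sketch, `LinearState` is not Equinox's `State` and the error message is made up. Calling `set` returns a fresh state and invalidates the old one. As a result, setting twice on the same state, or reading from an expired one, raises instead of silently losing an update:

```python
from typing import Any, Dict


class LinearState:
    """Hypothetical use-once state: every `set` hands back a new object."""

    def __init__(self, values: Dict[int, Any]) -> None:
        self._values = dict(values)
        self._expired = False

    def _check(self) -> None:
        if self._expired:
            raise ValueError("this state has already been replaced; use the newer one")

    def get(self, marker: int) -> Any:
        self._check()
        return self._values[marker]

    def set(self, marker: int, value: Any) -> "LinearState":
        self._check()
        new_values = dict(self._values)
        new_values[marker] = value
        # Invalidate the old state so it cannot be reused by mistake.
        self._expired = True
        return LinearState(new_values)


state0 = LinearState({0: 1.0})
state1 = state0.set(0, 2.0)
print(state1.get(0))  # 2.0

try:
    state0.set(0, 3.0)  # bug: reused the old state instead of state1
except ValueError as err:
    print("caught:", err)
```

Copying the dictionary on each `set` keeps earlier states unchanged. The expiry flag is what turns accidental reuse into an error.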

patrick-kidger commented 2 months ago

> You may want to consider adding a comment if that's behavior you're counting on for some use cases.

Honestly, I'm not sure that'd be good behaviour to rely on...! It's certainly not the usual path. I'd be happy to change that without considering it a compatibility break.

Lmk where you land on all of this + when you want a review of this PR. (Once it's passing tests.)

NeilGirdhar commented 2 months ago

> Honestly, I'm not sure that'd be good behaviour to rely on...!

Great! If I have time, I'll take a look at this again.

> Lmk where you land on all of this + when you want a review of this PR. (Once it's passing tests.)

Hmm, it passes the tests on my machine on both Python 3.11 (the version that fails in CI) and 3.12. I'm not sure how to debug this. Do you have any insight, by any chance?

patrick-kidger commented 2 months ago

It does seem a bit weird! I do note that JAX recently did a new release, which has apparently since been yanked. Possibly something to do with that new release?

NeilGirdhar commented 2 months ago

> I do note that JAX recently did a new release, which has apparently since been yanked. Possibly something to do with that new release?

I'm not sure, but I tested with the previous release, JAX 0.4.31, whereas the CI job ran with the new release. I'll re-run the job.

But the error

Closure-converted function called with different dynamic arguments to the example arguments provided.

is related to Equinox, right? I'm not sure how closure conversion works since I haven't used it yet.

NeilGirdhar commented 2 months ago

(Looks like this passes now.)

NeilGirdhar commented 2 months ago

@patrick-kidger Do you have time to take a look at this?

patrick-kidger commented 2 months ago

Yup, I do! Have been otherwise engaged this past week. I expect to have a look at this in the next couple of days :)

NeilGirdhar commented 2 months ago

No worries, take your time :)

patrick-kidger commented 1 month ago

Okay, LGTM -- merged! Thank you for the contribution :)