necludov / jam

Implementation of Action Matching
https://arxiv.org/abs/2210.06662
MIT License

Running Tutorials #1

Closed: atong01 closed this issue 1 year ago

atong01 commented 1 year ago

Very cool, love the update!

Two notes:

  1. The links for stochastic action matching and unbalanced action matching are swapped.
  2. I can't seem to run the examples; I get:
    
    ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
    ipython 7.9.0 requires jedi>=0.10, which is not installed.
    ---------------------------------------------------------------------------
    ModuleNotFoundError                       Traceback (most recent call last)
    <ipython-input-1-32e701e1e453> in <module>
      8 try:
    ----> 9   import flax
     10 except ModuleNotFoundError:
    
    ModuleNotFoundError: No module named 'flax'
    
    During handling of the above exception, another exception occurred:
    
    AttributeError                            Traceback (most recent call last)
    5 frames
    /usr/local/lib/python3.8/dist-packages/flax/core/meta.py in Partitioned()
        263     return self.replace(names=tuple(names))
        264 
    --> 265   def get_partition_spec(self) -> jax.sharding.PartitionSpec:
        266     """Returns the Partitionspec for this partitioned value."""
        267     return jax.sharding.PartitionSpec(*self.names)
    
    AttributeError: module 'jax.sharding' has no attribute 'PartitionSpec'

Also, numpy doesn't seem to be imported, but that's a minor issue.

necludov commented 1 year ago

Thank you, Alexander! It seems like Flax just stopped working on Google Colab. Simply running

    # Colab cell; the leading "!" runs pip as a shell command
    import jax
    !pip install --quiet flax
    import flax

produces the same error you got.
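The traceback in the issue shows the freshly installed Flax referring to `jax.sharding.PartitionSpec`, which the JAX on that Colab runtime did not provide. A minimal sketch for checking a given runtime, using only the standard library plus whatever jax happens to be installed; the helper name `jax_has_partition_spec` is hypothetical:

```python
import importlib
from typing import Optional


def jax_has_partition_spec() -> Optional[bool]:
    """Report whether the installed jax exposes jax.sharding.PartitionSpec.

    Returns None if jax cannot be imported, False if jax imports but has no
    jax.sharding.PartitionSpec, and True otherwise.
    """
    try:
        importlib.import_module("jax")
    except ImportError:
        return None  # jax is not importable in this environment
    try:
        sharding = importlib.import_module("jax.sharding")
    except ImportError:
        return False  # this jax has no jax.sharding submodule at all
    # The newer Flax in the traceback uses jax.sharding.PartitionSpec directly.
    return hasattr(sharding, "PartitionSpec")
```

A result of `False` would be consistent with the mismatch in the traceback (a Flax release newer than the installed JAX supports), which is the situation that pinning an older Flax works around.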

necludov commented 1 year ago

I solved it for now by downgrading Flax to 0.6.4.
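In a Colab cell the downgrade amounts to `!pip install --quiet flax==0.6.4`. A minimal sketch of the same pin from plain Python, assuming 0.6.4 is only the version reported in this thread (not a general recommendation for later JAX releases); `FLAX_PIN`, `pinned_install_command`, `installed_version` and `needs_pin` are hypothetical names:

```python
import sys
from importlib import metadata
from typing import List, Optional

# Version reported in this thread to work on Colab at the time; treat it as
# an assumption for any other runtime.
FLAX_PIN = "0.6.4"


def pinned_install_command(package: str, version: str) -> List[str]:
    """Build a pip command that installs exactly `version` of `package`
    into the interpreter running this code."""
    return [sys.executable, "-m", "pip", "install", "--quiet", f"{package}=={version}"]


def installed_version(package: str) -> Optional[str]:
    """Read the installed distribution version without importing the package,
    since importing a mismatched flax is exactly what fails."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def needs_pin(package: str, pin: str) -> bool:
    """True if `package` is missing or installed at a version other than `pin`."""
    return installed_version(package) != pin
```

When `needs_pin("flax", FLAX_PIN)` is true, one could pass `pinned_install_command("flax", FLAX_PIN)` to `subprocess.check_call`. Reading the version through `importlib.metadata` avoids triggering the very import that fails; since parts of a failed import may stay loaded in the session, restarting the runtime after the downgrade is the cautious choice.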