google / objax

Apache License 2.0

Request: remove strict dependence on tensorflow #199

Closed: bytbox closed this issue 3 years ago

bytbox commented 3 years ago

This feels like borderline bikeshedding, but...

tensorflow is quite a heavy dependency. In particular, I can't install it with pip right now, because I'm on Arch Linux running Python 3.9, which TF doesn't yet support (as of this writing).

As I understand it, the tensorflow dependency introduced by ced0e704e is only needed for the objax2TF converter. If that were made optional, it would be quite convenient, at least for me.
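For illustration, one common way to make a heavy import optional is to attempt it when the package loads and, if it fails, substitute a placeholder that only raises when the feature is actually used. This is a minimal sketch under that assumption, not the change objax shipped; `optional_import` and `MissingOptionalDependency` are hypothetical names.

```python
import importlib
from types import ModuleType
from typing import Union


class MissingOptionalDependency:
    """Hypothetical stand-in for a feature whose optional package is not installed.

    Importing the parent package still succeeds; the ImportError is raised
    only when someone actually calls the feature.
    """

    def __init__(self, feature: str, package: str, error: ImportError):
        self.feature = feature
        self.package = package
        self.error = error

    def __call__(self, *args, **kwargs):
        # Chain the original error so the real cause stays visible in the traceback.
        raise ImportError(
            f"{self.feature} requires the optional package {self.package!r}, "
            "which is not installed."
        ) from self.error


def optional_import(
    module_name: str, feature: str
) -> Union[ModuleType, MissingOptionalDependency]:
    """Hypothetical helper: import `module_name`, or return a placeholder if that fails."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:  # ModuleNotFoundError is a subclass of ImportError
        return MissingOptionalDependency(feature, module_name, exc)


# Illustrative use (not executed here):
#     tf = optional_import("tensorflow", "The objax2TF converter")
```

With this pattern, `import objax` would keep working without tensorflow, and only calling the converter would fail, with a message naming the missing package.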

david-berthelot commented 3 years ago

Where is the tensorflow dependency? It's not in the requirements.txt file: https://github.com/google/objax/blob/master/requirements.txt

If you're referring to the ones in docs, they're only for building the documentation and shouldn't be used when doing pip install objax, unless I'm mistaken. Would you mind clarifying how this problem happens for you?

bytbox commented 3 years ago

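Ah, sorry. I thought the tensorflow dependency was deliberate and didn't notice it wasn't in requirements.txt. There's an import line in util/__init__.py (from .objax2tf import Objax2Tf) that fails when tensorflow is not installed, because jax.experimental.jax2tf imports tensorflow as soon as it is loaded: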

>>> import objax
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/scott/tmp/lib/python3.9/site-packages/objax/__init__.py", line 17, in <module>
    from ._patch_jax import *
  File "/home/scott/tmp/lib/python3.9/site-packages/objax/_patch_jax.py", line 23, in <module>
    from .util import re_sign
  File "/home/scott/tmp/lib/python3.9/site-packages/objax/util/__init__.py", line 18, in <module>
    from .objax2tf import Objax2Tf
  File "/home/scott/tmp/lib/python3.9/site-packages/objax/util/objax2tf.py", line 17, in <module>
    from jax.experimental import jax2tf
  File "/home/scott/tmp/lib/python3.9/site-packages/jax/experimental/jax2tf/__init__.py", line 16, in <module>
    from .jax2tf import convert, shape_as_value, split_to_logical_devices
  File "/home/scott/tmp/lib/python3.9/site-packages/jax/experimental/jax2tf/jax2tf.py", line 42, in <module>
    import tensorflow as tf  # type: ignore[import]
ModuleNotFoundError: No module named 'tensorflow'
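
The traceback shows the chain: importing objax imports objax.util, which imports the converter, which imports jax2tf, which imports tensorflow. Another common pattern, sketched here assuming Python 3.7+ (PEP 562 module-level `__getattr__`), is to defer such an import until the attribute is first accessed. `lazy_attributes` is a hypothetical helper, and the demo resolves `dumps` from the standard-library json module in place of the converter so that it runs anywhere.

```python
import importlib
import types
from typing import Callable, Dict


def lazy_attributes(module_name: str, mapping: Dict[str, str]) -> Callable[[str], object]:
    """Hypothetical helper: build a PEP 562 module-level __getattr__.

    `mapping` maps a public attribute name to the absolute name of the module
    that defines it. Nothing is imported until the attribute is first accessed.
    """

    def __getattr__(name: str) -> object:
        if name not in mapping:
            raise AttributeError(f"module {module_name!r} has no attribute {name!r}")
        # The (possibly heavy) import happens here, on first access.
        # importlib caches modules in sys.modules, so later accesses are cheap.
        return getattr(importlib.import_module(mapping[name]), name)

    return __getattr__


# In a package __init__.py this might read (hypothetical, not objax's actual code):
#     __getattr__ = lazy_attributes(__name__, {"Objax2Tf": "objax.util.objax2tf"})

# Runnable demo: a throwaway module whose `dumps` attribute is resolved lazily from json.
demo = types.ModuleType("demo_util")
demo.__getattr__ = lazy_attributes("demo_util", {"dumps": "json"})
print(demo.dumps({"a": 1}))  # prints {"a": 1}
```

Compared with trying the import eagerly when the package loads, this also avoids paying tensorflow's import cost for users who never touch the converter, at the price of reporting a missing dependency only on first access.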

david-berthelot commented 3 years ago

Thanks, I've created a fix. It should be in tomorrow, and I'll make a patch release that contains it.

david-berthelot commented 3 years ago

Release 1.3.1 is out with the patch.