CQCL / qnlp_lorenz_etal_2021_resources

Code and resources for the Lorenz et al. (2021) QNLP paper
GNU General Public License v3.0

jax problems rp_task_simulation #2

Closed nlpirate closed 2 years ago

nlpirate commented 2 years ago

I am trying to run rp_task_simulation.ipynb on Google Colab, but I am getting the following error:

UnfilteredStackTrace                      Traceback (most recent call last)

<ipython-input-20-0fff64bd88f1> in <module>()
      4     start = time()
----> 5     print('Cost: ', get_cost_jit(rand_unshaped_pars))
      6     print('Time taken for this iteration: ', time()-start)


UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
While tracing the function get_cost at <ipython-input-18-b448627ced96>:8 for jit, this concrete value was not available in Python because it depends on the value of the argument 'unshaped_params'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TracerArrayConversionError                Traceback (most recent call last)

/usr/local/lib/python3.7/dist-packages/discopy/quantum/gates.py in array(self)
    410     def array(self):
    411         half_theta = self.modules.pi * self.phase
--> 412         sin, cos = self.modules.sin(half_theta), self.modules.cos(half_theta)
    413         return Tensor.np.array([[cos, -1j * sin], [-1j * sin, cos]])
    414 

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
While tracing the function get_cost at <ipython-input-18-b448627ced96>:8 for jit, this concrete value was not available in Python because it depends on the value of the argument 'unshaped_params'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

I have already set the variable IMPORT_JAX = True in DisCoPy's config.py. Any suggestions on how to fix it?
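The last frame shows discopy/quantum/gates.py evaluating `self.modules.sin(half_theta)` on a value that `jax.jit` is tracing. The toy sketch below reproduces that failure mode outside DisCoPy. It assumes, without confirmation from this thread, that the mismatched DisCoPy version routed those calls to plain NumPy. `rx_like` and `toy_cost` are hypothetical functions that mirror the gates.py lines. The only difference between the two jitted versions is the array module `xp`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def rx_like(phase, xp):
    """Hypothetical Rx-style 2x2 matrix built with array module `xp` (numpy or jax.numpy)."""
    half_theta = xp.pi * phase
    sin, cos = xp.sin(half_theta), xp.cos(half_theta)
    return xp.array([[cos, -1j * sin], [-1j * sin, cos]])


def toy_cost(phase, xp):
    """Hypothetical scalar cost: |cos(pi * phase)|^2 read off the matrix."""
    return xp.abs(rx_like(phase, xp)[0, 0]) ** 2


# JAX backend: jnp.sin, jnp.cos and jnp.array all accept traced values.
cost_jax = jax.jit(lambda p: toy_cost(p, jnp))
# NumPy backend: np.sin calls __array__() on the tracer, which JAX forbids.
cost_numpy = jax.jit(lambda p: toy_cost(p, np))

print(cost_jax(0.25))  # approximately 0.5

try:
    cost_numpy(0.25)
except jax.errors.TracerArrayConversionError as err:
    print("NumPy backend under jit fails with", type(err).__name__)
```

The general rule this illustrates is that every array operation inside a jitted function must go through `jax.numpy` (or another JAX-aware module). A single NumPy call on a traced argument is enough to trigger this error.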

dimkart commented 2 years ago

Hi, and sorry for the delay in answering. The problem may be that you are using the wrong version of DisCoPy; can you check this? (Version 0.3.5 is required.)
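Before re-running the notebook, it can help to confirm which DisCoPy version is actually installed. Here is a minimal sketch, assuming the 0.3.5 pin stated above. `installed_version` and `is_required_version` are hypothetical helpers, not part of DisCoPy. The sketch uses `importlib.metadata`, which requires Python 3.8+. On Python 3.7 (as in the Colab traceback above), the `importlib_metadata` backport provides the same `version` and `PackageNotFoundError` names. The fix itself is the usual `pip install discopy==0.3.5` (prefixed with `!` in a Colab cell), followed by a runtime restart.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

# Assumption: the notebook pins DisCoPy 0.3.5, as stated in this thread.
REQUIRED_DISCOPY = "0.3.5"


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def is_required_version(installed: Optional[str], required: str = REQUIRED_DISCOPY) -> bool:
    """True only if a version is installed and exactly matches the pinned one."""
    return installed is not None and installed == required


found = installed_version("discopy")
if is_required_version(found):
    print(f"DisCoPy {found} OK")
else:
    print(f"DisCoPy {found or 'not installed'}; try: pip install discopy=={REQUIRED_DISCOPY}")
```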

nlpirate commented 2 years ago

Thanks for the reply. The JAX error seems to be solved with version 0.3.5 of DisCoPy. However, when I execute the last cell of the notebook, I now get the following error:

---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/matplotlib/texmanager.py in _run_checked_subprocess(self, command, tex)
    305                                              cwd=self.texcache,
--> 306                                              stderr=subprocess.STDOUT)
    307         except FileNotFoundError as exc:

FileNotFoundError: [Errno 2] No such file or directory: 'latex': 'latex'

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/matplotlib/texmanager.py in _run_checked_subprocess(self, command, tex)
    308             raise RuntimeError(
    309                 'Failed to process string with tex because {} could not be '
--> 310                 'found'.format(command[0])) from exc
    311         except subprocess.CalledProcessError as exc:
    312             raise RuntimeError(

RuntimeError: Failed to process string with tex because latex could not be found
Error in callback <function install_repl_displayhook.<locals>.post_execute at 0x7f2f758fb7a0> (for post_execute):
---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/matplotlib/texmanager.py in _run_checked_subprocess(self, command, tex)
    305                                              cwd=self.texcache,
--> 306                                              stderr=subprocess.STDOUT)
    307         except FileNotFoundError as exc:

FileNotFoundError: [Errno 2] No such file or directory: 'latex': 'latex'

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/matplotlib/texmanager.py in _run_checked_subprocess(self, command, tex)
    308             raise RuntimeError(
    309                 'Failed to process string with tex because {} could not be '
--> 310                 'found'.format(command[0])) from exc
    311         except subprocess.CalledProcessError as exc:
    312             raise RuntimeError(

RuntimeError: Failed to process string with tex because latex could not be found
<Figure size 936x576 with 2 Axes>

It seems to be related to LaTeX, but the error remains even after installing it via pip install latex. Any suggestions on how to fix it?

dimkart commented 2 years ago

Hi again. matplotlib tries to find a LaTeX installation on your machine in order to render the text. You don't need it to see the results; you can simply remove the line:

plt.rcParams.update({"text.usetex": True})

If you really want to use LaTeX, keep in mind that pip install latex only installs a Python wrapper that expects an actual LaTeX engine to be installed on your machine already. Check https://www.latex-project.org/get/#tex-distributions to find a LaTeX distribution suitable for your system and install it (but, as I said, it's not necessary).
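As a middle ground between deleting the `text.usetex` line and installing a full TeX distribution, a notebook could turn LaTeX rendering on only when a `latex` executable is actually on the PATH. This is a minimal sketch. `configure_usetex` is a hypothetical helper, not something the notebook defines. It checks only for the `latex` program that the traceback reports as missing. matplotlib's usetex documentation lists further programs (such as dvipng) that some backends also need.

```python
import shutil

import matplotlib.pyplot as plt


def configure_usetex() -> bool:
    """Enable matplotlib's LaTeX text rendering only if `latex` is on PATH.

    Returns True if usetex was enabled. Returns False if plots fall back to
    matplotlib's built-in text rendering (mathtext for $...$ expressions).
    """
    # `pip install latex` does not provide this executable; a TeX distribution does.
    has_latex = shutil.which("latex") is not None
    plt.rcParams.update({"text.usetex": has_latex})
    return has_latex


enabled = configure_usetex()
print("LaTeX text rendering:", "on" if enabled else "off (built-in mathtext)")
```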

nlpirate commented 2 years ago

It works, thank you!

dimkart commented 2 years ago

Thanks for noticing. We have updated the notebooks to remove the unnecessary LaTeX dependency. This issue will now be closed.