tum-pbs / PhiFlow

A differentiable PDE solving framework for machine learning
MIT License

JIT compilation fails when using implicit diffusion #118

Closed: rcremese closed this issue 1 year ago

rcremese commented 1 year ago

Hello, it's me again. I'm opening this issue because I updated PhiFlow to version 2.3.2 and tried to run the following code:

import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
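# from snake_ai.utils import Direction  # project-specific import, unused in this example
# from snake_ai.physim import ConvolutionWindow, GradientField  # project-specific import, unused in this example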
from phi.jax import flow

@flow.math.jit_compile
def explicit_diffusion(concentration : flow.field.Field, obstacle_mask : flow.field.Field, diffusivity : float, dt : float) -> flow.field.Field:
    return (1-obstacle_mask) * flow.diffuse.explicit(concentration, diffusivity=diffusivity, dt=dt)

@flow.math.jit_compile
def implicit_diffusion(concentration : flow.field.Field, obstacle_mask : flow.field.Field, diffusivity : float, dt : float) -> flow.field.Field:
    return (1-obstacle_mask) * flow.diffuse.implicit(concentration, diffusivity=diffusivity, dt=dt)

field = flow.CenteredGrid(flow.Noise(0.1, 0.1), x=10, y=10)
obstacle = flow.Box(x=(4,6), y=(4,6))
obstacle_mask = flow.CenteredGrid(obstacle, x=10, y=10)

initial_field = (1-obstacle_mask) * field
t=0
dt = 0.1
while t < 1:
    # field = explicit_diffusion(field, obstacle_mask, diffusivity=1, dt=dt)
    field = implicit_diffusion(field, obstacle_mask, diffusivity=1, dt=dt)
    t+=dt

flow.vis.plot([initial_field, field])
plt.show()

As you can see, I'm using JAX (>= 0.3) as the backend so that the code can be JIT-compiled. It's a very basic example: I diffuse a random field while imposing absorbing conditions on an obstacle in the center of the scene.
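For readers unfamiliar with PhiFlow, the explicit variant above amounts to a forward-Euler diffusion step followed by zeroing the concentration inside the obstacle. The NumPy sketch below illustrates that idea only; it is not PhiFlow's implementation. It assumes unit grid spacing, zero values outside the domain, and a hard 0/1 obstacle mask covering cells 4 and 5 on both axes (the names laplace and explicit_step are hypothetical). It uses a fixed number of steps instead of `while t < 1`, which avoids floating-point drift in the loop condition.

```python
import numpy as np

def laplace(c):
    # 5-point Laplacian with unit spacing; cells outside the grid are assumed to be zero
    p = np.pad(c, 1, mode="constant")
    return p[:-2, 1:-1] + p[2:, 1:-1] + p[1:-1, :-2] + p[1:-1, 2:] - 4 * c

def explicit_step(c, mask, diffusivity, dt):
    # Forward-Euler diffusion, then absorb everything that lies inside the obstacle
    return (1 - mask) * (c + diffusivity * dt * laplace(c))

rng = np.random.default_rng(0)
mask = np.zeros((10, 10))
mask[4:6, 4:6] = 1.0  # hypothetical hard mask standing in for Box(x=(4,6), y=(4,6))
c = (1 - mask) * rng.random((10, 10))
initial_total = c.sum()

dt = 0.1
for _ in range(10):  # 10 steps of dt = 0.1, i.e. t from 0 to 1
    c = explicit_step(c, mask, diffusivity=1.0, dt=dt)
```

With diffusivity * dt = 0.1 this step stays stable (the 2D limit is 0.25 for unit spacing), so the field stays non-negative and total mass only decreases through the obstacle and the boundary.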

My script runs fine with flow.diffuse.explicit, even when the function is decorated with flow.math.jit_compile. However, when I use flow.diffuse.implicit together with the jit_compile decorator, I get the following exception (JAX-specific stack trace removed):

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/rcremese/projects/snake-ai/src/snake_ai/draft.py", line 25, in <module>
    field = implicit_diffusion(field, obstacle_mask, diffusivity=1, dt=dt)
  File "/home/rcremese/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/math/_functional.py", line 206, in __call__
    native_result = self.traces[key](*natives)
  File "/home/rcremese/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/jax/_jax_backend.py", line 168, in run_jit_f
    return self.as_registered.call(jit_f, *args, name=f"run jit-compiled '{f.__name__}'")
  File "/home/rcremese/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/math/backend/_backend.py", line 354, in call
    return f(*args)
  File "/home/rcremese/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/math/_functional.py", line 177, in jit_f_native
    result = self.f(**kwargs, **in_key.auxiliary_kwargs)  # Tensor or tuple/list of Tensors
  File "/home/rcremese/projects/snake-ai/src/snake_ai/draft.py", line 14, in implicit_diffusion
    return (1-obstacle_mask) * flow.diffuse.implicit(concentration, diffusivity=diffusivity, dt=dt)
  File "/home/rcremese/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/physics/diffuse.py", line 63, in implicit
    return solve_linear(sharpen, y=field, solve=solve)
  File "/home/rcremese/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/math/_optimize.py", line 524, in solve_linear
    matrix, bias = f.sparse_matrix_and_bias(solve.x0, *f_args, **f_kwargs)
  File "/home/rcremese/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/math/_functional.py", line 366, in sparse_matrix_and_bias
    return self._get_or_trace(key, args, kwargs)
  File "/home/rcremese/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/math/_functional.py", line 303, in _get_or_trace
    matrix, bias = matrix_from_function(self.f, *args, **f_kwargs, auto_compress=True)
  File "/home/rcremese/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/math/_trace.py", line 272, in matrix_from_function
    result = f(**x_kwargs, **aux_args)
  File "/home/rcremese/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/physics/diffuse.py", line 59, in sharpen
    return explicit(x, diffusivity, -dt, substeps=order)
  File "/home/rcremese/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/physics/diffuse.py", line 34, in explicit
    delta = laplace(field, weights=amount) if 'vector' in shape(amount) else amount * laplace(field)
  File "/home/rcremese/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/field/_field.py", line 227, in __mul__
    return self._op2(other, lambda d1, d2: d1 * d2)
  File "/home/rcremese/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/field/_field.py", line 295, in _op2
    other = math.tensor(other)
  File "/home/rcremese/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/math/_tensors.py", line 1753, in tensor
    data = convert_(data, use_dlpack=False)
  File "/home/rcremese/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/math/backend/_backend.py", line 1515, in convert
    nparray = current_backend.numpy(tensor)
  File "/home/rcremese/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/jax/_jax_backend.py", line 99, in numpy
    return np.array(x)
jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function native(implicit_diffusion) at /home/rcremese/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/math/_functional.py:173 for jit. This concrete value was not available in Python because it depends on the values of the arguments natives[0] and natives[1].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
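The last frames show PhiFlow's JAX backend calling np.array() on a value that, during jit tracing, is an abstract JAX tracer rather than a concrete number. The same error class can be reproduced in plain JAX, independent of PhiFlow. The sketch below uses hypothetical function names; marking an argument static with static_argnums is a generic JAX pattern, and whether an equivalent PhiFlow setting would have avoided this particular bug is not established in this thread (it was resolved by a library fix).

```python
import numpy as np
import jax
import jax.numpy as jnp

def scale_numpy(x, factor):
    # np.array() needs a concrete value; under jax.jit, factor is an abstract tracer
    return x * np.array(factor)

def scale_jnp(x, factor):
    # jnp.asarray keeps the value symbolic, so tracing succeeds
    return x * jnp.asarray(factor)

x = jnp.ones(3)
traced_ok = jax.jit(scale_jnp)(x, 0.5)

try:
    jax.jit(scale_numpy)(x, 0.5)
    conversion_failed = False
except jax.errors.TracerArrayConversionError:
    conversion_failed = True

# Declaring factor static makes it a concrete Python float at trace time,
# so np.array(factor) works (at the cost of recompiling for each new value)
static_ok = jax.jit(scale_numpy, static_argnums=1)(x, 0.5)
```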

The problem only arises when I JIT-compile the function; without the decorator, the code works fine. If you need any additional information, don't hesitate to contact me. Sincerely yours.
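The traceback also shows how the implicit step is structured: diffuse.implicit calls solve_linear(sharpen, y=field), where sharpen runs explicit diffusion with -dt, and solve_linear first converts that function into a sparse matrix and bias. In other words, it solves for x such that x - diffusivity * dt * laplace(x) = field (a backward-Euler step). The NumPy sketch below shows that idea under stated assumptions: unit spacing, zero values outside the domain, a dense matrix built by probing the linear function with unit vectors, and np.linalg.solve in place of PhiFlow's sparse tracing and solvers. All helper names are hypothetical.

```python
import numpy as np

def laplace(c):
    # 5-point Laplacian with unit spacing; cells outside the grid are assumed to be zero
    p = np.pad(c, 1, mode="constant")
    return p[:-2, 1:-1] + p[2:, 1:-1] + p[1:-1, :-2] + p[1:-1, 2:] - 4 * c

def explicit(c, diffusivity, dt):
    return c + diffusivity * dt * laplace(c)

def sharpen(c, diffusivity, dt):
    # Explicit diffusion with negative dt: c - diffusivity * dt * laplace(c)
    return explicit(c, diffusivity, -dt)

def matrix_from_linear_function(f, shape):
    # Dense stand-in for PhiFlow's tracing: apply f to every unit vector to get the columns
    n = int(np.prod(shape))
    columns = []
    for i in range(n):
        e = np.zeros(n)
        e[i] = 1.0
        columns.append(f(e.reshape(shape)).ravel())
    return np.stack(columns, axis=1)

def implicit(y, diffusivity, dt):
    # Backward-Euler step: solve sharpen(x) == y for x
    A = matrix_from_linear_function(lambda x: sharpen(x, diffusivity, dt), y.shape)
    return np.linalg.solve(A, y.ravel()).reshape(y.shape)

rng = np.random.default_rng(0)
y = rng.random((10, 10))
x = implicit(y, diffusivity=1.0, dt=0.1)
x_large_dt = implicit(y, diffusivity=1.0, dt=10.0)  # stays bounded, unlike an explicit step
```

Unlike the explicit step, the implicit step remains stable for large dt, which is the usual reason to prefer it despite the cost of a linear solve.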

holl- commented 1 year ago

Thanks! I fixed it on the 2.4-develop branch: https://github.com/tum-pbs/PhiFlow/commit/cfebb7777f0db868ec514a3f1849d24c38a37a2f
You can install it using:

$ pip install --force-reinstall git+https://github.com/tum-pbs/PhiFlow@2.4-develop

holl- commented 1 year ago

The fix will also be included in version 2.3.3: https://github.com/tum-pbs/PhiFlow/commit/2135494fe28156c9e51d0b396fa77761eb7bf183
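To check whether an installed PhiFlow already includes the fix, one option is to compare the installed version against 2.3.3. This is a minimal sketch, assuming the PyPI distribution name is phiflow, that later releases keep the fix, and that a git install reports a version with a numeric prefix; parse_version and has_fix are hypothetical helpers that only look at leading numeric components.

```python
from importlib.metadata import PackageNotFoundError, version

def parse_version(text):
    # Keep leading numeric components: "2.3.3" -> (2, 3, 3), "2.4-develop" -> (2, 4)
    parts = []
    for piece in text.replace("-", ".").split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)

def has_fix(installed, fixed_in="2.3.3"):
    # Tuple comparison: (2, 4) >= (2, 3, 3) is True, (2, 3, 2) >= (2, 3, 3) is False
    return parse_version(installed) >= parse_version(fixed_in)

try:
    installed = version("phiflow")
    print("phiflow", installed, "includes fix:", has_fix(installed))
except PackageNotFoundError:
    print("phiflow is not installed")
```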

rcremese commented 1 year ago

Great! Thank you for your quick response!