tum-pbs / PhiFlow

A differentiable PDE solving framework for machine learning
MIT License

Documentation issue + no jit compilation for advection function #110

Closed rcremese closed 1 year ago

rcremese commented 1 year ago

Hi, it's me again. I'm still using PhiFlow and have enjoyed it so far, but I'm stuck on a documentation issue. The jit_compile documentation page says: "Args: f: Function to be traced. All positional arguments must be of type Tensor or PhiTreeNode returning a single Tensor or PhiTreeNode."

However, I cannot find any information about the PhiTreeNode class in the documentation. It would be nice to include a link to the class, or to remove the mention if it is deprecated. Moreover, I managed to implement a jit-compiled method whose arguments weren't Tensors but a CenteredGrid and an int. Are these special cases of PhiTreeNode?
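For readers wondering what "Tensor or PhiTreeNode" means in practice, here is a generic sketch assuming the idea resembles JAX pytrees: a jit wrapper walks each argument, collects the tensors it has to trace, and folds everything else (container structure, ints, resolutions) into a cache key. That matches the "called with new key ... shapes=..., args=..." debug message visible in the traceback further down. `Leaf`, `Grid`, `__tree_attrs__` and `flatten` are hypothetical names, not PhiFlow's API, and the sketch does not settle which types PhiFlow itself accepts.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Leaf:
    """Hypothetical stand-in for a Tensor: the only values a tracer would trace."""
    value: float


class Grid:
    """Hypothetical field-like node that names its traced attributes."""
    __tree_attrs__ = ("values",)  # hypothetical protocol, not PhiFlow's

    def __init__(self, values, resolution):
        self.values = values          # traced: a Leaf
        self.resolution = resolution  # static: ends up in the cache key


def flatten(obj):
    """Split obj into (list of traced leaves, hashable structure key)."""
    if isinstance(obj, Leaf):
        return [obj], "Leaf"
    if isinstance(obj, (tuple, list)):
        leaves, keys = [], []
        for item in obj:
            sub_leaves, sub_key = flatten(item)
            leaves += sub_leaves
            keys.append(sub_key)
        return leaves, (type(obj).__name__, tuple(keys))
    tree_attrs = getattr(type(obj), "__tree_attrs__", None)
    if tree_attrs is not None:
        leaves, keys = [], []
        for name in tree_attrs:
            sub_leaves, sub_key = flatten(getattr(obj, name))
            leaves += sub_leaves
            keys.append((name, sub_key))
        # Non-traced attributes are kept as constants inside the key.
        static = tuple(sorted((k, v) for k, v in vars(obj).items() if k not in tree_attrs))
        return leaves, (type(obj).__name__, tuple(keys), static)
    # Anything else (int, float, str, ...) is treated as a static constant.
    return [], ("static", obj)
```

Under this assumption, calling a wrapped function with `(Grid(Leaf(1.0), 100), 5)` would trace one leaf, while changing the resolution or the int produces a different key and therefore a new trace rather than a silently reused one.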

Additionally, I'm facing a problem when trying to jit-compile the advection function with your wrapper. The input to the advection function is a PointCloud of Tensors that I want to advect with a velocity field, represented as a CenteredGrid, which is the second parameter of the function. I define jited_advection = flow.math.jit_compile(flow.advect.advect) to speed up the advection, but calling it produces the following, particularly unhelpful, error:

Traceback (most recent call last):    
  File "/home/rocremes/projects/snake-ai/src/snake_ai/physim/main.py", line 57, in deterministic_walk
    point_cloud : flow.PointCloud = jited_advection(point_cloud, gradient_field, time_step)
  File "/home/rocremes/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/math/_functional.py", line 197, in __call__
    self.traces[key] = self._jit_compile(key)
  File "/home/rocremes/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/math/_functional.py", line 171, in _jit_compile
    PHI_LOGGER.debug(f"Φ-jit: '{f_name(self.f)}' called with new key. shapes={[s.volume for s in in_key.shapes]}, args={in_key.tree}")
  File "/home/rocremes/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/field/_point_cloud.py", line 170, in __repr__
    return "PointCloud[%s]" % (self.shape,)
  File "/home/rocremes/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/field/_field.py", line 136, in __getattr__
    return BoundDim(self, name)
  File "/home/rocremes/mambaforge/envs/snake-env/lib/python3.9/site-packages/phi/math/magic.py", line 449, in __init__
    raise AttributeError
AttributeError

Backend used: JAX

I would love to know whether it's possible to jit-compile the advection function and to differentiate through it. Thank you for taking the time to respond.
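One plausible reading of the traceback above, offered as an assumption rather than a confirmed diagnosis: the jit wrapper's debug f-string formats `in_key.tree`, which calls `PointCloud.__repr__`; that reads `self.shape`, and if anything inside that lookup raises `AttributeError`, Python falls back to the class's `__getattr__`, which treats `shape` as a dimension name and raises a bare `AttributeError`. Note that an f-string passed to `debug()` is evaluated even when debug logging is disabled. This pure-Python sketch reproduces the pattern with hypothetical classes (`Cloud`, `BoundDimLike`); it is not PhiFlow's code.

```python
class BoundDimLike:
    """Hypothetical dimension accessor that rejects unknown dimension names."""

    def __init__(self, obj, name):
        if name not in getattr(obj, "_known_dims", ()):
            raise AttributeError  # bare, no message, like the last traceback frame
        self.obj, self.name = obj, name


class Cloud:
    """Hypothetical field-like class: unknown attributes are read as dimensions."""

    _known_dims = ("point", "vector")

    def __init__(self, elements=None):
        # Hypothetical scenario: an object rebuilt during tracing lacks its data.
        self._elements = elements

    @property
    def shape(self):
        # With _elements=None this raises a useful AttributeError:
        # "'NoneType' object has no attribute 'shape'".
        return self._elements.shape

    def __getattr__(self, name):
        # Python calls this whenever normal lookup raises AttributeError,
        # including from inside the property above, so the useful error is discarded.
        return BoundDimLike(self, name)

    def __repr__(self):
        return "Cloud[%s]" % (self.shape,)
```

Calling `repr(Cloud())` ends in the message-less `AttributeError` raised by `BoundDimLike`, while `Cloud().point` still works as a dimension accessor. In this reading, the informative error is lost before it reaches the user, which would explain why the message is so unhelpful.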

holl- commented 1 year ago

Hi again,

  1. I'm actually not sure why pdoc doesn't automatically link to the PhiTreeNode documentation. I'll look into it. If you know why the links are broken, let me know.
  2. Compiling and differentiating through the advection should be possible. Could you send me a runnable script so I can reproduce the error?
holl- commented 1 year ago

I found the reason for the broken links. The online documentation will be fixed once 2.3 is released (which should be within the week).

rcremese commented 1 year ago

Thank you for the quick response. I also saw several broken links to demos in the documentation. If you are interested, I can list some of them. I will try to provide a reproducible example as soon as possible. What format works best for this: a Colab notebook or a Python file with all the dependencies listed?

holl- commented 1 year ago

Sure, if you have a list, that would be great! A Python file would be the easiest for me to debug.

rcremese commented 1 year ago
from phi.jax import flow
import matplotlib.pyplot as plt
import numpy as np

# Field and gradient field definition
def l2(x):
    return flow.math.l2_loss(x)

field = flow.CenteredGrid(l2, bounds=flow.Box(x=(-1, 1), y=(-1, 1)), x=100, y=100)
gradient_field = flow.field.spatial_gradient(-field, type=flow.StaggeredGrid)

# Point cloud definition
points = []
for x in np.linspace(-0.9, 0.9, 10):
    for y in np.linspace(-0.9, 0.9, 10):
        points.append(flow.vec(x=x, y=y))
points = flow.tensor(points, flow.instance('point'))
point_cloud = flow.PointCloud(points, bounds=field.bounds)

# First visualisation
flow.vis.plot([field, gradient_field, point_cloud], show_color_bar=False, same_scale=False)

# Advection functions definition (with and without decorator)
jited_advection = flow.math.jit_compile(flow.advect.points)

@flow.math.jit_compile
def step(point_cloud: flow.PointCloud, field: flow.field.Field, dt: float):
    return flow.advect.points(point_cloud, field, dt=dt)  # use the arguments, not the global gradient_field / a fixed dt

history = []
for t in range(10):
    ## 1 Works fine if not jit compiled
    point_cloud = flow.advect.points(point_cloud, gradient_field, dt=0.1)
    ## 2 Test cases in which jit compilation throws an error
    # point_cloud = jited_advection(point_cloud, gradient_field, dt=0.1)
    # point_cloud = step(point_cloud, gradient_field, dt=0.1)

    history.append(point_cloud)

# Final visualisation
history = flow.stack(history, flow.batch('time'))  # keep the stacked result for the animation
flow.vis.plot(history, animate='time')
plt.show()
rcremese commented 1 year ago

Sorry to post this as a comment on the issue, but I couldn't find an email address to send you my Python file. To trigger the bug, comment out line 33 (the non-jitted flow.advect.points call) and uncomment line 35 or line 36 (one of the jit-compiled calls). The resulting error is the one I copied in my first message.

My PhiFlow version: phiflow @ git+https://github.com/tum-pbs/PhiFlow@128d0809675b0be370c37d260ad19dae64a7d22e
My backend: JAX

rcremese commented 1 year ago

Some of the broken links I found in the documentation:

I don't know how hard it would be to implement, but a search tool for the documentation would be a nice feature.

holl- commented 1 year ago

I'll check it out soon.

holl- commented 1 year ago

The error seems to be fixed in the latest 2.3-develop version. You can upgrade by first uninstalling phiflow and then running

pip install git+https://github.com/tum-pbs/PhiFlow@2.3-develop

I've updated the links there as well.

rcremese commented 1 year ago

Sorry to bother you again, but it seems the 2.3-develop branch is unknown:

pip install git+https://github.com/tum-pbs/PhiFlow@2.3-develop
Collecting git+https://github.com/tum-pbs/PhiFlow@2.3-develop
  Cloning https://github.com/tum-pbs/PhiFlow (to revision 2.3-develop) to /tmp/pip-req-build-5o0sgquo
  Running command git clone --filter=blob:none --quiet https://github.com/tum-pbs/PhiFlow /tmp/pip-req-build-5o0sgquo
  WARNING: Did not find branch or tag '2.3-develop', assuming revision or ref.
  Running command git checkout -q 2.3-develop
  error: pathspec '2.3-develop' did not match any file(s) known to git.
  error: subprocess-exited-with-error
holl- commented 1 year ago

Yes, 2.3 has been released! You can now run

$ pip install phiflow==2.3.0

The fixed documentation is also online!