tum-pbs / PhiFlow

A differentiable PDE solving framework for machine learning
MIT License

values returned from jit compiled function #155

Closed: KarlisFre closed this issue 1 month ago

KarlisFre commented 6 months ago

Hi, the workaround mentioned in #154 does not work when I want to use the value returned from the jit-compiled function:

from phi.tf.flow import *

@math.jit_compile
def step(velocity: Field, wind: Tensor):
    print("Tracing step()")
    velocity = velocity.with_extrapolation({'x': wind, 'y': ZERO_GRADIENT})
    velocity, pressure = fluid.make_incompressible(velocity)
    return velocity

for i in range(100):
    wind = vec(x=math.random_uniform(), y=0)
    velocity = StaggeredGrid((0,0), {'x': 0, 'y': ZERO_GRADIENT}, x=100, y=100, bounds=Box(x=100, y=100))
    velocity = step(velocity, wind)
    velocity = advect.semi_lagrangian(velocity, velocity, dt=0.1)  # do something with velocity

It throws an error: The tensor <tf.Tensor 'natives_2:0' shape=(2,) dtype=float32> cannot be accessed from here, because it was defined in FuncGraph(name=native(step), id=139997627365216), which is out of scope.

holl- commented 6 months ago

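Right, the same problem occurs on the function output. Returning only the values from the compiled function and re-attaching them to the field outside of it should work: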

from phi.tf.flow import *

@math.jit_compile
def step(velocity: Field, wind: Tensor):
    print("Tracing step()")
    velocity = velocity.with_extrapolation({'x': wind, 'y': ZERO_GRADIENT})
    velocity, pressure = fluid.make_incompressible(velocity)
    return velocity.values

for i in range(100):
    wind = vec(x=math.random_uniform(), y=0)
    velocity = StaggeredGrid((0,0), {'x': 0, 'y': ZERO_GRADIENT}, x=100, y=100, bounds=Box(x=100, y=100))
    velocity = velocity.with_values(step(velocity, wind))
    velocity = advect.semi_lagrangian(velocity, velocity, dt=0.1)  # do something with velocity

Cheers!
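
For readers who want to see the pattern in isolation, here is a minimal pure-Python sketch. It does not use PhiFlow or TensorFlow: `Scope`, `Traced`, `ToyField`, and `toy_jit` are hypothetical stand-ins invented for illustration, and the real tracing machinery is far more involved. It only assumes that a compiled function hands its inputs to the body as placeholders that become unreadable once tracing ends, which is what the error message above suggests.

```python
from dataclasses import dataclass, replace


class Scope:
    """Hypothetical stand-in for a tracing graph; closed once tracing ends."""
    def __init__(self):
        self.active = True


class Traced:
    """A placeholder that can only be read while its scope is still active."""
    def __init__(self, value, scope):
        self._value = value
        self._scope = scope

    def read(self):
        if not self._scope.active:
            raise RuntimeError("traced value accessed outside its scope")
        return self._value


@dataclass(frozen=True)
class ToyField:
    """Toy field: numeric values plus boundary metadata."""
    values: tuple
    boundary: object  # either a plain number or a Traced placeholder

    def with_values(self, values):
        return replace(self, values=values)

    def with_boundary(self, boundary):
        return replace(self, boundary=boundary)

    def boundary_value(self):
        b = self.boundary
        return b.read() if isinstance(b, Traced) else b


def toy_jit(fn):
    """Call fn with its second argument wrapped in a scope-bound placeholder."""
    def wrapper(field, wind):
        scope = Scope()
        try:
            return fn(field, Traced(wind, scope))
        finally:
            scope.active = False  # placeholders become unreadable from here on
    return wrapper


@toy_jit
def step_returning_field(field, wind):
    field = field.with_boundary(wind)
    shifted = tuple(v + wind.read() for v in field.values)
    return field.with_values(shifted)  # boundary still holds the placeholder


@toy_jit
def step_returning_values(field, wind):
    field = field.with_boundary(wind)
    return tuple(v + wind.read() for v in field.values)  # plain data only


start = ToyField(values=(0.0, 0.0), boundary=0.0)

# Returning the whole field leaks the placeholder through the boundary.
leaky = step_returning_field(start, 0.5)
try:
    leaky.boundary_value()
    leak_error = None
except RuntimeError as err:
    leak_error = str(err)

# Returning only the values and re-attaching them outside avoids the leak.
safe = start.with_values(step_returning_values(start, 0.5))
```

In this toy, `safe` keeps the boundary that was defined outside the compiled function, not the wind-dependent one set inside it. The same holds for `velocity.with_values(...)` in the workaround above, so a wind-dependent extrapolation that is needed later would have to be set again outside of `step()`.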