tum-pbs / PhiFlow

A differentiable PDE solving framework for machine learning
MIT License

jit compile extrapolation #154

Closed. KarlisFre closed this issue 6 months ago

KarlisFre commented 6 months ago

Hi, extrapolations are not handled correctly by math.jit_compile. If I run this program:

from phi.tf.flow import *
from phiml.math.extrapolation import ConstantExtrapolation

@math.jit_compile()
def step(velocity):
    velocity, pressure = fluid.make_incompressible(velocity)
    return velocity

for i in range(100):
    windy = tf.random.uniform((), 0, 1)
    wind = vec(x=windy, y=tf.constant(0.0))
    boundary = ConstantExtrapolation(wind)  # a new boundary value on every iteration
    velocity = StaggeredGrid((0,0), extrapolation.combine_sides(x=boundary, y=ZERO_GRADIENT), x=100, y=100, bounds=Box(x=100, y=100))
    print(velocity)
    step(velocity)

it prints a warning and runs very slowly: WARNING:tensorflow:5 out of the last 5 calls to <function JitFunction

holl- commented 6 months ago

It looks like step() is being re-traced on every call because the boundary condition of velocity changes on every iteration. I'll look into it.

A simple workaround is to pass the boundary value as a tensor argument and set the boundary condition inside the jit-compiled function.
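To see why the workaround helps, here is a toy trace cache in plain Python. It is not PhiFlow's or PhiML's implementation. It only assumes, as described above, that the boundary condition acts as part of the compiled function's static cache key, while tensor arguments are runtime inputs whose values do not affect that key. ToyJit, Boundary, step_static and step_dynamic are hypothetical names made up for this illustration.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple


@dataclass(frozen=True)
class Boundary:
    """Hypothetical stand-in for a boundary condition stored as static metadata."""
    value: float


class ToyJit:
    """Toy JIT cache: one 'trace' per distinct static key (illustration only)."""

    def __init__(self, fn: Callable[..., Any]):
        self.fn = fn
        self.traces: Dict[Tuple[Any, ...], Callable[..., Any]] = {}

    def __call__(self, *, static: Tuple[Any, ...], dynamic: Tuple[Any, ...]):
        if static not in self.traces:
            # Cache miss: "trace" fn with the static values baked in as constants.
            self.traces[static] = lambda *dyn: self.fn(*static, *dyn)
        # Dynamic values are only passed at run time and never change the key.
        return self.traces[static](*dynamic)

    @property
    def trace_count(self) -> int:
        return len(self.traces)


def step_static(boundary: Boundary, u: float) -> float:
    # Boundary is part of the static key, like the original program.
    return u + boundary.value


def step_dynamic(u: float, wind: float) -> float:
    # Boundary value is a runtime input, like the suggested workaround.
    return u + wind


bad = ToyJit(step_static)
good = ToyJit(step_dynamic)
for i in range(100):
    wind = i / 100  # a different wind value on every iteration
    bad(static=(Boundary(wind),), dynamic=(0.0,))
    good(static=(), dynamic=(0.0, wind))

print(bad.trace_count, good.trace_count)  # prints: 100 1
```

Every new Boundary value misses the cache and triggers a fresh trace, which mirrors the repeated retracing in the original program. Passing the value as a runtime input keeps a single trace.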

holl- commented 6 months ago

Changing how boundaries are handled would require a major patch, so it will have to wait. In the meantime, here is an implementation of the workaround I mentioned:

from phi.tf.flow import *

@math.jit_compile
def step(velocity: Field, wind: Tensor):
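    print("Tracing step()")  # runs only when step() is (re)traced
    # Set the boundary inside the compiled function; wind is a tensor input, so new values don't force a retrace
    velocity = velocity.with_extrapolation({'x': wind, 'y': ZERO_GRADIENT})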
    velocity, pressure = fluid.make_incompressible(velocity)
    return velocity

for i in range(100):
    wind = vec(x=math.random_uniform(), y=0)
    # Placeholder boundary; the actual wind boundary is applied inside step()
    velocity = StaggeredGrid((0,0), {'x': 0, 'y': ZERO_GRADIENT}, x=100, y=100, bounds=Box(x=100, y=100))
    print(velocity)
    step(velocity, wind)