tum-pbs / PhiFlow

A differentiable PDE solving framework for machine learning
MIT License

jit_compile does not work with ZERO_GRADIENT extrapolation #136

Open KarlisFre opened 1 year ago

KarlisFre commented 1 year ago

Hi, the code below fails when `step` is decorated with `@jit_compile`:

`RuntimeError: The size of tensor a (63) must match the size of tensor b (62) at non-singleton dimension 0`

Without `jit_compile` it works. I am using the 2.4-develop branch.

```python
from phi.torch.flow import *

# Rotated box obstacle that also spins while the simulation runs
object_geometry = Box(x=(40, 60), y=(40, 60), z=(40, 60))
object_geometry = object_geometry.rotated((0.0, 0.1, 0.0))
OBSTACLE = Obstacle(object_geometry, velocity=(0., 0., 0.), angular_velocity=(0., 0.2, 0.2))

# 62^3 grids with ZERO_GRADIENT boundaries on a 100^3 domain
velocity = StaggeredGrid((0, 0, 0), ZERO_GRADIENT, x=62, y=62, z=62, bounds=Box(x=100, y=100, z=100))
smoke = CenteredGrid(0, ZERO_GRADIENT, x=62, y=62, z=62, bounds=Box(x=100, y=100, z=100))
INFLOW = 0.2 * resample(Sphere(x=50, y=50, z=10, radius=5), to=smoke, soft=True)
velocity, pressure = fluid.make_incompressible(velocity)

obst_mask = resample(OBSTACLE.geometry, smoke)
plot({"3D": obst_mask})
vis.show()

@jit_compile  # Only for PyTorch, TensorFlow and Jax
def step(v, s, p, obstacle, dt=1.):
    s = advect.mac_cormack(s, v, dt) + INFLOW  # advect smoke, then add the inflow
    buoyancy = resample(s * (0, 0, 0.1), to=v)  # buoyancy along z, sampled onto the staggered grid
    v = advect.semi_lagrangian(v, v, dt) + buoyancy * dt
    new_geometry = obstacle.geometry.rotated((0.0, 0.1, 0.1))  # rotate the obstacle every step
    obstacle = obstacle.copied_with(geometry=new_geometry)
    v, p = fluid.make_incompressible(v, obstacle)  # pressure projection around the obstacle
    return v, s, p, obstacle

#for _ in view(smoke, velocity, 'pressure',obst_mask, play=True, namespace=globals(),port=6006).range(warmup=1):
for i in range(10):
    velocity, smoke, pressure, OBSTACLE = step(velocity, smoke, pressure, OBSTACLE)
    obst_mask = resample(OBSTACLE.geometry, smoke)
    print(i)
```
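The sizes in the error message are suggestive: with `x=62` cells per axis, a staggered layout that stores both outer faces has 63 samples along that axis, one more than a centered grid. This thread does not confirm that this is what goes wrong inside the torch-traced `step`. The plain-Python sketch below (no PhiFlow; all helper names are hypothetical) only shows how zero-gradient padding turns 62 cells into 63 face values, and how combining a face-sized array with a cell-sized one produces the same 63-vs-62 pattern.

```python
from typing import List


def pad_zero_gradient(cells: List[float]) -> List[float]:
    """Add one ghost cell per side by copying the edge value (zero-gradient / Neumann padding)."""
    return [cells[0]] + cells + [cells[-1]]


def face_gradients(cells: List[float], dx: float = 1.0) -> List[float]:
    """Differences between neighbouring padded samples, one value per face.
    N cells -> N + 2 padded samples -> N + 1 faces."""
    padded = pad_zero_gradient(cells)
    return [(b - a) / dx for a, b in zip(padded[:-1], padded[1:])]


def faces_to_centers(faces: List[float]) -> List[float]:
    """Average neighbouring faces back onto cell centres: N + 1 faces -> N cells."""
    return [(a + b) / 2 for a, b in zip(faces[:-1], faces[1:])]


def add_elementwise(a: List[float], b: List[float]) -> List[float]:
    """Element-wise sum that, like a tensor op without broadcasting, rejects mismatched sizes."""
    if len(a) != len(b):
        raise ValueError(f"size mismatch: {len(a)} vs {len(b)}")
    return [x + y for x, y in zip(a, b)]


cells = [float(i * i) for i in range(62)]  # 62 cells, as with x=62 in the issue
faces = face_gradients(cells)
print(len(cells), len(faces))  # 62 63
print(faces[0], faces[-1])     # 0.0 0.0 (zero gradient at both outer faces)

try:
    add_elementwise(faces, cells)  # face-sized array + cell-sized array
except ValueError as err:
    print(err)  # size mismatch: 63 vs 62

# Interpolating faces back to centres first makes the sizes agree.
print(len(add_elementwise(faces_to_centers(faces), cells)))  # 62
```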

holl- commented 1 year ago

Hi, thanks for the bug report. Running the simulation with Jax or TensorFlow seems to work and is also faster once jit-compiled. I recommend using Jax for the time being. I'll look into why torch fails.
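The recommended workaround is the Jax backend, which in PhiFlow means importing from `phi.jax.flow` instead of `phi.torch.flow`. If a project has to stay on PyTorch, the reporter's observation that the un-jitted step works suggests a stop-gap: try the compiled step and switch to the plain one once it raises. The sketch below is a generic, hypothetical helper (`EagerFallback` is not part of PhiFlow), exercised with stand-in functions. With PhiFlow one would presumably define `step` without the decorator and wrap it as `EagerFallback(jit_compile(step), step)`; that combination is untested here. Catching every `RuntimeError` can also hide genuine bugs, and the fallback gives up the jit speed-up.

```python
from typing import Any, Callable


class EagerFallback:
    """Hypothetical helper: call `compiled` until it raises RuntimeError once,
    then use `eager` for this and all later calls."""

    def __init__(self, compiled: Callable[..., Any], eager: Callable[..., Any]) -> None:
        self.compiled = compiled
        self.eager = eager
        self.use_eager = False

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if not self.use_eager:
            try:
                return self.compiled(*args, **kwargs)
            except RuntimeError as err:
                print(f"compiled step failed, falling back to eager execution: {err}")
                self.use_eager = True
        return self.eager(*args, **kwargs)


def eager_step(v: int, s: int):
    # Stand-in for the uncompiled simulation step.
    return v + 1, s * 2


def broken_compiled_step(v: int, s: int):
    # Stand-in for a jit-compiled step that fails, mimicking the error in this issue.
    raise RuntimeError("The size of tensor a (63) must match the size of tensor b (62)")


step = EagerFallback(broken_compiled_step, eager_step)
v, s = 0, 1
for _ in range(3):
    v, s = step(v, s)
print(v, s, step.use_eager)  # 3 8 True
```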