tum-pbs / PhiFlow

A differentiable PDE solving framework for machine learning
MIT License

Smoke Plume Sim Difference between Jax and PyTorch #71

Closed: DalInar closed this issue 2 years ago.

DalInar commented 2 years ago

Hello! I'm running the simple smoke plume demo from the examples and compared the results of phi.torch and phi.jax. The top GIF shows the results I get with phi.jax, and the bottom one shows the results with phi.torch. The only difference between the two runs is from phi.jax import flow vs. from phi.torch import flow. The torch results seem correct, since they agree with the CPU results. My setup: Python 3.8.10, jax==0.3.14, jaxlib==0.3.14+cuda11.cudnn82, torch==1.12.0+cu116, CUDA 11.6.

Am I doing something wrong here?

[Animations: smoke_plume_jax (top), smoke_plume_torch (bottom)]
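To pin down where the two runs start to disagree, rather than comparing the GIFs by eye, one option is to keep every frame from both backends (for example the smoke.values.numpy("y,x") arrays the script already extracts each step) and compare them step by step. The helper below is a minimal, hypothetical sketch using only the standard library; the names max_abs_diff and first_divergent_frame and the default tolerance of 1e-4 are assumptions, chosen because small float32 differences between Jax and PyTorch are expected even when both are correct.

```python
import math
from typing import Optional, Sequence

# A frame is any 2D row-major sequence of numbers (nested lists or a 2D NumPy array).
Frame = Sequence[Sequence[float]]


def max_abs_diff(a: Frame, b: Frame) -> float:
    """Largest absolute element-wise difference between two equally shaped frames."""
    if len(a) != len(b):
        raise ValueError("frames have different numbers of rows")
    worst = 0.0
    for row_a, row_b in zip(a, b):
        if len(row_a) != len(row_b):
            raise ValueError("frames have different row lengths")
        for x, y in zip(row_a, row_b):
            d = abs(float(x) - float(y))
            if math.isnan(d):
                return math.inf  # treat NaN as maximal divergence
            worst = max(worst, d)
    return worst


def first_divergent_frame(run_a, run_b, tol: float = 1e-4) -> Optional[int]:
    """Index of the first time step whose frames differ by more than tol, else None."""
    if len(run_a) != len(run_b):
        raise ValueError("runs have different numbers of frames")
    for step, (frame_a, frame_b) in enumerate(zip(run_a, run_b)):
        if max_abs_diff(frame_a, frame_b) > tol:
            return step
    return None
```

Because the helpers only iterate over rows and values, they also accept lists of 2D NumPy arrays, e.g. frames saved from each backend with numpy.save and loaded back for the comparison.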

Source code:

```python
# Toggle the backend by switching which import is commented out.
#from phi.jax import flow
from phi.torch import flow
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
from celluloid import Camera

N_TIME_STEPS = 150

def main():
    # Staggered velocity grid (64x64 cells) on a 100x100 domain, zero at the walls
    velocity = flow.StaggeredGrid(
        values=(0.0, 0.0),
        extrapolation=0.0,
        x=64,
        y=64,
        bounds=flow.Box(x=100, y=100),
    )
    # Smoke density on a finer 200x200 grid over the same domain
    smoke = flow.CenteredGrid(
        values=0.0,
        extrapolation=flow.extrapolation.BOUNDARY,
        x=200,
        y=200,
        bounds=flow.Box(x=100, y=100),
    )
    # Constant smoke source: a soft-edged sphere near the bottom of the domain
    inflow = 0.2 * flow.CenteredGrid(
        values=flow.SoftGeometryMask(
            flow.Sphere(
                x=40,
                y=9.5,
                radius=5,
            )
        ),
        extrapolation=0.0,
        bounds=smoke.bounds,
        resolution=smoke.resolution,
    )
    pressure = None  # no initial guess for the first pressure solve

    @flow.math.jit_compile
    def step(velocity_prev, smoke_prev, pressure, dt=1.0):
        smoke_next = flow.advect.mac_cormack(smoke_prev, velocity_prev, dt) + inflow
        # Resample smoke to the velocity sample points. Use the argument rather than
        # the enclosing `velocity`, which the jit-compiled function would capture from main().
        buoyancy_force = smoke_next * (0.0, 0.1) @ velocity_prev
        velocity_tent = flow.advect.semi_lagrangian(velocity_prev, velocity_prev, dt) + buoyancy_force * dt
        # Pressure projection; the previous pressure warm-starts the solver via x0
        velocity_next, pressure = flow.fluid.make_incompressible(velocity_tent, (), flow.Solve('auto', 1e-5, 0, x0=pressure))
        return velocity_next, smoke_next, pressure

    plt.style.use("dark_background")
    fig = plt.figure()
    camera = Camera(fig)

    for _ in tqdm(range(N_TIME_STEPS)):
        velocity, smoke, pressure = step(velocity, smoke, pressure)
        smoke_values_extracted = smoke.values.numpy("y,x")
        plt.imshow(smoke_values_extracted, origin="lower")
        camera.snap()

    animation = camera.animate()
    os.makedirs('media', exist_ok=True)  # saving fails if the output folder does not exist
    animation.save('media/smoke_plume_torch.gif')  # adjust the file name when running with phi.jax
    print("Done!")

if __name__ == "__main__":
    main()
```
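The script switches backends by commenting out one of two imports and always writes media/smoke_plume_torch.gif, so the Jax run overwrites the torch animation unless the file name is also edited. A minimal sketch of selecting the backend in one place instead, assuming PhiFlow's per-backend modules phi.jax.flow, phi.torch.flow and phi.tf.flow (what from phi.jax import flow loads); the environment variable PHI_BACKEND and the helper names flow_module_name, output_path and load_flow are hypothetical, not part of PhiFlow.

```python
import importlib
import os

# Backends assumed to ship a phi.<name>.flow module ("tf" = TensorFlow).
SUPPORTED_BACKENDS = ("jax", "torch", "tf")


def flow_module_name(backend=None):
    """Resolve the backend (argument, else hypothetical PHI_BACKEND env var, else torch)."""
    name = (backend or os.environ.get("PHI_BACKEND", "torch")).strip().lower()
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(f"unsupported backend {name!r}; expected one of {SUPPORTED_BACKENDS}")
    return f"phi.{name}.flow"


def output_path(backend, directory="media"):
    """Per-backend GIF path, so the jax and torch runs do not overwrite each other."""
    return os.path.join(directory, f"smoke_plume_{backend}.gif")


def load_flow(backend=None):
    """Equivalent of `from phi.<backend> import flow`; requires PhiFlow to be installed."""
    return importlib.import_module(flow_module_name(backend))
```

In the script, flow = load_flow("jax") would replace the two import lines, and animation.save(output_path("jax")) the hard-coded file name, keeping both animations side by side for comparison.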
holl- commented 2 years ago

Strange, when I run the script I get the correct results. I am using PhiFlow from the 2.2-develop branch with Jax 0.3.15, but I haven't encountered this problem with any other version so far either. Could you try the latest version of Φ-Flow?

```shell
pip uninstall phiflow
pip install git+https://github.com/tum-pbs/PhiFlow@2.2-develop
```
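After reinstalling, it can help to confirm which versions the interpreter actually picks up, since a stale install in another environment is easy to miss. A minimal sketch using the standard library's importlib.metadata (available from Python 3.8, the version in the original report); the function name installed_versions is hypothetical, and the distribution names are assumed to match the pip package names (phiflow, jax, jaxlib, torch). A git install reports the version declared in PhiFlow's package metadata, not the branch name.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("phiflow", "jax", "jaxlib", "torch")):
    """Map each distribution name to its installed version string, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # not installed in this environment
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```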