NVIDIA / warp

A Python framework for high performance GPU simulation and graphics
https://nvidia.github.io/warp/

[BUG] Incorrect code generation with in-place division when using Warp 1.4.0+ #342

Open romerojosh opened 1 week ago

romerojosh commented 1 week ago

Bug Description

I have a program that works fine with Warp 1.3.1 and breaks with Warp 1.4.0+. I've found one kernel so far that leads to issues:

from typing import Any

import warp as wp

@wp.kernel
def trisolve_periodic_multi(x: wp.array2d(dtype=Any),
                            q: wp.array2d(dtype=Any),
                            s: wp.array2d(dtype=Any),
                            qe: wp.array2d(dtype=Any),
                            ap: wp.array2d(dtype=Any),
                            am: wp.array2d(dtype=Any),
                            ac: wp.array2d(dtype=Any),
                            n: int):

  irhs = wp.tid()

  q[0, irhs] = -ap[0, irhs] / ac[0, irhs]
  s[0, irhs] = -am[0, irhs] / ac[0, irhs]
  fn = x[n - 1, irhs]
  x[0, irhs] /= ac[0, irhs]

  # forward elimination sweep
  for i in range(1, n):
    p = x.dtype(1.0) / (ac[i, irhs] + am[i, irhs]* q[i - 1, irhs])
    q[i, irhs] = -ap[i, irhs] * p
    s[i, irhs] = -am[i, irhs] * s[i - 1, irhs] * p
    x[i, irhs] = (x[i, irhs] - am[i, irhs] * x[i - 1, irhs]) * p

  s[n - 1, irhs] = x.dtype(1.0)
  qe[n - 1, irhs] = x.dtype(0.0)

  # backward pass
  for i in range(n - 2, -1, -1):
    s[i, irhs] += q[i, irhs] * s[i + 1, irhs]
    qe[i, irhs] = x[i, irhs] + q[i, irhs] * qe[i + 1, irhs]

  x[n - 1, irhs] = ((fn - ap[0, irhs] * qe[0, irhs] - am[0, irhs] * qe[n - 2, irhs]) /
                    (ap[0, irhs] * s[0, irhs] + am[0, irhs] * s[n - 2, irhs] + ac[0, irhs]))

  # backward elimination pass
  for i in range(n - 2, -1, -1):
    x[i, irhs] = x[n - 1, irhs] * s[i, irhs] + qe[i, irhs]
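A pure-Python mirror of the same algorithm for a single right-hand side can serve as a reference for checking kernel output independently of Warp's code generation. This is a minimal sketch and is not code from the report. The function name trisolve_periodic_ref is hypothetical. Like the kernel, it closes the system with the row-0 coefficients, so it assumes am, ac and ap take the same values in rows 0 and n-1 (for example, uniform coefficients).

```python
from typing import List


def trisolve_periodic_ref(am: List[float], ac: List[float], ap: List[float],
                          f: List[float]) -> List[float]:
    """Hypothetical pure-Python mirror of trisolve_periodic_multi (one RHS).

    Row i reads am[i]*x[i-1] + ac[i]*x[i] + ap[i]*x[i+1] = f[i] with periodic
    wrap-around. As in the kernel, the closing equation uses row-0
    coefficients, so am, ac, ap are assumed equal in rows 0 and n-1.
    """
    n = len(f)
    x = list(f)  # the kernel overwrites its right-hand side in place
    q, s, qe = [0.0] * n, [0.0] * n, [0.0] * n

    q[0] = -ap[0] / ac[0]
    s[0] = -am[0] / ac[0]
    fn = x[n - 1]
    x[0] = x[0] / ac[0]  # the statement written as x[0, irhs] /= ac[0, irhs]

    # forward elimination sweep
    for i in range(1, n):
        p = 1.0 / (ac[i] + am[i] * q[i - 1])
        q[i] = -ap[i] * p
        s[i] = -am[i] * s[i - 1] * p
        x[i] = (x[i] - am[i] * x[i - 1]) * p

    s[n - 1] = 1.0
    qe[n - 1] = 0.0

    # backward pass: x[i] = qe[i] + s[i] * x[n - 1] for i < n - 1
    for i in range(n - 2, -1, -1):
        s[i] += q[i] * s[i + 1]
        qe[i] = x[i] + q[i] * qe[i + 1]

    # closing (periodic) equation for the last unknown
    x[n - 1] = ((fn - ap[0] * qe[0] - am[0] * qe[n - 2]) /
                (ap[0] * s[0] + am[0] * s[n - 2] + ac[0]))

    # backward elimination pass
    for i in range(n - 2, -1, -1):
        x[i] = x[n - 1] * s[i] + qe[i]
    return x


# Usage: build a small periodic system with a known solution and recover it.
n = 5
am, ac, ap = [1.0] * n, [4.0] * n, [2.0] * n
x_true = [1.0, -2.0, 3.0, 0.5, -1.0]
# x_true[i - 1] wraps to x_true[n - 1] when i == 0 (periodic coupling)
f = [am[i] * x_true[i - 1] + ac[i] * x_true[i] + ap[i] * x_true[(i + 1) % n]
     for i in range(n)]
x = trisolve_periodic_ref(am, ac, ap, f)
print(max(abs(a - b) for a, b in zip(x, x_true)))  # on the order of rounding error
```

If the in-place division is dropped, as in the 1.4.0 output below, x[0] is never scaled by ac[0]. A comparison against a reference like this would therefore generally flag the mismatch unless ac[0, irhs] happens to equal 1.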

After some digging, I think the issue might stem from the line with in-place division:

x[0, irhs] /= ac[0, irhs]

If I look at the generated code, I find the following in 1.3.1 for that line:

        // x[0, irhs] /= ac[0, irhs]                                                              <L 21>
        var_19 = wp::address(var_x, var_1, var_0);
        var_20 = wp::address(var_ac, var_1, var_0);
        var_21 = wp::load(var_19);
        var_22 = wp::load(var_20);
        var_23 = wp::div(var_21, var_22);
        wp::array_store(var_x, var_1, var_0, var_23);
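The 1.3.1 output has the expected shape for an in-place update of an array element: compute both addresses, load both values, apply the operator, and store the result back. As a rough illustration only, the toy function below produces that same six-step shape from the Python statement using the standard ast module (Python 3.9+ for ast.unparse). The name lower_subscript_augassign is hypothetical, this is not Warp's codegen, and variable naming is simplified (indices are printed as source text rather than as var_N temporaries).

```python
import ast

# Operator-node type -> name of the arithmetic builtin used in the emitted lines.
OP_NAMES = {ast.Add: "add", ast.Sub: "sub", ast.Mult: "mul", ast.Div: "div"}


def _indices(sub: ast.Subscript) -> str:
    """Render the index expressions of a subscript, e.g. '0, irhs'."""
    s = sub.slice
    elts = s.elts if isinstance(s, ast.Tuple) else [s]
    return ", ".join(ast.unparse(e) for e in elts)


def lower_subscript_augassign(src: str) -> list:
    """Toy lowering of `a[idx] op= b[idx]` into address/load/op/store lines."""
    node = ast.parse(src).body[0]
    if not (isinstance(node, ast.AugAssign)
            and isinstance(node.target, ast.Subscript)
            and isinstance(node.value, ast.Subscript)):
        raise ValueError("expected `array[idx] op= array[idx]`")
    dst, rhs = node.target, node.value
    dst_arr = "var_" + ast.unparse(dst.value)
    rhs_arr = "var_" + ast.unparse(rhs.value)
    op = OP_NAMES[type(node.op)]
    return [
        f"dst = wp::address({dst_arr}, {_indices(dst)})",
        f"src = wp::address({rhs_arr}, {_indices(rhs)})",
        "a = wp::load(dst)",    # current value of the target element
        "b = wp::load(src)",    # right-hand-side value
        f"r = wp::{op}(a, b)",  # the in-place operator itself
        f"wp::array_store({dst_arr}, {_indices(dst)}, r)",  # write the result back
    ]


for line in lower_subscript_augassign("x[0, irhs] /= ac[0, irhs]"):
    print(line)
```

Measured against this shape, the 1.4.0 output keeps only the second line, the address of the right-hand side.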

while in 1.4.0 I see only a single address computation for ac[0, irhs]. There is no load, no division, and no store back to x, which is clearly not right:

        // x[0, irhs] /= ac[0, irhs]                                                              <L 21>
        var_19 = wp::address(var_ac, var_1, var_0);

Compiling the program with 1.4.0 and 1.4.1 also prints a warning like:

Warning: in-place op <ast.Div object at 0x7f71739d6aa0> is not differentiable

which also indicates this in-place division might be causing issues.
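The unusual-looking <ast.Div object at 0x...> in the warning appears to be Python's default representation of the operator node that the standard ast module builds for /=. The warning prints this node rather than the operator's name. The following is a small illustration using only the standard library, not Warp code:

```python
import ast

stmt = ast.parse("x[0, irhs] /= ac[0, irhs]").body[0]
print(type(stmt).__name__)     # AugAssign: the in-place update statement
print(repr(stmt.op))           # default object repr; its exact form depends on the Python version
print(type(stmt.op).__name__)  # Div: a readable operator name for messages
```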

As an additional experiment, I also changed the in-place division line to:

x[0, irhs] = x[0, irhs] / ac[0, irhs]

and that fixes it with Warp 1.4.0+.

I don't need my kernel to be differentiable, so it is a little alarming that the code path behind this in-place division warning silently produces incorrect code. Is this expected, and should this warning be fatal?

System Information

No response

shi-eric commented 1 week ago

Thanks for reporting this @romerojosh, I'm surprised that we don't test in-place operations in our unit tests 😬

The problem started with a69d061f6a959307828acf239f73a58012e389bf according to git bisect:

a69d061f6a959307828acf239f73a58012e389bf is the first bad commit
commit a69d061f6a959307828acf239f73a58012e389bf
Author: Zach Corse <zcorse@nvidia.com>
Date:   Wed Sep 4 22:10:21 2024 -0700

    Array in-place autodiff fixes

 CHANGELOG.md             |   1 +
 warp/codegen.py          |  63 +++++++++++++++++-
 warp/native/array.h      |   1 -
 warp/tests/test_array.py | 168 +++++++++++++++++++++++++++++++++++++++++++++++
 warp/types.py            |  37 +++++++++++
 5 files changed, 266 insertions(+), 4 deletions(-)

shi-eric commented 1 week ago

In this case, the new code added as part of a69d061 does nothing when the operation is a multiply or division: https://github.com/NVIDIA/warp/blob/1d9a0719d11b76162e84c32468081306603052f5/warp/codegen.py#L2579

For now, I will restore the old behavior by adding the missing make_new_assign_statement() call and muting the warning unless verbose output is enabled.
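The restored behavior essentially treats x[0, irhs] /= ac[0, irhs] like the explicit form x[0, irhs] = x[0, irhs] / ac[0, irhs] from the report's workaround. The sketch below performs that rewrite at the source level using Python's ast.NodeTransformer. It is a hedged illustration: DesugarSubscriptAugAssign is a hypothetical name, and this is not the implementation of Warp's make_new_assign_statement().

```python
import ast
import copy


class DesugarSubscriptAugAssign(ast.NodeTransformer):
    """Toy rewrite of `a[idx] op= b` into `a[idx] = a[idx] op b`."""

    def visit_AugAssign(self, node: ast.AugAssign):
        self.generic_visit(node)
        if not isinstance(node.target, ast.Subscript):
            return node  # plain-name targets are left untouched in this sketch
        # Read the current element value with a Load context ...
        current = ast.Subscript(value=copy.deepcopy(node.target.value),
                                slice=copy.deepcopy(node.target.slice),
                                ctx=ast.Load())
        # ... combine it with the right-hand side using the same operator ...
        combined = ast.BinOp(left=current, op=node.op, right=node.value)
        # ... and store the result back with an ordinary assignment.
        new_node = ast.Assign(targets=[node.target], value=combined)
        return ast.copy_location(new_node, node)


tree = ast.parse("x[0, irhs] /= ac[0, irhs]")
tree = ast.fix_missing_locations(DesugarSubscriptAugAssign().visit(tree))
print(ast.unparse(tree))  # x[0, irhs] = x[0, irhs] / ac[0, irhs]
```

Rewriting at the statement level keeps the load, the operator, and the store together in a single assignment. This makes it harder for a later pass to drop any one of them.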