mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering
BSD 3-Clause "New" or "Revised" License

[bug?] Missing `dr.sync_thread` call in `export_`? #198

Status: open. dvicini opened this issue 9 months ago.

dvicini commented 9 months ago

Hi,

We were looking into some buggy behavior when using Dr.Jit's `.jax()` call. After digging a bit, it seems like there might be a missing `sync_thread` call in `export_` for CUDA arrays.

The `export_` function in generic.py synchronizes when migrating CUDA arrays to the host, and always for LLVM arrays. But there is no synchronization when exporting a Dr.Jit CUDA array to another framework's CUDA array:

```python
if b.IsCUDA and migrate_to_host:
    if b is a:
        b = type(a)(b)
    b = b.migrate_(_dr.AllocType.Host)
    _dr.sync_thread()
elif b.IsLLVM:
    _dr.sync_thread()
# No branch covers CUDA arrays that stay on the device, so they are not synchronized
```

Isn't this a bit risky? I could see another framework using a different CUDA stream and just assuming that the memory buffer is already filled in with data.

Here is my minimal reproducer. Because this is a synchronization problem, it might not be easy to reproduce on other systems. If I uncomment the synchronization code, it always produces correct results; if I don't synchronize, the second `imshow` shows a garbage image (see below).

```python
import drjit as dr
import jax
from matplotlib import pyplot as plt
import mitsuba as mi
import numpy as np

mi.set_variant("cuda_ad_rgb")

scene = mi.load_dict({
    'type': 'scene',
    'integrator': {'type': 'direct'},
    'emitter': {'type': 'constant'},
    'shape': {
        'type': 'rectangle',
        'bsdf': {
            'type': 'diffuse',
            'reflectance': {
                'type': 'bitmap',
                'bitmap': mi.Bitmap(np.random.rand(512, 512, 3)),
            },
        },
    },
    'sensor': {
        'type': 'perspective',
        'to_world': mi.ScalarTransform4f.look_at(
            [0, 0, 5], [0, 0, 0], [0, 1, 0]
        ),
        'film': {
            'type': 'hdrfilm',
            'width': 512,
            'height': 512,
            'pixel_format': 'rgb',
        },
    },
})

params = mi.traverse(scene)
image = mi.render(scene, params, spp=128)
plt.figure(figsize=(10, 10))
plt.imshow(np.array(image))
plt.title('First render')

params = mi.traverse(scene)
texture_init = np.random.rand(512, 512, 3)
params['shape.bsdf.reflectance.data'] = mi.TensorXf(texture_init)
params.update()

def render_image():
    rendered_image = mi.render(scene, params, spp=32, seed=1)
    # dr.eval(rendered_image)  # uncommenting these 2 lines fixes the issue
    # dr.sync_thread()
    rendered_image = rendered_image.jax()
    plt.figure(figsize=(10, 10))
    plt.imshow(np.array(rendered_image))
    plt.title('Second render')

render_image()
```
[Screenshot: the second render shows a garbage image]
njroussel commented 9 months ago

Hi @dvicini

This is a bit surprising. Dr.Jit should be using the default stream, which should force serial execution with respect to other streams. So, to my understanding, the copy to the host in `plt.imshow(np.array(rendered_image))` should only execute after the `_dr.eval(b)` in generic.py, even if the memcpy is launched on another stream while the evaluation is still in progress.

njroussel commented 9 months ago

Never mind, my understanding was wrong. The synchronization behavior of the default stream seems to be quite a bit more complex than I thought.

dvicini commented 9 months ago

I am also unsure about the synchronization behavior of the default stream. Couldn't it be that the other framework (in this case Jax) launches some compute on a non-default stream that is not waiting on the default stream? So Dr.Jit would be returning a pointer to a memory region that is yet to be used by itself, but maybe Jax already tries to read from it.

I don't think another framework can know that a given CUDA memory pointer points to the result of an unfinished Dr.Jit kernel.

wjakob commented 9 months ago

We're using stream zero, with the intent of synchronizing with respect to other streams (bidirectionally). We changed to that behavior at some point specifically to deal with other frameworks that might use a non-standard stream.

Inserting `sync_thread` would be quite a major pessimization, since we would then be waiting for all work on the GPU to finish, which is not actually needed. (It makes sense for the CPU/LLVM version to do this, since we subsequently access the data directly on the same device.)

At some point, CUDA also added even more relaxed stream primitives that don't synchronize with respect to stream zero at all. If JAX uses those, that could be an issue. But then JAX has an issue as well, because it had better expose some synchronization mechanism to deal with the resulting mess (maybe it does, and we should then call that).
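To make the stream semantics discussed here concrete, below is a small pure-Python model of CUDA's implicit ordering rules between two streams in the same context, following the "stream synchronization behavior" section of the CUDA Runtime API docs. It is a simplification that ignores events, explicit waits and host-side synchronization, and the names `Kind`, `Stream` and `implicitly_ordered` are made up for illustration. It assumes Dr.Jit's stream zero has legacy default-stream semantics, and shows why that is enough against ordinary streams but not against streams created with `cudaStreamNonBlocking`.

```python
from dataclasses import dataclass
from enum import Enum, auto


class Kind(Enum):
    LEGACY_DEFAULT = auto()      # NULL stream with legacy semantics (assumed for Dr.Jit's stream zero)
    PER_THREAD_DEFAULT = auto()  # cudaStreamPerThread, or NULL under per-thread default streams
    BLOCKING = auto()            # cudaStreamCreate() with default flags
    NON_BLOCKING = auto()        # cudaStreamCreateWithFlags(..., cudaStreamNonBlocking)


@dataclass(frozen=True)
class Stream:
    name: str
    kind: Kind


def implicitly_ordered(s1: Stream, s2: Stream) -> bool:
    """Return True if CUDA implicitly orders work on s1 and s2 (same context assumed)."""
    if s1 == s2:
        return True  # work within a single stream always runs in issue order
    for legacy, other in ((s1, s2), (s2, s1)):
        if legacy.kind is Kind.LEGACY_DEFAULT:
            # The legacy default stream waits on, and is waited on by, every
            # stream except those created with cudaStreamNonBlocking.
            return other.kind is not Kind.NON_BLOCKING
    return False  # two ordinary streams: no implicit ordering at all


drjit = Stream("Dr.Jit stream zero", Kind.LEGACY_DEFAULT)
print(implicitly_ordered(drjit, Stream("consumer", Kind.BLOCKING)))      # True
print(implicitly_ordered(drjit, Stream("consumer", Kind.NON_BLOCKING)))  # False
```

In this model, only the `NON_BLOCKING` case lets a consumer read Dr.Jit's output before the producing kernel has finished, which matches the scenario described in this thread.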

njroussel commented 9 months ago

> So Dr.Jit would be returning a pointer to a memory region that is yet to be used by itself, but maybe Jax already tries to read from it.

Yes, that seems like what is happening here.

Ok, I found some exceptions to the synchronization/blocking behavior of the default stream in the CUDA driver API docs, so Jax could be using one of those. Actually, this might be on NumPy's end: Jax shouldn't be doing anything other than getting a device pointer here and passing it to NumPy via DLPack.

dvicini commented 9 months ago

This all makes sense to me. I agree that Dr.Jit should not be unnecessarily pessimistic. I will try to investigate what's going on on the Jax side. One potential cause of these issues could be the use of CUDA_API_PER_THREAD_DEFAULT_STREAM, if Jax enables it: using a per-thread default stream would also break the synchronization if Jax uses threads. I will keep you posted.

dvicini commented 9 months ago

I got some feedback from the Jax developers on what's going on here: creating a Jax array from a Dr.Jit array with `from_dlpack(a.__dlpack__())` uses the legacy interface, which ignores CUDA streams. It turns out that 1) Jax uses non-blocking streams, and 2) the DLPack API allows specifying a `stream` argument. The consumer of the data can specify on which stream it will use the data, so that the producer can synchronize appropriately.

Here is the current DLPack spec on this topic: https://dmlc.github.io/dlpack/latest/python_spec.html#syntax-for-data-interchange-with-dlpack

Here is the relevant Jax code; Dr.Jit currently takes the legacy path in the `else` branch: https://github.com/google/jax/blob/main/jax/_src/dlpack.py#L67

So the correct solution would be for Dr.Jit to support the `stream` argument of `__dlpack__` and, when one is provided, synchronize with that stream. This is likely relevant beyond Jax: other frameworks may eventually rely on DLPack-level stream support as well.
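A minimal sketch of the producer side of this handshake, assuming the CUDA stream conventions from the DLPack Python spec linked above: `None` means the legacy default stream, `1` is the legacy default stream, `2` the per-thread default stream, larger values are raw stream handles, `-1` asks for no synchronization, and `0` is disallowed. `StreamAwareExport`, `resolve_cuda_stream` and the `wait_on` callback are hypothetical names, not Dr.Jit's or Jax's API. Only the `stream` parameter is modeled; newer revisions of the protocol add further keyword arguments that are omitted here.

```python
from typing import Callable, Optional

# CUDA stream values defined by the DLPack Python specification.
LEGACY_DEFAULT_STREAM = 1
PER_THREAD_DEFAULT_STREAM = 2
NO_SYNC = -1
KDLCUDA = 2  # DLDeviceType value for CUDA devices


def resolve_cuda_stream(stream: Optional[int]) -> Optional[int]:
    """Map the consumer's `stream` argument to the stream the producer must
    order its pending work against, or None if no synchronization is wanted."""
    if stream is None:
        return LEGACY_DEFAULT_STREAM  # spec: None means the legacy default stream
    if stream == NO_SYNC:
        return None
    if stream == 0:
        raise ValueError("stream=0 is ambiguous and disallowed for CUDA")
    if stream < NO_SYNC:
        raise ValueError(f"invalid stream value {stream}")
    return stream  # 1, 2, or a raw cudaStream_t handle passed as an int


class StreamAwareExport:
    """Hypothetical producer wrapper (not Dr.Jit's API).

    `capsule_factory` builds the DLPack capsule. `wait_on` stands in for the
    real GPU-side ordering, e.g. recording an event on the producer's stream
    (cuEventRecord) and making the consumer's stream wait on it
    (cuStreamWaitEvent), which does not block the host.
    """

    def __init__(self, capsule_factory: Callable[[], object],
                 wait_on: Callable[[int], None], device_id: int = 0):
        self._capsule_factory = capsule_factory
        self._wait_on = wait_on
        self._device_id = device_id

    def __dlpack_device__(self):
        return (KDLCUDA, self._device_id)

    def __dlpack__(self, stream: Optional[int] = None):
        target = resolve_cuda_stream(stream)
        if target is not None:
            # Order the consumer's stream after the producer's pending kernels,
            # instead of waiting for the whole device as sync_thread would.
            self._wait_on(target)
        return self._capsule_factory()
```

A stream-aware consumer such as Jax would call `x.__dlpack__(stream=<its own stream handle>)`, so a non-blocking consumer stream gets an explicit dependency on the producer's work without a device-wide synchronization.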

futscdav commented 9 months ago

To be clear, this is indeed not limited to Jax: I first encountered the issue when converting to TensorFlow GPU tensors with `.tf()`. The synchronization seems to fix the issue, albeit at some cost in performance.

Thanks for looking into this!