mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering
BSD 3-Clause "New" or "Revised" License

JIT compiling when inverse rendering #175

Closed uguuuuuu closed 1 year ago

uguuuuuu commented 1 year ago

The JIT compiler recompiles the kernels at every iteration (except the first), even when I run the same optimization a second time. Is there a workaround? I want to time the optimization but exclude the compilation time.

njroussel commented 1 year ago

Hi @uguuuuuu

Unless something in your setup has changed between your two optimizations, there is no reason for the JIT to recompile any kernels (not even those of the first iteration).

There is some tooling to measure the different runtimes:

```python
import drjit as dr

dr.set_flag(dr.JitFlag.KernelHistory, True)
some_computation()  # placeholder for the code you want to profile
# All kernels executed since the last call to dr.kernel_history() (or since process start)
kernels = dr.kernel_history()
```

kernels is a list of dict objects. You'll want to filter the list to keep only the kernels whose type is JIT (the other entries are mostly internal to the JIT itself). In each dict you'll find the following keys (all values are in milliseconds, I believe):
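The goal here is to time the optimization without the compilation. The filtering described above can be wrapped in a small helper.

This is a sketch, not part of Dr.Jit. It takes the list returned by dr.kernel_history() together with the value that marks JIT kernels (dr.KernelType.JIT in this thread). It then counts how many kernels missed the cache and sums the per-kernel timings. The keys type and cache_hit appear in this thread. The timing key names codegen_time, backend_time and execution_time are assumptions, so check them against the dicts your Dr.Jit version returns. summarize_kernels is a hypothetical name.

```python
from typing import Any, Dict, Iterable

# 'type' and 'cache_hit' are used in this thread; these timing key names are an
# ASSUMPTION about the dicts returned by dr.kernel_history(), so verify them.
TIME_KEYS = ("codegen_time", "backend_time", "execution_time")


def summarize_kernels(kernels: Iterable[Dict[str, Any]], jit_type: Any) -> Dict[str, float]:
    """Count and time the JIT kernels in a kernel-history list.

    jit_type is whatever marks JIT kernels (dr.KernelType.JIT in this thread);
    passing it in keeps this helper independent of Dr.Jit.
    """
    jit = [k for k in kernels if k.get("type") == jit_type]
    summary: Dict[str, float] = {
        "launched": len(jit),
        # cache_hit may be a bool or an int (the thread compares it to 0)
        "compiled": sum(1 for k in jit if not k.get("cache_hit", False)),
    }
    for key in TIME_KEYS:
        # A missing key counts as 0 ms instead of raising KeyError
        summary[key] = sum(float(k.get(key, 0.0)) for k in jit)
    return summary
```

In an optimization loop, one could call summarize_kernels(dr.kernel_history(), dr.KernelType.JIT) after each iteration and keep only execution_time to exclude compilation. A nonzero compiled count in a later iteration is exactly the symptom discussed in this issue.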

uguuuuuu commented 1 year ago

The recompilation was caused by my creating a new sampler in every iteration. However, somehow the cache is still hit only in the first iteration.

njroussel commented 1 year ago

This most likely indicates that you have some constant that is being baked into your kernel and this constant changes at every iteration. My recommendation would be to step through your code with pdb and enable INFO-level logs. This will at least allow you to understand the scope of your individual kernels and hopefully determine what is changing at each iteration.
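To see why a changing constant defeats the cache, here is a toy model in plain Python. It is not Dr.Jit's implementation: kernels are looked up by a hash of their generated source, roughly as a JIT kernel cache does.

- Embedding a per-iteration value as a literal changes the source, and therefore the cache key, every time.
- Reading the value from an input slot keeps the source identical. This is what wrapping a value in dr.opaque() achieves.

The names trace, ToyKernelCache and run are hypothetical.

```python
import hashlib
from typing import Set


def trace(scale: float, as_literal: bool) -> str:
    """Generate toy 'kernel source' for y = x * scale (hypothetical, not Dr.Jit IR)."""
    if as_literal:
        # The value is baked into the code, like a non-opaque constant
        return f"y = x * {scale!r}"
    # The value is read from an input slot, like a variable wrapped in dr.opaque()
    return "y = x * param[0]"


class ToyKernelCache:
    """Toy cache that 'compiles' a kernel only when its source hash is new."""

    def __init__(self) -> None:
        self.seen: Set[str] = set()
        self.misses = 0

    def launch(self, source: str) -> None:
        key = hashlib.sha256(source.encode()).hexdigest()
        if key not in self.seen:
            self.misses += 1  # a real JIT would compile here (the slow part)
            self.seen.add(key)


def run(as_literal: bool, iterations: int = 4) -> int:
    """Return the number of cache misses over an optimization-like loop."""
    cache = ToyKernelCache()
    for i in range(iterations):
        cache.launch(trace(1.0 / (i + 1), as_literal))  # value changes every iteration
    return cache.misses


print(run(as_literal=True))   # 4: one compilation per iteration
print(run(as_literal=False))  # 1: compiled once, then reused
```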

Without some sort of a reproducer, I won't be able to help much more than this. I will close this issue for now, as it doesn't seem to be a bug in Dr.Jit directly. We can continue discussing in this thread if you have any more information to add.

uguuuuuu commented 1 year ago

Hi Nicolas, thank you for your advice and help, I really appreciate it. The problem is probably that I have to build a distribution whose number of entries differs from iteration to iteration. Here is a small reproducer:

```python
import drjit as dr
import mitsuba as mi
import util  # the poster's own helper module (create_sampler, prepare_sampler), not shown

mi.set_variant("cuda_ad_rgb")  # assumption: a JIT variant, as used later in the thread

# create a sampler with spp = 1
sampler = util.create_sampler(spp=1)

dr.set_flag(dr.JitFlag.KernelHistory, True)
for i in range(4):
    # build a distribution whose number of entries is a function of the iteration number
    sampler_cloned, _ = util.prepare_sampler(sampler, n_cells=i+1, seed=0)
    distr = mi.DiscreteDistribution(sampler_cloned.next_1d())
    # sample from distribution
    sampler_cloned, _ = util.prepare_sampler(sampler, 1024, seed=i+1)
    idx = distr.sample(sampler_cloned.next_1d())
    # some computation
    dr.eval(dr.sum(idx))

    # print only the JIT kernels that missed the cache, i.e. were compiled
    kernels = dr.kernel_history()
    dr.kernel_history_clear()
    for k in kernels:
        if k['type'] == dr.KernelType.JIT:
            if k['cache_hit'] == 0:
                print(k)
```

With this, the second kernel is recompiled at every iteration.

njroussel commented 1 year ago

Hi again @uguuuuuu

Indeed, unfortunately I don't think there is a workaround in this case.

DoeringChristian commented 1 year ago

Hi @uguuuuuu,

I think I have found a solution to your problem, though it isn't very clean. You can always create a distribution with a fixed number of bins n (4 in your case) but only sample from the first i + 1 of them. A better approach would be to implement your own distribution. Here is my approach:

```python
import mitsuba as mi
import drjit as dr

if __name__ == "__main__":
    mi.set_variant("cuda_ad_rgb")

    dr.set_flag(dr.JitFlag.KernelHistory, True)
    # dr.set_log_level(dr.LogLevel.Trace)
    for i in range(4):
        dist_sampler: mi.Sampler = mi.load_dict({"type": "independent"})
        # build a fixed-size (4-bin) distribution; bins after index i are padded with epsilon
        epsilon = 1.0

        dist_sampler.seed(seed=0, wavefront_size=4)
        dist_samples = dist_sampler.next_1d()
        dist_samples[dr.arange(mi.UInt, 4) > dr.opaque(mi.UInt, i)] = epsilon
        distr = mi.DiscreteDistribution(dist_samples)

        sampler: mi.Sampler = mi.load_dict({"type": "independent"})
        sampler.seed(seed=i + 1, wavefront_size=1024)
        sample = sampler.next_1d()
        # rescale the sample so that it never reaches the CDF mass of the epsilon bins
        n_epsilons = 4 - (i + 1)
        sample *= dr.opaque(mi.Float, 1.0 - ((epsilon * n_epsilons) / distr.sum()[0]))
        idx = distr.sample(sample)
        # some computation
        dr.eval(dr.sum(idx))
        print(f"{i=}")

        kernels = dr.kernel_history()
        dr.kernel_history_clear()
        for k in kernels:
            if k["type"] == dr.KernelType.JIT:
                if k["cache_hit"] == 0:
                    print(k)
```

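The padding trick in the snippet above can be checked without a GPU. The sketch below is a plain-Python stand-in for mi.DiscreteDistribution. It uses inverse-CDF sampling via bisect and is not Mitsuba's code; padded_weights and sample_first_k are hypothetical names. It works as follows:

- The weight list always has n_bins entries, with the trailing ones set to epsilon.
- The uniform sample is scaled by 1 - epsilon * n_pad / sum, so it only covers the CDF mass of the first k = i + 1 bins.

Because the array size never changes, this is what should let the real version avoid size-dependent recompilation.

```python
import bisect
import random
from itertools import accumulate
from typing import List


def padded_weights(weights: List[float], n_bins: int, epsilon: float = 1.0) -> List[float]:
    """Pad a variable-length list of weights to a fixed n_bins with epsilon entries."""
    if len(weights) > n_bins:
        raise ValueError("more weights than bins")
    return list(weights) + [epsilon] * (n_bins - len(weights))


def sample_first_k(padded: List[float], k: int, u: float, epsilon: float = 1.0) -> int:
    """Inverse-CDF sampling restricted to the first k bins of a padded distribution."""
    total = sum(padded)
    n_pad = len(padded) - k
    # Same rescaling as in the snippet: drop the CDF mass held by the epsilon bins
    u_scaled = u * (1.0 - (epsilon * n_pad) / total)
    cdf = list(accumulate(padded))
    idx = bisect.bisect_right(cdf, u_scaled * total)  # first bin with cdf > target
    # Guard against floating-point round-off pushing the target onto a boundary
    return min(idx, k - 1)


# Example: the "real" size grows with i while the array size stays fixed at 4
random.seed(0)
for i in range(4):
    padded = padded_weights([random.random() for _ in range(i + 1)], n_bins=4)
    hits = {sample_first_k(padded, i + 1, random.random()) for _ in range(1000)}
    print(i, sorted(hits))
```
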
uguuuuuu commented 1 year ago

Hi @DoeringChristian @njroussel, I found out that even if the size of the distribution doesn't change, the second kernel somehow still recompiles as long as the values of the distribution depend on the iteration number. Below is a small reproducer:

```python
import drjit as dr
import mitsuba as mi
import util  # the poster's own helper module (create_sampler, prepare_sampler), not shown

mi.set_variant("cuda_ad_rgb")  # assumption: a JIT variant, as used elsewhere in the thread

# create a sampler with spp = 1
sampler = util.create_sampler(spp=1)

dr.set_flag(dr.JitFlag.KernelHistory, True)
for i in range(3):
    # build a distribution whose entry values are a function of the iteration number
    sampler_cloned, _ = util.prepare_sampler(sampler, n_cells=64, seed=0)
    distr = mi.DiscreteDistribution2D(mi.TensorXf(sampler_cloned.next_1d() + dr.opaque(mi.Float, i), [8, 8]))
    # sample from distribution
    sampler_cloned, _ = util.prepare_sampler(sampler, 1024, seed=i+1)
    idx, sample, pmf = distr.sample(sampler_cloned.next_2d())
    # some computation
    dr.eval(idx, sample, pmf)

    kernels = dr.kernel_history()
    dr.kernel_history_clear()
    for k in kernels:
        if k['type'] == dr.KernelType.JIT:
            if k['cache_hit'] == 0:
                print(k)
```

DoeringChristian commented 1 year ago

Hi @uguuuuuu, OK, I think I found the issue.

With a 1D distribution the kernels are reused, but with a 2D distribution they are recompiled. Compare the two reproducers below.

Comparing the generated code of two consecutive kernels, I found that r121 and r195 differ. In the last kernel, r195 is set to 0x3a0ad917, which is the bit pattern of the single-precision float 0.000529663113412. Printing the distribution shows that its normalization factor is 0.000529663, so the normalization is being baked into the kernel as a literal.

Looking at the implementations of both distributions, the 1D distribution passes its normalization through an opaque (see here), whereas the 2D distribution does not (see here).
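The hex-to-float step above is easy to double-check. This sketch reinterprets the 32-bit pattern as an IEEE-754 single-precision value with Python's struct module; f32_from_hex is a hypothetical helper name.

The magnitude is also plausible if, as assumed here, the 2D normalization factor is the reciprocal of the sum of the 64 entries. In the last iteration (i = 29) each entry is a uniform sample plus 29. The sum is therefore roughly 64 * 29.5 = 1888, and 1/1888 ≈ 0.00052966.

```python
import struct


def f32_from_hex(bits: str) -> float:
    """Reinterpret a 32-bit hex bit pattern as an IEEE-754 single-precision float."""
    # int(..., 16) accepts an optional "0x" prefix; ">I" / ">f" use big-endian byte order
    return struct.unpack(">f", struct.pack(">I", int(bits, 16)))[0]


value = f32_from_hex("0x3a0ad917")
print(f"{value:.12g}")     # the literal baked into the last kernel
print(f"{1 / 1888:.12g}")  # rough estimate of 1 / sum at i = 29 (see the text above)
```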

@njroussel Should I create a PR to fix that?

1D distribution:

```python
import drjit as dr
import mitsuba as mi

if __name__ == "__main__":
    mi.set_variant("cuda_ad_rgb")
    dr.set_flag(dr.JitFlag.KernelHistory, True)

    sampler_base: mi.Sampler = mi.load_dict({"type": "independent"})
    for i in range(30):
        sampler = sampler_base.clone()
        sampler.seed(0, 64)
        distr = mi.DiscreteDistribution(sampler.next_1d() + dr.opaque(mi.Float, i))

        sampler = sampler_base.clone()
        sampler.seed(i, 1024)
        idx = distr.sample(sampler.next_1d())

        dr.eval(idx)
        print(f"{idx=}")

        # print only the JIT kernels that missed the cache, i.e. were compiled
        kernels = dr.kernel_history()
        dr.kernel_history_clear()
        for k in kernels:
            if k["type"] == dr.KernelType.JIT:
                if k["cache_hit"] == 0:
                    print(k)
```

2D distribution:

```python
import drjit as dr
import mitsuba as mi

if __name__ == "__main__":
    mi.set_variant("cuda_ad_rgb")
    dr.set_flag(dr.JitFlag.KernelHistory, True)

    sampler_base: mi.Sampler = mi.load_dict({"type": "independent"})
    for i in range(30):
        sampler = sampler_base.clone()
        sampler.seed(0, 64)
        distr = mi.DiscreteDistribution2D(
            mi.TensorXf(sampler.next_1d() + dr.opaque(mi.Float, i), [8, 8])
        )

        sampler = sampler_base.clone()
        sampler.seed(i, 1024)
        # the 2D distribution expects a 2D sample (as in the reproducer above)
        idx, sample, pmf = distr.sample(sampler.next_2d())

        dr.eval(idx, sample, pmf)
        print(f"{idx=}")
        print(f"{sample=}")
        print(f"{pmf=}")

        # print only the JIT kernels that missed the cache, i.e. were compiled
        kernels = dr.kernel_history()
        dr.kernel_history_clear()
        for k in kernels:
            if k["type"] == dr.KernelType.JIT:
                if k["cache_hit"] == 0:
                    print(k)
```

njroussel commented 1 year ago

Thanks for pointing this out, @DoeringChristian! I've pushed a fix in this commit: https://github.com/mitsuba-renderer/mitsuba3/commit/c69079ad263138adae7586cdd06ff804b3e7dd5f 😉