mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering
BSD 3-Clause "New" or "Revised" License

Kernel freezing frontend #223

Open merlinND opened 5 months ago

merlinND commented 5 months ago

(As discussed with @wjakob and @DoeringChristian)

This code is not intended to work as-is, but rather to illustrate the kinds of cases we need to check for, and how to map the Python function's inputs/outputs to the frozen function's inputs/outputs while accounting for aliasing, literals, etc. A fair bit of the code is there to detect and prevent unsupported cases.
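To make the input-mapping problem more concrete, here is a minimal sketch. It assumes that each Python argument has already been flattened into a list of leaves, each either a JIT variable index or a literal, and that literals are baked into the kernel rather than passed as inputs. `ArgLeaf`, `SlotMapping` and `map_inputs` are hypothetical names, not part of Dr.Jit. A variable reachable through several arguments (aliasing) gets a single kernel input slot, and the mapping records which slot each leaf uses.

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Hypothetical flattened argument leaf: either a JIT variable or a literal.
struct ArgLeaf {
    uint32_t var_index; // JIT variable index (ignored for literals)
    bool is_literal;    // assumption: literals are baked into the kernel
};

// Hypothetical result of mapping argument leaves to kernel input slots.
struct SlotMapping {
    std::vector<uint32_t> slot_vars;   // one unique variable per input slot
    std::vector<int32_t> leaf_to_slot; // slot used by each leaf, -1 for literals
    bool has_aliasing = false;         // true if a variable appears more than once
};

SlotMapping map_inputs(const std::vector<ArgLeaf> &leaves) {
    SlotMapping m;
    std::unordered_map<uint32_t, int32_t> var_to_slot;
    m.leaf_to_slot.reserve(leaves.size());
    for (const ArgLeaf &leaf : leaves) {
        if (leaf.is_literal) {
            // Literals do not occupy a kernel input slot.
            m.leaf_to_slot.push_back(-1);
            continue;
        }
        // Assumption: index 0 means "no variable", so it is never a valid input.
        assert(leaf.var_index != 0);
        auto it = var_to_slot.find(leaf.var_index);
        if (it != var_to_slot.end()) {
            // Same variable passed through several arguments: reuse its slot.
            m.has_aliasing = true;
            m.leaf_to_slot.push_back(it->second);
        } else {
            int32_t slot = (int32_t) m.slot_vars.size();
            var_to_slot.emplace(leaf.var_index, slot);
            m.slot_vars.push_back(leaf.var_index);
            m.leaf_to_slot.push_back(slot);
        }
    }
    return m;
}
```

For example, the leaves `r10, r10, <literal>, r12` would map to two input slots (`r10`, `r12`) with `has_aliasing` set. Whether aliased inputs should be deduplicated like this or rejected as unsupported is a design choice left open here.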

I opened this PR mostly to make it easy to read & comment on the code, but we can also use it to iterate on the frontend.

Some assumptions made by this implementation that would be nice to keep in the final version as well:

A few things that will change for sure in the final implementation:

merlinND commented 5 months ago

A comment by @DoeringChristian made me realize that the launch mechanism was actually not super trivial, so I have pushed it in a new commit. Sorry I didn't remember about this earlier.

On the drjit-core side, the function below is used to create a FrozenKernel variable. Then, if a variable of kind VarKind::FrozenKernel is scheduled in jitc_eval, we automatically switch to the dedicated frozen kernel launch mechanism. I think you already have that logic in place, but please let me know if you'd like me to open a similar PR on the drjit-core side.

/**
 * Note: the output variables must have already been created by the caller.
 */
uint32_t jitc_var_frozen_kernel(uint64_t hash_low,
                                uint64_t hash_high,
                                const char *kernel_ir_source,
                                uint32_t size,
                                uint32_t n_inputs,
                                const uint32_t *input_vars,
                                uint32_t n_outputs,
                                const uint32_t *output_vars,
                                const uint32_t *kernel_slot_to_var_index) {
    uint32_t var0;
    if (n_inputs > 0)
        var0 = input_vars[0];
    else if (n_outputs > 0)
        var0 = output_vars[0];
    else
        jitc_fail("jitc_var_frozen_kernel(): expected at least one input or output variable.");
    auto [var_info, _] =
        jitc_var_check("jitc_var_frozen_kernel", var0);
    size = std::max(size, var_info.size);

    // TODO: should extend `jitc_var_check` to support arrays
    bool placeholder = false;
    for (uint32_t i = 0; i < n_inputs; ++i) {
        auto [inputs_info, inputs_v]
            = jitc_var_check("jitc_var_frozen_kernel", input_vars[i]);
        if (inputs_info.backend != var_info.backend) {
            jitc_fail(
                "jitc_var_frozen_kernel(): input %u (r%u) has different backend %d, expected %d!",
                i,
                input_vars[i],
                (int) inputs_info.backend,
                (int) var_info.backend);
        }

        size = std::max(size, inputs_info.size);
        placeholder |= inputs_info.placeholder;
    }

    // TODO: create the input and output variables in exactly the
    // order that makes them land in the right kernel input and
    // output slots.
    uint32_t result = jitc_var_new_node_0(var_info.backend,
                                          VarKind::FrozenKernel,
                                          VarType::Void,
                                          size,
                                          placeholder,
                                          /*payload*/ 0,
                                          /*disable_lvn*/ true);
    jitc_var_mark_side_effect(result);

    // Mark the outputs as depending on the frozen kernel node.
    for (uint32_t i = 0; i < n_outputs; ++i) {
        Variable *o = jitc_var(output_vars[i]);
        o->dep[0] = result;
        jitc_var_inc_ref(result);
    }

    // Pass all this information and variables indices via `extra`
    Variable *v = jitc_var(result);
    v->extra = 1;
    Extra &e = state.extra[result];
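    // Only the inputs and the output entries (initially zero) are visible to
    // the system through `n_dep`. The hashes, output count, kernel IR pointer,
    // slot-to-variable mapping and (in debug builds) the output variable
    // indices are stored past the end of the dependency array.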
    e.n_dep = n_inputs + n_outputs;
    size_t dep_size
        = (e.n_dep + n_outputs + e.n_dep) * sizeof(uint32_t) + 3 * sizeof(size_t) + sizeof(char *);
    e.dep = (uint32_t *) malloc_check(dep_size);
    for (uint32_t i = 0; i < n_inputs; ++i) {
        e.dep[i] = input_vars[i];
        jitc_var_inc_ref(input_vars[i]);
    }
    for (uint32_t i = 0; i < n_outputs; ++i) {
        e.dep[n_inputs + i] = 0;
    }

    // Retain a copy of the kernel's source, which we will need
    // later for lookup in the kernel cache.
    // TODO: consider looking up the cached kernel right now and
    //       just copying the pointer to its own copy.
    size_t n_chars = strlen(kernel_ir_source) + 1;
    size_t n_bytes = sizeof(char) * n_chars;
    char *kernel_ir = (char *) malloc_check(n_bytes);
    memcpy(kernel_ir, kernel_ir_source, n_bytes);

    // Fill in the rest of the extras. The offset after `n_dep` 32-bit entries
    // is not necessarily 8-byte aligned, so use memcpy() rather than writing
    // through `size_t *` / `char **` casts.
    uint8_t *ptr_bytes = (uint8_t *) e.dep;
    ptr_bytes += e.n_dep * sizeof(uint32_t);
    size_t header[3] = { (size_t) hash_low, (size_t) hash_high, (size_t) n_outputs };
    memcpy(ptr_bytes, header, sizeof(header));
    ptr_bytes += sizeof(header);
    memcpy(ptr_bytes, &kernel_ir, sizeof(char *));
    ptr_bytes += sizeof(char *);

    // Mapping from kernel slot to var index
    memcpy(ptr_bytes, kernel_slot_to_var_index, e.n_dep * sizeof(uint32_t));
#if !defined(NDEBUG)
    ptr_bytes += e.n_dep * sizeof(uint32_t);
    // Output var indices
    memcpy(ptr_bytes, output_vars, n_outputs * sizeof(uint32_t));
#endif

    e.callback_data = kernel_ir;
    e.callback = [](uint32_t, int free_var, void *ptr) {
        if (free_var) {
            free(ptr);
        }
    };
    e.callback_internal = true;

    jitc_log(Debug, "jitc_var_frozen_kernel(): r%u (%u outputs) <- %u inputs", result, n_outputs, n_inputs);
    return result;
}
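The trailing part of the `extra` buffer packs several fields back to back after the `n_dep` dependency indices: two hashes, the output count, the kernel IR pointer, the kernel-slot-to-variable mapping and the output variable indices. The sketch below writes and reads that layout with `memcpy`, which stays correct when a 64-bit field follows an odd number of 32-bit entries. `FrozenKernelInfo`, `frozen_extra_size`, `pack_frozen_extra` and `unpack_frozen_extra` are hypothetical helpers, not drjit-core API. The sketch also assumes fixed 64-bit fields instead of `size_t` and always stores the output indices, whereas the function above only does so in debug builds.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <vector>

// Hypothetical decoded view of the trailing `extra` buffer.
struct FrozenKernelInfo {
    std::vector<uint32_t> deps;        // n_inputs + n_outputs entries
    uint64_t hash_low, hash_high, n_outputs;
    const char *kernel_ir;
    std::vector<uint32_t> slot_to_var; // one entry per dependency
    std::vector<uint32_t> output_vars; // n_outputs entries
};

// Buffer size, mirroring `dep_size` (64-bit fields assumed instead of size_t).
size_t frozen_extra_size(uint32_t n_dep, uint32_t n_outputs) {
    return (n_dep + n_dep + n_outputs) * sizeof(uint32_t)
         + 3 * sizeof(uint64_t) + sizeof(char *);
}

// Copy `n` bytes at `off` and advance it; memcpy imposes no alignment needs.
static void put(uint8_t *buf, size_t &off, const void *src, size_t n) {
    if (n) memcpy(buf + off, src, n);
    off += n;
}
static void get(const uint8_t *buf, size_t &off, void *dst, size_t n) {
    if (n) memcpy(dst, buf + off, n);
    off += n;
}

uint8_t *pack_frozen_extra(const FrozenKernelInfo &info) {
    uint32_t n_dep = (uint32_t) info.deps.size();
    assert(info.slot_to_var.size() == n_dep);
    assert(info.output_vars.size() == info.n_outputs);
    size_t size = frozen_extra_size(n_dep, (uint32_t) info.n_outputs);
    uint8_t *buf = (uint8_t *) malloc(size);
    if (!buf) abort();
    size_t off = 0;
    put(buf, off, info.deps.data(), n_dep * sizeof(uint32_t));
    put(buf, off, &info.hash_low, sizeof(uint64_t));
    put(buf, off, &info.hash_high, sizeof(uint64_t));
    put(buf, off, &info.n_outputs, sizeof(uint64_t));
    put(buf, off, &info.kernel_ir, sizeof(char *));
    put(buf, off, info.slot_to_var.data(), n_dep * sizeof(uint32_t));
    put(buf, off, info.output_vars.data(), info.output_vars.size() * sizeof(uint32_t));
    assert(off == size); // every byte of the buffer was written exactly once
    return buf;
}

FrozenKernelInfo unpack_frozen_extra(const uint8_t *buf, uint32_t n_dep) {
    FrozenKernelInfo info;
    size_t off = 0;
    info.deps.resize(n_dep);
    get(buf, off, info.deps.data(), n_dep * sizeof(uint32_t));
    get(buf, off, &info.hash_low, sizeof(uint64_t));
    get(buf, off, &info.hash_high, sizeof(uint64_t));
    get(buf, off, &info.n_outputs, sizeof(uint64_t));
    get(buf, off, &info.kernel_ir, sizeof(char *));
    info.slot_to_var.resize(n_dep);
    get(buf, off, info.slot_to_var.data(), n_dep * sizeof(uint32_t));
    info.output_vars.resize(info.n_outputs);
    get(buf, off, info.output_vars.data(), info.n_outputs * sizeof(uint32_t));
    return info;
}
```

With two inputs and one output (`n_dep = 3`), the buffer holds 3 + 3 + 1 = 7 32-bit entries plus three 64-bit fields and one pointer, i.e. 60 bytes on a 64-bit platform. The `hash_low` field then starts at byte offset 12, which is not 8-byte aligned.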