microsoft / knossos-ksc

Compiler with automatic differentiation

Generate code for elementwise operations #962

Closed by dcrc2 2 years ago.

dcrc2 commented 2 years ago

When the ks entry point is an elementwise function, generate its C++ code from Python rather than via ksc. The main purpose of this is to allow the loop to be parallelized in future (e.g. on GPU). It also avoids copying the output tensor, because the results are written directly into the tensor that is returned.
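Each iteration of an elementwise loop reads only element `i` of its inputs and writes only element `i` of its output, which is what makes the loop a candidate for parallelization. The serial C++ sketch below illustrates that property under its own assumptions: the helpers `map_sequential` and `map_chunked` are hypothetical, and processing the index range in chunks from last to first only imitates the way independent workers (CPU threads or GPU blocks) might pick up work in any order. It is not how ksc or the generated code schedules anything.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: sequential elementwise map, out[i] depends only on in[i].
template <typename F>
void map_sequential(F f, const std::vector<float>& in, std::vector<float>& out) {
    assert(out.size() == in.size());
    for (std::size_t i = 0; i != in.size(); ++i) {
        out[i] = f(in[i]);
    }
}

// Hypothetical helper: the same map, but the index range is split into chunks
// of `chunk` elements that are processed last-to-first. No iteration reads
// another iteration's output, so the visiting order cannot change the result.
template <typename F>
void map_chunked(F f, const std::vector<float>& in, std::vector<float>& out,
                 std::size_t chunk) {
    assert(out.size() == in.size());
    assert(chunk > 0);
    const std::size_t n = in.size();
    const std::size_t nchunks = (n + chunk - 1) / chunk;
    for (std::size_t c = nchunks; c-- > 0;) {
        const std::size_t begin = c * chunk;
        const std::size_t end = std::min(begin + chunk, n);
        for (std::size_t i = begin; i != end; ++i) {
            out[i] = f(in[i]);
        }
    }
}
```

Each chunk could just as well be handed to a separate worker; the sketch stays serial to keep it minimal.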

Current limitations:

It doesn't look like it will be hard to generalize either of those things.

The following code is generated for vrelu3:

```cpp
torch::Tensor entry(torch::Tensor arg0) {

    KS_ASSERT(arg0.sizes().size() == 1u);
    KS_ASSERT(arg0.is_contiguous());
    KS_ASSERT(arg0.scalar_type() == scalar_type_of_Float);
    auto* arg_data0 = arg0.data_ptr<float>();

    auto ret = torch::empty_like(arg0);
    auto* ret_data = ret.data_ptr<float>();
    for (int i = 0, ne = arg0.size(0); i != ne; ++i) {
        ret_data[i] = ks::relu3$af(&g_alloc , arg_data0[i]);
    }
    return ret;
}
torch::Tensor entry_vjp(torch::Tensor arg0, torch::Tensor arg1) {

    KS_ASSERT(arg0.sizes().size() == 1u);
    KS_ASSERT(arg0.is_contiguous());
    KS_ASSERT(arg0.scalar_type() == scalar_type_of_Float);
    auto* arg_data0 = arg0.data_ptr<float>();

    KS_ASSERT(arg1.sizes().size() == 1u);
    KS_ASSERT(arg1.is_contiguous());
    KS_ASSERT(arg1.scalar_type() == scalar_type_of_Float);
    auto* arg_data1 = arg1.data_ptr<float>();

    auto ret = torch::empty_like(arg0);
    auto* ret_data = ret.data_ptr<float>();
    for (int i = 0, ne = arg0.size(0); i != ne; ++i) {
        ret_data[i] = ks::rev$relu3$af(&g_alloc , arg_data0[i], arg_data1[i]);
    }
    return ret;
}
```
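For readers without libtorch, the following is a self-contained sketch of the same pattern on `std::vector<float>`. The scalar functions `relu3` and `rev_relu3` are hypothetical stand-ins for `ks::relu3$af` and `ks::rev$relu3$af`, and the piecewise-cubic definition of `relu3` used here (0 for x < 0, x^3/3 for 0 <= x < 1, x - 2/3 otherwise) is an assumption made for illustration, not something stated in this PR. The allocator argument and the shape and dtype checks are omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for ks::relu3$af (assumed piecewise-cubic definition).
float relu3(float x) {
    if (x < 0.0f) return 0.0f;
    if (x < 1.0f) return x * x * x / 3.0f;
    return x - 2.0f / 3.0f;
}

// Hypothetical stand-in for ks::rev$relu3$af: given the input x and the
// incoming gradient dret, return the gradient with respect to x.
float rev_relu3(float x, float dret) {
    if (x < 0.0f) return 0.0f;
    if (x < 1.0f) return x * x * dret;
    return dret;
}

// Elementwise forward entry point: allocate the output once and write each
// element directly into it, mirroring the generated `entry`.
std::vector<float> entry(const std::vector<float>& arg0) {
    std::vector<float> ret(arg0.size());
    for (std::size_t i = 0, ne = arg0.size(); i != ne; ++i) {
        ret[i] = relu3(arg0[i]);
    }
    return ret;
}

// Elementwise VJP entry point: one loop applying the scalar reverse function
// to each (input, incoming gradient) pair, mirroring the generated `entry_vjp`.
std::vector<float> entry_vjp(const std::vector<float>& arg0,
                             const std::vector<float>& arg1) {
    assert(arg0.size() == arg1.size());
    std::vector<float> ret(arg0.size());
    for (std::size_t i = 0, ne = arg0.size(); i != ne; ++i) {
        ret[i] = rev_relu3(arg0[i], arg1[i]);
    }
    return ret;
}
```

One difference from the generated code: `std::vector<float> ret(n)` zero-initializes its elements, whereas `torch::empty_like` leaves them uninitialized. That is harmless there because every element is overwritten by the loop.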

As things stand, this improves the performance of the backward pass but not the forward pass. (Calling `[sufrev relu3]` on each element is better optimized than running one loop of `[suffwdpass relu3]` followed by another loop of `[sufrevpass relu3]`.)
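To make the comparison concrete, here is a C++ sketch of the two strategies for a vector-Jacobian product. All names are hypothetical: `fwdpass_relu3` and `revpass_relu3` stand in for the split `[suffwdpass relu3]` / `[sufrevpass relu3]` pair, with the saved residual assumed to be the local derivative, and `rev_relu3` stands in for the fused `[sufrev relu3]`. `relu3` is assumed to be piecewise cubic (0 for x < 0, x^3/3 for 0 <= x < 1, x - 2/3 otherwise). The sketch only shows that the fused form needs a single traversal and no intermediate buffer; it makes no claim about which optimizations ksc actually applies.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical split forward pass: returns the primal result together with a
// saved residual (here assumed to be the local derivative d relu3 / dx).
std::pair<float, float> fwdpass_relu3(float x) {
    if (x < 0.0f) return {0.0f, 0.0f};
    if (x < 1.0f) return {x * x * x / 3.0f, x * x};
    return {x - 2.0f / 3.0f, 1.0f};
}

// Hypothetical split reverse pass: combines the incoming gradient with the
// residual saved by the forward pass.
float revpass_relu3(float dret, float residual) { return dret * residual; }

// Hypothetical fused reverse function: computes the gradient from x directly.
float rev_relu3(float x, float dret) {
    if (x < 0.0f) return 0.0f;
    if (x < 1.0f) return x * x * dret;
    return dret;
}

// Two loops: the first fills an intermediate residual buffer, the second
// reads it back. (The primal results are not needed for the VJP.)
std::vector<float> vjp_two_pass(const std::vector<float>& x,
                                const std::vector<float>& dret) {
    assert(x.size() == dret.size());
    std::vector<float> residuals(x.size());
    for (std::size_t i = 0; i != x.size(); ++i) {
        residuals[i] = fwdpass_relu3(x[i]).second;
    }
    std::vector<float> dx(x.size());
    for (std::size_t i = 0; i != x.size(); ++i) {
        dx[i] = revpass_relu3(dret[i], residuals[i]);
    }
    return dx;
}

// One loop and no intermediate buffer: each element is handled completely
// by the fused reverse function.
std::vector<float> vjp_fused(const std::vector<float>& x,
                             const std::vector<float>& dret) {
    assert(x.size() == dret.size());
    std::vector<float> dx(x.size());
    for (std::size_t i = 0; i != x.size(); ++i) {
        dx[i] = rev_relu3(x[i], dret[i]);
    }
    return dx;
}
```

Both functions compute the same gradients; the fused version simply avoids writing and re-reading the residual buffer.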

Before:

[benchmark plot: elementwise-no]

After:

[benchmark plot: elementwise-yes]

awf commented 2 years ago

@dcrc2 I've just realised: I don't see code in this PR to test or benchmark the elementwise calls. If that's right, can you make another PR to show that?

dcrc2 commented 2 years ago


The existing benchmarks for vrelu3 now run this code. Do you mean that you'd like to be able to compare it to the previous method (where ksc generated the code for map)? I have some results for this above, but we could maintain both versions as separate benchmarks if we wanted.

awf commented 2 years ago

Oh of course, sorry for the noise.