microsoft / knossos-ksc

Compiler with automatic differentiation

Use torch::tensor instead of ks::tensor in entry points #931

Closed: dcrc2 closed this pull request 3 years ago

dcrc2 commented 3 years ago

AB#19887

(This was originally marked as a draft PR because an earlier version broke some of the Python tests; that is now fixed.)

This PR moves the conversions between torch types and ks types into C++ code (generated by python/ksc/cgen.py). For example, this is the code generated for relu3:

namespace ks {
namespace entry_points {
namespace generated {

torch::Tensor entry(torch::Tensor arg0) {
    if (g_logging) {
        std::cerr << "vrelu3$aT1f(" << arg0 << ") =" << std::endl;
    }
    auto ks_arg0 = convert_argument<ks::tensor<1, ks::Float>>(arg0);
    auto ks_ret = ks::vrelu3$aT1f(&g_alloc, ks_arg0);
    auto ret = convert_return_value<torch::Tensor>(ks_ret);
    if (g_logging) {
        std::cerr << ret << std::endl;
    }
    return ret;
}

torch::Tensor entry_vjp(torch::Tensor arg0, torch::Tensor arg1) {
    if (g_logging) {
        std::cerr << "sufrev$vrelu3$aT1f(" << arg0 << ", "  << arg1 << ") =" << std::endl;
    }
    auto ks_arg0 = convert_argument<ks::tensor<1, ks::Float>>(arg0);
    auto ks_arg1 = convert_argument<ks::tensor<1, ks::Float>>(arg1);
    auto ks_ret = ks::sufrev$vrelu3$aT1f(&g_alloc, ks_arg0, ks_arg1);
    auto ret = convert_return_value<torch::Tensor>(ks_ret);
    if (g_logging) {
        std::cerr << ret << std::endl;
    }
    return ret;
}

}  // namespace generated
}  // namespace entry_points
}  // namespace ks

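
A self-contained sketch of the conversion pattern these entry points rely on, assuming one explicit template specialization per target type. FakeTorchTensor, KsTensor1f and ks_relu are hypothetical stand-ins so that the example needs neither PyTorch nor the knossos headers. The real convert_argument / convert_return_value are not shown in this PR and may well be implemented differently (for example with overloads or class templates):

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for torch::Tensor: flat float storage plus its shape.
struct FakeTorchTensor {
    std::vector<float> data;
    std::vector<std::size_t> sizes;
};

// Hypothetical stand-in for ks::tensor<1, ks::Float>.
struct KsTensor1f {
    std::vector<float> data;
};

// Primary templates are declared but not defined, so asking for an
// unsupported conversion fails at link time instead of silently compiling.
template <typename KsType>
KsType convert_argument(const FakeTorchTensor& arg);

template <typename TorchType, typename KsType>
TorchType convert_return_value(const KsType& ret);

// torch-like tensor -> rank-1 ks-like tensor, checking the shape first.
template <>
KsTensor1f convert_argument<KsTensor1f>(const FakeTorchTensor& arg) {
    if (arg.sizes.size() != 1 || arg.sizes[0] != arg.data.size()) {
        throw std::invalid_argument("expected a rank-1 tensor");
    }
    return KsTensor1f{arg.data};
}

// rank-1 ks-like tensor -> torch-like tensor with shape {n}.
template <>
FakeTorchTensor convert_return_value<FakeTorchTensor, KsTensor1f>(const KsTensor1f& ret) {
    return FakeTorchTensor{ret.data, {ret.data.size()}};
}

// Hypothetical ks-side kernel: a plain elementwise ReLU, not ksc's relu3.
KsTensor1f ks_relu(const KsTensor1f& x) {
    KsTensor1f y{x.data};
    for (float& v : y.data) {
        if (v < 0.0f) {
            v = 0.0f;
        }
    }
    return y;
}

// Same structure as the generated entry(): convert in, call the ks function, convert out.
FakeTorchTensor entry(const FakeTorchTensor& arg0) {
    auto ks_arg0 = convert_argument<KsTensor1f>(arg0);
    auto ks_ret = ks_relu(ks_arg0);
    return convert_return_value<FakeTorchTensor>(ks_ret);
}
```

Keying the specializations on the target type mirrors the explicit template arguments in the generated calls, so each entry point only has to name the types involved, and the shape checking for each type pair lives in one place.
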
dcrc2 commented 3 years ago

An additional complication: we have some tests (test/python/test_tracing_core.py) which use ksc-generated code without PyTorch. In order to support this, I've moved the torch-specific code to a separate header (knossos-entry-points-torch.h), so that we include either knossos-entry-points.h or knossos-entry-points-torch.h as appropriate. In non-PyTorch mode it is an error to define an entry point which uses a tensor.
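
To make the two modes concrete, here is a rough C++ rendering of the generator-side choice. The actual logic lives in python/ksc/cgen.py and is not reproduced here. TypeKind, EntryPointSig and generate_preamble are hypothetical names, and the check is deliberately simplified: it only inspects top-level argument and return types, not tensors nested inside tuples.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical, simplified classification of entry-point types.
enum class TypeKind { Float, Integer, Bool, Tuple, Tensor };

// Hypothetical description of one entry point to be generated.
struct EntryPointSig {
    std::string name;
    std::vector<TypeKind> arg_types;
    TypeKind return_type;
};

// True if a tensor appears as a top-level argument or as the return type.
bool uses_tensor(const EntryPointSig& sig) {
    if (sig.return_type == TypeKind::Tensor) {
        return true;
    }
    for (TypeKind t : sig.arg_types) {
        if (t == TypeKind::Tensor) {
            return true;
        }
    }
    return false;
}

// Picks the entry-points header from the use_torch flag, and rejects
// tensor-using entry points when PyTorch is not available.
std::string generate_preamble(bool use_torch, const std::vector<EntryPointSig>& sigs) {
    if (!use_torch) {
        for (const EntryPointSig& sig : sigs) {
            if (uses_tensor(sig)) {
                throw std::invalid_argument("entry point '" + sig.name +
                                            "' uses a tensor, which requires use_torch");
            }
        }
    }
    const char* header = use_torch ? "knossos-entry-points-torch.h" : "knossos-entry-points.h";
    return std::string("#include \"") + header + "\"\n";
}
```

Failing at generation time gives a clearer message than letting the non-PyTorch build fail later on a missing torch type.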

At the moment this is all done by adding a new boolean parameter use_torch to the C++-generating functions. In future there might be more than two choices here (supporting bindings to various languages/libraries), so we might want to upgrade this parameter to an abstract base class.
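
One possible shape for such an abstract base class, sketched in C++ under the assumption that each binding target only needs to supply its entry-points header and its external tensor type. BindingTarget, TorchBinding, PlainBinding, tensor_param_type and make_binding are all hypothetical names, not part of ksc:

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical interface that would replace the use_torch flag.
class BindingTarget {
public:
    virtual ~BindingTarget() = default;
    // Header that the generated code should #include.
    virtual std::string entry_points_header() const = 0;
    // External C++ type for tensor arguments/results; empty if tensors are unsupported.
    virtual std::string tensor_type() const = 0;
    bool supports_tensors() const { return !tensor_type().empty(); }
};

class TorchBinding : public BindingTarget {
public:
    std::string entry_points_header() const override { return "knossos-entry-points-torch.h"; }
    std::string tensor_type() const override { return "torch::Tensor"; }
};

class PlainBinding : public BindingTarget {
public:
    std::string entry_points_header() const override { return "knossos-entry-points.h"; }
    std::string tensor_type() const override { return ""; }
};

// Hypothetical generator step: the external type to emit for a tensor parameter.
std::string tensor_param_type(const BindingTarget& target) {
    if (!target.supports_tensors()) {
        throw std::invalid_argument("this binding target does not support tensors");
    }
    return target.tensor_type();
}

// Bridge from the current boolean flag to the class hierarchy.
std::unique_ptr<BindingTarget> make_binding(bool use_torch) {
    if (use_torch) {
        return std::make_unique<TorchBinding>();
    }
    return std::make_unique<PlainBinding>();
}
```

make_binding suggests how existing use_torch call sites could be migrated one at a time. A binding for another language or library would then be a new subclass rather than another boolean parameter.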