dcrc2 closed this 3 years ago.
I've tried to avoid having to `#include "knossos.h"` from the pybind code, in order to minimize compile times. But when a function takes a scalar argument, the entry points still use the ks typedefs `ks::Float` or `ks::Integer`, so we do need to include those definitions. I've resolved this here by moving the typedefs to a new header, `knossos-types.h`.
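A minimal sketch of what the split might look like, assuming `ks::Float` is `double` and `ks::Integer` is `int` (the real definitions in `knossos.h` may differ). The header carries only the typedefs, so a generated entry point can include it without pulling in the rest of the runtime. `my_entry_point` is a hypothetical example, not one of the actual generated functions.

```cpp
// ---- knossos-types.h (sketch) ----
// Only the scalar typedefs live here, so the pybind code can include this
// header without paying the compile-time cost of the whole of knossos.h.
#ifndef KNOSSOS_TYPES_H
#define KNOSSOS_TYPES_H

namespace ks {
// Assumed definitions; check knossos.h for the real ones.
typedef double Float;
typedef int Integer;
}  // namespace ks

#endif  // KNOSSOS_TYPES_H

// ---- generated entry point (sketch) ----
// A hypothetical entry point with scalar arguments: its signature needs
// ks::Float and ks::Integer, but nothing else from knossos.h.
ks::Float my_entry_point(ks::Float x, ks::Integer n) {
  ks::Float result = 1.0;
  for (ks::Integer i = 0; i < n; ++i) {
    result *= x;  // x raised to the power n by repeated multiplication
  }
  return result;
}
```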
(Aside: I'm not sure we should actually be using `ks::Float` or `ks::Integer` in our entry points at all. Should we use `torch::Scalar` (or a 0-dimensional `torch::Tensor`) instead? I tried using `torch::Scalar`, but it seems that it would prevent us from testing the function with plain Python floating-point arguments, which is how we're using these functions at the moment.)
I like the idea of plain Python arguments, which AFAIU will always mean `double` and `int`, so either way we don't want to use `ks::{Float,Integer}` in the entry points?
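A sketch of the alternative suggested in the reply: the entry point's signature uses only the C++ types that pybind11 converts plain Python arguments to (`double` for a Python `float`, and an integer type for a Python `int`; `std::int64_t` is an assumption here), and converts to the ks types internally. `ks_power` and `entry_power` are hypothetical names, and the `ks` typedefs are assumed as `double` and `int`.

```cpp
#include <cstdint>

namespace ks {
typedef double Float;  // assumed definition
typedef int Integer;   // assumed definition
}  // namespace ks

// Hypothetical generated ks function, written in terms of the ks scalar types.
ks::Float ks_power(ks::Float x, ks::Integer n) {
  ks::Float r = 1.0;
  for (ks::Integer i = 0; i < n; ++i) r *= x;
  return r;
}

// Entry point exposed to Python, using plain C++ types only. pybind11
// converts a Python float to double and a Python int to the requested
// integer type, so a declaration of this function mentions no ks types
// and the pybind11 module file would not need any knossos header.
double entry_power(double x, std::int64_t n) {
  // The narrowing conversion to ks::Integer is made explicit; a real
  // implementation might range-check n first.
  return static_cast<double>(
      ks_power(static_cast<ks::Float>(x), static_cast<ks::Integer>(n)));
}
```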
This PR changes the function `generate_cpp_for_py_module_from_ks` so that two C++ files are generated:

- one containing the generated code, including the definitions of the C++ entry points;
- one defining the pybind11 module that exposes those entry points.

This separation will be necessary for building CUDA code. (The first of these files will then become a `.cu` file, whereas the second remains a `.cpp` file.)

At the moment we could just concatenate the two strings and compile a single C++ file. But I've changed the function `build_module_using_pytorch_from_cpp_backend` to accept multiple C++ strings, as it seemed cleaner to compile these separately even in the CPU-only case.

In order for this to work, the file defining the pybind11 module needs access to declarations of the C++ entry points. So the functions that generate code for the entry points have been changed to return declarations as well as definitions.
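The declaration/definition split can be illustrated in a single file, with comments marking which part would belong to which generated file. This is a sketch under assumptions: the module file sees only the entry-point declarations, which is enough for it to compile on its own, and the definitions in the other file are resolved at link time. `entry_add`, `entry_scale`, and `call_from_module` (a stand-in for the pybind11 `m.def(...)` bindings) are hypothetical names.

```cpp
#include <cstdint>

// ===== Would be in the generated module file (remains .cpp) =====
// Declarations only: these are what the entry-point generators would now
// return alongside the definitions.
double entry_add(double a, double b);
double entry_scale(double x, std::int64_t k);

// Stand-in for the pybind11 module body, which would register the entry
// points (e.g. via m.def). Taking their addresses and calling through them
// requires only the declarations above.
double call_from_module(double a, double b, std::int64_t k) {
  double (*add)(double, double) = &entry_add;
  double (*scale)(double, std::int64_t) = &entry_scale;
  return scale(add(a, b), k);
}

// ===== Would be in the generated entry-point file (later .cu) =====
// Definitions. When this part is compiled as a separate file, the linker
// resolves the calls made from the module file.
double entry_add(double a, double b) { return a + b; }

double entry_scale(double x, std::int64_t k) {
  return x * static_cast<double>(k);
}
```

Keeping the module file dependent only on declarations is what allows the entry-point file to be compiled by a different compiler (e.g. `nvcc`) without changing the module file.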