KomputeProject / kompute

General purpose GPU compute framework built on Vulkan to support 1000s of cross vendor graphics cards (AMD, Qualcomm, NVIDIA & friends). Blazing fast, mobile-enabled, asynchronous and optimized for advanced GPU data processing usecases. Backed by the Linux Foundation.
http://kompute.cc/
Apache License 2.0

Support custom parameter in sequence template functions #331

Open Crydsch opened 11 months ago

Crydsch commented 11 months ago

This PR enables custom parameters in the template functions sequence.record(..), sequence.eval(..) and sequence.evalAsync(..) by replacing the specializations with a more generic approach.

In theory, the compiler should be able to deduce the correct constructor of the passed operation class if there is just one. However, it seems to require at least the first argument type to deduce the constructor when given an initializer list like eval<op>({...}). Currently this is done with template specializations that pass either std::vector<std::shared_ptr<Tensor>> tensors or std::shared_ptr<Algorithm> algorithm to disambiguate the templates.

I found a way to supply this first parameter type through the respective operation class itself, rather than hard-coding it in template specializations. This makes the code a little more compact, but more importantly it allows custom user operations to use the same templates and thus the same concise syntax.

This turns

std::shared_ptr<OpCustom> op{ new OpCustom({CustomParameter}) };
sequence->eval(op);

into

sequence->eval<OpCustom>({CustomParameter});

The only change to existing operations is that each class now declares the type of its first constructor parameter. Existing functionality is otherwise unchanged.
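As a rough illustration of the idea, here is a minimal, self-contained sketch that does not use Kompute itself. All names in it (OpBase, OpSum, FirstParam, the simplified Sequence) are hypothetical stand-ins and are not the identifiers used in this PR. Each operation exposes the type of its first constructor parameter as a member alias, and the single generic eval template uses that alias as the type of its first argument, so a braced initializer list can be converted without any per-type specialization.

```cpp
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for the library's operation base class.
struct OpBase {
    virtual ~OpBase() = default;
    virtual int run() = 0;
};

// A "built-in"-style operation whose first constructor parameter is a vector.
struct OpSum : OpBase {
    using FirstParam = std::vector<int>;  // hypothetical alias name

    explicit OpSum(FirstParam values) : values_(std::move(values)) {}

    int run() override {
        int sum = 0;
        for (int v : values_) sum += v;
        return sum;
    }

private:
    FirstParam values_;
};

// A user-defined parameter type and an operation that takes it first.
struct CustomParameter {
    int factor;
    int offset;
};

struct OpCustom : OpBase {
    using FirstParam = CustomParameter;  // same hypothetical alias

    explicit OpCustom(FirstParam p, int input = 1) : p_(p), input_(input) {}

    int run() override { return p_.factor * input_ + p_.offset; }

private:
    FirstParam p_;
    int input_;
};

// Simplified sequence: one generic template instead of one specialization
// per first-parameter type. Because T is given explicitly, the type of
// `first` is fully known, so `{...}` can list-initialize it.
struct Sequence {
    template <typename T, typename... Extra>
    int eval(typename T::FirstParam first, Extra&&... extra) {
        std::shared_ptr<T> op =
            std::make_shared<T>(std::move(first), std::forward<Extra>(extra)...);
        return op->run();
    }
};
```

With these definitions, `Sequence seq; seq.eval<OpSum>({1, 2, 3});` and `seq.eval<OpCustom>({2, 5}, 10);` both compile against the same template. The first parameter is a non-deduced context, which is what lets the braced list work; any remaining constructor arguments are forwarded as usual.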