ROCm / MIOpen

AMD's Machine Intelligence Library
https://rocm.docs.amd.com/projects/MIOpen/en/latest/

[Solvers] New interface for solvers #1887

Open averinevg opened 1 year ago

averinevg commented 1 year ago

End goals

  1. Ability to reuse almost any solver-related code between primitives and modes.
     1.1. Ability to get solvers of any primitive from the registry.
  2. Make developing, maintaining and debugging solver-related code easier. Less code duplication means less time spent on it by every developer.
  3. Make the library less error-prone. Duplicate code leads to copy-paste errors and to fixes that are applied to only one copy.
  4. Make the source friendlier to the reader.
     4.1. New team members and readers in general are easily overwhelmed by the thousands of lines of solver-related machinery. This may lead them to treat everything outside the solver they are writing as an opaque black box, without understanding the machinery they routinely use.

Steps

Implementation

A simplified version of the future base class is shown below:

struct SolverBase
{
    const std::string& SolverDbId();
    const std::string& AltSolverDbId();
    bool IsApplicable(const ExecutionContext& ctx, const ProblemDescriptionBase& problem);
    bool IsTunable();
    bool IsDynamic();
    float GetWti(const ExecutionContext& ctx, const ProblemDescriptionBase& problem);
    size_t GetWorkspaceSize(const ExecutionContext& ctx, const ProblemDescriptionBase& problem);
    bool MayNeedWorkspace();
    ConvSolution FindSolution(const ExecutionContext& ctx,
                              const ProblemDescriptionBase& problem,
                              Db& db,
                              const AnyInvokeParams& invoke_ctx);
};

We have several methods inside AnySolver that are only needed for Tuna. My suggestion is to move them to a separate class:

#ifdef MIOPEN_ENABLE_FIN
struct FinSolverBase : SolverBase
{
    bool TestSysDbRecord(const ExecutionContext& ctx, const ProblemDescriptionBase& problem, const DbRecord& record);
    std::string GetPerfCfgParams(const ExecutionContext& ctx, const ProblemDescriptionBase& problem, Db& db);
    std::vector<ConvSolution> GetAllSolutions(const ExecutionContext& ctx, const ProblemDescriptionBase& problem);
};
#endif

This is a mixin class, which is already implemented:

template <class Context, class Problem>
struct SolverMixin : SolverBase
{
    bool IsApplicable(const Context& ctx, const Problem& problem);
    float GetWti(const Context& ctx, const Problem& problem);
    size_t GetWorkspaceSize(const Context& ctx, const Problem& problem);
};
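
To illustrate how the mixin can connect the type-erased interface of SolverBase to the typed overloads, here is a minimal sketch. It assumes that the methods are virtual and const and that the mixin downcasts with dynamic_cast; neither is stated in the simplified listings above, and the actual implementation may differ. ActivProblemSketch and ActivFwdSolverSketch are hypothetical names used only for this example.

```cpp
// Stand-ins for the real MIOpen types, reduced to what the sketch needs.
struct ExecutionContext
{
    virtual ~ExecutionContext() = default;
};

struct ProblemDescriptionBase
{
    virtual ~ProblemDescriptionBase() = default;
};

struct SolverBase
{
    virtual ~SolverBase() = default;
    // Type-erased entry point (assumed virtual for this sketch).
    virtual bool IsApplicable(const ExecutionContext& ctx,
                              const ProblemDescriptionBase& problem) const = 0;
};

template <class Context, class Problem>
struct SolverMixin : SolverBase
{
    // Typed overload that each concrete solver implements.
    virtual bool IsApplicable(const Context& ctx, const Problem& problem) const = 0;

    // Downcast once here, so concrete solvers never see the base types.
    // dynamic_cast on a reference throws std::bad_cast on a type mismatch.
    bool IsApplicable(const ExecutionContext& ctx,
                      const ProblemDescriptionBase& problem) const final
    {
        return IsApplicable(dynamic_cast<const Context&>(ctx),
                            dynamic_cast<const Problem&>(problem));
    }
};

// Hypothetical problem description and solver, for illustration only.
struct ActivProblemSketch : ProblemDescriptionBase
{
    int mode = 0;
};

struct ActivFwdSolverSketch : SolverMixin<ExecutionContext, ActivProblemSketch>
{
    bool IsApplicable(const ExecutionContext&, const ActivProblemSketch& problem) const override
    {
        return problem.mode == 0; // made-up applicability rule
    }
};
```

With this layout, code that only holds a SolverBase reference can query any solver, while the solver itself is written purely in terms of its own Context and Problem types.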

The base class for non-tunable solvers is shown below. It contains only one method:

template <class Context, class Problem>
struct NonTunableSolverBase : SolverMixin<Context, Problem>
{
    ConvSolution GetSolution(const Context& ctx, const Problem& problem);
};
// For non tunable convolution solvers
using ConvSolver = NonTunableSolverBase<ConvolutionContext, ProblemDescription>;
// For activ solvers
using ActivSolver = NonTunableSolverBase<ExecutionContext, miopen::activ::ProblemDescription>;
// For batchnorm solvers
using BatchnormSolver = NonTunableSolverBase<ExecutionContext, miopen::batchnorm::ProblemDescription>;
// For pooling solvers
using PoolingSolver = NonTunableSolverBase<ExecutionContext, miopen::pooling::ProblemDescription>;

The base class for tunable solvers is shown below:

template <class Context, class Problem, class PerformanceConfig>
struct TunableSolverBase : SolverMixin<Context, Problem>
{
    PerformanceConfig GetDefaultPerformanceConfig(const Context&, const Problem& problem);
    bool IsValidPerformanceConfig(const Context&, const Problem& problem, const PerformanceConfig& config);
    PerformanceConfig Search(const Context&, const Problem& problem, const AnyInvokeParams&);
    ConvSolution GetSolution(const Context&, const Problem& problem, const PerformanceConfig& config);
};
// For tunable convolution solvers
template <class PerformanceConfig>
using ConvTunableSolver = TunableSolverBase<ConvolutionContext, ProblemDescription, PerformanceConfig>;
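
As a rough illustration of how the four tunable methods fit together, here is a toy solver that is duck-typed to the same method names. Everything in it is hypothetical: the Toy* types are not MIOpen types, Search takes no AnyInvokeParams, and the cost function is a made-up stand-in for actually benchmarking kernels.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-ins; none of these are MIOpen types.
struct ToyContext
{
    int max_threads;
};

struct ToyProblem
{
    int n;
};

struct ToyPerfConfig
{
    int block_size;
};

struct ToySolution
{
    std::string kernel_name;
    int block_size;
};

struct ToyTunableSolver
{
    ToyPerfConfig GetDefaultPerformanceConfig(const ToyContext&, const ToyProblem&) const
    {
        return ToyPerfConfig{64};
    }

    bool IsValidPerformanceConfig(const ToyContext& ctx,
                                  const ToyProblem& problem,
                                  const ToyPerfConfig& config) const
    {
        return config.block_size > 0 && config.block_size <= ctx.max_threads &&
               problem.n % config.block_size == 0;
    }

    // Exhaustive search over a fixed candidate list. A real Search would
    // time each candidate on the device instead of calling FakeCost.
    ToyPerfConfig Search(const ToyContext& ctx, const ToyProblem& problem) const
    {
        ToyPerfConfig best = GetDefaultPerformanceConfig(ctx, problem);
        int best_cost      = FakeCost(problem, best);
        const std::vector<int> candidates{64, 128, 256, 512};
        for(const int block_size : candidates)
        {
            const ToyPerfConfig candidate{block_size};
            if(!IsValidPerformanceConfig(ctx, problem, candidate))
                continue;
            const int cost = FakeCost(problem, candidate);
            if(cost < best_cost)
            {
                best      = candidate;
                best_cost = cost;
            }
        }
        return best;
    }

    ToySolution GetSolution(const ToyContext&,
                            const ToyProblem&,
                            const ToyPerfConfig& config) const
    {
        return ToySolution{"toy_kernel_b" + std::to_string(config.block_size),
                           config.block_size};
    }

private:
    // Made-up cost model: fewer blocks is better, oversized blocks are penalized.
    static int FakeCost(const ToyProblem& problem, const ToyPerfConfig& config)
    {
        return problem.n / config.block_size + config.block_size / 16;
    }
};
```

The point of the split is that GetSolution stays a pure function of the chosen config, so the same config can come from Search, from the default, or from a database record.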

averinevg commented 1 year ago

@atamazov @DrizztDoUrden Please take a look

JehandadKhan commented 1 year ago

@averinevg How would this impact the solver registry ? Currently, the solver registry does properly support any primitive other than convolution.

cderb commented 1 year ago

@averinevg I would recommend not splitting the functionality used by Fin into a separate class of solvers, because the functionality Fin needs to access should be inherent to the Solver.

I think it is easy to agree that IsTunable is in the same category of informative functions as IsApplicable and IsDynamic. This would also be important for casting from base solver to tunable and non-tunable solvers.

GetAllSolutions is for sure a tuning specific function as it pulls out details of the process (list of all possible solutions from the solver), but this should be callable from all types of solvers to be effective.

GetPerfCfgParams and TestSysDbRecord (now renamed to TestPerfCfgParams) are vital functions for interacting with the performance database using the serialized PerformanceConfig. They were implemented in any_solver because there was previously no public access to the PerformanceConfig used by the solver. In your proposal it seems this structure would now be retrievable. As a matter of convenience they would be nice to keep, but it appears that the limitation that prevented fleshing these out in Fin would be lifted.

atamazov commented 1 year ago

@cderb Your thoughts about IsTunable and GetAllSolutions look reasonable to me.

Regarding TestSysDbRecord/TestPerfCfgParams: it is not a Solver's job to access databases. If fin needs to read a database and validate a record, then fin should be the one to read the database, and it is then free to use the Solver API to validate the record.

atamazov commented 1 year ago

@JehandadKhan

@averinevg How would this impact the solver registry ? Currently, the solver registry does properly support any primitive other than convolution.

I think that the registry will continue to support all primitives after the proposed changes are made. @averinevg Please confirm.

averinevg commented 1 year ago

@JehandadKhan

@averinevg How would this impact the solver registry ? Currently, the solver registry does properly support any primitive other than convolution.

AnySolver would be replaced with a pointer to SolverBase inside the solver registry. All functionality will be kept as @atamazov wrote. After that, we could add the necessary functionality for the rest of the primitives.
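
A minimal sketch of what storing pointers to SolverBase in the registry could look like. RegistrySketch, NamedSolverSketch and the numeric ids are hypothetical and only illustrate the idea; the real registry has more responsibilities than this.

```cpp
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Reduced stand-in for the proposed base class (methods assumed virtual).
struct SolverBase
{
    virtual ~SolverBase() = default;
    virtual const std::string& SolverDbId() const = 0;
    virtual bool IsTunable() const                = 0;
};

// Hypothetical concrete solver; it could belong to any primitive.
struct NamedSolverSketch : SolverBase
{
    NamedSolverSketch(std::string id, bool tunable) : id_(std::move(id)), tunable_(tunable) {}

    const std::string& SolverDbId() const override { return id_; }
    bool IsTunable() const override { return tunable_; }

private:
    std::string id_;
    bool tunable_;
};

// Hypothetical registry keyed by a numeric solver id.
class RegistrySketch
{
public:
    // Returns false if the id is already taken; the existing entry is kept.
    bool Register(std::uint64_t id, std::unique_ptr<SolverBase> solver)
    {
        return solvers_.emplace(id, std::move(solver)).second;
    }

    // Returns nullptr when no solver is registered under the id.
    const SolverBase* Find(std::uint64_t id) const
    {
        const auto it = solvers_.find(id);
        return it == solvers_.end() ? nullptr : it->second.get();
    }

private:
    std::map<std::uint64_t, std::unique_ptr<SolverBase>> solvers_;
};
```

Because the registry only depends on SolverBase, convolution and non-convolution solvers can sit in the same container without a type-erasing wrapper.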

averinevg commented 1 year ago

Update: IsTunable() has been moved to SolverBase as proposed by @cderb.

averinevg commented 1 year ago

@cderb I believe that no parts of Fin should be present in the release build, but we can keep them there for a reasonable transition period. If moving them into Fin turns out to be a non-trivial task, they will remain inside the base class. What is your opinion?

cderb commented 1 year ago

@averinevg If these two functions can be made accessible from all solvers:

PerformanceConfig GetDefaultPerformanceConfig(const Context&, const Problem& problem);
bool IsValidPerformanceConfig(const Context&, const Problem& problem, const PerformanceConfig& config);

Then I can accomplish the following functions within the fin library:

bool TestSysDbRecord(const ExecutionContext& ctx, const ProblemDescriptionBase& problem, const DbRecord& record);
std::string GetPerfCfgParams(const ExecutionContext& ctx, const ProblemDescriptionBase& problem, Db& db);

It would be helpful if there is some transition time so that I can migrate these functions into fin as well.
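
A rough sketch of how the record check could live on the fin side once the PerformanceConfig is reachable. The database read is left to the caller, so the function only receives the serialized value. TestPerfCfgParamsSketch, the Toy* types and the single-integer serialization format are all hypothetical and exist only for this example.

```cpp
#include <sstream>
#include <string>

// Hypothetical stand-ins; none of these are MIOpen types.
struct ToyContext
{
    int max_threads;
};

struct ToyProblem
{
    int n;
};

struct ToyPerfConfig
{
    int block_size;

    // Made-up format: the config is stored as a single decimal integer.
    std::string Serialize() const { return std::to_string(block_size); }

    // Accepts only a string that is exactly one integer; leaves the
    // config untouched when parsing fails.
    bool Deserialize(const std::string& text)
    {
        std::istringstream in(text);
        int value = 0;
        if(!(in >> value) || !in.eof())
            return false;
        block_size = value;
        return true;
    }
};

struct ToySolver
{
    ToyPerfConfig GetDefaultPerformanceConfig(const ToyContext&, const ToyProblem&) const
    {
        return ToyPerfConfig{64};
    }

    bool IsValidPerformanceConfig(const ToyContext& ctx,
                                  const ToyProblem& problem,
                                  const ToyPerfConfig& config) const
    {
        return config.block_size > 0 && config.block_size <= ctx.max_threads &&
               problem.n % config.block_size == 0;
    }
};

// Would live in fin: validates a serialized config using only the solver's API.
template <class PerformanceConfig, class Solver, class Context, class Problem>
bool TestPerfCfgParamsSketch(const Solver& solver,
                             const Context& ctx,
                             const Problem& problem,
                             const std::string& serialized)
{
    PerformanceConfig config{};
    if(!config.Deserialize(serialized))
        return false; // malformed record
    return solver.IsValidPerformanceConfig(ctx, problem, config);
}
```

Under these assumptions the solver never touches a database: fin reads the record, and the solver only answers whether the deserialized config is valid for the given problem.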

atamazov commented 1 year ago

Good!

averinevg commented 1 year ago

Update: ifdef TUNA_BUILD was replaced with MIOPEN_ENABLE_FIN

averinevg commented 1 year ago

Update: description has been changed as suggested by @DrizztDoUrden

averinevg commented 1 year ago

ProblemDescriptionBase has been implemented in https://github.com/ROCmSoftwarePlatform/MIOpen/pull/1875