microsoft / proxy

Proxy: Next Generation Polymorphism in C++
https://wg21.link/p3086
MIT License
2.15k stars 131 forks source link

What's the best way to implement an LD_PRELOAD-able plug-in registry class in proxy style? #143

Closed MoFHeka closed 3 months ago

MoFHeka commented 3 months ago

I usually create a static InitOnStartupMarker that holds a registry class. The registry class contains a map whose key is the plug-in name and whose value is the constructor. The Lookup function returns the virtual class pointer for the given plug-in name from the registry map. There is also a DeferRegister function, invoked through the static marker in the plug-in side code, that registers the plug-in when the dynamic library is loaded.

Are there any good ideas or examples for making it more modern?

Here is a sample:

storage_registration.hpp:

#pragma once

#ifndef MEEPO_EMBEDDING_STORAGE_REGISTRATION_H_
#define MEEPO_EMBEDDING_STORAGE_REGISTRATION_H_

#include <type_traits>
#include <utility>

// Stringifies a type name exactly as it is spelled at the call site. The macro
// is variadic so that template-ids containing commas (for example
// `std::pair<int, double>`) are accepted and stringified as a single name.
#define ME_TYPE_NAME(...) #__VA_ARGS__

// Compiler attributes
#if (defined(__GNUC__) || defined(__APPLE__)) && !defined(SWIG)
// Compiler supports GCC-style attributes
#define MEEPO_EMBEDDING_ATTRIBUTE_UNUSED __attribute__((unused))
#else
// Non-GCC compilers (MSVC and others): no equivalent, expand to nothing
#define MEEPO_EMBEDDING_ATTRIBUTE_UNUSED
#endif

// Short alias used by the registration macros. It is defined in terms of the
// detected macro above (instead of hard-coding the GCC syntax) so that it
// expands to nothing on compilers without GCC-style attributes.
#define ME_ATTRIBUTE_UNUSED MEEPO_EMBEDDING_ATTRIBUTE_UNUSED

// SELECTIVE_REGISTRATION (not supported now)
// Compile-time predicate deciding whether storage class `cls` should be
// registered. Selective registration is not implemented, so the argument is
// ignored and the macro always expands to `true`.
#define ME_SHOULD_REGISTER_STORAGE(cls) true

namespace meepo_embedding {
namespace storage {

// An InitOnStartupMarker is 'initialized' on program startup, purely for the
// side-effects of that initialization - the struct itself is empty. (The type
// is expected to be used to define globals.)
//
// The '<<' operator should be used in initializer expressions to specify what
// to run on startup. The following values are accepted:
//   - An InitOnStartupMarker. Example:
//      InitOnStartupMarker F();
//      InitOnStartupMarker const kInitF =
//        InitOnStartupMarker{} << F();
//   - Something to call, which returns an InitOnStartupMarker. Example:
//      InitOnStartupMarker const kInit =
//        InitOnStartupMarker{} << []() { G(); return InitOnStartupMarker{}; };
//
// See also: ME_INIT_ON_STARTUP_IF
struct InitOnStartupMarker {
  // Chaining an already-built marker has no effect of its own: the operand
  // is discarded and an (empty) marker is returned.
  constexpr InitOnStartupMarker operator<<(InitOnStartupMarker) const {
    return InitOnStartupMarker{};
  }

  // Chaining anything else treats the operand as a callable: it is invoked
  // with its original value category preserved, and the marker it returns
  // becomes the result of the expression.
  template <typename Callable>
  constexpr InitOnStartupMarker operator<<(Callable&& callable) const {
    return std::forward<Callable>(callable)();
  }
};

// Conditional initializer expressions for InitOnStartupMarker:
//   ME_INIT_ON_STARTUP_IF(cond) << f
// If 'cond' is true, 'f' is evaluated (and called, if applicable) on startup.
// Otherwise, 'f' is *not evaluated*. Note that 'cond' is required to be a
// constant-expression, and so this approximates #ifdef.
//
// The implementation uses the ?: operator (!cond prevents evaluation of 'f').
// The relative precedence of ?: and << is significant; this effectively expands
// to (see extra parens):
//   !cond ? InitOnStartupMarker{} : (InitOnStartupMarker{} << f)
//
// Note that although forcing 'cond' to be a constant-expression should not
// affect binary size (i.e. the same optimizations should apply if it 'happens'
// to be one), it was found to be necessary (for a recent version of clang;
// perhaps an optimizer bug).
//
// The parens are necessary to hide the ',' from the preprocessor; it could
// otherwise act as a macro argument separator.
// Usage note: the expansion deliberately ends with the last operand of the
// conditional expression, so the `<< f` written after the macro invocation
// attaches only to that final InitOnStartupMarker (the branch taken when
// 'cond' is true).
#define ME_INIT_ON_STARTUP_IF(cond)                       \
  (::std::integral_constant<bool, !(cond)>::value)        \
      ? ::meepo_embedding::storage::InitOnStartupMarker{} \
      : ::meepo_embedding::storage::InitOnStartupMarker {}

// Wrapper for generating unique IDs (for 'anonymous' InitOnStartup definitions)
// using __COUNTER__. The new ID (__COUNTER__ already expanded) is provided as a
// macro argument.
//
// Usage:
//   #define M_IMPL(id, a, b) ...
//   #define M(a, b) ME_NEW_ID_FOR_INIT(M_IMPL, a, b)
// ME_NEW_ID_FOR_INIT passes __COUNTER__ as an ordinary macro argument through
// the _1 and _2 helpers. Because arguments are macro-expanded before being
// substituted, `m` receives the counter's numeric value as its first argument
// and may therefore paste it into an identifier with ##.
#define ME_NEW_ID_FOR_INIT_2(m, c, ...) m(c, __VA_ARGS__)
#define ME_NEW_ID_FOR_INIT_1(m, c, ...) ME_NEW_ID_FOR_INIT_2(m, c, __VA_ARGS__)
#define ME_NEW_ID_FOR_INIT(m, ...) \
  ME_NEW_ID_FOR_INIT_1(m, __COUNTER__, __VA_ARGS__)

}  // namespace storage
}  // namespace meepo_embedding

#endif  // MEEPO_EMBEDDING_STORAGE_REGISTRATION_H_

storage_registry.hpp:

#pragma once

#ifndef MEEPO_EMBEDDING_STORAGE_registry_H_
#define MEEPO_EMBEDDING_STORAGE_registry_H_

#include <algorithm>
#include <map>
#include <memory>
#include <string>

#include "storage_interface.hpp"
#include "storage_registration.hpp"

namespace meepo_embedding {
namespace storage {
namespace registry {

// Builds the registry key for a storage backend: the device name converted to
// lower case, followed by an underscore and the class name unchanged
// (e.g. "CPU" and "fake_table<int>" give "cpu_fake_table<int>").
//
// Both parameters are rvalue references, so callers must pass temporaries or
// std::move'd strings; the arguments are only read, never modified.
const std::string CreateFactoryKey(const std::string&& device,
                                   const std::string&& cls_name);

class StorageFactory {
 public:
  virtual pro::proxy<StorageInterface> Create() = 0;
  virtual ~StorageFactory() = default;
};  // class StorageFactory

class StorageRegistry {
 private:
  std::map<std::string, std::unique_ptr<StorageFactory>> DeferRegistrationData_;

 private:
  struct StorageFactoryImpl : public StorageFactory {
    explicit StorageFactoryImpl(pro::proxy<StorageInterface> (*create_func)())
        : create_func_(create_func) {}

    pro::proxy<StorageInterface> Create() override;

    pro::proxy<StorageInterface> (*create_func_)();

  };  // struct StorageFactoryImpl

  StorageRegistry() {}

 public:
  ~StorageRegistry() {}

  void DeferRegister(const std::string&& device, const std::string&& cls_name,
                     pro::proxy<StorageInterface>(create_fn)());

  pro::proxy<StorageInterface> LookUp(const std::string&& factory_key);

  static StorageRegistry* Global() {
    static registry::StorageRegistry me_global_registry;
    return &me_global_registry;
  }
};

// REGISTER_STORAGE_IMPL_2, with a unique 'ctr' as the first argument.
//
// Defines a static InitOnStartupMarker named storage_<ctr>. Its initializer
// calls StorageRegistry::Global()->DeferRegister(device, cls_name, factory),
// where `factory` is a capture-less lambda returning
// pro::make_proxy<StorageInterface, cls_type>(). The registration is guarded
// by ME_INIT_ON_STARTUP_IF(ME_SHOULD_REGISTER_STORAGE(cls_name)).
// The expansion already ends with ';', so no semicolon is needed at the use
// site.
#define REGISTER_STORAGE_IMPL_3(ctr, device, cls_name, cls_type)               \
  static meepo_embedding::storage::InitOnStartupMarker const storage_##ctr     \
      ME_ATTRIBUTE_UNUSED =                                                    \
          ME_INIT_ON_STARTUP_IF(ME_SHOULD_REGISTER_STORAGE(cls_name))          \
          << ([]() {                                                           \
               ::meepo_embedding::storage::registry::StorageRegistry::Global() \
                   ->DeferRegister(                                            \
                       device, cls_name,                                       \
                       []() -> pro::proxy<                                     \
                                meepo_embedding::storage::StorageInterface> {  \
                         return pro::make_proxy<                               \
                             meepo_embedding::storage::StorageInterface,       \
                             cls_type>();                                      \
                       });                                                     \
               return meepo_embedding::storage::InitOnStartupMarker{};         \
             })();

// Forwards its arguments to REGISTER_STORAGE_IMPL_3 through
// ME_NEW_ID_FOR_INIT, which supplies a fresh __COUNTER__ value as `ctr`.
#define REGISTER_STORAGE_IMPL_2(...) \
  ME_NEW_ID_FOR_INIT(REGISTER_STORAGE_IMPL_3, __VA_ARGS__)

// Checks at compile time that the storage type (everything after `device`) is
// default-constructible, then registers it under `device` using the
// stringified type (ME_TYPE_NAME) as the class name.
#define REGISTER_STORAGE_IMPL(device, ...)                            \
  static_assert(std::is_default_constructible<__VA_ARGS__>::value,    \
                "Meepo Embedding storage backend must has a default " \
                "constructor with empty parameters!");                \
  REGISTER_STORAGE_IMPL_2(device, ME_TYPE_NAME(__VA_ARGS__), __VA_ARGS__)

// Public entry point: REGISTER_STORAGE(device, StorageType).
#define REGISTER_STORAGE(...) REGISTER_STORAGE_IMPL(__VA_ARGS__)

}  // namespace registry
}  // namespace storage
}  // namespace meepo_embedding

#endif  // MEEPO_EMBEDDING_STORAGE_registry_H_

storage_registry.cpp:

#include "storage_registry.hpp"

namespace meepo_embedding {
namespace storage {
namespace registry {

// Builds the registry key "<lower-cased device>_<cls_name>".
//
// device:   device name; compared case-insensitively, so it is lower-cased
//           (per character, via std::tolower) before being used in the key.
// cls_name: class name; copied into the key unchanged (case-sensitive).
// Returns the assembled key, e.g. ("CPU", "fake_table<int>") gives
// "cpu_fake_table<int>".
const std::string CreateFactoryKey(const std::string&& device,
                                   const std::string&& cls_name) {
  std::string key;
  // One allocation for the whole key instead of several temporaries.
  key.reserve(device.size() + 1 + cls_name.size());
  key += device;
  // std::tolower takes/returns int; going through unsigned char avoids
  // passing negative values, and the cast back to char is made explicit.
  std::transform(key.begin(), key.end(), key.begin(), [](unsigned char c) {
    return static_cast<char>(std::tolower(c));
  });
  key += '_';
  key += cls_name;
  // Returned by name (no std::move) so copy elision can apply.
  return key;
}

// Registers a backend factory.
//
// device, cls_name: combined by CreateFactoryKey into the map key.
// create_fn:        function that builds the backend; it is wrapped in a
//                   StorageFactoryImpl owned by the registry.
//
// If the key is already present the existing factory is kept and the new one
// is discarded (first registration wins).
void registry::StorageRegistry::DeferRegister(
    const std::string&& device, const std::string&& cls_name,
    pro::proxy<StorageInterface>(create_fn)()) {
  // `auto` drops the const of the returned string, so the key can be moved
  // into the map node instead of being copied.
  auto key = CreateFactoryKey(std::move(device), std::move(cls_name));
  DeferRegistrationData_.try_emplace(
      std::move(key), std::make_unique<StorageFactoryImpl>(create_fn));
}

// Looks up the factory registered under `factory_key` and asks it to create a
// backend. Returns a default-constructed (empty) proxy when the key is
// unknown.
pro::proxy<StorageInterface> registry::StorageRegistry::LookUp(
    const std::string&& factory_key) {
  const auto entry = DeferRegistrationData_.find(factory_key);
  if (entry != DeferRegistrationData_.end()) {
    return entry->second->Create();
  }
  return pro::proxy<StorageInterface>();
}

// Calls the stored creation function and returns the proxy it produces.
pro::proxy<StorageInterface> StorageRegistry::StorageFactoryImpl::Create() {
  return create_func_();
}

}  // namespace registry
}  // namespace storage
}  // namespace meepo_embedding

storage_interface.hpp:

#pragma once

#ifndef MEEPO_EMBEDDING_STORAGE_INTERFACE_HPP_
#define MEEPO_EMBEDDING_STORAGE_INTERFACE_HPP_

#include <cstddef>
#include <cstdint>
#include <expected>
#include <map>
#include <stdexcept>
#include <string>
#include <system_error>
#include <utility>

#include "proxy.h"  // from @proxy

namespace meepo_embedding {
namespace storage {
// Fallback used as the third argument of the PRO_DEF_WEAK_DISPATCH
// declarations below.
// NOTE(review): presumably this is what runs when a backend does not provide
// the corresponding member function - confirm against the proxy library docs.
struct NotImplemented {
  // Accepts (and ignores) any arguments and always throws
  // std::runtime_error with the message below.
  explicit NotImplemented(auto &&...) {
    throw std::runtime_error{
        "Not implemented function in storage backend class instance!"};
  }

  // Conversion to any type T, so the object can syntactically stand in for
  // any return type. The body is std::unreachable(): the only constructor
  // always throws, so no object can exist to be converted.
  template <class T>
  operator T() const noexcept {
    std::unreachable();
  }
};

//
// Specifications of abstraction
// For more details, please check:
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2024/p3086r2.pdf
//
// Member dispatches: each pairs a dispatch name (first argument) with a
// member-function name (second argument): init, dim and find.
// NOTE(review): the exact semantics come from proxy.h - presumably each macro
// declares a dispatch type that calls the named member function; confirm.
PRO_DEF_MEM_DISPATCH(MemInit, init);
PRO_DEF_MEM_DISPATCH(MemDim, dim);
PRO_DEF_MEM_DISPATCH(MemFind, find);

// Weak variants: each names a new dispatch, the member dispatch it is based
// on, and NotImplemented as the third argument.
// NOTE(review): presumably NotImplemented is the fallback used when the
// underlying type lacks the member function - confirm against proxy.h.
PRO_DEF_WEAK_DISPATCH(WeakMemInit, MemInit, NotImplemented);
PRO_DEF_WEAK_DISPATCH(WeakMemDim, MemDim, NotImplemented);
PRO_DEF_WEAK_DISPATCH(WeakMemFind, MemFind, NotImplemented);

// Facade describing what a storage backend exposes through pro::proxy.
// Three conventions are declared, each through its weak dispatch:
//   - init(config):          takes a pointer to a string->string map.
//   - dim():                 takes no arguments.
//   - find(keys, values):    two overloads, with int64_t* or double* values.
// Every convention returns std::expected<size_t, std::error_code>.
struct StorageInterface
    : pro::facade_builder
      ::add_convention<
          WeakMemInit,
          std::expected<size_t, std::error_code>(
              const std::map<std::string, std::string> *config)>
      ::add_convention<
          WeakMemDim,
          std::expected<size_t, std::error_code>()>
      ::add_convention<
          WeakMemFind,
          std::expected<size_t, std::error_code>(
              const int64_t *keys, int64_t *values),
          std::expected<size_t, std::error_code>(
              const int64_t *keys, double *values)>
      ::build {
};

}  // namespace storage
}  // namespace meepo_embedding

#endif  // MEEPO_EMBEDDING_STORAGE_INTERFACE_HPP_

main.cpp:

#include <bits/stdc++.h>

#include "storage_registry.hpp"

// Minimal stand-in storage backend used by this sample. It provides `find`
// (double-valued overload only) and `dim`, and no `init`.
template <typename T>
class fake_table {
 public:
  fake_table() {}
  ~fake_table() {}

  // Always reports 42.
  std::expected<size_t, std::error_code> dim() {
    return 42;
  }

  // Prints the first key, the first value and the name of T, then reports
  // 123. Both pointers are dereferenced, so they must be non-null.
  std::expected<size_t, std::error_code> find(const int64_t *keys, double *values) {
    std::cout << "keys=" << *keys
              << " values=" << *values
              << " type=" << typeid(T).name()
              << std::endl;
    return 123;
  }
};

// Registers fake_table<int> for device "cpu" through a static initializer.
// The registry key is built by CreateFactoryKey from the device and the
// stringified type, i.e. "cpu_fake_table<int>". No trailing ';' is needed:
// the macro expansion already ends with one.
REGISTER_STORAGE("cpu", fake_table<int>)

int main() {
  auto factory_key = ::meepo_embedding::storage::registry::CreateFactoryKey("CPU", ME_TYPE_NAME(fake_table<int>));
  auto p = ::meepo_embedding::storage::registry::StorageRegistry::Global()->LookUp(std::move(factory_key));
  // auto p = pro::make_proxy<::meepo_embedding::storage::StorageInterface, fake_table<int>>();
  std::cout << typeid(p).name() << std::endl;
  int64_t keys = 1;
  double values = 0.1;
  if (p.has_value()) {
    std::cout << (p->find(&keys, &values)).value() << std::endl;
    // std::cout << p->dim().value() << std::endl;
  } else {
    std::cout << "no key name " << factory_key << std::endl;
  }
}
mingxwa commented 3 months ago

@MoFHeka Looks like StorageFactory is implemented with virtual functions rather than proxy. Is there any consideration for not using proxy? Could you simplify the code so that we can understand your question better?