dragonflydb / dragonfly

A modern replacement for Redis and Memcached
https://www.dragonflydb.io/
Other
24.76k stars 892 forks source link

redis-cell port #270

Closed kakserpom closed 1 month ago

kakserpom commented 1 year ago

Hey!

There's a small and yet very nice module for Redis: https://github.com/brandur/redis-cell It would be really nice to have it ported into dragonflydb :)

romange commented 1 year ago

Как серпом по яйцам?! Ok, it's indeed a nice module. Also, it has a reference to the algorithm it implements, so should not be very hard to implement.

romange commented 1 year ago

For reference, here is the blog post about the algorithm by the author of the module, and here is the wiki article (the algorithm appears to be the generic cell rate algorithm, GCRA — the links were lost in this copy).

kakserpom commented 1 year ago

Да-да. kak.serpom.po.yaitsam собака gmail ком

aryenkhandal commented 1 year ago

For reference, here is the blog post about the algorithm by the author of the module, here is the wiki article

What is it used for, and when would this code be executed?

romange commented 1 year ago

it's a rate limiting algorithm

zetanumbers commented 1 year ago

Here's reference implementation with tests ported from redis-cell (start from cl_throttle function):

#include <array>
#include <chrono>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>

using std::int64_t;
using std::string;
using std::chrono::nanoseconds;
using namespace std::chrono_literals;
using time_point =
    std::chrono::time_point<std::chrono::steady_clock, nanoseconds>;

// Maximum number of times to retry set_if_not_exists/compare_and_swap
// operations before returning an error.
constexpr int64_t MAX_CAS_ATTEMPTS = 5;

struct Rate {
  std::chrono::nanoseconds period;

public:
  /// Produces a rate for some number of actions per period. For example, if
  /// we wanted to have 10 actions every 2 seconds, the period produced would
  /// be 200 ms.
  static Rate per_period(std::int64_t n, std::chrono::nanoseconds period) {
    const std::int64_t total_ns = period.count();

    // Degenerate inputs (zero actions or a zero-length period) map to a
    // zero rate; don't rely on floating point math to get there.
    if (n == 0 || total_ns == 0) {
      return Rate{std::chrono::nanoseconds{0}};
    }

    const double ns_per_action =
        static_cast<double>(total_ns) / static_cast<double>(n);
    return Rate{
        std::chrono::nanoseconds{static_cast<std::int64_t>(ns_per_action)}};
  }
};

// Describes a quota: a burst allowance on top of a steady maximum rate.
// RateLimiter derives its total limit as max_burst + 1.
struct RateQuota {
  // Extra actions permitted in a burst beyond the one steady-rate action.
  int64_t max_burst;
  // Steady-state rate: one action allowed per max_rate.period.
  Rate max_rate;
};

// Outcome of a single rate_limit() call: the configured limit, how many
// actions remain, when the limiter fully resets, and (when limited) how
// long until a retry could succeed. Negative durations mean "not
// applicable" (-1s sentinel).
struct RateLimitResult {
  std::int64_t limit;
  std::int64_t remaining;
  std::chrono::nanoseconds reset_after;
  std::chrono::nanoseconds retry_after;

public:
  /// Streams a human-readable rendering of every field.
  /// BUG FIX: the original printed "{.limit" without " = " and printed
  /// `remaining` twice; each field is now labeled and printed exactly once.
  friend std::ostream &operator<<(std::ostream &os, RateLimitResult c) {
    os << "{.limit = " << c.limit << ", .remaining = " << c.remaining
       << ", .reset_after = " << c.reset_after.count()
       << "ns, .retry_after = " << c.retry_after.count() << "ns}";
    return os;
  }
};

/// Abstract key/value store used by RateLimiter. Implementations supply
/// conditional-update primitives so that concurrent limiters racing on the
/// same key resolve safely (see the CAS loop in RateLimiter::rate_limit).
class Store {
public:
  virtual ~Store() = default;

  /// Sets `key` to `new_val` with expiry `ttl`, but only if it currently
  /// holds `old_val` (or is absent — see MemoryStore). Returns true when
  /// the swap was applied.
  virtual bool compare_and_swap_with_ttl(const string &key, int64_t old_val,
                                         int64_t new_val, nanoseconds ttl) = 0;

  /// Returns the value at `key` (callers treat -1 as "missing") together
  /// with the store's notion of the current time.
  virtual std::pair<int64_t, time_point> get_with_time(const string &key) = 0;

  /// Sets `key` to `val` with expiry `ttl` only if `key` is absent.
  /// Returns true when the value was inserted.
  virtual bool set_if_not_exists_with_ttl(const string &key, int64_t val,
                                          nanoseconds ttl) = 0;
};

class MemoryStore : public Store {
  std::unordered_map<string, int64_t> map;

public:
  ~MemoryStore() final = default;

  bool compare_and_swap_with_ttl(const string &key, int64_t old_val,
                                 int64_t new_val, nanoseconds ttl) final {
    try {
      if (this->map.at(key) != old_val) {
        return false;
      }
    } catch (std::out_of_range e) {
    }

    this->map.insert_or_assign(key, new_val);
    return true;
  }

  std::pair<int64_t, time_point> get_with_time(const string &key) final {
    int64_t val;
    try {
      val = this->map.at(key);
    } catch (std::out_of_range e) {
      val = -1;
    }
    return std::make_pair(val, std::chrono::steady_clock::now());
  }

  bool set_if_not_exists_with_ttl(const string &key, int64_t val,
                                  nanoseconds ttl) final {
    return this->map.insert(std::make_pair(key, val)).second;
  }
};

class RateLimiter {
  /// Think of the DVT as our flexibility: how far can you deviate from the
  /// nominal equally spaced schedule? If you like leaky buckets, think about
  /// it as the size of your bucket.
  nanoseconds delay_variation_tolerance;

  /// Think of the emission interval as the time between events in the
  /// nominal equally spaced schedule. If you like leaky buckets, think of it
  /// as how frequently the bucket leaks one unit.
  nanoseconds emission_interval;

  /// Total burst capacity reported to callers: max_burst + 1.
  int64_t limit;

public:
  Store *store;  // non-owning; outlives the limiter

public:
  RateLimiter(Store *store, const RateQuota &quota)
      : delay_variation_tolerance(quota.max_rate.period.count() *
                                  (quota.max_burst + 1)),
        emission_interval(quota.max_rate.period), limit(quota.max_burst + 1),
        store(store) {}

  /// RateLimit checks whether a particular key has exceeded a rate limit. It
  /// also returns a RateLimitResult to provide additional information about
  /// the state of the RateLimiter.
  ///
  /// If the rate limit has not been exceeded, the underlying storage is
  /// updated by the supplied quantity. For example, a quantity of 1 might be
  /// used to rate limit a single request while a greater quantity could rate
  /// limit based on the size of a file upload in megabytes. If quantity is
  /// 0, no update is performed allowing you to "peek" at the state of the
  /// RateLimiter for a given key.
  ///
  /// Throws std::runtime_error on a zero rate, or when the store stays
  /// contended for MAX_CAS_ATTEMPTS consecutive update attempts.
  std::pair<bool, RateLimitResult> rate_limit(const string &key,
                                              int64_t quantity) {
    auto rlc = RateLimitResult{.limit = this->limit,
                               .remaining = 0,
                               .reset_after = -1s,
                               .retry_after = -1s};

    if (this->emission_interval == 0ns) {
      throw std::runtime_error("Zero rates are not supported");
    }

    // Cost of this call in "schedule time": quantity emission intervals.
    auto increment = nanoseconds{
        this->emission_interval.count() * quantity,
    };

    bool limited;
    nanoseconds ttl;

    // Looping here is not about retrying communication failures, it's
    // about retrying contention. While we're performing our calculations
    // it's possible for another limiter to be doing its own simultaneously
    // and beat us to the punch. In that case only one limiter should win.
    //
    // Note that when running with our internal Redis store (i.e. the
    // normal case for the redis-cell project) this is actually *not* true
    // because our entire operation will execute atomically.
    int64_t i = 0;

    while (true) {
      // tat_val is the stored "theoretical arrival time"; -1 means absent.
      auto [tat_val, now] = this->store->get_with_time(key);

      time_point tat = tat_val == -1 ? now : time_point{nanoseconds{tat_val}};
      time_point new_tat = std::max(now, tat) + increment;

      // The request is allowed once `now` is within the delay variation
      // tolerance of the prospective new TAT.
      time_point allow_at = new_tat - this->delay_variation_tolerance;
      nanoseconds diff = now - allow_at;

      if (diff < 0ns) {
        // Only report a retry time when a retry could ever succeed (a
        // request larger than the whole tolerance never fits).
        if (increment <= this->delay_variation_tolerance) {
          rlc.retry_after = -diff;
        }

        limited = true;
        ttl = tat - now;
        break;
      }

      int64_t new_tat_ns = new_tat.time_since_epoch().count();
      ttl = new_tat - now;

      // If the key was originally missing, set it if it doesn't exist.
      // If it was there, try to compare and swap.
      //
      // Both of these cases are designed to work around the fact that
      // another limiter could be running in parallel.
      bool updated =
          tat_val == -1
              ? this->store->set_if_not_exists_with_ttl(key, new_tat_ns, ttl)
              : this->store->compare_and_swap_with_ttl(key, tat_val, new_tat_ns,
                                                       ttl);

      if (updated) {
        limited = false;
        break;
      }

      i += 1;
      // BUG FIX: the original tested `i < MAX_CAS_ATTEMPTS`, which threw on
      // the very first contention failure and never actually retried. The
      // intent (and the error message) is to give up only after
      // MAX_CAS_ATTEMPTS failed attempts.
      if (i >= MAX_CAS_ATTEMPTS) {
        std::stringstream ss;
        ss << "Failed to update rate limit after " << MAX_CAS_ATTEMPTS
           << " attempts";
        throw std::runtime_error(ss.str());
      }
    }

    nanoseconds next = this->delay_variation_tolerance - ttl;
    if (next > -this->emission_interval) {
      rlc.remaining = static_cast<int64_t>(
          static_cast<double>(next.count()) /
          static_cast<double>(this->emission_interval.count()));
    }
    rlc.reset_after = ttl;

    return std::make_pair(limited, rlc);
  }
};

// --- START HERE ---
/// Port of redis-cell's CL.THROTTLE command. Returns five integers:
///   [0] 1 when the action was throttled, else 0
///   [1] total limit (max_burst + 1)
///   [2] remaining actions
///   [3] seconds until retry (-1 when no retry applies)
///   [4] seconds until the limiter fully resets
std::array<int64_t, 5> cl_throttle(Store *store, const string &key,
                                   int64_t max_burst, int64_t count,
                                   int64_t period, int64_t quantity) {
  auto rate = Rate::per_period(count, std::chrono::seconds{period});
  auto limiter =
      RateLimiter(store, RateQuota{.max_burst = max_burst, .max_rate = rate});

  auto [throttled, rate_limit_result] = limiter.rate_limit(key, quantity);

  // If either time had a partial component, bump it up to the next full
  // second because otherwise a fast-paced caller could try again too
  // early.
  //
  // BUG FIX: the original tested `count() / 1'000'000 > 0` ("at least one
  // millisecond"), which wrongly rounded exact whole seconds up and failed
  // to round up sub-millisecond remainders; test the actual sub-second
  // remainder instead.
  int64_t retry_after = rate_limit_result.retry_after.count() / 1'000'000'000;
  if (rate_limit_result.retry_after.count() % 1'000'000'000 > 0) {
    retry_after += 1;
  }
  int64_t reset_after = rate_limit_result.reset_after.count() / 1'000'000'000;
  if (rate_limit_result.reset_after.count() % 1'000'000'000 > 0) {
    reset_after += 1;
  }

  return {throttled ? 1 : 0, rate_limit_result.limit,
          rate_limit_result.remaining, retry_after, reset_after};
}

/// Convenience overload of cl_throttle with the quantity defaulted to a
/// single action per call.
std::array<int64_t, 5> cl_throttle(Store *store, const string &key,
                                   int64_t max_burst, int64_t count,
                                   int64_t period) {
  constexpr int64_t kDefaultQuantity = 1;
  return cl_throttle(store, key, max_burst, count, period, kDefaultQuantity);
}

// --- TESTS ---

// One row of the it_rate_limits() expectation table: the fake clock value
// to install, the request volume, and the expected limiter outputs.
struct RateLimitCase {
  std::int64_t num;        // test case number, for log output
  std::chrono::time_point<std::chrono::steady_clock, std::chrono::nanoseconds>
      now;                 // fake clock installed into TestStore
  std::int64_t volume;     // quantity passed to rate_limit()
  std::int64_t remaining;  // expected RateLimitResult.remaining
  std::chrono::nanoseconds reset_after;  // expected reset_after
  std::chrono::nanoseconds retry_after;  // expected retry_after
  bool limited;            // expected "was limited" flag

public:
  /// Streams a human-readable rendering of the case (the clock value is
  /// omitted — steady_clock time points have no portable textual form).
  /// BUG FIX: the original printed "{.num" without " = ".
  friend std::ostream &operator<<(std::ostream &os, RateLimitCase c) {
    os << "{.num = " << c.num << ", .volume = " << c.volume
       << ", .remaining = " << c.remaining
       << ", .reset_after = " << c.reset_after.count()
       << "ns, .retry_after = " << c.retry_after.count()
       << "ns, .limited = " << c.limited << "}";
    return os;
  }
};

/// TestStore is a Store implementation that wraps a MemoryStore and allows
/// us to tweak certain behavior, like for example setting the effective
/// system clock or forcing every write to fail.
struct TestStore : public Store {
  time_point clock;   // value substituted for "now" by get_with_time
  bool fail_updates;  // when true, every write operation reports failure
  MemoryStore *store; // non-owning backing storage

public:
  /// FIX: the original member-initializer list ran in an order different
  /// from the declaration order (triggering -Wreorder); it now matches.
  /// Also marked explicit — single-argument constructors should not be
  /// implicit conversions.
  explicit TestStore(MemoryStore *store)
      : clock(), fail_updates(false), store(store) {}

  ~TestStore() final = default;

  /// Delegates to the wrapped store unless fail_updates forces a failure,
  /// which lets tests exercise the CAS retry path.
  bool compare_and_swap_with_ttl(const string &key, int64_t old_val,
                                 int64_t new_val, nanoseconds ttl) final {
    if (this->fail_updates) {
      return false;
    }
    return this->store->compare_and_swap_with_ttl(key, old_val, new_val, ttl);
  }

  /// Returns the stored value, substituting the test-controlled clock for
  /// the real time.
  std::pair<int64_t, time_point> get_with_time(const string &key) final {
    auto tup = this->store->get_with_time(key);
    return std::make_pair(tup.first, this->clock);
  }

  /// Delegates to the wrapped store unless fail_updates forces a failure.
  bool set_if_not_exists_with_ttl(const string &key, int64_t value,
                                  nanoseconds ttl) final {
    if (this->fail_updates) {
      return false;
    }
    return this->store->set_if_not_exists_with_ttl(key, value, ttl);
  }
};

/// Throws std::runtime_error when the two values differ, embedding both
/// (rendered via operator<<) in the error message.
template <class T> void assert_eq(const T &lhs, const T &rhs) {
  if (!(lhs != rhs)) {
    return;
  }
  std::ostringstream msg;
  msg << "Values are not equal. Left: " << lhs << "; Right: " << rhs;
  throw std::runtime_error(msg.str());
}

/// Equality assertion for types without a usable operator<<: the `fmt`
/// callback renders each side into the failure message.
template <class T, class Fmt>
void assert_eq(const T &lhs, const T &rhs, Fmt fmt) {
  if (!(lhs != rhs)) {
    return;
  }
  std::ostringstream msg;
  msg << "Values are not equal. Left: ";
  fmt(msg, lhs);
  msg << "; Right: ";
  fmt(msg, rhs);
  throw std::runtime_error(msg.str());
}

/// Runs RateLimiter against a table of expected outcomes (ported from the
/// redis-cell test suite), advancing time through TestStore's fake clock.
void it_rate_limits() {
  const int64_t limit = 5;
  auto quota = RateQuota{
      .max_burst = limit - 1,
      .max_rate = Rate::per_period(1, 1s),
  };
  time_point start = std::chrono::steady_clock::now();
  auto memory_store = MemoryStore();
  auto test_store = TestStore(&memory_store);
  auto limiter = RateLimiter(&test_store, quota);

  // Columns: test case #, now, volume, remaining, reset_after, retry_after,
  // limited.
  auto cases = std::array{
      // You can never make a request larger than the maximum.
      RateLimitCase{0, start, 6, 5, 0ns, -1s, true},

      // Rate limit normal requests appropriately.
      RateLimitCase{1, start, 1, 4, 1s, -1s, false},
      RateLimitCase{2, start, 1, 3, 2s, -1s, false},
      RateLimitCase{3, start, 1, 2, 3s, -1s, false},
      RateLimitCase{4, start, 1, 1, 4s, -1s, false},
      RateLimitCase{5, start, 1, 0, 5s, -1s, false},
      RateLimitCase{6, start, 1, 0, 5s, 1s, true},

      RateLimitCase{7, start + 3000ms, 1, 2, 3000ms, -1s, false},
      RateLimitCase{8, start + 3100ms, 1, 1, 3900ms, -1s, false},
      RateLimitCase{9, start + 4000ms, 1, 1, 4000ms, -1s, false},
      RateLimitCase{10, start + 8000ms, 1, 4, 1000ms, -1s, false},
      RateLimitCase{11, start + 9500ms, 1, 4, 1000ms, -1s, false},

      // Zero-volume request just peeks at the state.
      RateLimitCase{12, start + 9500ms, 0, 4, 1s, -1s, false},

      // High-volume request uses up more of the limit.
      RateLimitCase{13, start + 9500ms, 2, 2, 3s, -1s, false},

      // Large requests cannot exceed limits
      RateLimitCase{14, start + 9500ms, 5, 2, 3s, 3s, true},
  };

  // Renderer for duration fields (no operator<< on nanoseconds).
  auto fmt_nanoseconds = [](std::ostream &os, nanoseconds t) {
    os << t.count() << "ns";
  };

  for (const auto &tc : cases) {
    std::cerr << "starting test case = " << tc.num << '\n';
    std::cerr << tc << '\n';

    // Install the fake clock, then run the limiter.
    test_store.clock = tc.now;
    auto [was_limited, results] = limiter.rate_limit("foo", tc.volume);

    std::cerr << "limited = " << was_limited << '\n';
    std::cerr << results << "\n\n";

    assert_eq(tc.limited, was_limited);
    assert_eq(limit, results.limit);
    assert_eq(tc.remaining, results.remaining);
    assert_eq(tc.reset_after, results.reset_after, fmt_nanoseconds);
    assert_eq(tc.retry_after, results.retry_after, fmt_nanoseconds);
  }
}

/// Entry point: run the ported redis-cell test suite.
int main() {
  it_rate_limits();
  return 0;
}