boostorg / cobalt

Coroutines for C++20 & asio
https://www.boost.org/doc/libs/master/libs/cobalt/doc/html/index.html

How to await a shared resource from multiple coroutines #199

Open mitsubishirgb opened 1 week ago

mitsubishirgb commented 1 week ago

Multiple requests are being made to the same domain, and I want every coroutine to wait for the pending resolution. But in this snippet, an error is thrown at the second pending co_await:

class DnsCache {
public:
    DnsCache() : resolver_(cobalt::this_thread::get_executor()) {}

    cobalt::promise<std::optional<tcp::endpoint>> get_address(const std::string& domain) {
        // Check if the domain is already cached
        if (auto it = dns_cache.find(domain); it != dns_cache.end()) {
            DEBUG_PRINT("[INFO] Address found in cache for domain: " + domain);
            co_return it->second;  // Return cached address
        }

        // Check if there is a pending resolution
        if (auto it = pending_resolutions.find(domain); it != pending_resolutions.end()) {
            DEBUG_PRINT("[INFO] Waiting for pending resolution for domain: " + domain);
            co_return co_await it->second;  // Await the result from the existing coroutine
        }

        // Create a new promise coroutine
        auto resolution_coroutine = resolve_addr(domain, "443");

        // Use emplace to insert the promise into the map
        pending_resolutions.emplace(domain, std::move(resolution_coroutine));

        // Await the resolution
        auto resolved_addr = co_await resolution_coroutine;

        if (resolved_addr.port() != 0) {
            store_address(domain, resolved_addr);  // Cache the result if valid
        }

        // Remove the pending resolution
        pending_resolutions.erase(domain);

        co_return resolved_addr;  // Return the resolved address
    }

    void store_address(const std::string& domain, const tcp::endpoint& resolved_address) {
        DEBUG_PRINT("[INFO] Storing resolved address for domain: " + domain);
        dns_cache[domain] = resolved_address;  // Store the resolved address in the cache
    }

    cobalt::promise<tcp::endpoint> resolve_addr(const std::string& host, const std::string& port) {
        system::error_code ec;
        auto result = co_await resolver_.async_resolve(host, port, asio::redirect_error(use_op, ec));
        if (ec) {
            DEBUG_PRINT("Failed resolving " + host);
            co_return {};  // Return an empty endpoint on error
        }

        if (!result.empty()) {
            co_return result.begin()->endpoint();  // Return the first resolved endpoint
        }

        co_return {};  // Return an empty endpoint if no results were found
    }

private:
    tcp::resolver resolver_;  // Resolver for asynchronous DNS lookups
    std::unordered_map<std::string, tcp::endpoint> dns_cache;  // Cached DNS resolutions
    std::unordered_map<std::string, cobalt::promise<tcp::endpoint>> pending_resolutions;  // Pending DNS resolutions
};

klemens-morgenstern commented 4 days ago

What is the error? Are you sure you want to use const & in a coroutine argument list?

Also, is this an exercise or a production problem? I ask because the OS will cache DNS lookups for you.