saghul / pycares

Python interface for c-ares
https://pypi.org/project/pycares/
MIT License

Finding which nameserver responded with a result #130

Closed arjunv closed 3 years ago

arjunv commented 4 years ago

When multiple nameservers are passed to the resolver, is there any way to identify which server responded with a particular result? Also, how can the same be done in aiodns?

I can get all the nameservers with .servers, but that's just the list of all servers that will be used, right?

And how exactly is resolution done when multiple servers are passed, given that it's asynchronous? Primary with fallbacks, or round-robin?

boytm commented 4 years ago

If you want to know which server responded with a particular result, you should use multiple channels, with one server per channel.
The resolving order is round-robin; see https://github.com/c-ares/c-ares/blob/7ebedab25dab50b2f008fbef8601c223096bb780/ares_send.c#L101
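
To make the "one server per channel" idea concrete, here is a minimal sketch that queries each server through its own resolver and labels every answer with the server it came from. It is not part of pycares or aiodns: `query_each` and `StubResolver` are hypothetical names, and the only assumption about a real resolver is that it exposes an awaitable `query(name, qtype)`.

```python
import asyncio


class StubResolver:
    """Hypothetical stand-in for a real single-server resolver (no network)."""

    def __init__(self, answer):
        self.answer = answer

    async def query(self, name, qtype):
        if self.answer is None:
            raise LookupError("no answer for " + name)
        return self.answer


async def query_each(resolvers, name, qtype="A"):
    """Ask every server separately.

    resolvers maps a server address to a resolver bound to that server only.
    Returns a list of (server, result, error) tuples, one per server.
    """

    async def one(server, resolver):
        try:
            return server, await resolver.query(name, qtype), None
        except Exception as exc:
            # Keep the failure next to the server that produced it
            return server, None, exc

    return await asyncio.gather(*(one(s, r) for s, r in resolvers.items()))
```

With aiodns, the mapping could presumably be built as `{s: aiodns.DNSResolver(nameservers=[s]) for s in servers}`; treat that wiring as an assumption to verify against the aiodns version you use.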

arjunv commented 4 years ago

Alright. Is there any way I can tweak the result to add a field that tells me which nameserver was used (in a single-channel, multiple-server setup)?

It's probably not a common enough situation, but if I were to use multiple channels, I'd have to implement round-robin all over again to spread the queries across the nameservers.
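
For reference, spreading queries across per-server resolvers does not need much code. The sketch below rotates through the resolvers with `itertools.cycle` and returns the server alongside each result. `RoundRobinResolver` and `StubResolver` are hypothetical names introduced for illustration, and each real resolver is assumed to expose an awaitable `query(name, qtype)`.

```python
import asyncio
import itertools


class StubResolver:
    """Hypothetical stand-in for a real single-server resolver (no network)."""

    def __init__(self, answer):
        self.answer = answer

    async def query(self, name, qtype):
        return self.answer


class RoundRobinResolver:
    """Send each query to the next server in turn and tag the result."""

    def __init__(self, resolvers):
        # resolvers maps a server address to a resolver bound to that server only
        self._cycle = itertools.cycle(list(resolvers.items()))

    async def query(self, name, qtype="A"):
        server, resolver = next(self._cycle)
        result = await resolver.query(name, qtype)
        return server, result
```

Each call to `query()` advances the rotation by one server, so consecutive queries alternate between the configured nameservers regardless of which names are being looked up.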

boytm commented 4 years ago

Implementing a round-robin query is simple in asyncio. See below.

The core logic is in many_resolve() and resolve().

The DNSResolver class is copied from https://github.com/saghul/pycares/blob/master/examples/cares-asyncio.py; I only added one extra parameter.

This is just a demo. You should add your own exception handling.

Run with args: python many_resolve.py qq.com sina.com twitter.com facebook.com ciis-cn.net

many_resolve.py


import asyncio
import socket
import sys

import pycares

class DNSResolver(object):
    EVENT_READ = 0
    EVENT_WRITE = 1

    def __init__(self, servers, loop=None):
        self.servers = servers
        self._channel = pycares.Channel(sock_state_cb=self._sock_state_cb, servers=servers)
        self._timer = None
        self._fds = set()
        self.loop = loop or asyncio.get_event_loop()

    def _sock_state_cb(self, fd, readable, writable):
        if fd in self._fds:
            # clear old events
            self.loop.remove_reader(fd)
            self.loop.remove_writer(fd)
        if readable or writable:
            if readable:
                self.loop.add_reader(fd, self._process_events, fd, self.EVENT_READ)
            if writable:
                self.loop.add_writer(fd, self._process_events, fd, self.EVENT_WRITE)
            self._fds.add(fd)
            if self._timer is None:
                self._timer = self.loop.call_later(1.0, self._timer_cb)
        else:
            # socket is now closed
            self._fds.discard(fd)
            if not self._fds:
                self._timer.cancel()
                self._timer = None

    def _timer_cb(self):
        self._channel.process_fd(pycares.ARES_SOCKET_BAD, pycares.ARES_SOCKET_BAD)
        self._timer = self.loop.call_later(1.0, self._timer_cb)

    def _process_events(self, fd, event):
        if event == self.EVENT_READ:
            read_fd = fd
            write_fd = pycares.ARES_SOCKET_BAD
        elif event == self.EVENT_WRITE:
            read_fd = pycares.ARES_SOCKET_BAD
            write_fd = fd
        else:
            read_fd = write_fd = pycares.ARES_SOCKET_BAD
        self._channel.process_fd(read_fd, write_fd)

    def query(self, name, query_type, cb):
        # Argument order matches pycares.Channel.query(name, query_type, callback)
        self._channel.query(name, query_type, cb)

    def gethostbyname(self, name, cb):
        self._channel.gethostbyname(name, socket.AF_INET, cb)

def resolve(resolver, name):
    # Bridge the callback-style query to a future that remembers its resolver
    fut = resolver.loop.create_future()
    def cb(result, error):
        print(result, error)
        # Tag the answer with the servers of the resolver that produced it
        fut.set_result((resolver.servers, (name, result, error)))
    resolver.query(name, pycares.QUERY_TYPE_A, cb)
    return fut

async def many_resolve(resolvers, name):
    tasks = []
    # try each resolver in turn, waiting up to 1.0 second before sending the next query
    for resolver in resolvers:
        tasks.append(resolve(resolver, name))
        done, pending = await asyncio.wait(tasks, timeout=1.0, return_when=asyncio.FIRST_COMPLETED)
        if done:
            # the finished future may belong to an earlier resolver, not the latest one
            return done.pop().result()

    # extra wait
    done, pending = await asyncio.wait(tasks, timeout=5.0, return_when=asyncio.FIRST_COMPLETED)
    if done:
        return done.pop().result()

    # no result
    return None, None

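def main():
    def cb(fut):
        print("Result: {}".format(fut.result()))

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    resolvers = [DNSResolver([s], loop) for s in ['1.1.1.1', '8.8.8.8']]
    futs = []
    for name in sys.argv[1:]:
        fut = asyncio.ensure_future(many_resolve(resolvers, name), loop=loop)
        fut.add_done_callback(cb)
        futs.append(fut)

    # Stop once every lookup has finished instead of running forever
    if futs:
        loop.run_until_complete(asyncio.gather(*futs))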

if __name__ == '__main__':
    main()
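
The staggered "send to the next server if the previous one is slow" logic in many_resolve() can also be tried without pycares or network access. The sketch below is a generic version under stated assumptions: `staggered_first` is a hypothetical helper, and each query is represented by a zero-argument coroutine function paired with a label for the server it targets. Keeping a task-to-label map is what lets it report which server actually answered.

```python
import asyncio


async def staggered_first(starters, stagger=1.0, extra_wait=5.0):
    """Start queries one at a time and return the first one to finish.

    starters is a list of (label, zero-argument coroutine function).
    Returns (label, result) for the winner, or (None, None) if nothing
    finished in time.
    """
    tasks = {}  # task -> label of the server it was sent to
    try:
        for label, start in starters:
            tasks[asyncio.ensure_future(start())] = label
            # Give the queries sent so far `stagger` seconds before adding another
            done, _ = await asyncio.wait(
                list(tasks), timeout=stagger, return_when=asyncio.FIRST_COMPLETED)
            if done:
                winner = done.pop()
                return tasks[winner], winner.result()
        if tasks:
            # Every query has been sent; wait a little longer for any of them
            done, _ = await asyncio.wait(
                list(tasks), timeout=extra_wait, return_when=asyncio.FIRST_COMPLETED)
            if done:
                winner = done.pop()
                return tasks[winner], winner.result()
        return None, None
    finally:
        for task in tasks:
            task.cancel()  # no effect on tasks that already finished
```

Unlike the demo above, this version cancels the losing queries once a winner is known; whether that is desirable depends on whether you want late answers from the other servers.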