natesales / q

A tiny command-line DNS client with support for UDP, TCP, DoT, DoH, DoQ and ODoH.
GNU General Public License v3.0
1.73k stars 63 forks source link

Enhancement: support querying multiple servers at once #73

Closed CosmicToast closed 12 months ago

CosmicToast commented 1 year ago

Please add support for querying multiple DNS servers at once, as dog and doggo do. This is particularly useful when testing a given DNS server for conformance against a known-good server, or when debugging issues. Here is an example in doggo, whose output makes the feature more obvious:

image

Since q already supports querying multiple RR types at once, this seems like a logical extension. It would also bring q closer to feature parity with dog and doggo.

CosmicToast commented 1 year ago

I made an attempt at implementing this. It is relatively straightforward until we get to the output: we don't store the originating server in the replies, and we don't print the replying server by default. Changing this is possible, but awkward because of #72. Here's the diff in case someone wants to pick up where I left off; all but the final section should be relevant:

```diff diff --git a/cli/flags.go b/cli/flags.go index 5e396ae..536b5eb 100644 --- a/cli/flags.go +++ b/cli/flags.go @@ -4,7 +4,7 @@ import "time" type Flags struct { Name string `short:"q" long:"qname" description:"Query name"` - Server string `short:"s" long:"server" description:"DNS server"` + Server []string `short:"s" long:"server" description:"DNS server"` Types []string `short:"t" long:"type" description:"RR type (e.g. A, AAAA, MX, etc.) or type integer"` Reverse bool `short:"x" long:"reverse" description:"Reverse lookup"` DNSSEC bool `short:"d" long:"dnssec" description:"Set the DO (DNSSEC OK) bit in the OPT record"` diff --git a/main.go b/main.go index 3076d9e..0071d22 100644 --- a/main.go +++ b/main.go @@ -228,7 +228,7 @@ func parseServer(s string) (string, transport.Type, error) { // driver is the "main" function for this program that accepts a flag slice for testing func driver(args []string, out io.Writer) error { parser := flags.NewParser(&opts, flags.Default) - parser.Usage = `[OPTIONS] [@server] [type...] [name] + parser.Usage = `[OPTIONS] [@server...] [type...] 
[name] All long form (--) flags can be toggled with the dig-standard +[no]flag notation.` _, err := parser.ParseArgs(args) @@ -279,9 +279,9 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation // Add non-flag RR types for _, arg := range args { - // Find a server by @ symbol if it isn't set by flag - if opts.Server == "" && strings.HasPrefix(arg, "@") { - opts.Server = strings.TrimPrefix(arg, "@") + // @ servers added to server list + if strings.HasPrefix(arg, "@") { + opts.Server = append(opts.Server, strings.TrimPrefix(arg, "@")) } // Parse chaos class @@ -335,23 +335,24 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation log.Debugf("RR types: %+v", rrTypeStrings) } - // Set default DNS server - if opts.Server == "" { + // Set default DNS server if none were set explicitly + if len(opts.Server) == 0 { + opts.Server = make([]string, 1) if os.Getenv(defaultServerVar) != "" { - opts.Server = os.Getenv(defaultServerVar) + opts.Server[0] = os.Getenv(defaultServerVar) log.Debugf("Using %s from %s environment variable", opts.Server, defaultServerVar) } else { log.Debugf("No server specified or %s set, using /etc/resolv.conf", defaultServerVar) conf, err := dns.ClientConfigFromFile("/etc/resolv.conf") if err != nil { - opts.Server = "https://cloudflare-dns.com/dns-query" + opts.Server[0] = "https://cloudflare-dns.com/dns-query" log.Debugf("no server set, using %s", opts.Server) } else { if len(conf.Servers) == 0 { - opts.Server = "https://cloudflare-dns.com/dns-query" + opts.Server[0] = "https://cloudflare-dns.com/dns-query" log.Debugf("no server set, using %s", opts.Server) } else { - opts.Server = conf.Servers[0] + opts.Server[0] = conf.Servers[0] log.Debugf("found server %s from /etc/resolv.conf", opts.Server) } } @@ -363,8 +364,10 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation if !strings.HasPrefix(opts.ODoHProxy, "https://") { return fmt.Errorf("ODoH proxy must use 
HTTPS") } - if !strings.HasPrefix(opts.Server, "https://") { - return fmt.Errorf("ODoH target must use HTTPS") + for _, v := range opts.Server { + if !strings.HasPrefix(v, "https://") { + return fmt.Errorf("ODoH target must use HTTPS") + } } } @@ -420,17 +423,20 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation ) // Parse server address and transport type - server, transportType, err := parseServer(opts.Server) - if err != nil { - return err + type tserver struct { + server string + ttype transport.Type + trans *transport.Transport } - log.Debugf("Using server %s with transport %s", server, transportType) - - // QUIC specific overrides - if transportType == transport.TypeQUIC { - tlsConfig.NextProtos = opts.QUICALPNTokens - // Skip ID check if QUIC (https://datatracker.ietf.org/doc/html/rfc9250#section-4.2.1) - opts.NoIDCheck = true + servers := make([]*tserver, 0, len(opts.Server)) + for _, v := range opts.Server { + server, transportType, err := parseServer(v) + if err != nil { + log.Debugf("Skipping server %s due to error %s", v, err) + continue + } + log.Debugf("Adding server %s with transport %s", server, transportType) + servers = append(servers, &tserver{server, transportType, nil}) } // Recursive zone transfer @@ -438,28 +444,63 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation if opts.Name == "" { return fmt.Errorf("no name specified for AXFR") } - _ = RecAXFR(opts.Name, server, out) + for _, v := range servers { + _ = RecAXFR(opts.Name, v.server, out) + } return nil } - // Create transport - txp, err := newTransport(server, transportType, tlsConfig) - if err != nil { - return err + // Create transports + for _, v := range servers { + tlsConfig := tlsConfig.Clone() + + // QUIC specific overrides + if v.ttype == transport.TypeQUIC { + tlsConfig.NextProtos = opts.QUICALPNTokens + // Skip ID check if QUIC (https://datatracker.ietf.org/doc/html/rfc9250#section-4.2.1) + opts.NoIDCheck = true // 
TODO: per-server overrides? + } + + // Create transport + txp, err := newTransport(v.server, v.ttype, tlsConfig) + if err != nil { + log.Debugf("Skipping server %s due to error %s", v.server, err) + continue + } + v.trans = txp } - startTime := time.Now() - var replies []*dns.Msg - for _, msg := range msgs { - reply, err := (*txp).Exchange(&msg) - if err != nil { - return err + // Filter failed servers, so we can preallocate responses + { + n := 0 + for _, v := range servers { + if v.trans == nil { + continue + } + servers[n] = v + n++ } + servers = servers[:n] + } - if !opts.NoIDCheck && reply.Id != msg.Id { - return fmt.Errorf("ID mismatch: expected %d, got %d", msg.Id, reply.Id) + // preallocate replies storage + replies := make([]*dns.Msg, 0, len(msgs) * len(servers)) + startTime := time.Now() + for _, v := range servers { + for _, msg := range msgs { + reply, err := (*v.trans).Exchange(&msg) + if err != nil { + log.Debugf("Skipping message %s with servers %s due to error %s", + &msg, v.server, err) + replies = append(replies, nil) // append nil so we can keep server sizes stable + continue + } + + if !opts.NoIDCheck && reply.Id != msg.Id { + return fmt.Errorf("ID mismatch: expected %d, got %d", msg.Id, reply.Id) + } + replies = append(replies, reply) } - replies = append(replies, reply) } queryTime := time.Since(startTime) @@ -474,27 +515,35 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation output.PrettyPrintNSID(replies, out) } - printer := output.Printer{ - Server: server, - Out: out, - Opts: &opts, - QueryTime: queryTime, - NumReplies: len(replies), - Transport: txp, - } - if opts.Format == "column" { - printer.PrintColumn(replies) - } else { - for i, reply := range replies { - switch opts.Format { - case "pretty": - printer.PrintPretty(i, reply) - case "raw": - printer.PrintRaw(i, reply) - case "json", "yml", "yaml": - printer.PrintStructured(i, reply) - default: - return fmt.Errorf("invalid output format") + for i, v := 
range servers { + fmt.Println(v.server) + printer := output.Printer{ + Server: v.server, + Out: out, + Opts: &opts, + QueryTime: queryTime, + NumReplies: len(msgs), // TODO: filter nils + Transport: v.trans, + } + if opts.Format == "column" { + printer.PrintColumn(replies[i * len(msgs) : (i+1) * len(msgs)]) + } else { + for i, reply := range replies[i * len(msgs) : (i+1) * len(msgs)] { + if reply == nil { + continue + } + switch opts.Format { + case "pretty": + printer.PrintPretty(i, reply) + case "raw": + printer.PrintRaw(i, reply) + // TODO: jq and co can handle multipe separate json objects on stdin + // however, it would be nice to potentially return a [] instead + case "json", "yml", "yaml": + printer.PrintStructured(i, reply) + default: + return fmt.Errorf("invalid output format") + } } } } ```