#!/usr/bin/env -S python3 -B
#
#   s5dnstls - TLS resolver using dnspython as the backend.
#
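#   Queries search for both IPv4 and IPv6 addresses; with -4 or
#   -6, the preferred family is queried first, and the other one
#   is skipped if the preferred family returns any addresses.  If
#   multiple servers are provided, they are tried in turn for each
#   query, starting with the most recently successful server,
#   until addresses are returned or every server has been tried.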
#
#   Use --server-addr to provide one or more DNS-over-TLS
#   servers to query.
#
#   Use --hosts-path to specify the path to an optional hosts
#   file to read.
#
#   Copyright (C) 2026 by Matt Roberts, KK5JY.
#   All rights reserved.
#
#   License: GPL3 (www.gnu.org)
#
#

#
#  TODO:
#   1. Add options for changing some settings
#      a. server port (also support in server address string)
#      b. certs, sigs, validation, etc.
#   2. Persist the DNS-TLS connections for some period of time;
#      According to the docs: "dns/query.py now provides make_socket(),
#      make_ssl_socket(), and make_ssl_context() to make using persistent
#      connections with the query code easier..."
#

# system modules
import os
import sys
import getopt

# the dnspython module
try:
	import dns.message
	import dns.query
	import dns.rdatatype
	import dns.exception
except ModuleNotFoundError as ex:
	sys.stderr.write("Error: this program requires the 'dnspython' module, but that module was not found.\n")
	sys.exit(5)

# local utilities
from s5utils import *


# DNS-over-TLS servers used when no --server-addr option is given
default_addr = [ '1.1.1.1', '1.0.0.1' ]

# servers actually queried; filled from --server-addr, or copied from
#   default_addr by main() when none were supplied
server_addr = [ ]

# index into server_addr of the server to try first on the next query;
#   resolve() advances it only when a server yields no addresses
server_idx = 0

# DEBUG: when True, report each lookup action on stderr
debug_lookups = False

# when True, resolve() consults and updates the lookup cache
cache_lookups = False

# per-query timeout, in seconds
dns_query_timeout = 4

# preferred IP version: None for no preference, otherwise 4 or 6; the
#   preferred record type is queried first, and the remaining type is
#   skipped when the preferred one already produced addresses
prefer_ipv = None

# DNS record types to query, in order (A before AAAA; main() reverses
#   this list when IPv6 is preferred)
address_types = [
	dns.rdatatype.RdataType.A,
	dns.rdatatype.RdataType.AAAA,
]


#
#  resolve(name) - resolve 'name' into a list of IP addresses
#
#  Queries the configured DNS-over-TLS servers one at a time, starting
#  at server_addr[server_idx], for each record type in address_types.
#  It stops at the first server that yields any addresses; when a
#  server yields none, server_idx advances so that the next query
#  starts with the following server.
#
#  Returns a tuple (addresses, ttl):
#    addresses - list of address strings, or [ '0' ] if none were found
#    ttl       - minimum nonzero TTL over all answer RRsets seen (0 if none)
#
#  NOTE(review): on a cache hit the value from cache_search() is returned
#  unchanged; confirm that it has the same (addresses, ttl) shape as the
#  uncached return value.
#
def resolve(name):
	global server_addr, server_idx

	# if cache enabled, try that first
	if cache_lookups:
		cached = cache_search(name)

		# DEBUG: report cache activity only when lookup debugging is on
		if debug_lookups:
			sys.stderr.write("DEBUG: cache_search(%s) => %s\n" % (name, "HIT" if cached else "MISS"))

		# and return the result if a valid record was found
		if cached:
			return cached

	# collected addresses, minimum nonzero TTL, and servers tried so far
	result = [ ]
	ttl = 0
	count = 0

	# repeat queries until results returned, or we run out of servers
	while not result and count < len(server_addr):
		server = server_addr[server_idx]
		count += 1

		# DEBUG:
		if debug_lookups:
			sys.stderr.write("DEBUG: PID = %d query IP = %s for name %s\n" % (os.getpid(), server, name))

		# query each address type, in the configured order
		for rdt in address_types:
			try:
				q = dns.message.make_query(name, rdt)
				r = dns.query.tls(q, server, timeout=dns_query_timeout)
			except (TimeoutError, dns.exception.Timeout):
				# must precede the OSError clause: TimeoutError is a subclass
				# of OSError and would otherwise never reach this handler
				sys.stderr.write("Error: DNS timeout: PID = %d query IP = %s for name %s\n" % (os.getpid(), server, name))
				break
			except (EOFError, OSError, dns.exception.DNSException) as ex:
				# DNSException covers failures such as malformed names or bad
				# responses, which would otherwise escape and end the resolver
				sys.stderr.write("Error: DNS resolver PID = %d raised %s\n" % (os.getpid(), str(ex)))
				break

			# collect the results
			for rrset in r.answer:
				# capture the minimum nonzero TTL value in this overall query sequence
				ttl = rrset.ttl if ttl == 0 else min(ttl, rrset.ttl)

				# keep only address records (skips e.g. CNAMEs), without duplicates
				for rr in rrset:
					if rr.rdtype in address_types and rr.address not in result:
						result.append(rr.address)

			# if preferred IP version lookup succeeded, stop
			if prefer_ipv and result:
				break

		# move on to the next server if nothing was found
		if not result:
			server_idx = (server_idx + 1) % len(server_addr)

	# if cache enabled, update the cache with positive return
	if cache_lookups and result:
		cache_update(name, result, ttl)

	# add a failure line if no addresses were found
	if not result:
		result.append('0')

	return result, ttl


#
#  usage() - print the command-line help to stderr, then exit with status 1
#
def usage():
	prog = os.path.basename(sys.argv[0])
	help_lines = (
		"%s [options]\n" % prog,
		"   -h = show this information\n",
		"   -4 = prefer IPv4 addresses\n",
		"   -6 = prefer IPv6 addresses\n",
		"   -c = enable lookup cache\n",
		"   -l = show debugging information about each query\n",
		"   -s | --server-addr <addr>[,<addr> ...]\n",
		"        = set remote server addresses\n",
		"   -p | --hosts-path <path>\n",
		"        = read hosts from 'path'\n",
		"   -t | --timeout <seconds>\n",
		"        = set DNS timeout value\n",
	)
	for text in help_lines:
		sys.stderr.write(text)
	sys.exit(1)


#
#  main()
#
#  Parse the command line into the module-level settings, then hand
#  control to main_common() with resolve() as the resolver.  Invalid
#  options or option values print usage() and exit with status 1.
#
def main():
	# every setting assigned below must be declared global; without
	# cache_lookups and debug_lookups here, -c and -l would only set locals
	global server_addr, dns_query_timeout, prefer_ipv, cache_lookups, debug_lookups

	# path to the optional hosts file
	hosts_path = None

	# servers collected from all -s / --server-addr options
	servers = [ ]

	try:
		# read the command line
		optlist, _ = getopt.getopt(sys.argv[1:], '46chls:p:t:', [ 'server-addr=', 'hosts-path=', 'timeout=' ])

		# process the command line
		for opt, val in optlist:
			if opt in ( '-s', '--server-addr' ):
				# accept "a,b" or "a, b"; drop empty entries such as a trailing comma
				servers += [ addr.strip() for addr in val.split(',') if addr.strip() ]
			elif opt in ( '-p', '--hosts-path' ):
				hosts_path = os.path.expanduser(val)
			elif opt in ( '-t', '--timeout' ):
				dns_query_timeout = float(val)
				# reject zero, negative, and NaN timeouts
				if not dns_query_timeout > 0:
					usage()
			elif opt == '-c':
				cache_lookups = True
			elif opt == '-l':
				debug_lookups = True
			elif opt == '-4':
				prefer_ipv = 4
			elif opt == '-6':
				prefer_ipv = 6
			else:
				# -h, or anything else getopt accepted
				usage()
	except (getopt.GetoptError, ValueError):
		# unknown option, missing option argument, or non-numeric timeout
		usage()

	# A is already searched first; for IPv6 preference, put AAAA first
	if prefer_ipv == 6:
		address_types.reverse()

	# user-supplied servers replace the defaults; copy the default list so
	# that server_addr never aliases default_addr
	if servers:
		server_addr = servers
	elif not server_addr:
		server_addr = list(default_addr)

	# call the main processing loop
	main_common(resolver=resolve, hosts_path=hosts_path)


# entry point - run main() with exception safety
#   exit status: 0 = normal exit or Ctrl-C, 1 = unexpected error,
#   2 = output pipe closed by the reader
if __name__ == '__main__':
	try:
		main()
		sys.exit(0)
	except BrokenPipeError:
		# Python flushes stdout again at interpreter exit, which would raise
		# a second BrokenPipeError and print an "Exception ignored" warning;
		# point stdout at the null device so that final flush is harmless.
		# This is best effort: if stdout has no usable descriptor, just exit.
		try:
			devnull = os.open(os.devnull, os.O_WRONLY)
			os.dup2(devnull, sys.stdout.fileno())
		except (OSError, ValueError):
			pass
		sys.exit(2)
	except KeyboardInterrupt:
		sys.exit(0)
	except Exception as ex:
		sys.stderr.write("Received unexpected exception: %s: %s\n" % (str(type(ex)), str(ex)))
		sys.exit(1)

# EOF: s5dnstls
