#!/usr/bin/env -S python3 -B
#
#   s5dnstls - TLS resolver using dnspython as the backend.
#
#   Queries will search for both IPv4 and IPv6 addresses.  If
#   multiple servers are provided, they will be searched in
#   round-robin fashion for each query, until addresses are
#   returned, or all servers have returned no results for a
#   query.
#
#   Many options are available; run s5dnstls -h to list them.
#
#   Copyright (C) 2026 by Matt Roberts, KK5JY.
#   All rights reserved.
#
#   License: GPL3 (www.gnu.org)
#
#

#
#  TODO:
#   1. Add option for specifying the name of custom server
#      that's not in the WK table
#   2. Add an option to bootstrap a WKN via reverse-lookup in DNS.
#   3. Add support for including port number in server address.
#   4. Show debugging error for TLS auth failures.
#   5. When looking up IP address in WKN table, do proper IPv6
#      address comparison, not just string match
#

#
#  required modules
#

# system modules
import getopt
import ipaddress
import os
import socket
import sys
import threading
import time

# the dnspython module
try:
	import dns.message
	import dns.query
	import dns.rdatatype
	import dns.exception
except ModuleNotFoundError as ex:
	sys.stderr.write("Error: this program requires the 'dnspython' module, but that module was not found.\n")
	sys.exit(5)

# local utilities
from s5utils import *


#
#  settings
#

# the server addresses, tried in round-robin order (change with -s;
#   the first -s replaces these defaults, later ones append)
server_addr = [ '1.1.1.1', '1.0.0.1' ]

# the server TCP port, shared by all servers (change with -P)
server_port = 853 # the standard DoT port

# DEBUG: set to True to show key resolution actions (enable with -z)
debug_lookups = False

# set to True to enable lookup cache (enable with -c)
#    (socks5pp does this on its own across all resolvers, so
#     leave this disabled)
cache_lookups = False

# timeout for a single query; also bounds the connect + TLS handshake
#   of a new connection (change with -t)
dns_query_timeout = 4 # seconds

# how long to keep an idle socket open (change with -i; values
#   below 1 are ignored)
socket_idle = 5 # seconds

# prefer one IP address type over another; if non-None, this will
#   search the preferred version first, and stop searching if the
#   preferred IP version returns one or more addresses
#   (enable this with -6 or -4)
prefer_ipv = None

# flag indicating whether to validate the server's TLS identity when
#   its address is listed in wk_names (enable with -w)
auto_validate = False

# address types to search, in order (v4 first, then v6; reversed by -6)
address_types = [ dns.rdatatype.RdataType.A, dns.rdatatype.RdataType.AAAA ]

# table of well-known host names
#   key   = server IP address, as a string
#   value = host name used as the TLS server name when validation
#           is enabled (-w) and the server address is listed here
#   keep the keys in lowercase: the lookup lower-cases the queried
#   address before matching
wk_names = {
	# cloudflare
	'1.1.1.1'             : 'one.one.one.one',
	'1.0.0.1'             : 'one.one.one.one',
	'2606:4700:4700::1111': 'one.one.one.one',
	'2606:4700:4700::1001': 'one.one.one.one',

	# quad9
	'9.9.9.9'        : 'dns.quad9.net', # default addresses
	'149.112.112.112': 'dns.quad9.net',
	'2620:fe::fe'    : 'dns.quad9.net',
	'2620:fe::9'     : 'dns.quad9.net',
	'9.9.9.11'       : 'dns.quad9.net', # filtered addresses
	'149.112.112.11' : 'dns.quad9.net',
	'2620:fe::11'    : 'dns.quad9.net',
	'2620:fe::fe:11' : 'dns.quad9.net',

	# google
	'8.8.8.8'             : 'dns.google',
	'8.8.4.4'             : 'dns.google',
	'2001:4860:4860::8888': 'dns.google',
	'2001:4860:4860::8844': 'dns.google',

	# mullvad
	'194.242.2.2'    : 'dns.mullvad.net',
	'2a07:e340::2'   : 'dns.mullvad.net',

	# Control-D
	'76.76.2.11'     : 'p0.freedns.controld.com',
	'76.76.10.11'    : 'p0.freedns.controld.com',
}


#
#  state
#

# the currently selected server index into server_addr; resolve_core()
#   advances it (with wrap-around) when a server yields no addresses
server_idx = 0

# mutex to protect the socket cache; held by resolve_locked() and by
#   background_clean() around their work
mutex = threading.Lock()

# how often to clean the cached sockets when idle
clean_interval = 2 # seconds

# the last time cleaning was done (time.time() value; set by clean_cache())
last_clean = 0

# the connection cache (str -> node); keys come from make_socket_key(),
#   values are connection_node objects
socket_cache = { }


#
#  connection cache node
#
class connection_node:
	"""Socket-cache entry: a socket plus the timestamps kept about it."""

	def __init__(self, _sock=None):
		# time of last completed query; stays 0 until one completes
		self.success = 0

		# time at which this entry was built
		self.created = time.time()

		# the socket held by this entry (None if none was supplied)
		self.socket = _sock


#
#  return cache key for specific socket target
#
def make_socket_key(ip, port):
	"""Return the socket-cache key string for the target (ip, port)."""
	# address and integer port, joined by a dash
	key_format = "%s-%d"
	return key_format % (ip, port)


#
#  make_connected_socket(ip, port) - return a connected SSL socket for the given IP
#
def make_connected_socket(ip, port):
	"""Open a TLS connection to the DNS server at (ip, port).

	If auto_validate is set and 'ip' is found in the wk_names table, the
	matching host name is passed as the TLS server name and validation is
	requested from dns.query.make_ssl_context(); otherwise no server name
	is given and validation is not requested.

	Parameters:
		ip   - server IP address string; treated as IPv6 if it contains ':'
		port - server TCP port (integer)

	Returns the connected socket, or None if the connect or the TLS
	handshake raised an exception.  Both steps share one deadline of
	dns_query_timeout seconds.
	"""
	# determine the address family
	af = socket.AF_INET6 if ':' in ip else socket.AF_INET

	# if auto_validate is True, look up the IP in the well-known server table
	shn = _wk_server_name(ip) if auto_validate else None
	validate = shn is not None

	# DEBUG:
	if validate and debug_lookups:
		sys.stderr.write("DEBUG: PID = %d query IP = %s TLS server name = '%s'\n" % (os.getpid(), ip, shn))

	# make the socket
	sock = dns.query.make_ssl_socket(
		af=af,
		server_hostname=shn,
		type=socket.SOCK_STREAM,
		ssl_context=dns.query.make_ssl_context(validate, validate))

	# connect and negotiate using dnspython inner utilities
	expiration = time.time() + dns_query_timeout
	try:
		dns.query._connect(sock, (ip, port), expiration)
		dns.query._tls_handshake(sock, expiration)
	except Exception as ex:
		# DEBUG: say why the connection failed (includes TLS failures)
		if debug_lookups:
			sys.stderr.write("DEBUG: PID = %d connect to %s+%d failed: %s: %s\n" % (os.getpid(), ip, port, type(ex).__name__, ex))

		# closing is best-effort; the socket is being discarded anyway
		try:
			sock.close()
		except Exception:
			pass
		return None

	# return the socket
	return sock


def _wk_server_name(ip):
	"""Return the wk_names host name for address 'ip', or None if not listed."""
	# fast path: exact match on the lower-cased address string
	name = wk_names.get(ip.lower())
	if name is not None:
		return name

	# slow path: compare parsed addresses, so that equivalent spellings
	#   of one IPv6 address (e.g. '2620:fe:0::fe' vs '2620:fe::fe') match
	try:
		target = ipaddress.ip_address(ip)
	except ValueError:
		return None
	for addr, name in wk_names.items():
		try:
			if ipaddress.ip_address(addr) == target:
				return name
		except ValueError:
			continue # skip any table key that is not a plain IP address
	return None


#
#  get_socket(ip, port) - return a connected socket cache node; create socket if needed
#
def get_socket(ip, port):
	"""Return the cache node for (ip, port), connecting first if needed.

	Returns the cached connection_node when one exists.  Otherwise a new
	connection is attempted; on success it is wrapped in a node, stored in
	socket_cache and returned, and on failure None is returned.
	"""
	cache_key = make_socket_key(ip, port)

	# reuse the cached entry when there is one
	cached = socket_cache.get(cache_key, None)
	if cached:
		return cached

	# nothing cached: try to connect, and give up if that fails
	fresh = make_connected_socket(ip, port)
	if not fresh:
		return None

	# remember the new connection and hand back its node
	entry = connection_node(fresh)
	socket_cache[cache_key] = entry
	return entry


#
#  reset_socket(ip, port) - close socket and remove it from cache
#
def reset_socket(ip, port):
	"""Close the cached socket for (ip, port) and remove it from the cache.

	Does nothing if no entry is cached for that target.
	"""
	# fetch and remove the cache node
	node = socket_cache.pop(make_socket_key(ip, port), None)
	if not node:
		return

	# then close it; best-effort, since the socket may already be dead.
	#   only ordinary errors are ignored, so that KeyboardInterrupt and
	#   SystemExit still propagate
	try:
		node.socket.close()
	except Exception:
		pass


#
#  clean_cache() - remove old sockets
#
def clean_cache():
	"""Close and drop cached sockets that have been idle for too long.

	An entry is idle once more than socket_idle seconds have passed since
	its last completed query or, if it has not completed one yet, since
	it was created.  Also records the time of this sweep in last_clean.
	"""
	global last_clean
	now = time.time()
	last_clean = now

	# find long-idle sockets; 'success' is 0 until the first query
	#   completes, so fall back to the creation time in that case
	to_remove = [
		key for key, node in socket_cache.items()
		if (now - max(node.success, node.created)) > socket_idle
	]

	# then close and remove them
	for key in to_remove:
		node = socket_cache.pop(key, None)

		# closing is best-effort; ignore ordinary errors only
		try:
			node.socket.close()
		except Exception:
			pass


#
#  resolve_core(name) - resolve 'name' into a list of IP addresses
#
def resolve_core(name):
	"""Resolve 'name' into a list of IP address strings.

	Servers from server_addr are tried starting at server_idx.  For each
	server, every record type in address_types is queried over the cached
	TLS connection to that server.  The search ends at the first server
	that yields at least one address, or once every server has been tried.
	A query is repeated once on a fresh connection only when a socket that
	had already completed a query fails with EOFError/OSError.

	Returns a tuple (addresses, ttl).  'addresses' is the list of address
	strings, or [ '0' ] if none were found; 'ttl' is the smallest TTL
	among the answer record sets seen (0 if there were none).  On a
	lookup-cache hit, the value of cache_search() is returned unchanged.

	This uses and cleans the shared socket cache; resolve_locked() is
	the wrapper that calls it with 'mutex' held.
	"""
	global server_addr, server_idx

	# if cache enabled, try that
	if cache_lookups:
		# search the cache
		result = cache_search(name)

		# DEBUG:
		if debug_lookups:
			sys.stderr.write("DEBUG: cache_search(%s) => %s\n" % (name, "HIT" if result else "MISS"))

		# and return the result if a valid record was found
		# NOTE(review): this returns the cache_search() value as-is, while the
		#   normal path below returns (result, ttl) - confirm both shapes are
		#   what the caller expects
		if result:
			return result

	# start with empty result set
	result = [ ]

	# number of servers queried so far
	count = 0

	# ttl value
	ttl = 0

	# repeat queries until results returned, or we run out of servers
	while not result and count < len(server_addr):
		# select the server to query
		server = server_addr[server_idx]

		# increment the servers-queried count
		count += 1

		# DEBUG:
		if debug_lookups:
			sys.stderr.write("DEBUG: PID = %d query IP = %s for name '%s'\n" % (os.getpid(), server, name))

		# query both IPv4 and IPv6 with the same socket connection
		for rdt in address_types:
			r = None
			tries = 2 # try at most twice; second try only if socket has expired before first query
			while tries > 0: # the retry loop
				# get or build the connection
				node = get_socket(server, server_port)

				# if could not connect to server...
				if not node:
					# DEBUG:
					if debug_lookups:
						sys.stderr.write("DEBUG: DNS resolver PID = %d could not connect to server %s+%d\n" % (os.getpid(), server, server_port))

					r = None
					tries = 0 # don't try again for this server
					break     # but do try the next server if there is one

				# try the query
				try:
					# DEBUG:
					if debug_lookups:
						sys.stderr.write("DEBUG: PID = %d query [%s] IP = %s for name '%s'\n" % (os.getpid(), rdt, server, name))

					# build and perform the query
					q = dns.message.make_query(name, rdt)
					r = dns.query.tls(q, server, sock=node.socket, ssl_context=node.socket.context, timeout=dns_query_timeout)
					node.success = time.time()  # mark the socket as OK as of now
					break                       # exit the retry loop

				# socket closed before query - retry same server
				#   (the builtin TimeoutError is an OSError subclass, so it is
				#    handled here as well, not by the clause below)
				except ( EOFError, OSError ):
					tries = (tries - 1) if node.success else 0 # retry if this socket worked previously, and has simply closed
					reset_socket(server, server_port)          # clean the old socket, since it's closed anyway
					r = None

					# DEBUG:
					if debug_lookups:
						sys.stderr.write("DEBUG: DNS EOF: PID = %d query IP = %s for name '%s'\n" % (os.getpid(), server, name))
						if tries:
							sys.stderr.write("DEBUG: DNS EOF: will retry\n")

				# dnspython timeout or any other failure
				except Exception as ex:
					tries = 0                           # timeouts and hard-fails don't get retried
					reset_socket(server, server_port)   # clean the old socket, regardless of its state
					r = None

					# DEBUG:
					if debug_lookups:
						why = "timeout" if isinstance(ex, ( TimeoutError, dns.exception.Timeout )) else "failed"
						sys.stderr.write("DEBUG: DNS %s: PID = %d query IP = %s for name '%s'\n" % (why, os.getpid(), server, name))

			# if retries exceeded, fail
			if not tries:
				break

			# don't process results if the query failed
			if r is None:
				break

			# collect the results
			for rrset in r.answer:
				# capture the minimum nonzero TTL value in this overall query sequence
				ttl = rrset.ttl if ttl == 0 else min(ttl, rrset.ttl)

				# filter results down to IP addresses
				for rr in filter(lambda x: x.rdtype in address_types and x.address not in result, rrset):
					result.append(rr.address)

			# if preferred IP version lookup succeeded, stop
			if prefer_ipv and result:
				break

		# increment the index if nothing was found
		if not result:
			server_idx = (server_idx + 1) % len(server_addr)

	# if cache enabled, update the cache with positive return
	if cache_lookups and result:
		cache_update(name, result, ttl)

	# add a failure line if no addresses provided
	if not result:
		result.append('0')

	# clean out sockets that are too old
	clean_cache()

	# return the addresses (or the failure line) and the TTL
	return result, ttl


#
#  resolve_locked(name)
#
def resolve_locked(name):
	"""Run resolve_core(name) while holding 'mutex'.

	Returns whatever resolve_core() returns.  If it raises an Exception,
	the error is reported on stderr and None is returned instead.
	"""
	with mutex:
		try:
			answer = resolve_core(name)
		except Exception as err:
			sys.stderr.write("DEBUG: resolve_locked() raised %s\n" % str(err))
			answer = None
	return answer


#
#  background_clean()
#
def background_clean():
	"""Loop forever, sweeping the socket cache when it has not been swept lately.

	Sleeps clean_interval seconds between passes.  Each pass takes
	'mutex' and calls clean_cache() if at least clean_interval seconds
	have gone by since last_clean.  Exceptions are reported on stderr and
	the loop carries on.  Never returns.
	"""
	while True:
		# be very nice to the system
		time.sleep(clean_interval)

		try:
			with mutex:
				# only sweep if nobody else has done so recently
				since_last = time.time() - last_clean
				if since_last >= clean_interval:
					clean_cache()
		except Exception as err:
			sys.stderr.write("DEBUG: background_clean() raised %s\n" % str(err))


#
#  usage()
#
def usage():
	"""Write the command-line help to stderr, then exit with status 1."""
	program = os.path.basename(sys.argv[0])

	# one entry per output line, in display order
	help_lines = (
		"%s [options]\n" % program,
		"   -h = show this information\n",
		"   -4 = prefer IPv4 addresses\n",
		"   -6 = prefer IPv6 addresses\n",
		"   -c = enable lookup cache\n",
		"   -w = enable TLS validation for well-known servers\n",
		"   -P <port> = set server TCP port\n",
		"   -z = show debugging information about each query\n",
		"   -s | --server-addr <addr>[,<addr> ...]\n",
		"        = set remote server addresses\n",
		"   -p | --hosts-path <path>\n",
		"        = read hosts from 'path'\n",
		"   -t | --timeout <seconds>\n",
		"        = set DNS timeout value\n",
		"   -i <seconds>\n",
		"        = set max time before idle connections are closed\n",
	)
	for line in help_lines:
		sys.stderr.write(line)

	sys.exit(1)


#
#  main()
#
def main():
	"""Read the command line, start the cache cleaner, and run the resolver loop.

	Options update the module-level settings; see usage() for the list.
	An unknown option, -h, or a malformed option value ends the program
	through usage().  After option processing, the background cleaner
	thread is started and control passes to main_common() with
	resolve_locked as the resolver.

	Raises Exception if no server address remains after option processing.
	"""
	global server_addr, server_port, cache_lookups, dns_query_timeout, prefer_ipv, auto_validate, socket_idle, debug_lookups

	# path to the optional hosts file
	hosts_path = None

	# if True, use default servers
	default_servers = True

	try:
		# read the command line
		optlist, args = getopt.getopt(sys.argv[1:], '46chi:s:p:P:t:wz', [ 'server-addr=', 'hosts-path=', 'timeout=' ])

		# process the command line
		for flag, value in optlist:
			if flag in [ '-s', '--server-addr' ]:
				# the first -s discards the built-in defaults; later ones append
				if default_servers:
					default_servers = False
					server_addr = [ ]

				# drop blank entries and stray whitespace (e.g. "a, b" or "a,"),
				#   which could never be connected to
				server_addr += [ addr.strip() for addr in value.split(',') if addr.strip() ]
			elif flag in [ '-p', '--hosts-path' ]:
				hosts_path = os.path.expanduser(value)
			elif flag in [ '-t', '--timeout' ]:
				dns_query_timeout = float(value)
			elif flag == '-P':
				server_port = int(value)
			elif flag == '-i':
				# values below one second are ignored
				new_value = int(value)
				if new_value >= 1:
					socket_idle = new_value
			elif flag == '-c':
				cache_lookups = True
			elif flag == '-z':
				debug_lookups = True
			elif flag == '-4':
				prefer_ipv = 4
			elif flag == '-6':
				prefer_ipv = 6
			elif flag == '-w':
				auto_validate = True
			else:
				# -h and anything unexpected
				usage()
	except Exception:
		# bad option or bad option value (usage() exits via SystemExit,
		#   which is not caught here)
		usage()

	# sort the address type list to search in the preferred order
	if prefer_ipv == 6:
		address_types.reverse()

	# refuse to run without at least one server address
	if not server_addr:
		raise Exception("Must specify at least one server address.")

	# start the background cleaner thread
	bg_thread = threading.Thread(target=background_clean, daemon=True)
	bg_thread.start()

	# call the main processing loop
	main_common(resolver=resolve_locked, hosts_path=hosts_path)


# entry point - run main() with exception safety
if __name__ == '__main__':
	try:
		main()
	except BrokenPipeError:
		# the reader of our output went away
		sys.exit(2)
	except KeyboardInterrupt:
		# interrupted by the user; not an error
		sys.exit(0)
	except Exception as err:
		sys.stderr.write("Received unexpected exception: %s: %s\n" % (str(type(err)), str(err)))
		sys.exit(1)
	else:
		# main() returned normally
		sys.exit(0)

# EOF: s5dnstls
