#!/usr/bin/env -S python3 -B
#
#   s5dnstls - TLS resolver using dnspython as the backend.
#
#   Queries will search for both IPv4 and IPv6 addresses.  If
#   multiple servers are provided, each query starts with the
#   server that most recently returned results, and moves on
#   through the list in round-robin order until addresses are
#   returned, or every server has been tried once without results.
#
#   Many options are available; run s5dnstls -h to list them.
#
#   This program is part of the Socks5++ source code.
#
#   Copyright (C) 2026 by Matt Roberts, KK5JY.
#   All rights reserved.
#
#   License: GPL3 (www.gnu.org)
#
#

#
#  required modules
#

# system modules
import os
import sys
import time
import socket
import getopt
import threading

# the dnspython module
try:
	import dns.message
	import dns.query
	import dns.rdatatype
	import dns.exception
except ModuleNotFoundError as ex:
	sys.stderr.write("Error: this program requires the 'dnspython' module, but that module was not found.\n")
	sys.exit(5)

# local utilities
from s5utils import *


#
#  settings
#

# default server addresses; the first -s option replaces this list,
#   and any further -s options append to it
server_addr = [ '1.1.1.1', '1.0.0.1' ]

# the server TCP port (change with -P)
server_port = 853 # the standard DoT port

# DEBUG: set to True to show key resolution actions (enable with -z)
debug_lookups = False

# set to True to enable the local lookup cache (enable with -c)
#    (socks5pp does this on its own across all resolvers, so
#     leave this disabled)
cache_lookups = False

# timeout for a single query; also the shared deadline for the TCP
#   connect plus TLS handshake of a new connection (change with -t)
dns_query_timeout = 4 # seconds

# how long a cached connection may go without a completed query before
#   clean_cache() closes it (change with -i; values below 1 are ignored)
socket_idle = 5 # seconds

# prefer one IP address type over another; if non-None (4 or 6), this
#   will search the preferred version first, and stop searching if the
#   preferred IP version returns one or more addresses
#   (enable this with -6 or -4)
prefer_ipv = None

# when True, servers whose IP is listed in wk_names are connected with
#   TLS validation against that name; servers not in the table are
#   connected with validation turned off (enable with -w)
auto_validate = False

# record types to query, in search order: A (IPv4) first by default;
#   main() reverses this list when -6 is given
address_types = [ dns.rdatatype.RdataType.A, dns.rdatatype.RdataType.AAAA ]

# table of well-known DNS-over-TLS server names, consulted by
#   make_connected_socket() when -w is given to pick the TLS server
#   name to validate against; -n ip=name adds or overrides entries
#   map the integer equivalent of the IP address to the name,
#   so that variations in the IPv6 string don't cause the key
#   to be different than those used here, e.g., :: and hex case
wk_names = {
	# cloudflare
	ip_int('1.1.1.1')              : 'one.one.one.one',
	ip_int('1.0.0.1')              : 'one.one.one.one',
	ip_int('2606:4700:4700::1111') : 'one.one.one.one',
	ip_int('2606:4700:4700::1001') : 'one.one.one.one',

	# quad9
	ip_int('9.9.9.9')              : 'dns.quad9.net', # default addresses
	ip_int('149.112.112.112')      : 'dns.quad9.net',
	ip_int('2620:fe::fe')          : 'dns.quad9.net',
	ip_int('2620:fe::9')           : 'dns.quad9.net',
	# NOTE(review): Quad9 appears to publish dns11.quad9.net as the TLS
	#   name for the .11 addresses; confirm their certificate also covers
	#   dns.quad9.net before relying on -w with these entries
	ip_int('9.9.9.11')             : 'dns.quad9.net', # filtered addresses
	ip_int('149.112.112.11')       : 'dns.quad9.net',
	ip_int('2620:fe::11')          : 'dns.quad9.net',
	ip_int('2620:fe::fe:11')       : 'dns.quad9.net',

	# google
	ip_int('8.8.8.8')              : 'dns.google',
	ip_int('8.8.4.4')              : 'dns.google',
	ip_int('2001:4860:4860::8888') : 'dns.google',
	ip_int('2001:4860:4860::8844') : 'dns.google',

	# mullvad
	ip_int('194.242.2.2')    : 'dns.mullvad.net',
	ip_int('2a07:e340::2')   : 'dns.mullvad.net',

	# Control-D
	ip_int('76.76.2.11')     : 'p0.freedns.controld.com',
	ip_int('76.76.10.11')    : 'p0.freedns.controld.com',
}

# the version number of this program
__version__ = '1.0.7'


#
#  state
#

# index in server_addr of the server resolve_core() tries first;
#   advanced (wrapping around) whenever a server yields no addresses
server_idx = 0

# held by resolve_locked() and background_clean(), so lookups and the
#   background cache cleaning never touch socket_cache at the same time
mutex = threading.Lock()

# how often the background thread wakes up to check whether the cached
#   sockets need cleaning
clean_interval = 2 # seconds

# time.time() of the most recent clean_cache() run (0 = never run)
last_clean = 0

# the connection cache: make_socket_key(ip, port) -> connection_node
socket_cache = { }


#
#  connection cache node
#
class connection_node:
	# One entry of the connection cache: a connected socket plus the
	#   timestamps used to decide when it has gone idle.
	def __init__(self, _sock=None):
		# when this cache entry was built (seconds since the epoch)
		self.created = time.time()

		# the connected socket being cached (None when not supplied)
		self.socket = _sock

		# time of the last completed query on this socket; 0 = none yet
		self.success = 0


#
#  return cache key for specific socket target
#
def make_socket_key(ip, port):
	# Build the socket_cache key for one server endpoint: the IP text and
	#   the decimal port joined by '-', e.g. "1.1.1.1-853".
	return "-".join((str(ip), "%d" % port))


#
#  make_connected_socket(ip, port) - return a connected SSL socket for the given IP
#
def make_connected_socket(ip, port):
	# Open a TCP connection to (ip, port) and complete the TLS handshake.
	#   'ip' is an IPv4 or IPv6 address string, 'port' the TCP port.
	# Returns the connected SSL socket, or None if ANY step fails (table
	#   lookup, socket creation, connect, or handshake); a partially
	#   built socket is closed before returning None.

	# determine the address family (only IPv6 literals contain ':')
	af = socket.AF_INET6 if ':' in ip else socket.AF_INET

	sock = None
	try:
		# if auto-validation (-w) is on, look up the IP in the well-known
		#   server table; only a match turns on TLS validation
		validate = False
		shn = None
		if auto_validate:
			key = ip_int(ip)
			if key in wk_names:
				shn = wk_names[key]
				validate = True

				# DEBUG:
				if debug_lookups:
					sys.stderr.write("DEBUG: PID %d query %s TLS server name is '%s'\n" % (os.getpid(), ip, shn))

		# make the socket; this is inside the try so that a creation
		#   failure is reported as None, like a connect failure, instead
		#   of propagating out of the resolver
		sock = dns.query.make_ssl_socket(
			af=af,
			server_hostname=shn,
			type=socket.SOCK_STREAM,
			ssl_context=dns.query.make_ssl_context(validate, validate))

		# connect and negotiate using dnspython inner utilities (these are
		#   private, underscore-prefixed dnspython helpers, so they may
		#   change between dnspython releases); both steps share one deadline
		expiration = time.time() + dns_query_timeout
		dns.query._connect(sock, (ip, port), expiration)
		dns.query._tls_handshake(sock, expiration)
	except Exception as ex:
		# release the half-built socket, if one was created
		if sock is not None:
			try:
				sock.close()
			except Exception:
				pass

		# DEBUG:
		if debug_lookups:
			sys.stderr.write("DEBUG: PID %d make_connected_socket(%s, %d) failed: %s\n" % (os.getpid(), ip, port, str(ex)))

		return None

	# return the connected socket
	return sock


#
#  get_socket(ip, port) - return a connected socket cache node; create socket if needed
#
def get_socket(ip, port):
	# Return the connection_node for (ip, port).  A cached node is reused
	#   when present; otherwise a new TLS connection is opened, wrapped in
	#   a node, cached, and returned.  Returns None if connecting fails.
	key = make_socket_key(ip, port)

	node = socket_cache.get(key)
	if not node:
		# nothing cached for this endpoint: connect, or give up
		new_sock = make_connected_socket(ip, port)
		if not new_sock:
			return None

		# remember the new connection for later queries
		node = connection_node(new_sock)
		socket_cache[key] = node

	return node


#
#  reset_socket(ip, port) - close socket and remove it from cache
#
def reset_socket(ip, port):
	# Remove the cached connection for (ip, port), if any, and close it.
	#   Calling this when nothing is cached for the endpoint is a no-op.
	node = socket_cache.pop(make_socket_key(ip, port), None)
	if node is None:
		return

	# best-effort close: the connection is being discarded either way.
	#   Catch Exception rather than using a bare except, so that
	#   KeyboardInterrupt and SystemExit are not silently swallowed.
	try:
		node.socket.close()
	except Exception:
		pass


#
#  clean_cache() - remove old sockets
#
def clean_cache():
	# Close and drop every cached connection that has not completed a
	#   query within the last 'socket_idle' seconds.  A connection that
	#   never completed one has success == 0, so it is always dropped.
	#   Records the run time in 'last_clean' for background_clean().
	global last_clean
	now = time.time()
	last_clean = now

	# find long-idle sockets; collect the keys first, because the dict
	#   cannot change size while it is being iterated
	to_remove = [ key for key, node in socket_cache.items() if (now - node.success) > socket_idle ]

	# then close and remove them
	for key in to_remove:
		node = socket_cache.pop(key, None)
		if node is None:
			continue

		# best-effort close: one failing socket must not abort the sweep,
		#   but don't swallow KeyboardInterrupt/SystemExit (no bare except)
		try:
			node.socket.close()
		except Exception:
			pass


#
#  query_retry_loop(server, port, name, addr_type)
#
#  Query remote server for a specific name and address type.
#
#  The retry loop below will obtain a socket for 'server',
#  either by making a new one, or reusing a cached one.
#
#  If the socket is cached, it may have timed out since
#  it was last used, and disconnected without warning.
#  Some DNS/TLS servers are aggressive about disconnecting
#  idle TCP connections, even just a few seconds after the
#  most recent query.
#
#  There is no reliable way to test a cached socket for
#  "good" state, since even a well-formed test has a race
#  condition between this process, the remote server, and
#  the networks between them.  There is always a window for
#  failure that we cannot completely avoid by simply asking
#  the socket if it is still connected.
#
#  The selective retry loop below compensates for this.
#
#  The retry loop will try the query on the socket obtained
#  from get_socket(server).  If the socket was cached, and
#  if it successfully completed at least one prior query to
#  its server, and if this query fails to get a response
#  (even an empty one), the retry loop will call get_socket
#  one more time, to get a fresh socket and try again. This
#  ensures that the query doesn't fail due to a stale cached
#  socket.
#
#  If a fresh socket fails to connect, or fails to get a
#  response from the server (timeout, etc.), then the query
#  is treated as irrecoverable.
#
def query_retry_loop(server, server_port, name, rdt):
	# Parameters: 'server' is the server IP string, 'server_port' its TCP
	#   port, 'name' the host name to look up, and 'rdt' the record type
	#   to ask for (an entry of address_types).
	# Returns the response from dns.query.tls(), or None when no
	#   connection could be made or the query failed.

	# retry count: try at most twice
	#    second try only if a previously-working socket went stale
	tries = 2

	# the retry loop; see NOTE above
	while tries > 0:
		# get or build the connection
		node = get_socket(server, server_port)

		# if a connected socket could not be created, fail now
		if not node:
			# DEBUG:
			if debug_lookups:
				sys.stderr.write("DEBUG: PID %d could not connect to server %s+%d\n" % (os.getpid(), server, server_port))

			return None # return failure

		# try the query
		try:
			# DEBUG:
			if debug_lookups:
				sys.stderr.write("DEBUG: PID %d query %s for %s name '%s'\n" % (os.getpid(), server, "IPv4" if rdt == 1 else "IPv6", name))

			# build and perform the query
			q = dns.message.make_query(name, rdt)
			r = dns.query.tls(q, server, sock=node.socket, ssl_context=node.socket.context, timeout=dns_query_timeout)
			node.success = time.time()  # mark the socket as OK as of now
			return r                    # return results

		# timeouts are never retried.  This clause must come before the
		#   OSError clause below: TimeoutError is a subclass of OSError, so
		#   that clause would otherwise catch it and retry it as if the
		#   socket had merely gone stale.
		except ( TimeoutError, dns.exception.Timeout ):
			tries = 0

			# DEBUG:
			if debug_lookups:
				sys.stderr.write("DEBUG: PID %d %s %s for name '%s'\n" % (os.getpid(), "timeout", server, name))

		# socket closed/reset before or during the query: retry once, but
		#   only if this socket completed a query before (it went stale);
		#   a fresh socket that fails this way is not retried
		except ( EOFError, OSError ):
			tries = (tries - 1) if node.success else 0

			# DEBUG:
			if debug_lookups:
				sys.stderr.write("DEBUG: DNS EOF: PID %d query %s for name '%s'\n" % (os.getpid(), server, name))
				if tries:
					sys.stderr.write("DEBUG: DNS EOF: will retry\n")

		# any other failure (malformed response, etc.) is not retried
		except Exception:
			tries = 0

			# DEBUG:
			if debug_lookups:
				sys.stderr.write("DEBUG: PID %d %s %s for name '%s'\n" % (os.getpid(), "failed", server, name))

		# discard the failed socket before retrying or returning failure
		reset_socket(server, server_port)

	# return failure
	return None


#
#  resolve_core(name) - resolve 'name' into a list of IP addresses
#
def resolve_core(name):
	# Resolve 'name' by querying the configured servers in turn, starting
	#   at server_idx, until one of them returns addresses.
	# Returns (addresses, ttl): 'addresses' is a list of address strings,
	#   or ['0'] when nothing was found; 'ttl' is the minimum nonzero TTL
	#   of the answer RRsets from the server that produced the addresses
	#   (0 when no TTL was seen).
	global server_idx

	# if cache enabled, try that
	if cache_lookups:
		# search the cache
		result = cache_search(name)

		# DEBUG: only with -z, like every other debug message in this file
		if debug_lookups:
			sys.stderr.write("DEBUG: cache_search(%s) => %s\n" % (name, "HIT" if result else "MISS"))

		# and return the result if a valid record was found
		#   NOTE(review): cache_search()'s value is returned as-is;
		#   presumably it has the same (addresses, ttl) shape as the
		#   normal return below -- confirm against s5utils
		if result:
			return result

	# start with empty result set
	result = [ ]

	# the count of servers tried for this request; this counter
	#    will make sure that each server is tried at most once
	count = 0

	# minimum nonzero TTL of the records returned (0 = none seen yet),
	#   since IPv4 and IPv6 queries may return different TTL values
	ttl = 0

	# repeat queries until results returned, or we run out of servers
	while not result and count < len(server_addr):
		# select the server to query
		server = server_addr[server_idx]

		# increment the servers-queried count
		count += 1

		# only TTLs from the server that actually answers should count;
		#   'result' is always empty here, so start this server afresh
		ttl = 0

		# DEBUG:
		if debug_lookups:
			sys.stderr.write("DEBUG: PID %d query %s name '%s'\n" % (os.getpid(), server, name))

		# query both IPv4 and IPv6 (order determined by -4 or -6)
		for rdt in address_types:
			# perform the query in a safe retry loop
			r = query_retry_loop(server, server_port, name, rdt)

			# don't try to process results if the query failed (not just empty, but *failed*)
			if r is None:
				break # move to the next server (exit address-type loop)

			# collect the results
			for rrset in r.answer:
				# keep the minimum nonzero TTL; zero TTLs are skipped so
				#   that a 0 can't reset the running minimum and make the
				#   result depend on the order of the RRsets
				if rrset.ttl and (ttl == 0 or rrset.ttl < ttl):
					ttl = rrset.ttl

				# keep only address records (A/AAAA), without duplicates;
				#   other answer records (e.g. CNAME) are skipped
				for rr in rrset:
					if rr.rdtype in address_types and rr.address not in result:
						result.append(rr.address)

			# if preferred IP version lookup succeeded, stop now, and use what we have so far
			if prefer_ipv and result:
				break

		# if nothing found (failed or empty), move to the next server
		if not result:
			server_idx = (server_idx + 1) % len(server_addr)

	# if cache enabled, update the cache with positive return
	if cache_lookups and result:
		cache_update(name, result, ttl)

	# add a failure line if no addresses provided
	if not result:
		result.append('0')

	# clean out sockets that are too old
	clean_cache()

	# return the result and the minimum TTL that was received
	return result, ttl


#
#  resolve_locked(name)
#
def resolve_locked(name):
	# Resolver entry point handed to main_common(): runs resolve_core()
	#   while holding 'mutex', so lookups never overlap each other or the
	#   background cache cleaning.  Returns resolve_core()'s result, or
	#   None (after logging to stderr) if it raised.
	with mutex:
		try:
			outcome = resolve_core(name)
		except Exception as err:
			sys.stderr.write("DEBUG: resolve_locked() raised %s\n" % str(err))
			outcome = None
	return outcome


#
#  background_clean()
#
def background_clean():
	# Loops forever (main() runs it in a daemon thread): every
	#   'clean_interval' seconds, calls clean_cache() unless a clean has
	#   already happened within that interval, e.g. at the end of a
	#   lookup.  Errors are logged to stderr and the loop continues.
	while True:
		# sleep first, which also keeps this loop very cheap
		time.sleep(clean_interval)

		try:
			with mutex:
				since_last = time.time() - last_clean
				if since_last >= clean_interval:
					clean_cache()
		except Exception as err:
			sys.stderr.write("DEBUG: background_clean() raised %s\n" % str(err))


#
#  add_wkn(s) - add "ip=name" association in well-known DNS name table
#
def add_wkn(s):
	# Parse 's' (the "ip=name" argument of -n) and record 'name' as the
	#   TLS server name for that IP in wk_names.
	# Raises Exception with a user-facing message when 's' is not of the
	#   form ip=name, when ip_int() rejects the address, or when the name
	#   fails is_valid().
	try:
		addr, value = s.split('=')
		key = ip_int(addr)
	except Exception as ex:
		# keep the underlying error attached (as __cause__) for debugging
		raise Exception("Argument to -n must be of the form ip=name") from ex

	# a false result from ip_int() (e.g. None) is treated as invalid
	if not key:
		raise Exception("Address provided to -n is invalid")
	if not is_valid(value):
		raise Exception("Hostname provided to -n is invalid")

	# update the wkn table
	wk_names[key] = value


#
#  usage()
#
def usage():
	# Print the command-line help to stderr, then exit with status 1.
	program = os.path.basename(sys.argv[0])
	help_text = (
		"%s [options]\n" % program,
		"   -h = show this information\n",
		"   -4 = prefer IPv4 addresses\n",
		"   -6 = prefer IPv6 addresses\n",
		"   -c = enable lookup cache\n",
		"   -w = enable TLS validation for well-known servers\n",
		"   -P <port> = set server TCP port\n",
		"   -z = show debugging information about each query\n",
		"   -s | --server-addr <addr>[,<addr> ...]\n",
		"        = set remote server addresses\n",
		"   -p | --hosts-path <path>\n",
		"        = read hosts from 'path'\n",
		"   -t | --timeout <seconds>\n",
		"        = set DNS timeout value\n",
		"   -i <seconds>\n",
		"        = set max time before idle connections are closed\n",
		"   -n <ip=name>\n",
		"        = add mapping from 'ip' to 'name' in the well-known hosts table\n",
	)
	for line in help_text:
		sys.stderr.write(line)
	sys.exit(1)


#
#  main()
#
def main():
	# Parse the command line into the module settings, start the
	#   background socket cleaner, and hand resolve_locked() to
	#   main_common().  Bad options print an error plus usage() (exit 1);
	#   raises Exception if no server address remains.
	global server_addr, server_port, cache_lookups, dns_query_timeout, prefer_ipv, auto_validate, socket_idle, debug_lookups

	# path to the optional hosts file (set with -p)
	hosts_path = None

	# True until the first -s option replaces the built-in server list
	default_servers = True

	try:
		# read the command line
		optlist, _ = getopt.getopt(sys.argv[1:], '46chi:n:s:p:P:t:wz', [ 'server-addr=', 'hosts-path=', 'timeout=' ])

		# process the command line
		for opt, arg in optlist:
			if opt in [ '-s', '--server-addr' ]:
				if default_servers:
					default_servers = False
					server_addr = [ ]
				# accept "a,b" and "a, b"; drop empty entries so that
				#   "-s ''" or a trailing comma can't add a blank address
				server_addr += [ addr.strip() for addr in arg.split(',') if addr.strip() ]
			elif opt in [ '-p', '--hosts-path' ]:
				hosts_path = os.path.expanduser(arg)
			elif opt in [ '-t', '--timeout' ]:
				dns_query_timeout = float(arg)
			elif opt == '-P':
				server_port = int(arg)
			elif opt == '-i':
				# values below one second are ignored
				new_value = int(arg)
				if new_value >= 1:
					socket_idle = new_value
			elif opt == '-c':
				cache_lookups = True
			elif opt == '-z':
				debug_lookups = True
			elif opt == '-4':
				prefer_ipv = 4
			elif opt == '-6':
				prefer_ipv = 6
			elif opt == '-w':
				auto_validate = True
			elif opt == '-n':
				add_wkn(arg)
			else:
				# -h (and anything else getopt accepted but isn't handled)
				usage()
	except Exception as ex:
		sys.stderr.write("\nError: %s\n\n" % str(ex))
		usage()

	# search the preferred address type first: address_types starts as
	#   [A, AAAA], so only -6 requires reordering
	if prefer_ipv == 6:
		address_types.reverse()

	# refuse to run without at least one usable server address
	if not server_addr:
		raise Exception("Must specify at least one server address.")

	# start the background cleaner thread (daemon, so it never blocks exit)
	bg_thread = threading.Thread(target=background_clean, daemon=True)
	bg_thread.start()

	# call the main processing loop
	main_common(resolver=resolve_locked, hosts_path=hosts_path)


# entry point - run main() with exception safety
if __name__ == '__main__':
	# exit status: 0 on normal completion or Ctrl-C, 2 on BrokenPipeError,
	#   1 on any other unexpected exception.  SystemExit raised inside
	#   main() (e.g. by usage()) is not an Exception and passes through.
	status = 0
	try:
		main()
	except BrokenPipeError:
		status = 2
	except KeyboardInterrupt:
		status = 0
	except Exception as ex:
		sys.stderr.write("Received unexpected exception: %s: %s\n" % (str(type(ex)), str(ex)))
		status = 1
	sys.exit(status)

# EOF: s5dnstls
