File indexing completed on 2026-04-20 07:58:58
0001 import random
0002 import socket
0003 import threading
0004 import uuid
0005
0006 import pexpect
0007
0008 from pandaharvester.harvestercore import core_utils
0009
# Single alias for the pexpect entry point used to launch ssh children.
# NOTE(review): per the pexpect documentation, spawnu is the text-mode (str)
# variant of spawn; confirm it is still available in the installed version.
pexpect_spawn = pexpect.spawnu


# Module-wide logger named "ssh_tunnel_pool", created through core_utils.
baseLogger = core_utils.setup_logger("ssh_tunnel_pool")
0014
0015
0016
class SshTunnelPool(object):
    """Pool of ``ssh -L`` port-forwarding tunnels grouped by remote endpoint.

    Each tunnel is an ssh child process driven through pexpect.  Tunnels are
    stored per ``"host:port"`` key together with the settings that were used
    to create them, so that dead tunnels can later be re-created with the
    same settings.  ``self.lock`` serializes access to ``self.pool`` and
    ``self.params``.
    """

    # Positions of the fixed entries in the pattern list given to expect().
    # The per-login marker string is appended after these.
    _IDX_EOF = 0
    _IDX_TIMEOUT = 1
    _IDX_NEW_HOST_KEY = 2
    _IDX_PASSWORD = 3
    _IDX_PASSPHRASE = 4

    # Number of prompt/response rounds attempted before a login is abandoned.
    _MAX_LOGIN_ROUNDS = 3

    def __init__(self):
        # Guards pool and params.
        self.lock = threading.Lock()
        # "host:port" -> list of (local_bind_port, pexpect child).
        self.pool = dict()
        # "host:port" -> settings dict recorded by make_tunnel_server.
        self.params = dict()

    def make_dict_key(self, host, port):
        """Return the ``"host:port"`` key used by ``pool`` and ``params``."""
        return f"{host}:{port}"

    def make_tunnel_server(
        self,
        remote_host,
        remote_port,
        remote_bind_port=None,
        num_tunnels=1,
        ssh_username=None,
        ssh_password=None,
        private_key=None,
        pass_phrase=None,
        jump_host=None,
        jump_port=None,
        login_timeout=60,
        reconnect=False,
        with_lock=True,
    ):
        """Open tunnels to an endpoint until ``num_tunnels`` are pooled.

        Args:
            remote_host: host the ssh client connects to.
            remote_port: ssh port on ``remote_host``.
            remote_bind_port: port on the remote side (127.0.0.1) that the
                local port is forwarded to.
            num_tunnels: target number of pooled tunnels for this endpoint.
            ssh_username: login name (also used for the jump host).
            ssh_password: answer sent to a password prompt.
            private_key: identity file passed with ``-i`` when not None.
            pass_phrase: answer sent to a key passphrase prompt.
            jump_host: host for the ProxyCommand hop.
            jump_port: ssh port of the jump host; the ProxyCommand option is
                added only when this is not None.
            login_timeout: seconds to wait for each expected prompt.
            reconnect: when True, ignore the keyword settings above and reuse
                the ones recorded by the previous non-reconnect call.
            with_lock: when False the caller already holds ``self.lock``.

        Raises:
            KeyError: if ``reconnect`` is True and nothing was recorded for
                this endpoint.

        Tunnels whose login fails are logged and skipped; nothing is returned.
        """
        dict_key = self.make_dict_key(remote_host, remote_port)
        if with_lock:
            self.lock.acquire()
        try:
            tunnels = self.pool.setdefault(dict_key, [])
            if reconnect:
                settings = self.params[dict_key]
            else:
                settings = {
                    "remote_bind_port": remote_bind_port,
                    "num_tunnels": num_tunnels,
                    "ssh_username": ssh_username,
                    "ssh_password": ssh_password,
                    "private_key": private_key,
                    "pass_phrase": pass_phrase,
                    "jump_host": jump_host,
                    "jump_port": jump_port,
                    "login_timeout": login_timeout,
                }
                self.params[dict_key] = settings
            # The number of missing tunnels is fixed before any is opened.
            for _ in range(settings["num_tunnels"] - len(tunnels)):
                tunnel = self._open_tunnel(remote_host, remote_port, settings)
                if tunnel is not None:
                    tunnels.append(tunnel)
        finally:
            # Release even when spawning fails, so later callers do not
            # block forever on a lock nobody will release.
            if with_lock:
                self.lock.release()

    @staticmethod
    def _pick_free_local_port():
        """Return a local port number the OS reported as free."""
        # The probing socket is closed before ssh binds the port, so another
        # process could take the port in between; ssh would then fail to
        # forward and the login attempt is reported as failed.
        with socket.socket() as probe:
            probe.bind(("", 0))
            return probe.getsockname()[1]

    @staticmethod
    def _build_ssh_command(remote_host, remote_port, local_bind_port, settings):
        """Return the ssh command line for one forwarding tunnel."""
        com = "ssh -L {local_bind_port}:127.0.0.1:{remote_bind_port} "
        com += "-p {remote_port} {ssh_username}@{remote_host} "
        com += "-o ServerAliveInterval=120 -o ServerAliveCountMax=2 "
        if settings["private_key"] is not None:
            com += "-i {private_key} "
        if settings["jump_port"] is not None:
            com += '-o ProxyCommand="ssh -p {jump_port} {ssh_username}@{jump_host} -W %h:%p" '
        return com.format(
            remote_host=remote_host,
            remote_port=remote_port,
            remote_bind_port=settings["remote_bind_port"],
            ssh_username=settings["ssh_username"],
            private_key=settings["private_key"],
            jump_host=settings["jump_host"],
            jump_port=settings["jump_port"],
            local_bind_port=local_bind_port,
        )

    def _login(self, c, com, loginString, settings):
        """Answer ssh prompts on child ``c``; return True once logged in.

        Login is confirmed by echoing ``loginString`` on the remote side and
        seeing it come back.  Failures are logged; closing the child is left
        to the caller.
        """
        login_timeout = settings["login_timeout"]
        expected_list = [
            pexpect.EOF,
            pexpect.TIMEOUT,
            "(?i)are you sure you want to continue connecting",
            "(?i)password:",
            "(?i)enter passphrase for key.*",
            loginString,
        ]
        idx_logged_in = expected_list.index(loginString)
        for _ in range(self._MAX_LOGIN_ROUNDS):
            idx = c.expect(expected_list, timeout=login_timeout)
            if idx == idx_logged_in:
                return True
            if idx == self._IDX_TIMEOUT:
                baseLogger.error(f"timeout when making a tunnel with com={com} out={c.buffer}")
                return False
            if idx == self._IDX_NEW_HOST_KEY:
                # Accept the unknown host key, then handle whatever follows.
                c.sendline("yes")
                idx = c.expect(expected_list, timeout=login_timeout)
                if idx == self._IDX_TIMEOUT:
                    baseLogger.error(f"timeout after accepting new cert with com={com} out={c.buffer}")
                    return False
            if idx == self._IDX_PASSWORD:
                c.sendline(settings["ssh_password"])
            elif idx == self._IDX_PASSPHRASE:
                c.sendline(settings["pass_phrase"])
            elif idx == self._IDX_EOF:
                baseLogger.error(f"something weired with com={com} out={c.buffer}")
                return False
            # Ask the remote shell to print the marker that confirms login.
            c.sendline(f"echo {loginString}")
        baseLogger.error(f"gave up making a tunnel after {self._MAX_LOGIN_ROUNDS} attempts with com={com}")
        return False

    def _open_tunnel(self, remote_host, remote_port, settings):
        """Start one tunnel; return ``(local_bind_port, child)`` or None."""
        local_bind_port = self._pick_free_local_port()
        com = self._build_ssh_command(remote_host, remote_port, local_bind_port, settings)
        loginString = "login_to_be_confirmed_with " + uuid.uuid4().hex
        c = pexpect_spawn(com, echo=False)
        logged_in = False
        try:
            # Copy everything read from ssh into the logger's first handler.
            c.logfile_read = baseLogger.handlers[0].stream
            logged_in = self._login(c, com, loginString, settings)
        finally:
            # Close the child on any failure so no ssh process is left
            # running outside the pool.
            if not logged_in:
                c.close()
        return (local_bind_port, c)

    def get_tunnel(self, remote_host, remote_port):
        """Return ``("127.0.0.1", port, child)`` for a random live tunnel.

        Dead tunnels are closed and removed from the pool, and replacements
        are opened with the recorded settings.  ``port`` and ``child`` are
        None when the endpoint has no live tunnel, including the case where
        no tunnel was ever made for it.
        """
        dict_key = self.make_dict_key(remote_host, remote_port)
        with self.lock:
            active_tunnels = []
            some_closed = False
            for port, child in self.pool.get(dict_key, []):
                if child.isalive():
                    active_tunnels.append((port, child))
                else:
                    child.close()
                    some_closed = True
            if some_closed:
                # Drop the dead entries first; otherwise the pool still looks
                # full and no replacement would be opened.
                self.pool[dict_key] = active_tunnels
                self.make_tunnel_server(remote_host, remote_port, reconnect=True, with_lock=False)
                active_tunnels = [item for item in self.pool[dict_key] if item[1].isalive()]
            if active_tunnels:
                port, child = random.choice(active_tunnels)
            else:
                port, child = None, None
        return ("127.0.0.1", port, child)
0169
0170
0171
# Module-level singleton: the one pool instance created by this module.
sshTunnelPool = SshTunnelPool()
# Remove the class name from the module namespace, leaving only the instance
# above bound at module level.
del SshTunnelPool