Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-20 07:58:58

0001 import random
0002 import socket
0003 import threading
0004 import uuid
0005 
0006 import pexpect
0007 
0008 from pandaharvester.harvestercore import core_utils
0009 
# Logger for this module, registered under the name "ssh_tunnel_pool".
baseLogger = core_utils.setup_logger("ssh_tunnel_pool")

# Module-level alias that SshTunnelPool calls to spawn its ssh child processes.
pexpect_spawn = pexpect.spawnu
0014 
0015 
class SshTunnelPool(object):
    """Pool of ``ssh -L`` port-forwarding tunnels, grouped by remote ``host:port``.

    Each tunnel is a pexpect child running
    ``ssh -L <local_bind_port>:127.0.0.1:<remote_bind_port>`` against the remote
    host. ``pool`` maps a key from :meth:`make_dict_key` to a list of
    ``(local_bind_port, child)`` tuples; ``params`` keeps the connection
    parameters per key so dead tunnels can be re-created. The public methods
    access both dicts while holding ``lock``.
    """

    def __init__(self):
        # serializes access to pool and params
        self.lock = threading.Lock()
        # dict key -> list of (local_bind_port, pexpect child) tuples
        self.pool = dict()
        # dict key -> connection parameters saved by make_tunnel_server
        self.params = dict()

    def make_dict_key(self, host, port):
        """Return the pool key ``"<host>:<port>"`` for a remote endpoint."""
        return f"{host}:{port}"

    def make_tunnel_server(
        self,
        remote_host,
        remote_port,
        remote_bind_port=None,
        num_tunnels=1,
        ssh_username=None,
        ssh_password=None,
        private_key=None,
        pass_phrase=None,
        jump_host=None,
        jump_port=None,
        login_timeout=60,
        reconnect=False,
        with_lock=True,
    ):
        """Spawn ssh tunnels until the pool for ``remote_host:remote_port`` holds ``num_tunnels``.

        Tunnels that fail to log in are logged and skipped, so the pool may end
        up with fewer than ``num_tunnels`` entries.

        :param remote_host: host ssh connects to.
        :param remote_port: ssh port on ``remote_host`` (``-p``).
        :param remote_bind_port: port on the remote side's 127.0.0.1 that the
            local port is forwarded to.
        :param num_tunnels: target number of tunnels for this endpoint.
        :param ssh_username: login user (also used for the jump host).
        :param ssh_password: sent when a ``password:`` prompt appears.
        :param private_key: identity file passed with ``-i`` when not None.
        :param pass_phrase: sent when a key passphrase prompt appears.
        :param jump_host: host used in the ProxyCommand.
        :param jump_port: a ProxyCommand via ``jump_host`` is added only when this is not None.
        :param login_timeout: timeout in seconds for each pexpect ``expect`` call.
        :param reconnect: if True, ignore the other connection arguments and reuse
            the parameters saved by an earlier non-reconnect call (``KeyError`` if none).
        :param with_lock: if False, the caller must already hold ``self.lock``.
        """
        dict_key = self.make_dict_key(remote_host, remote_port)
        if with_lock:
            self.lock.acquire()
        # release the lock even if spawning/logging in raises, otherwise every
        # later caller would block forever
        try:
            self.pool.setdefault(dict_key, [])
            if not reconnect:
                # preserve parameters so dead tunnels can be re-created later
                self.params[dict_key] = {
                    "remote_bind_port": remote_bind_port,
                    "num_tunnels": num_tunnels,
                    "ssh_username": ssh_username,
                    "ssh_password": ssh_password,
                    "private_key": private_key,
                    "pass_phrase": pass_phrase,
                    "jump_host": jump_host,
                    "jump_port": jump_port,
                    "login_timeout": login_timeout,
                }
            else:
                saved = self.params[dict_key]
                remote_bind_port = saved["remote_bind_port"]
                num_tunnels = saved["num_tunnels"]
                ssh_username = saved["ssh_username"]
                ssh_password = saved["ssh_password"]
                private_key = saved["private_key"]
                pass_phrase = saved["pass_phrase"]
                jump_host = saved["jump_host"]
                jump_port = saved["jump_port"]
                login_timeout = saved["login_timeout"]
            # only top up the missing tunnels; a negative count means nothing to do
            for _ in range(num_tunnels - len(self.pool[dict_key])):
                tunnel = self._open_tunnel(
                    remote_host,
                    remote_port,
                    remote_bind_port,
                    ssh_username,
                    ssh_password,
                    private_key,
                    pass_phrase,
                    jump_host,
                    jump_port,
                    login_timeout,
                )
                if tunnel is not None:
                    self.pool[dict_key].append(tunnel)
        finally:
            if with_lock:
                self.lock.release()

    def _build_ssh_command(self, local_bind_port, remote_host, remote_port, remote_bind_port, ssh_username, private_key, jump_host, jump_port):
        """Return the ``ssh -L`` command line for one tunnel."""
        com = "ssh -L {local_bind_port}:127.0.0.1:{remote_bind_port} "
        com += "-p {remote_port} {ssh_username}@{remote_host} "
        com += "-o ServerAliveInterval=120 -o ServerAliveCountMax=2 "
        if private_key is not None:
            com += "-i {private_key} "
        if jump_port is not None:
            com += '-o ProxyCommand="ssh -p {jump_port} {ssh_username}@{jump_host} -W %h:%p" '
        return com.format(
            remote_host=remote_host,
            remote_port=remote_port,
            remote_bind_port=remote_bind_port,
            ssh_username=ssh_username,
            private_key=private_key,
            jump_host=jump_host,
            jump_port=jump_port,
            local_bind_port=local_bind_port,
        )

    def _open_tunnel(self, remote_host, remote_port, remote_bind_port, ssh_username, ssh_password, private_key, pass_phrase, jump_host, jump_port, login_timeout):
        """Spawn one ssh tunnel and confirm the login.

        Return ``(local_bind_port, child)`` on success, or None on failure.
        On failure, or if an exception escapes, the child is closed so no ssh
        process is leaked.
        """
        # get a free local port by binding to port 0. The socket is closed
        # before ssh binds the port, so another process could take it in between.
        s = socket.socket()
        try:
            s.bind(("", 0))
            local_bind_port = s.getsockname()[1]
        finally:
            s.close()
        com = self._build_ssh_command(local_bind_port, remote_host, remote_port, remote_bind_port, ssh_username, private_key, jump_host, jump_port)
        # unique marker echoed once logged in, proving a remote shell was reached
        loginString = "login_to_be_confirmed_with " + uuid.uuid4().hex
        # the indices of this list are relied upon by _login
        expected_list = [
            pexpect.EOF,
            pexpect.TIMEOUT,
            "(?i)are you sure you want to continue connecting",
            "(?i)password:",
            "(?i)enter passphrase for key.*",
            loginString,
        ]
        c = pexpect_spawn(com, echo=False)
        try:
            # NOTE(review): assumes the first logger handler exposes a text `stream`
            c.logfile_read = baseLogger.handlers[0].stream
            isOK = self._login(c, com, expected_list, loginString, ssh_password, pass_phrase, login_timeout)
        except Exception:
            c.close()
            raise
        if not isOK:
            # also covers exhausting all attempts, which previously leaked the child
            c.close()
            return None
        return (local_bind_port, c)

    def _login(self, c, com, expected_list, loginString, ssh_password, pass_phrase, login_timeout):
        """Answer ssh prompts on child ``c`` for up to 3 rounds.

        Return True once ``loginString`` is seen, otherwise False. Does not close ``c``.
        NOTE(review): the marker echo is only sent after a prompt was handled,
        so a login that shows no prompt at all hits the timeout branch below.
        """
        login_idx = expected_list.index(loginString)
        for _ in range(3):
            idx = c.expect(expected_list, timeout=login_timeout)
            if idx == login_idx:
                # succeeded
                return True
            if idx == 1:
                baseLogger.error(f"timeout when making a tunnel with com={com} out={c.buffer}")
                return False
            if idx == 2:
                # unknown host key: accept it and wait for the next prompt
                c.sendline("yes")
                idx = c.expect(expected_list, timeout=login_timeout)
                if idx == 1:
                    baseLogger.error(f"timeout after accepting new cert with com={com} out={c.buffer}")
                    return False
            if idx == 3:
                # password prompt
                c.sendline(ssh_password)
            elif idx == 4:
                # passphrase prompt
                c.sendline(pass_phrase)
            elif idx == 0:
                baseLogger.error(f"something weired with com={com} out={c.buffer}")
                return False
            # exec to confirm login; the marker is matched on the next round
            c.sendline(f"echo {loginString}")
        return False

    def get_tunnel(self, remote_host, remote_port):
        """Return ``("127.0.0.1", local_port, child)`` for a random live tunnel.

        Dead tunnels are closed and removed from the pool, and the pool is
        refilled with the parameters saved for this endpoint. ``local_port`` and
        ``child`` are None when no live tunnel remains.

        :raises KeyError: if no tunnels were ever created for this endpoint.
        """
        dict_key = self.make_dict_key(remote_host, remote_port)
        # `with` releases the lock even if the lookup or the reconnect raises
        with self.lock:
            active_tunnels = []
            someClosed = False
            for port, child in self.pool[dict_key]:
                if child.isalive():
                    active_tunnels.append((port, child))
                else:
                    child.close()
                    someClosed = True
            if someClosed:
                # drop dead entries first; otherwise make_tunnel_server still
                # counts them and never spawns replacements
                self.pool[dict_key] = active_tunnels
                self.make_tunnel_server(remote_host, remote_port, reconnect=True, with_lock=False)
                active_tunnels = [item for item in self.pool[dict_key] if item[1].isalive()]
            if active_tunnels:
                port, child = random.choice(active_tunnels)
            else:
                port, child = None, None
        return ("127.0.0.1", port, child)
0170 
# Module-level singleton: every importer shares this one pool instance.
# The class name is deleted right after, so the instance is what this
# module exposes instead of the class.
sshTunnelPool = SshTunnelPool()
del SshTunnelPool