33import os
44import subprocess
55import tempfile
6- import time
76
8- import sshtunnel
7+ # we support both pg8000 and psycopg2
8+ try :
9+ import psycopg2 as pglib
10+ except ImportError :
11+ try :
12+ import pg8000 as pglib
13+ except ImportError :
14+ raise ImportError ("You must have psycopg2 or pg8000 modules installed" )
915
1016from ..exceptions import ExecUtilException
1117
1218from .os_ops import OsOperations , ConnectionParams
13- from .os_ops import pglib
14-
15- sshtunnel .SSH_TIMEOUT = 5.0
16- sshtunnel .TUNNEL_TIMEOUT = 5.0
1719
1820ConsoleEncoding = locale .getdefaultlocale ()[1 ]
1921if not ConsoleEncoding :
@@ -50,21 +52,28 @@ def __init__(self, conn_params: ConnectionParams):
5052 self .remote = True
5153 self .username = conn_params .username or self .get_user ()
5254 self .add_known_host (self .host )
55+ self .tunnel_process = None
5356
5457 def __enter__ (self ):
5558 return self
5659
5760 def __exit__ (self , exc_type , exc_val , exc_tb ):
58- self .close_tunnel ()
61+ self .close_ssh_tunnel ()
5962
60- def close_tunnel (self ):
61- if getattr (self , 'tunnel' , None ):
62- self .tunnel .stop (force = True )
63- start_time = time .time ()
64- while self .tunnel .is_active :
65- if time .time () - start_time > sshtunnel .TUNNEL_TIMEOUT :
66- break
67- time .sleep (0.5 )
63+ def establish_ssh_tunnel (self , local_port , remote_port ):
64+ """
65+ Establish an SSH tunnel from a local port to a remote PostgreSQL port.
66+ """
67+ ssh_cmd = ['-N' , '-L' , f"{ local_port } :localhost:{ remote_port } " ]
68+ self .tunnel_process = self .exec_command (ssh_cmd , get_process = True , timeout = 300 )
69+
70+ def close_ssh_tunnel (self ):
71+ if hasattr (self , 'tunnel_process' ):
72+ self .tunnel_process .terminate ()
73+ self .tunnel_process .wait ()
74+ del self .tunnel_process
75+ else :
76+ print ("No active tunnel to close." )
6877
6978 def add_known_host (self , host ):
7079 cmd = 'ssh-keyscan -H %s >> /home/%s/.ssh/known_hosts' % (host , os .getlogin ())
@@ -78,21 +87,29 @@ def add_known_host(self, host):
7887 raise ExecUtilException (message = "Failed to add %s to known_hosts. Error: %s" % (host , str (e )), command = cmd ,
7988 exit_code = e .returncode , out = e .stderr )
8089
81- def exec_command (self , cmd : str , wait_exit = False , verbose = False , expect_error = False ,
90+ def exec_command (self , cmd , wait_exit = False , verbose = False , expect_error = False ,
8291 encoding = None , shell = True , text = False , input = None , stdin = None , stdout = None ,
83- stderr = None , proc = None ):
92+ stderr = None , get_process = None , timeout = None ):
8493 """
8594 Execute a command in the SSH session.
8695 Args:
8796 - cmd (str): The command to be executed.
8897 """
98+ ssh_cmd = []
8999 if isinstance (cmd , str ):
90100 ssh_cmd = ['ssh' , f"{ self .username } @{ self .host } " , '-i' , self .ssh_key , cmd ]
91101 elif isinstance (cmd , list ):
92102 ssh_cmd = ['ssh' , f"{ self .username } @{ self .host } " , '-i' , self .ssh_key ] + cmd
93103 process = subprocess .Popen (ssh_cmd , stdin = subprocess .PIPE , stdout = subprocess .PIPE , stderr = subprocess .PIPE )
104+ if get_process :
105+ return process
106+
107+ try :
108+ result , error = process .communicate (input , timeout = timeout )
109+ except subprocess .TimeoutExpired :
110+ process .kill ()
111+ raise ExecUtilException ("Command timed out after {} seconds." .format (timeout ))
94112
95- result , error = process .communicate (input )
96113 exit_status = process .returncode
97114
98115 if encoding :
@@ -372,41 +389,19 @@ def get_process_children(self, pid):
372389 raise ExecUtilException (f"Error in getting process children. Error: { result .stderr } " )
373390
374391 # Database control
375- def db_connect (self , dbname , user , password = None , host = "127.0.0.1 " , port = 5432 , ssh_key = None ):
392+ def db_connect (self , dbname , user , password = None , host = "localhost " , port = 5432 ):
376393 """
377- Connects to a PostgreSQL database on the remote system.
378- Args:
379- - dbname (str): The name of the database to connect to.
380- - user (str): The username for the database connection.
381- - password (str, optional): The password for the database connection. Defaults to None.
382- - host (str, optional): The IP address of the remote system. Defaults to "localhost".
383- - port (int, optional): The port number of the PostgreSQL service. Defaults to 5432.
384-
385- This function establishes a connection to a PostgreSQL database on the remote system using the specified
386- parameters. It returns a connection object that can be used to interact with the database.
394+ Established SSH tunnel and Connects to a PostgreSQL
387395 """
388- self .close_tunnel ()
389- self .tunnel = sshtunnel .open_tunnel (
390- (self .host , 22 ), # Remote server IP and SSH port
391- ssh_username = self .username ,
392- ssh_pkey = self .ssh_key ,
393- remote_bind_address = (self .host , port ), # PostgreSQL server IP and PostgreSQL port
394- local_bind_address = ('localhost' , 0 )
395- # Local machine IP and available port (0 means it will pick any available port)
396- )
397- self .tunnel .start ()
398-
396+ self .establish_ssh_tunnel (local_port = port , remote_port = 5432 )
399397 try :
400- # Use localhost and self.tunnel.local_bind_port to connect
401398 conn = pglib .connect (
402- host = 'localhost' , # Connect to localhost
403- port = self . tunnel . local_bind_port , # use the local bind port set up by the tunnel
399+ host = host ,
400+ port = port ,
404401 database = dbname ,
405- user = user or self . username ,
406- password = password
402+ user = user ,
403+ password = password ,
407404 )
408-
409405 return conn
410406 except Exception as e :
411- self .tunnel .stop ()
412- raise ExecUtilException ("Could not create db tunnel. {}" .format (e ))
407+ raise Exception (f"Could not connect to the database. Error: { e } " )
0 commit comments