handler.py
HTTP and WebSocket handlers for the web SSH terminal.
IndexHandler (MixinHandler, RequestHandler)
¶
Source code in wizardwebssh/handler.py
class IndexHandler(MixinHandler, tornado.web.RequestHandler):
    """Serve the terminal page and open SSH sessions for browser clients.

    ``GET`` renders ``index.html``. ``POST`` validates the connection form,
    connects to the SSH server on a thread-pool worker and registers the
    resulting ``Worker`` under the client's IP in ``clients`` so that a
    websocket from the same IP can attach to it by id.
    """

    # Thread pool for the blocking SSH connect performed in post().
    executor = ThreadPoolExecutor(max_workers=cpu_count() * 5)

    def initialize(self, loop, policy, host_keys_settings):
        """Store handler configuration and build a fresh SSH client.

        Args:
            loop: IOLoop forwarded to ``MixinHandler.initialize``.
            policy: paramiko missing-host-key policy for the client.
            host_keys_settings: mapping with ``system_host_keys``,
                ``host_keys`` and ``host_keys_filename`` entries.
        """
        super(IndexHandler, self).initialize(loop)
        self.policy = policy
        self.host_keys_settings = host_keys_settings
        self.ssh_client = self.get_ssh_client()
        # NOTE(review): defaults to True when the setting is absent, which
        # enables the "error" test hook in post(); confirm this is intended.
        self.debug = self.settings.get("debug", True)
        self.font = self.settings.get("font", "")
        # Response payload for POST; filled in as the request progresses.
        self.result = dict(id=None, status=None, encoding=None)

    def write_error(self, status_code, **kwargs):
        """Report an error.

        For POST requests with ``swallow_http_errors`` enabled, the reason is
        stored in ``self.result["status"]`` and returned with HTTP 200;
        otherwise Tornado's default error page is produced.
        """
        if swallow_http_errors and self.request.method == "POST":
            exc_info = kwargs.get("exc_info")
            if exc_info:
                reason = getattr(exc_info[1], "log_message", None)
                if reason:
                    self._reason = reason
            self.result.update(status=self._reason)
            self.set_status(200)
            self.finish(self.result)
        else:
            super(IndexHandler, self).write_error(status_code, **kwargs)

    def get_ssh_client(self):
        """Build an ``SSHClient`` that uses the configured host-key stores."""
        ssh = SSHClient()
        # Share the pre-loaded host key objects (private paramiko attributes).
        ssh._system_host_keys = self.host_keys_settings["system_host_keys"]
        ssh._host_keys = self.host_keys_settings["host_keys"]
        ssh._host_keys_filename = self.host_keys_settings["host_keys_filename"]
        ssh.set_missing_host_key_policy(self.policy)
        return ssh

    def get_privatekey(self):
        """Return ``(key_text, filename)`` from the ``privatekey`` field.

        A multipart upload yields its decoded, stripped body and filename;
        a urlencoded field yields its raw value and an empty filename.
        """
        name = "privatekey"
        lst = self.request.files.get(name)
        if lst:
            # multipart form
            filename = lst[0]["filename"]
            data = lst[0]["body"]
            value = self.decode_argument(data, name=name).strip()
        else:
            # urlencoded form
            value = self.get_argument(name, "")
            filename = ""
        return value, filename

    def get_hostname(self):
        """Return the ``hostname`` field, rejecting non-hostname/non-IP values.

        Raises:
            InvalidValueError: if the value is neither a hostname nor an IP.
        """
        value = self.get_value("hostname")
        if not (is_valid_hostname(value) or is_valid_ip_address(value)):
            raise InvalidValueError("Invalid hostname: {}".format(value))
        return value

    def get_port(self):
        """Return the ``port`` field as an int, or ``DEFAULT_PORT`` if empty.

        Raises:
            InvalidValueError: if the value is not a valid port number.
        """
        value = self.get_argument("port", "")
        if not value:
            return DEFAULT_PORT
        port = to_int(value)
        if port is None or not is_valid_port(port):
            raise InvalidValueError("Invalid port: {}".format(value))
        return port

    def lookup_hostname(self, hostname, port):
        """Require ``hostname:port`` to be present in a known-hosts store.

        Raises:
            tornado.web.HTTPError: 403 if neither store knows the host.
        """
        # known_hosts stores non-22 ports as "[host]:port".
        key = hostname if port == 22 else "[{}]:{}".format(hostname, port)
        if self.ssh_client._system_host_keys.lookup(key) is None:
            if self.ssh_client._host_keys.lookup(key) is None:
                raise tornado.web.HTTPError(403, "Connection to {}:{} is not allowed.".format(hostname, port))

    def get_args(self):
        """Collect the SSH connection arguments for this request.

        Form values take precedence; empty form values fall back to the
        module-level ``ssh_*`` defaults.

        Returns:
            tuple: ``(hostname, port, username, password, pkey)``.

        Raises:
            InvalidValueError: from the hostname/port validators.
            tornado.web.HTTPError: 403 when ``RejectPolicy`` is active and
                the host is not known.
        """
        try:
            # Presumably (re)loads the module-level ssh_* defaults read
            # below from the database -- confirm against its definition.
            default_ssh_connection(sshdb, default_ssh_connection_name)
        except Exception as exc:
            # Best effort: missing defaults must not block form logins.
            logging.warning("Unable to load default SSH connection: %s", exc)

        hostname_form = self.get_hostname()
        port_form = self.get_port()
        username_form = self.get_value("username")
        password_form = self.get_argument("password", "")
        # get_privatekey() returns a (value, filename) tuple, which is always
        # truthy, so the former ssh_private_key_file fallback could never run.
        # Read the form once; the key value itself falls back below.
        privatekey_form, filename = self.get_privatekey()
        passphrase_form = self.get_argument("passphrase", "")
        totp = self.get_argument("totp", "")

        # Prefer form values; fall back to the stored defaults when empty.
        # NOTE(review): get_port() returns DEFAULT_PORT for an empty field, so
        # ssh_port is never used; get_hostname() validates first, so the
        # ssh_hostname fallback is likely unreachable too -- confirm intent.
        hostname = hostname_form or ssh_hostname
        port = port_form or ssh_port
        username = username_form or ssh_username
        password = password_form or ssh_password
        privatekey = privatekey_form or ssh_private_key
        if not privatekey:
            logging.info("No Private key provided")
            privatekey = None
        passphrase = passphrase_form or ssh_key_passphrase

        if isinstance(self.policy, paramiko.RejectPolicy):
            self.lookup_hostname(hostname, port)

        if privatekey:
            pkey = PrivateKey(privatekey, passphrase, filename).get_pkey_obj()
        else:
            pkey = None

        self.ssh_client.totp = totp
        args = (hostname, port, username, password, pkey)
        # Do not write the password to the log in clear text.
        logging.debug((hostname, port, username, "******" if password else password, pkey))
        return args

    def parse_encoding(self, data):
        """Return ``data`` as an ASCII encoding name if valid, else ``None``."""
        try:
            encoding = to_str(data.strip(), "ascii")
        except UnicodeDecodeError:
            return
        if is_valid_encoding(encoding):
            return encoding

    def get_default_encoding(self, ssh):
        """Ask the remote shell for its charmap; fall back to ``"utf-8"``."""
        # Try a login shell first, then a plain interactive shell.
        commands = ['$SHELL -ilc "locale charmap"', '$SHELL -ic "locale charmap"']
        for command in commands:
            try:
                _, stdout, _ = ssh.exec_command(command, get_pty=True)
            except paramiko.SSHException as exc:
                logging.info(str(exc))
            else:
                data = stdout.read()
                logging.debug("{!r} => {!r}".format(command, data))
                result = self.parse_encoding(data)
                if result:
                    return result
        logging.warning("Could not detect the default encoding.")
        return "utf-8"

    def ssh_connect(self, args):
        """Connect the SSH client and open a non-blocking shell channel.

        Runs on ``executor`` (submitted from post()).

        Args:
            args: ``(hostname, port, username, password, pkey)``.

        Returns:
            Worker: wraps the client and channel; ``encoding`` comes from
            ``options.encoding`` or is detected on the remote host.

        Raises:
            ValueError: with a user-facing message for connect/auth failures.
        """
        ssh = self.ssh_client
        dst_addr = args[:2]
        logging.info("Connecting to {}:{}".format(*dst_addr))
        try:
            ssh.connect(
                *args,
                allow_agent=options.allow_agent,
                look_for_keys=options.look_for_keys,
                timeout=options.timeout,
                auth_timeout=options.auth_timeout,
            )
        except socket.error:
            raise ValueError("Unable to connect to {}:{}".format(*dst_addr))
        except paramiko.BadAuthenticationType:
            raise ValueError("Bad authentication type.")
        except paramiko.AuthenticationException:
            raise ValueError("Authentication failed.")
        except paramiko.BadHostKeyException:
            raise ValueError("Bad host key.")

        term = self.get_argument("term", "") or "xterm"
        chan = ssh.invoke_shell(term=term)
        logging.info(f"Channel to channel: {chan} ")
        chan.setblocking(0)
        worker = Worker(self.loop, ssh, chan, dst_addr)
        worker.encoding = options.encoding if options.encoding else self.get_default_encoding(ssh)
        return worker

    def check_origin(self):
        """Reject disallowed cross-origin requests.

        The origin is taken from the ``_origin`` argument or the ``Origin``
        header. When it came from the header and the policy is not "same",
        it is echoed in ``Access-Control-Allow-Origin``.

        Raises:
            tornado.web.HTTPError: 403 if the origin is not allowed.
        """
        event_origin = self.get_argument("_origin", "")
        header_origin = self.request.headers.get("Origin")
        origin = event_origin or header_origin
        if origin:
            if not super(IndexHandler, self).check_origin(origin):
                raise tornado.web.HTTPError(403, "Cross origin operation is not allowed.")
            if not event_origin and self.origin_policy != "same":
                self.set_header("Access-Control-Allow-Origin", origin)

    def head(self):
        """Answer HEAD requests with an empty response."""
        pass

    def get(self):
        """Render the terminal page."""
        self.render("index.html", debug=self.debug, font=self.font)

    @tornado.gen.coroutine
    def post(self):
        """Open an SSH session and reply with ``self.result``.

        On success ``result`` carries the worker ``id`` and ``encoding``; on
        connect failure its ``status`` holds the error message.
        """
        if self.debug and self.get_argument("error", ""):
            # for testing purpose only
            raise ValueError("Uncaught exception")

        ip, port = self.get_client_addr()
        workers = clients.get(ip, {})
        if workers and len(workers) >= options.maxconn:
            raise tornado.web.HTTPError(403, "Too many live connections.")

        self.check_origin()

        try:
            args = self.get_args()
        except InvalidValueError as exc:
            raise tornado.web.HTTPError(400, str(exc))

        # Connect off the IOLoop thread; the coroutine resumes when done.
        future = self.executor.submit(self.ssh_connect, args)

        try:
            worker = yield future
        except (ValueError, paramiko.SSHException) as exc:
            logging.error(traceback.format_exc())
            self.result.update(status=str(exc))
        else:
            if not workers:
                # First worker for this IP: register the dict in clients.
                clients[ip] = workers
            worker.src_addr = (ip, port)
            workers[worker.id] = worker
            # Presumably recycles the worker if no websocket attaches in time
            # (recycle_worker is not shown here -- confirm).
            self.loop.call_later(options.delay or DELAY, recycle_worker, worker)
            self.result.update(id=worker.id, encoding=worker.encoding)

        self.write(self.result)
write_error(self, status_code, **kwargs)
¶
Override to implement custom error pages.
`write_error` may call `write`, `render`, `set_header`, etc. to produce output as usual.
If this error was caused by an uncaught exception (including `HTTPError`), an `exc_info` triple will be available as `kwargs["exc_info"]`. Note that this exception may not be the "current" exception for purposes of methods like `sys.exc_info()` or `traceback.format_exc`.
Source code in wizardwebssh/handler.py
def write_error(self, status_code, **kwargs):
    """Produce the error response for a failed request.

    Unless ``swallow_http_errors`` is enabled and the request is a POST, this
    defers to Tornado's default error page. Otherwise the failure reason (the
    exception's ``log_message`` when available) is stored in
    ``self.result["status"]``, and the response finishes with HTTP 200 and
    ``self.result`` as the body.
    """
    if not (swallow_http_errors and self.request.method == "POST"):
        super(IndexHandler, self).write_error(status_code, **kwargs)
        return

    exc_info = kwargs.get("exc_info")
    log_message = getattr(exc_info[1], "log_message", None) if exc_info else None
    if log_message:
        # Surface the exception's message instead of the generic reason.
        self._reason = log_message

    self.result.update(status=self._reason)
    self.set_status(200)
    self.finish(self.result)
NotFoundHandler (MixinHandler, ErrorHandler)
¶
Source code in wizardwebssh/handler.py
class NotFoundHandler(MixinHandler, tornado.web.ErrorHandler):
    """Catch-all handler that answers every request with 404 Not Found."""

    def initialize(self):
        # Takes no route arguments; defer to the base-class initialisers.
        super().initialize()

    def prepare(self):
        # prepare() runs before any get/post/etc. method, so raising here
        # rejects the request regardless of its verb.
        raise tornado.web.HTTPError(status_code=404)
prepare(self)
¶
Called at the beginning of a request before `get`/`post`/etc.
Override this method to perform common initialization regardless of the request method.
Asynchronous support: use `async def` or decorate this method with `.gen.coroutine` to make it asynchronous. If this method returns an `Awaitable`, execution will not proceed until the `Awaitable` is done.
.. versionadded:: 3.1 Asynchronous support.
Source code in wizardwebssh/handler.py
def prepare(self):
    """Reject every request routed to this handler with a 404."""
    raise tornado.web.HTTPError(status_code=404)
WsockHandler (MixinHandler, WebSocketHandler)
¶
Source code in wizardwebssh/handler.py
class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
    """Websocket endpoint that bridges a browser terminal to an SSH worker.

    A worker must first be created by ``IndexHandler.post`` for the same
    client IP; the websocket then attaches to it by passing its ``id``.
    """

    def initialize(self, loop):
        super(WsockHandler, self).initialize(loop)
        # Set in open() to a weakref to the attached Worker; a weakref does
        # not keep the worker alive on its own.
        self.worker_ref = None

    def open(self):
        """Attach this websocket to the worker named by the ``id`` argument.

        The connection is closed if the client IP has no workers, the ``id``
        is missing or invalid, or no matching worker is waiting.
        """
        self.src_addr = self.get_client_addr()
        logging.info("Connected from {}:{}".format(*self.src_addr))

        workers = clients.get(self.src_addr[0])
        if not workers:
            self.close(reason="Websocket authentication failed.")
            return

        try:
            worker_id = self.get_value("id")
        except (tornado.web.MissingArgumentError, InvalidValueError) as exc:
            self.close(reason=str(exc))
        else:
            worker = workers.get(worker_id)
            if worker:
                # Mark the slot as taken: a later open() with the same id
                # finds None and is rejected.
                workers[worker_id] = None
                self.set_nodelay(True)
                worker.set_handler(self)
                self.worker_ref = weakref.ref(worker)
                self.loop.add_handler(worker.fd, worker, IOLoop.READ)
            else:
                self.close(reason="Websocket authentication failed.")

    def on_message(self, message):
        """Handle a JSON message from the browser terminal.

        ``{"resize": [cols, rows]}`` resizes the remote PTY and
        ``{"data": "<text>"}`` queues text for the SSH channel. Malformed
        messages, or messages arriving with no live worker, are ignored.
        """
        logging.debug("{!r} from {}:{}".format(message, *self.src_addr))
        worker = self.worker_ref() if self.worker_ref else None
        if worker is None:
            # No live worker: open() did not attach one, or it was collected.
            return

        try:
            msg = json.loads(message)
        except JSONDecodeError:
            return

        if not isinstance(msg, dict):
            return

        resize = msg.get("resize")
        # Client JSON may be any type; calling len() on a number or bool would
        # raise, so only a two-element array is treated as a size.
        if isinstance(resize, (list, tuple)) and len(resize) == 2:
            try:
                worker.chan.resize_pty(*resize)
            except (TypeError, struct.error, paramiko.SSHException):
                pass

        data = msg.get("data")
        if data and isinstance(data, UnicodeType):
            worker.data_to_dst.append(data)
            worker.on_write()

    def on_close(self):
        """Close the attached worker (if still alive) with the close reason."""
        logging.info("Disconnected from {}:{}".format(*self.src_addr))
        if not self.close_reason:
            self.close_reason = "client disconnected"

        worker = self.worker_ref() if self.worker_ref else None
        if worker:
            worker.close(reason=self.close_reason)
on_close(self)
¶
Invoked when the WebSocket is closed.
If the connection was closed cleanly and a status code or reason phrase was supplied, these values will be available as the attributes `self.close_code` and `self.close_reason`.
.. versionchanged:: 4.0 Added `close_code` and `close_reason` attributes.
Source code in wizardwebssh/handler.py
def on_close(self):
    """Tear down the SSH worker tied to this websocket.

    Logs the disconnect, records "client disconnected" as the close reason
    when none was supplied, and closes the worker if it is still alive.
    """
    logging.info("Disconnected from {}:{}".format(*self.src_addr))
    if not self.close_reason:
        self.close_reason = "client disconnected"

    if not self.worker_ref:
        return  # open() never attached a worker
    live_worker = self.worker_ref()
    if not live_worker:
        return  # the weakly-referenced worker is already gone
    live_worker.close(reason=self.close_reason)
on_message(self, message)
¶
Handle incoming messages on the WebSocket.
This method must be overridden.
.. versionchanged:: 4.5 `on_message` can be a coroutine.
Source code in wizardwebssh/handler.py
def on_message(self, message):
    """Handle a JSON message from the browser terminal.

    ``{"resize": [cols, rows]}`` resizes the remote PTY and
    ``{"data": "<text>"}`` queues text for the SSH channel. Malformed
    messages, or messages arriving with no live worker, are ignored.
    """
    logging.debug("{!r} from {}:{}".format(message, *self.src_addr))
    worker = self.worker_ref() if self.worker_ref else None
    if worker is None:
        # No live worker: open() did not attach one, or it was collected.
        return

    try:
        msg = json.loads(message)
    except JSONDecodeError:
        return

    if not isinstance(msg, dict):
        return

    resize = msg.get("resize")
    # Client JSON may be any type; calling len() on a number or bool would
    # raise, so only a two-element array is treated as a size.
    if isinstance(resize, (list, tuple)) and len(resize) == 2:
        try:
            worker.chan.resize_pty(*resize)
        except (TypeError, struct.error, paramiko.SSHException):
            pass

    data = msg.get("data")
    if data and isinstance(data, UnicodeType):
        worker.data_to_dst.append(data)
        worker.on_write()
open(self)
¶
Invoked when a new WebSocket is opened.
The arguments to `open` are extracted from the `tornado.web.URLSpec` regular expression, just like the arguments to `tornado.web.RequestHandler.get`.
`open` may be a coroutine. `on_message` will not be called until `open` has returned.
.. versionchanged:: 5.1 `open` may be a coroutine.
Source code in wizardwebssh/handler.py
def open(self):
    """Attach this websocket to the worker named by the ``id`` argument.

    The socket is closed with a reason when the client IP has no workers,
    the ``id`` argument is missing or invalid, or no waiting worker matches.
    """
    self.src_addr = self.get_client_addr()
    logging.info("Connected from {}:{}".format(*self.src_addr))

    client_ip = self.src_addr[0]
    ip_workers = clients.get(client_ip)
    if not ip_workers:
        self.close(reason="Websocket authentication failed.")
        return

    try:
        requested_id = self.get_value("id")
    except (tornado.web.MissingArgumentError, InvalidValueError) as exc:
        self.close(reason=str(exc))
        return

    target = ip_workers.get(requested_id)
    if not target:
        self.close(reason="Websocket authentication failed.")
        return

    # Mark the slot as taken: a later open() with the same id finds None
    # and is rejected.
    ip_workers[requested_id] = None
    self.set_nodelay(True)
    target.set_handler(self)
    self.worker_ref = weakref.ref(target)
    self.loop.add_handler(target.fd, target, IOLoop.READ)
get_default_ssh_connection_data(database_name, connection)
¶
Gets ssh connection data for default connection.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `database_name` | | QSqlDatabase connection or database to use. | required |
| `connection` | | Name of the SSH connection to get data for. | required |
Returns: dictionary of ssh connection data
Source code in wizardwebssh/handler.py
def get_default_ssh_connection_data(database_name, connection):
    """
    Gets ssh connection data for default connection.
    Args:
        database_name (): QSqlDatabase connection or database to use.
        connection (): name of the ssh connection to get data for.
    Returns: dictionary of ssh connection data (empty when the query fails
        or no connection of that name exists).
    """
    # Double any single quotes (standard SQL literal escaping) so a name such
    # as "Bob's box" cannot terminate the string literal and break or inject
    # into the query.
    safe_connection = str(connection).replace("'", "''")
    query = f"""
    SELECT ssh_group_name, ssh_connection_name,ssh_username,ssh_password,Host,HostName,Port,ProxyCommand,
    sshkey_name,sshkey_passphrase,sshkey_public,sshkey_private,sshkey_public_file,sshkey_private_file,
    ssh_config_name,ssh_config_content
    FROM sshconnections
    JOIN sshkeys ON sshconnections.ssh_key_id = sshkeys.id
    JOIN sshgroup ON sshconnections.ssh_group_id = sshgroup.id
    JOIN sshconfig ON sshconnections.ssh_config_id = sshconfig.id
    WHERE ssh_connection_name = '{safe_connection}'
    """
    return get_query_as_dict(query, database_name)
get_query_as_dict(query, database_name)
¶
Get QSqlQuery for single record as a dictionary with the columns as the keys.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
query |
QSqlQuery to use. |
required | |
database_name |
QSqlDatabase connection or database to use. |
required |
Returns: dictionary with key values of the query.
Source code in wizardwebssh/handler.py
def get_query_as_dict(query, database_name):
    """
    Get QSqlQuery for single record as a dictionary with the columns as the keys.
    Args:
        query (): SQL query text to execute.
        database_name (): QSqlDatabase connection or database to use.
    Returns: dictionary mapping each column name to the first row's value
        (converted with ``str``); empty if the query fails or returns no rows.
    """
    import logging

    row_values = {}
    try:
        # Build the query without SQL so it runs exactly once in exec():
        # QSqlQuery(sql, db) already executes non-empty SQL in its constructor.
        q = QSqlQuery(db=database_name)
        if q.exec(str(query)) and q.first():
            # Read the record after exec() so it describes the result columns.
            rec = q.record()
            for column in range(rec.count()):
                row_values[rec.fieldName(column)] = str(q.value(column))
    except Exception as e:
        logging.warning("Exception: %s", e)
    return row_values
paramiko_host_info(host)
¶
Get SSH host information via Paramiko Parser for terminal.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
host |
SSH host from ssh config file |
required |
Returns: dictionary with connection info parsed from config
Source code in wizardwebssh/handler.py
def paramiko_host_info(host):
    """
    Get SSH host information via Paramiko Parser for terminal.
    Args:
        host (): SSH host from ssh config file
    Returns: dictionary with connection info parsed from ``~/.ssh/config``;
        fields the config does not set keep the template defaults below.
    """
    import logging

    ssh_config = paramiko.SSHConfig()
    user_config_file = os.path.expanduser("~/.ssh/config")
    if os.path.exists(user_config_file):
        with open(user_config_file) as f:
            ssh_config.parse(f)
    o = ssh_config.lookup(host)

    # Template with the same keys as a database-backed connection record.
    con = {
        "ssh_group_name": "default",
        "ssh_connection_name": host,
        "ssh_username": "",
        "ssh_password": "",
        "Host": host,
        "HostName": "",
        "Port": "22",
        "ProxyCommand": "",
        "sshkey_name": "None",
        "sshkey_passphrase": "",
        "sshkey_public": "",
        "sshkey_private": "",
        "sshkey_public_file": "",
        "sshkey_private_file": "",
        "ssh_config_name": "paramiko",
        "ssh_config_content": "",
    }
    if o:
        if "hostname" in o:
            con["HostName"] = o["hostname"]
        if "user" in o:
            con["ssh_username"] = o["user"]
        if "port" in o:
            con["Port"] = o["port"]
        # NOTE(review): ProxyJump names a jump host, not a command; confirm
        # that consumers of "ProxyCommand" expect this value.
        if "proxyjump" in o:
            con["ProxyCommand"] = o["proxyjump"]

    # Previously a stdout print that reported the script directory as the
    # source; the data actually comes from user_config_file.
    logging.debug("Paramiko host lookup for %r (config %s): %s", host, user_config_file, con)
    return con