Skip to content

handler.py

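Tornado handlers for the web SSH terminal. It contains the HTTP index/connect handler, a 404 fallback handler and the WebSocket terminal handler, plus helpers for loading SSH connection data.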

IndexHandler (MixinHandler, RequestHandler)

Source code in wizardwebssh/handler.py
class IndexHandler(MixinHandler, tornado.web.RequestHandler):
    """Serve the terminal page and open SSH connections for the browser.

    ``GET`` renders ``index.html``. ``POST`` validates the connection form,
    connects on ``executor`` and registers the resulting ``Worker`` in the
    module-level ``clients`` mapping (keyed by client IP, then worker id) so
    a WebSocket can attach to it later.
    """

    # Blocking paramiko work (connect, exec_command, read) runs on this pool
    # so it does not block the IOLoop.
    executor = ThreadPoolExecutor(max_workers=cpu_count() * 5)

    def initialize(self, loop, policy, host_keys_settings):
        """Store per-request configuration and build a fresh SSH client.

        Args:
            loop: event loop forwarded to ``MixinHandler.initialize``; later
                read back as ``self.loop`` when creating the ``Worker``.
            policy: paramiko missing-host-key policy to install on the client.
            host_keys_settings: dict with ``system_host_keys``, ``host_keys``
                and ``host_keys_filename`` entries.
        """
        super(IndexHandler, self).initialize(loop)
        self.policy = policy
        self.host_keys_settings = host_keys_settings
        self.ssh_client = self.get_ssh_client()
        # NOTE(review): defaults to True when "debug" is missing from the
        # application settings, which enables the test-only error path in
        # post(); confirm this default is intended.
        self.debug = self.settings.get("debug", True)
        self.font = self.settings.get("font", "")
        # Payload written back by post() and write_error().
        self.result = dict(id=None, status=None, encoding=None)

    def write_error(self, status_code, **kwargs):
        """Report errors; for POST with ``swallow_http_errors`` reply 200.

        In that mode the error reason is stored in ``self.result["status"]``
        and the result is finished with status 200. Otherwise Tornado's
        default error rendering is used.
        """
        if swallow_http_errors and self.request.method == "POST":
            exc_info = kwargs.get("exc_info")
            if exc_info:
                reason = getattr(exc_info[1], "log_message", None)
                if reason:
                    self._reason = reason
            self.result.update(status=self._reason)
            self.set_status(200)
            self.finish(self.result)
        else:
            super(IndexHandler, self).write_error(status_code, **kwargs)

    def get_ssh_client(self):
        """Return an ``SSHClient`` sharing the configured host-key stores."""
        ssh = SSHClient()
        ssh._system_host_keys = self.host_keys_settings["system_host_keys"]
        ssh._host_keys = self.host_keys_settings["host_keys"]
        ssh._host_keys_filename = self.host_keys_settings["host_keys_filename"]
        ssh.set_missing_host_key_policy(self.policy)
        return ssh

    def get_privatekey(self):
        """Return ``(key_text, filename)`` from an upload or a form field.

        ``filename`` is ``""`` when the key came from a urlencoded field.
        """
        name = "privatekey"
        lst = self.request.files.get(name)
        if lst:
            # multipart form
            filename = lst[0]["filename"]
            data = lst[0]["body"]
            value = self.decode_argument(data, name=name).strip()
        else:
            # urlencoded form
            value = self.get_argument(name, "")
            filename = ""

        return value, filename

    def get_hostname(self):
        """Return the ``hostname`` argument, validated as hostname or IP.

        Raises:
            InvalidValueError: if it is neither.
        """
        value = self.get_value("hostname")
        if not (is_valid_hostname(value) or is_valid_ip_address(value)):
            raise InvalidValueError("Invalid hostname: {}".format(value))
        return value

    def get_port(self):
        """Return the ``port`` argument as int, or ``DEFAULT_PORT`` if empty.

        Raises:
            InvalidValueError: if the value is not a valid port number.
        """
        value = self.get_argument("port", "")
        if not value:
            return DEFAULT_PORT

        port = to_int(value)
        if port is None or not is_valid_port(port):
            raise InvalidValueError("Invalid port: {}".format(value))
        return port

    def lookup_hostname(self, hostname, port):
        """Reject hosts that are absent from both known-host stores.

        Raises:
            tornado.web.HTTPError: 403 when no stored key matches.
        """
        # Non-22 ports are stored in known_hosts as "[host]:port".
        key = hostname if port == 22 else "[{}]:{}".format(hostname, port)

        if self.ssh_client._system_host_keys.lookup(key) is None:
            if self.ssh_client._host_keys.lookup(key) is None:
                raise tornado.web.HTTPError(403, "Connection to {}:{} is not allowed.".format(hostname, port))

    def get_args(self):
        """Build the positional arguments for ``SSHClient.connect``.

        Each non-empty form value wins. An empty one falls back to the
        module-level default connection loaded by ``default_ssh_connection``.

        Returns:
            tuple: ``(hostname, port, username, password, pkey)`` where
            ``pkey`` is a key object from ``PrivateKey`` or None.

        Raises:
            InvalidValueError: from get_hostname()/get_port() on bad input.
            tornado.web.HTTPError: 403 from lookup_hostname() when the policy
                is ``RejectPolicy`` and the host is not known.
        """
        # These globals are only read here; the declaration lists the
        # module-level defaults presumably populated by default_ssh_connection().
        global priority, ssh_id, ssh_priority, ssh_connection_name, ssh_username, ssh_password, ssh_key_passphrase, ssh_public_key, ssh_private_key, ssh_host, ssh_hostname, ssh_port, ssh_proxy_command, ssh_public_key_file, ssh_private_key_file
        try:
            default_ssh_connection(sshdb, default_ssh_connection_name)
        except Exception:
            # Best effort: form values alone may be enough to connect, so a
            # missing or broken default connection must not abort the request.
            logging.debug("Could not load the default ssh connection.", exc_info=True)

        hostname_form = self.get_hostname()
        port_form = self.get_port()
        username_form = self.get_value("username")
        password_form = self.get_argument("password", "")
        # get_privatekey() always returns a (value, filename) pair, so read it once.
        privatekey_form, filename = self.get_privatekey()
        passphrase_form = self.get_argument("passphrase", "")
        totp = self.get_argument("totp", "")

        # Prefer the form value; fall back to the default connection when empty.
        hostname = hostname_form or ssh_hostname
        port = port_form or ssh_port
        username = username_form or ssh_username
        password = password_form or ssh_password
        privatekey = privatekey_form or ssh_private_key or None
        if not privatekey:
            logging.debug("No Private key provided")
        passphrase = passphrase_form or ssh_key_passphrase

        if isinstance(self.policy, paramiko.RejectPolicy):
            self.lookup_hostname(hostname, port)

        pkey = PrivateKey(privatekey, passphrase, filename).get_pkey_obj() if privatekey else None

        self.ssh_client.totp = totp
        args = (hostname, port, username, password, pkey)
        # Never write the plaintext password to the log.
        logging.debug((hostname, port, username, "******" if password else password, pkey))

        return args

    def parse_encoding(self, data):
        """Return ``data`` as an ASCII encoding name if valid, else None."""
        try:
            encoding = to_str(data.strip(), "ascii")
        except UnicodeDecodeError:
            return

        if is_valid_encoding(encoding):
            return encoding

    def get_default_encoding(self, ssh):
        """Ask the remote login shell for its charmap; default to utf-8."""
        commands = ['$SHELL -ilc "locale charmap"', '$SHELL -ic "locale charmap"']

        for command in commands:
            try:
                _, stdout, _ = ssh.exec_command(command, get_pty=True)
            except paramiko.SSHException as exc:
                logging.info(str(exc))
            else:
                data = stdout.read()
                logging.debug("{!r} => {!r}".format(command, data))
                result = self.parse_encoding(data)
                if result:
                    return result

        logging.warning("Could not detect the default encoding.")
        return "utf-8"

    def ssh_connect(self, args):
        """Connect to the SSH server and open an interactive shell channel.

        Runs on ``executor`` because every step blocks. If any step fails,
        the SSH client is closed so a half-open transport is not leaked.

        Args:
            args: ``(hostname, port, username, password, pkey)`` from get_args().

        Returns:
            Worker: wraps the client and a non-blocking shell channel, with
            ``encoding`` set.

        Raises:
            ValueError: with a user-facing message for connection,
                authentication and host-key failures.
            Other exceptions (e.g. ``paramiko.SSHException`` from
            ``invoke_shell``) propagate unchanged.
        """
        ssh = self.ssh_client
        dst_addr = args[:2]
        logging.info("Connecting to {}:{}".format(*dst_addr))

        try:
            try:
                ssh.connect(
                    *args,
                    allow_agent=options.allow_agent,
                    look_for_keys=options.look_for_keys,
                    timeout=options.timeout,
                    auth_timeout=options.auth_timeout,
                )
            except socket.error as exc:
                raise ValueError("Unable to connect to {}:{}".format(*dst_addr)) from exc
            except paramiko.BadAuthenticationType as exc:
                raise ValueError("Bad authentication type.") from exc
            except paramiko.AuthenticationException as exc:
                raise ValueError("Authentication failed.") from exc
            except paramiko.BadHostKeyException as exc:
                raise ValueError("Bad host key.") from exc

            term = self.get_argument("term", "") or "xterm"
            chan = ssh.invoke_shell(term=term)
            logging.info(f"Channel to channel: {chan} ")
            chan.setblocking(0)
            worker = Worker(self.loop, ssh, chan, dst_addr)
            worker.encoding = options.encoding if options.encoding else self.get_default_encoding(ssh)
        except Exception:
            # No worker was handed out, so nothing else will ever close this
            # client; release its transport/socket before propagating.
            ssh.close()
            raise
        return worker

    def check_origin(self):
        """Validate the request origin (``_origin`` arg or Origin header).

        Raises:
            tornado.web.HTTPError: 403 when the origin is not allowed.
        """
        event_origin = self.get_argument("_origin", "")
        header_origin = self.request.headers.get("Origin")
        origin = event_origin or header_origin

        if origin:
            if not super(IndexHandler, self).check_origin(origin):
                raise tornado.web.HTTPError(403, "Cross origin operation is not allowed.")

            if not event_origin and self.origin_policy != "same":
                self.set_header("Access-Control-Allow-Origin", origin)

    def head(self):
        """Answer HEAD requests with an empty response."""
        pass

    def get(self):
        """Render the terminal page."""
        self.render("index.html", debug=self.debug, font=self.font)

    @tornado.gen.coroutine
    def post(self):
        """Open an SSH connection and reply with the worker id/encoding.

        Connection errors are reported in ``result["status"]``; validation
        errors become HTTP 400 and connection limits/origin failures 403.
        """
        if self.debug and self.get_argument("error", ""):
            # for testing purpose only
            raise ValueError("Uncaught exception")

        ip, port = self.get_client_addr()
        workers = clients.get(ip, {})
        if workers and len(workers) >= options.maxconn:
            raise tornado.web.HTTPError(403, "Too many live connections.")

        self.check_origin()

        try:
            args = self.get_args()
        except InvalidValueError as exc:
            raise tornado.web.HTTPError(400, str(exc))

        future = self.executor.submit(self.ssh_connect, args)

        try:
            worker = yield future
        except (ValueError, paramiko.SSHException) as exc:
            logging.error(traceback.format_exc())
            self.result.update(status=str(exc))
        else:
            if not workers:
                clients[ip] = workers
            worker.src_addr = (ip, port)
            workers[worker.id] = worker
            # Recycle the worker if no WebSocket attaches within the delay.
            self.loop.call_later(options.delay or DELAY, recycle_worker, worker)
            self.result.update(id=worker.id, encoding=worker.encoding)

        self.write(self.result)

write_error(self, status_code, **kwargs)

Override to implement custom error pages.

write_error may call write, render, set_header, etc to produce output as usual.

If this error was caused by an uncaught exception (including HTTPError), an exc_info triple will be available as kwargs["exc_info"]. Note that this exception may not be the "current" exception for purposes of methods like sys.exc_info() or traceback.format_exc.

Source code in wizardwebssh/handler.py
def write_error(self, status_code, **kwargs):
    """Render an error, or fold it into the result for swallowed POSTs.

    With ``swallow_http_errors`` enabled and a POST request, the exception's
    ``log_message`` (when present) becomes the reason. That reason is stored
    in ``self.result["status"]`` and the result is finished with status 200.
    Any other request falls back to the default Tornado error page.
    """
    swallow = swallow_http_errors and self.request.method == "POST"
    if not swallow:
        super(IndexHandler, self).write_error(status_code, **kwargs)
        return

    exc_info = kwargs.get("exc_info")
    message = getattr(exc_info[1], "log_message", None) if exc_info else None
    if message:
        self._reason = message
    self.result.update(status=self._reason)
    self.set_status(200)
    self.finish(self.result)

NotFoundHandler (MixinHandler, ErrorHandler)

Source code in wizardwebssh/handler.py
class NotFoundHandler(MixinHandler, tornado.web.ErrorHandler):
    """Catch-all handler that answers every request with 404 Not Found."""

    def initialize(self):
        # Takes no route kwargs and forwards none to the next initialize()
        # in the MRO.
        super().initialize()

    def prepare(self):
        # prepare() runs before any verb method, so raising here rejects
        # every HTTP method uniformly.
        not_found = tornado.web.HTTPError(status_code=404)
        raise not_found

prepare(self)

Called at the beginning of a request before get/post/etc.

Override this method to perform common initialization regardless of the request method.

Asynchronous support: Use async def or decorate this method with .gen.coroutine to make it asynchronous. If this method returns an Awaitable execution will not proceed until the Awaitable is done.

.. versionadded:: 3.1 Asynchronous support.

Source code in wizardwebssh/handler.py
def prepare(self):
    """Abort the request with 404 before any verb method is dispatched."""
    not_found = tornado.web.HTTPError(status_code=404)
    raise not_found

WsockHandler (MixinHandler, WebSocketHandler)

Source code in wizardwebssh/handler.py
class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
    """WebSocket endpoint that bridges a browser terminal to an SSH worker."""

    def initialize(self, loop):
        super(WsockHandler, self).initialize(loop)
        # Weak reference to the attached Worker; stays None until open()
        # successfully attaches one.
        self.worker_ref = None

    def open(self):
        """Attach this socket to the worker registered for the client IP/id.

        The socket is closed with a reason when no worker matches.
        """
        self.src_addr = self.get_client_addr()
        logging.info("Connected from {}:{}".format(*self.src_addr))

        workers = clients.get(self.src_addr[0])
        if not workers:
            self.close(reason="Websocket authentication failed.")
            return

        try:
            worker_id = self.get_value("id")
        except (tornado.web.MissingArgumentError, InvalidValueError) as exc:
            self.close(reason=str(exc))
        else:
            worker = workers.get(worker_id)
            if worker:
                # Consume the slot so the same id cannot be attached twice.
                workers[worker_id] = None
                self.set_nodelay(True)
                worker.set_handler(self)
                self.worker_ref = weakref.ref(worker)
                self.loop.add_handler(worker.fd, worker, IOLoop.READ)
            else:
                self.close(reason="Websocket authentication failed.")

    def on_message(self, message):
        """Handle a JSON message sent by the browser terminal.

        Recognised keys:
            ``resize``: a two-item list passed positionally to ``chan.resize_pty``.
            ``data``: text queued for the SSH channel and flushed via ``on_write``.

        Messages that are not JSON objects are ignored. So are messages that
        arrive while no live worker is attached.
        """
        logging.debug("{!r} from {}:{}".format(message, *self.src_addr))
        # worker_ref stays None when open() rejected the socket, and the weak
        # reference yields None once the worker has been garbage-collected.
        worker = self.worker_ref() if self.worker_ref else None
        if worker is None:
            return

        try:
            msg = json.loads(message)
        except JSONDecodeError:
            return

        if not isinstance(msg, dict):
            return

        resize = msg.get("resize")
        # Only sequences can be sized and unpacked; other JSON scalars would
        # make len() raise outside the try below.
        if isinstance(resize, (list, tuple)) and len(resize) == 2:
            try:
                worker.chan.resize_pty(*resize)
            except (TypeError, struct.error, paramiko.SSHException):
                pass

        data = msg.get("data")
        if data and isinstance(data, UnicodeType):
            worker.data_to_dst.append(data)
            worker.on_write()

    def on_close(self):
        """Log the disconnect and close the attached worker, if still alive."""
        logging.info("Disconnected from {}:{}".format(*self.src_addr))
        if not self.close_reason:
            self.close_reason = "client disconnected"

        worker = self.worker_ref() if self.worker_ref else None
        if worker:
            worker.close(reason=self.close_reason)

on_close(self)

Invoked when the WebSocket is closed.

If the connection was closed cleanly and a status code or reason phrase was supplied, these values will be available as the attributes self.close_code and self.close_reason.

.. versionchanged:: 4.0

Added close_code and close_reason attributes.

Source code in wizardwebssh/handler.py
def on_close(self):
    """Log the disconnect and shut down the attached worker, if any.

    An empty close reason is replaced with ``"client disconnected"`` before
    it is passed on to the worker.
    """
    logging.info("Disconnected from {}:{}".format(*self.src_addr))
    reason = self.close_reason
    if not reason:
        reason = self.close_reason = "client disconnected"

    ref = self.worker_ref
    target = ref() if ref else None
    if target:
        target.close(reason=reason)

on_message(self, message)

Handle incoming messages on the WebSocket

This method must be overridden.

.. versionchanged:: 4.5

on_message can be a coroutine.

Source code in wizardwebssh/handler.py
def on_message(self, message):
    """Handle a JSON message sent by the browser terminal.

    Recognised keys:
        ``resize``: a two-item list passed positionally to ``chan.resize_pty``.
        ``data``: text queued for the SSH channel and flushed via ``on_write``.

    Messages that are not JSON objects are ignored. So are messages that
    arrive while no live worker is attached.
    """
    logging.debug("{!r} from {}:{}".format(message, *self.src_addr))
    # worker_ref stays None when open() rejected the socket, and the weak
    # reference yields None once the worker has been garbage-collected.
    worker = self.worker_ref() if self.worker_ref else None
    if worker is None:
        return

    try:
        msg = json.loads(message)
    except JSONDecodeError:
        return

    if not isinstance(msg, dict):
        return

    resize = msg.get("resize")
    # Only sequences can be sized and unpacked; other JSON scalars would
    # make len() raise outside the try below.
    if isinstance(resize, (list, tuple)) and len(resize) == 2:
        try:
            worker.chan.resize_pty(*resize)
        except (TypeError, struct.error, paramiko.SSHException):
            pass

    data = msg.get("data")
    if data and isinstance(data, UnicodeType):
        worker.data_to_dst.append(data)
        worker.on_write()

open(self)

Invoked when a new WebSocket is opened.

The arguments to open are extracted from the tornado.web.URLSpec regular expression, just like the arguments to tornado.web.RequestHandler.get.

open may be a coroutine. on_message will not be called until open has returned.

.. versionchanged:: 5.1

open may be a coroutine.

Source code in wizardwebssh/handler.py
def open(self):
    """Attach this WebSocket to the worker registered for the client.

    The worker is looked up in ``clients`` by client IP and then by the
    ``id`` argument. On any mismatch the socket is closed with a reason.
    """
    self.src_addr = self.get_client_addr()
    logging.info("Connected from {}:{}".format(*self.src_addr))

    workers = clients.get(self.src_addr[0])
    if not workers:
        self.close(reason="Websocket authentication failed.")
        return

    try:
        worker_id = self.get_value("id")
    except (tornado.web.MissingArgumentError, InvalidValueError) as exc:
        self.close(reason=str(exc))
        return

    worker = workers.get(worker_id)
    if not worker:
        self.close(reason="Websocket authentication failed.")
        return

    # Consume the slot so the same id cannot be attached a second time.
    workers[worker_id] = None
    self.set_nodelay(True)
    worker.set_handler(self)
    self.worker_ref = weakref.ref(worker)
    self.loop.add_handler(worker.fd, worker, IOLoop.READ)

get_default_ssh_connection_data(database_name, connection)

Gets ssh connection data for default connection.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `database_name` | | QSqlDatabase connection or database to use. | required |
| `connection` | | SSH connection to get data for. | required |

Returns: dictionary of ssh connection data

Source code in wizardwebssh/handler.py
def get_default_ssh_connection_data(database_name, connection):
    """
    Get the stored data for a single named ssh connection.

    Joins ``sshconnections`` with its key, group and config rows and returns
    the first match.

    Args:
        database_name (): QSqlDatabase connection or database to use.
        connection (): ``ssh_connection_name`` of the connection to get data for.

    Returns: dictionary of ssh connection data (empty when nothing matches
        or the query fails).

    """
    # get_query_as_dict() only accepts finished SQL text, so the name is
    # embedded as a string literal. Doubling single quotes (standard SQL
    # escaping) keeps a name containing ' from breaking out of the literal.
    safe_connection = str(connection).replace("'", "''")
    query = f"""
                SELECT ssh_group_name, ssh_connection_name,ssh_username,ssh_password,Host,HostName,Port,ProxyCommand,
                sshkey_name,sshkey_passphrase,sshkey_public,sshkey_private,sshkey_public_file,sshkey_private_file,
                ssh_config_name,ssh_config_content
                FROM sshconnections
                JOIN sshkeys ON sshconnections.ssh_key_id = sshkeys.id
                JOIN sshgroup ON sshconnections.ssh_group_id = sshgroup.id
                JOIN sshconfig ON sshconnections.ssh_config_id = sshconfig.id
                WHERE ssh_connection_name =  '{safe_connection}'
                """
    return get_query_as_dict(query, database_name)

get_query_as_dict(query, database_name)

Get QSqlQuery for single record as a dictionary with the columns as the keys.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `query` | | SQL query text to execute. | required |
| `database_name` | | QSqlDatabase connection or database to use. | required |

Returns: dictionary with key values of the query.

Source code in wizardwebssh/handler.py
def get_query_as_dict(query, database_name):
    """
    Run a query and return its first record as a dictionary keyed by column name.

    Args:
        query (): SQL text to execute.
        database_name (): QSqlDatabase connection or database to use.

    Returns: dictionary mapping each column name to ``str()`` of its value in
        the first row; empty if the query fails or returns no rows.

    """
    row_values = {}
    try:
        # Passing SQL text to the QSqlQuery constructor executes it right
        # away; build an empty query bound to the database and execute the
        # SQL exactly once via exec().
        q = QSqlQuery("", db=database_name)
        if q.exec(query) and q.first():
            # Read the record after execution so it describes this result set.
            rec = q.record()
            for column in range(rec.count()):
                row_values[rec.fieldName(column)] = str(q.value(column))
    except Exception as e:
        logging.warning(f"Exception: {e}")
    return row_values

paramiko_host_info(host)

Get SSH host information via Paramiko Parser for terminal.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `host` | | SSH host from the ssh config file. | required |

Returns: dictionary with connection info parsed from config

Source code in wizardwebssh/handler.py
def paramiko_host_info(host):
    """
    Get SSH host information via Paramiko Parser for terminal.

    Looks ``host`` up in ``~/.ssh/config`` (when that file exists). Its
    ``hostname``, ``user``, ``port`` and ``proxyjump`` entries are then
    overlaid onto a default connection template.

    Args:
        host (): SSH host from ssh config file

    Returns: dictionary with connection info parsed from config; just the
        template when there is no config file.

    """
    # Start empty so a missing config file yields the plain template instead
    # of an unbound-variable error.
    o = {}
    user_config_file = os.path.expanduser("~/.ssh/config")
    if os.path.exists(user_config_file):
        ssh_config = paramiko.SSHConfig()
        with open(user_config_file) as f:
            ssh_config.parse(f)
        o = ssh_config.lookup(host)

    # setup template
    con = {
        "ssh_group_name": "default",
        "ssh_connection_name": host,
        "ssh_username": "",
        "ssh_password": "",
        "Host": host,
        "HostName": "",
        "Port": "22",
        "ProxyCommand": "",
        "sshkey_name": "None",
        "sshkey_passphrase": "",
        "sshkey_public": "",
        "sshkey_private": "",
        "sshkey_public_file": "",
        "sshkey_private_file": "",
        "ssh_config_name": "paramiko",
        "ssh_config_content": "",
    }

    # Map ssh_config keys onto template keys.
    if "hostname" in o:
        con.update(HostName=o["hostname"])
    if "user" in o:
        con.update(ssh_username=o["user"])
    if "port" in o:
        con.update(Port=o["port"])
    if "proxyjump" in o:
        con.update(ProxyCommand=o["proxyjump"])

    logging.debug(
        "Paramiko host lookup dict: %s from %s",
        con,
        os.path.abspath(os.path.dirname(sys.argv[0])),
    )
    return con