Source code for dagster_ssh.resources

import getpass
import logging
import os
from io import StringIO

import paramiko
from paramiko.config import SSH_PORT
from sshtunnel import SSHTunnelForwarder

from dagster import BoolSource, Field, IntSource, StringSource
from dagster import _check as check
from dagster import resource
from dagster._utils import merge_dicts, mkdir_p


def key_from_str(key_str):
    """Creates a paramiko RSA key from a string.

    Args:
        key_str (str): Contents of a private key, in the form accepted by
            ``paramiko.RSAKey.from_private_key``.

    Returns:
        The key object returned by ``paramiko.RSAKey.from_private_key``.
    """
    check.str_param(key_str, "key_str")

    # The context manager closes the in-memory buffer even when parsing raises.
    with StringIO(key_str) as key_file:
        return paramiko.RSAKey.from_private_key(key_file)


class SSHResource:
    """Resource for ssh remote execution and SFTP file transfer using Paramiko.

    ref: https://github.com/paramiko/paramiko

    Args:
        remote_host (str): Host to connect to.
        remote_port (Optional[int]): Port used for the SSH connection.
        username (Optional[str]): Login name. When empty, ``getpass.getuser()`` is used.
        password (Optional[str]): Password for ``username``. An empty or whitespace-only
            password is treated as "no password" when connecting.
        key_file (Optional[str]): Path of a private key file. When neither ``password`` nor
            ``key_file`` is given, the first ``identityfile`` entry for ``remote_host`` in
            ``~/.ssh/config`` (if any) is used.
        key_string (Optional[str]): Private key contents, parsed with ``key_from_str``.
        timeout (Optional[int]): Passed as ``timeout`` to ``SSHClient.connect``.
        keepalive_interval (Optional[int]): When truthy, passed to ``set_keepalive`` on the
            transport after connecting.
        compress (Optional[bool]): Passed as ``compress`` to ``SSHClient.connect``.
        no_host_key_check (Optional[bool]): When true, ``paramiko.AutoAddPolicy`` is installed
            as the missing-host-key policy.
        allow_host_key_change (Optional[bool]): When false, the system host keys are loaded
            before connecting.
        logger: Logger used for all messages. When ``None``, a module-level
            ``logging`` logger is used instead.
    """

    def __init__(
        self,
        remote_host,
        remote_port,
        username=None,
        password=None,
        key_file=None,
        key_string=None,
        timeout=10,
        keepalive_interval=30,
        compress=True,
        no_host_key_check=True,
        allow_host_key_change=False,
        logger=None,
    ):
        self.remote_host = check.str_param(remote_host, "remote_host")
        self.remote_port = check.opt_int_param(remote_port, "remote_port")
        self.username = check.opt_str_param(username, "username")
        self.password = check.opt_str_param(password, "password")
        self.key_file = check.opt_str_param(key_file, "key_file")
        self.timeout = check.opt_int_param(timeout, "timeout")
        self.keepalive_interval = check.opt_int_param(keepalive_interval, "keepalive_interval")
        self.compress = check.opt_bool_param(compress, "compress")
        self.no_host_key_check = check.opt_bool_param(no_host_key_check, "no_host_key_check")
        self.allow_host_key_change = check.opt_bool_param(
            allow_host_key_change, "allow_host_key_change"
        )
        # ``logger`` defaults to None; fall back to a module logger so that the
        # logging calls in this class never fail with AttributeError.
        self.log = logger if logger is not None else logging.getLogger(__name__)
        self.host_proxy = None

        # Create RSAKey object from private key string
        self.key_obj = key_from_str(key_string) if key_string is not None else None

        # Auto detecting username values from system
        if not self.username:
            self.log.debug(
                "username to ssh to host: %s is not specified. Using system's default provided by"
                " getpass.getuser()" % self.remote_host
            )
            self.username = getpass.getuser()

        # Pick up proxy command / identity file for this host from the user's SSH config.
        user_ssh_config_filename = os.path.expanduser("~/.ssh/config")
        if os.path.isfile(user_ssh_config_filename):
            ssh_conf = paramiko.SSHConfig()
            # Close the config file handle as soon as it has been parsed.
            with open(user_ssh_config_filename, encoding="utf8") as ssh_conf_file:
                ssh_conf.parse(ssh_conf_file)
            host_info = ssh_conf.lookup(self.remote_host)
            if host_info and host_info.get("proxycommand"):
                self.host_proxy = paramiko.ProxyCommand(host_info.get("proxycommand"))

            # Explicit credentials win over the identity file from the SSH config.
            if not (self.password or self.key_file):
                if host_info and host_info.get("identityfile"):
                    self.key_file = host_info.get("identityfile")[0]

    def get_connection(self):
        """
        Opens a SSH connection to the remote host.

        The caller owns the returned client and is responsible for closing it.
        If connecting (or enabling keepalive) fails, the client is closed before
        the exception propagates.

        :rtype: paramiko.client.SSHClient
        """
        client = paramiko.SSHClient()
        # NOTE(review): this warning is emitted when host key changes are NOT
        # allowed, which reads inverted relative to its text - confirm intent.
        if not self.allow_host_key_change:
            self.log.warning(
                "Remote Identification Change is not verified. This won't protect against "
                "Man-In-The-Middle attacks"
            )
            client.load_system_host_keys()
        if self.no_host_key_check:
            self.log.warning(
                "No Host Key Verification. This won't protect against Man-In-The-Middle attacks"
            )
            # Default is RejectPolicy
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        connect_kwargs = {
            "hostname": self.remote_host,
            "username": self.username,
            "key_filename": self.key_file,
            "pkey": self.key_obj,
            "timeout": self.timeout,
            "compress": self.compress,
            "port": self.remote_port,
            "sock": self.host_proxy,
        }
        if self.password and self.password.strip():
            # With a usable password, send it and skip paramiko's key discovery.
            connect_kwargs["password"] = self.password
            connect_kwargs["look_for_keys"] = False

        try:
            client.connect(**connect_kwargs)
            if self.keepalive_interval:
                client.get_transport().set_keepalive(self.keepalive_interval)
        except Exception:
            # Do not leak a half-open client when the connection attempt fails.
            client.close()
            raise

        return client

    def get_tunnel(self, remote_port, remote_host="localhost", local_port=None):
        """
        Builds an ``SSHTunnelForwarder`` that goes through the SSH host of this resource.

        The forwarder is returned as constructed; this method does not start it.

        :param remote_port: port to bind on the remote side of the tunnel
        :param remote_host: host to bind on the remote side of the tunnel
        :param local_port: local port to bind; when ``None`` only the local host
            name is passed as the local bind address
        :rtype: sshtunnel.SSHTunnelForwarder
        """
        check.int_param(remote_port, "remote_port")
        check.str_param(remote_host, "remote_host")
        check.opt_int_param(local_port, "local_port")

        if local_port is not None:
            local_bind_address = ("localhost", local_port)
        else:
            local_bind_address = ("localhost",)

        # Will prefer key string if specified, otherwise use the key file
        pkey = self.key_obj if self.key_obj else self.key_file

        tunnel_kwargs = {
            "ssh_port": self.remote_port,
            "ssh_username": self.username,
            "ssh_pkey": pkey,
            "ssh_proxy": self.host_proxy,
            "local_bind_address": local_bind_address,
            "remote_bind_address": (remote_host, remote_port),
            "logger": self.log,
        }
        if self.password and self.password.strip():
            tunnel_kwargs["ssh_password"] = self.password
        else:
            # An empty ``host_pkey_directories`` is passed only on the password-less path.
            tunnel_kwargs["host_pkey_directories"] = []

        return SSHTunnelForwarder(self.remote_host, **tunnel_kwargs)

    def sftp_get(self, remote_filepath, local_filepath):
        """
        Downloads ``remote_filepath`` from the remote host to ``local_filepath``.

        Missing parent directories of ``local_filepath`` are created first. The
        SSH connection is closed even when the transfer fails.

        :return: ``local_filepath``
        """
        check.str_param(remote_filepath, "remote_filepath")
        check.str_param(local_filepath, "local_filepath")
        conn = self.get_connection()
        try:
            with conn.open_sftp() as sftp_client:
                local_folder = os.path.dirname(local_filepath)

                # Create intermediate directories if they don't exist. A bare file
                # name has no directory component, so there is nothing to create.
                if local_folder:
                    mkdir_p(local_folder)

                self.log.info(
                    "Starting to transfer from {0} to {1}".format(remote_filepath, local_filepath)
                )

                sftp_client.get(remote_filepath, local_filepath)
        finally:
            conn.close()
        return local_filepath

    def sftp_put(self, remote_filepath, local_filepath, confirm=True):
        """
        Uploads ``local_filepath`` to ``remote_filepath`` on the remote host.

        ``confirm`` is forwarded to the SFTP client's ``put``. The SSH connection
        is closed even when the transfer fails.

        :return: ``local_filepath``
        """
        check.str_param(remote_filepath, "remote_filepath")
        check.str_param(local_filepath, "local_filepath")
        conn = self.get_connection()
        try:
            with conn.open_sftp() as sftp_client:
                self.log.info(
                    "Starting to transfer file from {0} to {1}".format(
                        local_filepath, remote_filepath
                    )
                )

                sftp_client.put(local_filepath, remote_filepath, confirm=confirm)
        finally:
            conn.close()
        return local_filepath


@resource(
    config_schema={
        "remote_host": Field(
            StringSource, description="remote host to connect to", is_required=True
        ),
        "remote_port": Field(
            IntSource,
            description="port of remote host to connect (Default is paramiko SSH_PORT)",
            is_required=False,
            default_value=SSH_PORT,
        ),
        "username": Field(
            StringSource, description="username to connect to the remote_host", is_required=False
        ),
        "password": Field(
            StringSource,
            description="password of the username to connect to the remote_host",
            is_required=False,
        ),
        "key_file": Field(
            StringSource,
            description="key file to use to connect to the remote_host.",
            is_required=False,
        ),
        "key_string": Field(
            StringSource,
            description="key string to use to connect to remote_host",
            is_required=False,
        ),
        "timeout": Field(
            IntSource,
            description="timeout for the attempt to connect to the remote_host.",
            is_required=False,
            default_value=10,
        ),
        "keepalive_interval": Field(
            IntSource,
            description="send a keepalive packet to remote host every keepalive_interval seconds",
            is_required=False,
            default_value=30,
        ),
        "compress": Field(BoolSource, is_required=False, default_value=True),
        "no_host_key_check": Field(BoolSource, is_required=False, default_value=True),
        "allow_host_key_change": Field(BoolSource, is_required=False, default_value=False),
    }
)
def ssh_resource(init_context):
    """Builds an :class:`SSHResource` from the resource config.

    Every configured key is forwarded as a keyword argument to ``SSHResource``,
    together with ``init_context.log`` as its ``logger``.
    """
    args = merge_dicts(init_context.resource_config, {"logger": init_context.log})
    return SSHResource(**args)