import hashlib
import json
import logging
import sys
import time
from contextlib import contextmanager
from typing import Any, Dict, List, Mapping, Optional, cast
import requests
from dagster_airbyte.types import AirbyteOutput
from requests.exceptions import RequestException
from dagster import Failure, Field, StringSource
from dagster import _check as check
from dagster import get_dagster_logger, resource
from dagster._config.field_utils import Permissive
from dagster._utils.merger import deep_merge_dicts
DEFAULT_POLL_INTERVAL_SECONDS = 10
class AirbyteState:
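"""String constants for the job states reported by the Airbyte jobs API."""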
RUNNING = "running"
SUCCEEDED = "succeeded"
CANCELLED = "cancelled"
PENDING = "pending"
FAILED = "failed"
ERROR = "error"
INCOMPLETE = "incomplete"
class AirbyteResource:
"""
This class exposes methods on top of the Airbyte REST API.
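**Examples:**
A minimal sketch of direct use; the host, port, and connection ID below are
hypothetical placeholders for your own deployment:
.. code-block:: python
    airbyte = AirbyteResource(host="localhost", port="8000", use_https=False)
    output = airbyte.sync_and_poll("your-connection-id")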
"""
def __init__(
self,
host: str,
port: str,
use_https: bool,
request_max_retries: int = 3,
request_retry_delay: float = 0.25,
request_timeout: int = 15,
request_additional_params: Optional[Mapping[str, Any]] = None,
log: logging.Logger = get_dagster_logger(),
forward_logs: bool = True,
username: Optional[str] = None,
password: Optional[str] = None,
):
self._host = host
self._port = port
self._use_https = use_https
self._request_max_retries = request_max_retries
self._request_retry_delay = request_retry_delay
self._request_timeout = request_timeout
self._additional_request_params = request_additional_params or dict()
self._log = log
self._forward_logs = forward_logs
self._request_cache: Dict[str, Optional[Mapping[str, object]]] = {}
# Counter rather than a plain bool so that nested cache_requests contexts behave correctly
self._cache_enabled = 0
self._username = username
self._password = password
@property
def api_base_url(self) -> str:
return (
("https://" if self._use_https else "http://")
+ (f"{self._host}:{self._port}" if self._port else self._host)
+ "/api/v1"
)
@contextmanager
def cache_requests(self):
"""
Context manager that enables caching of certain requests to the Airbyte API;
the cache is cleared when the context is exited.
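**Examples:**
A sketch of how calls inside the context reuse cached responses; ``airbyte`` is
assumed to be an already-constructed AirbyteResource:
.. code-block:: python
    with airbyte.cache_requests():
        # both calls share a single cached "/workspaces/list" response
        workspace_id = airbyte.get_default_workspace()
        same_workspace_id = airbyte.get_default_workspace()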
"""
self.clear_request_cache()
self._cache_enabled += 1
try:
yield
finally:
self.clear_request_cache()
self._cache_enabled -= 1
def clear_request_cache(self):
self._request_cache = {}
def make_request_cached(self, endpoint: str, data: Optional[Mapping[str, object]]):
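"""Identical to make_request, except that the response is cached and reused for
repeated identical calls while inside a cache_requests() context; outside such a
context this falls through to an uncached make_request call."""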
if self._cache_enabled <= 0:
return self.make_request(endpoint, data)
data_json = json.dumps(data, sort_keys=True)
sha = hashlib.sha1()
sha.update(endpoint.encode("utf-8"))
sha.update(data_json.encode("utf-8"))
digest = sha.hexdigest()
if digest not in self._request_cache:
self._request_cache[digest] = self.make_request(endpoint, data)
return self._request_cache[digest]
def make_request(
self, endpoint: str, data: Optional[Mapping[str, object]]
) -> Optional[Mapping[str, object]]:
"""
Creates and sends a request to the desired Airbyte REST API endpoint.
Args:
endpoint (str): The Airbyte API endpoint to send this request to.
data (Optional[Mapping[str, object]]): JSON-serializable data to include in the request body.
Returns:
Optional[Mapping[str, object]]: Parsed JSON data from the response to this request, or
None if the response has no content (HTTP 204).
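**Examples:**
A sketch of a direct call; "/workspaces/list" is an endpoint this class itself
uses, and any other Airbyte REST endpoint path can be passed the same way:
.. code-block:: python
    workspaces = airbyte.make_request(endpoint="/workspaces/list", data={})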
"""
headers = {"accept": "application/json"}
num_retries = 0
while True:
try:
response = requests.request(
**deep_merge_dicts(
dict(
method="POST",
url=self.api_base_url + endpoint,
headers=headers,
json=data,
timeout=self._request_timeout,
auth=(self._username, self._password)
if self._username and self._password
else None,
),
self._additional_request_params,
),
)
response.raise_for_status()
if response.status_code == 204:
return None
return response.json()
except RequestException as e:
self._log.error("Request to Airbyte API failed: %s", e)
if num_retries == self._request_max_retries:
break
num_retries += 1
time.sleep(self._request_retry_delay)
raise Failure("Exceeded max number of retries.")
def cancel_job(self, job_id: int):
self.make_request(endpoint="/jobs/cancel", data={"id": job_id})
def get_default_workspace(self):
workspaces = cast(
List[Dict[str, Any]],
check.not_none(self.make_request_cached(endpoint="/workspaces/list", data={})).get(
"workspaces", []
),
)
return workspaces[0].get("workspaceId")
def get_source_definition_by_name(self, name: str, workspace_id: str) -> Optional[str]:
name_lower = name.lower()
definitions = cast(
Dict[str, List[Dict[str, str]]],
check.not_none(
self.make_request_cached(
endpoint="/source_definitions/list_for_workspace",
data={"workspaceId": workspace_id},
)
),
)
return next(
(
definition["sourceDefinitionId"]
for definition in definitions["sourceDefinitions"]
if definition["name"].lower() == name_lower
),
None,
)
def get_destination_definition_by_name(self, name: str, workspace_id: str) -> Optional[str]:
name_lower = name.lower()
definitions = cast(
Dict[str, List[Dict[str, str]]],
check.not_none(
self.make_request_cached(
endpoint="/destination_definitions/list_for_workspace",
data={"workspaceId": workspace_id},
)
),
)
return next(
(
definition["destinationDefinitionId"]
for definition in definitions["destinationDefinitions"]
if definition["name"].lower() == name_lower
),
None,
)
def get_source_catalog_id(self, source_id: str) -> str:
result = cast(
Dict[str, Any],
check.not_none(
self.make_request(endpoint="/sources/discover_schema", data={"sourceId": source_id})
),
)
return result["catalogId"]
def get_source_schema(self, source_id: str) -> Mapping[str, Any]:
return cast(
Dict[str, Any],
check.not_none(
self.make_request(endpoint="/sources/discover_schema", data={"sourceId": source_id})
),
)
def does_dest_support_normalization(
self, destination_definition_id: str, workspace_id: str
) -> bool:
return cast(
Dict[str, Any],
check.not_none(
self.make_request_cached(
endpoint="/destination_definition_specifications/get",
data={
"destinationDefinitionId": destination_definition_id,
"workspaceId": workspace_id,
},
)
),
).get("supportsNormalization", False)
def get_job_status(self, connection_id: str, job_id: int) -> Mapping[str, object]:
if self._forward_logs:
return check.not_none(self.make_request(endpoint="/jobs/get", data={"id": job_id}))
else:
# the "list all jobs" endpoint doesn't return logs, which actually makes it much more
# lightweight for long-running syncs with many logs
out = check.not_none(
self.make_request(
endpoint="/jobs/list",
data={
"configTypes": ["sync"],
"configId": connection_id,
# the sync we just started should be among the most recent jobs, so pageSize 5 is sufficient
"pagination": {"pageSize": 5},
},
)
)
job = next((job for job in cast(List, out["jobs"]) if job["job"]["id"] == job_id), None)
return check.not_none(job)
def start_sync(self, connection_id: str) -> Mapping[str, object]:
return check.not_none(
self.make_request(endpoint="/connections/sync", data={"connectionId": connection_id})
)
def get_connection_details(self, connection_id: str) -> Mapping[str, object]:
return check.not_none(
self.make_request(endpoint="/connections/get", data={"connectionId": connection_id})
)
def sync_and_poll(
self,
connection_id: str,
poll_interval: float = DEFAULT_POLL_INTERVAL_SECONDS,
poll_timeout: Optional[float] = None,
) -> AirbyteOutput:
"""
Initializes a sync operation for the given connector, and polls until it completes.
Args:
connection_id (str): The Airbyte Connector ID. You can retrieve this value from the
"Connection" tab of a given connection in the Arbyte UI.
poll_interval (float): The time (in seconds) that will be waited between successive polls.
poll_timeout (float): The maximum time (in seconds) that will be waited before this
operation times out. By default, this will never time out.
Returns:
:py:class:`~AirbyteOutput`:
Details of the sync job.
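**Examples:**
A sketch of a bounded sync; the connection ID is a hypothetical placeholder:
.. code-block:: python
    output = airbyte.sync_and_poll(
        connection_id="your-connection-id",
        poll_interval=30,
        poll_timeout=3600,  # raise Failure (and cancel the job) after one hour
    )
    job_details = output.job_details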
"""
connection_details = self.get_connection_details(connection_id)
job_details = self.start_sync(connection_id)
job_info = cast(Dict[str, object], job_details.get("job", {}))
job_id = cast(int, job_info.get("id"))
self._log.info(f"Job {job_id} initialized for connection_id={connection_id}.")
start = time.monotonic()
logged_attempts = 0
logged_lines = 0
state = None
try:
while True:
if poll_timeout and start + poll_timeout < time.monotonic():
raise Failure(
f"Timeout: Airbyte job {job_id} is not ready after the timeout {poll_timeout} seconds"
)
time.sleep(poll_interval)
job_details = self.get_job_status(connection_id, job_id)
attempts = cast(List, job_details.get("attempts", []))
cur_attempt = len(attempts)
# spit out the available Airbyte log info
if cur_attempt:
if self._forward_logs:
log_lines = attempts[logged_attempts].get("logs", {}).get("logLines", [])
for line in log_lines[logged_lines:]:
sys.stdout.write(line + "\n")
sys.stdout.flush()
logged_lines = len(log_lines)
# if there's a next attempt, this one will have no more log messages
if logged_attempts < cur_attempt - 1:
logged_lines = 0
logged_attempts += 1
job_info = cast(Dict[str, object], job_details.get("job", {}))
state = job_info.get("status")
if state in (AirbyteState.RUNNING, AirbyteState.PENDING, AirbyteState.INCOMPLETE):
continue
elif state == AirbyteState.SUCCEEDED:
break
elif state == AirbyteState.ERROR:
raise Failure(f"Job failed: {job_id}")
elif state == AirbyteState.CANCELLED:
raise Failure(f"Job was cancelled: {job_id}")
else:
raise Failure(f"Encountered unexpected state `{state}` for job_id {job_id}")
finally:
# if Airbyte sync has not completed, make sure to cancel it so that it doesn't outlive
# the python process
if state not in (AirbyteState.SUCCEEDED, AirbyteState.ERROR, AirbyteState.CANCELLED):
self.cancel_job(job_id)
return AirbyteOutput(job_details=job_details, connection_details=connection_details)
@resource(
config_schema={
"host": Field(
StringSource,
is_required=True,
description="The Airbyte Server Address.",
),
"port": Field(
StringSource,
is_required=True,
description="Port for the Airbyte Server.",
),
"username": Field(
StringSource,
description="Username if using basic auth.",
is_required=False,
),
"password": Field(
StringSource,
description="Password if using basic auth.",
is_required=False,
),
"use_https": Field(
bool,
default_value=False,
description="Use https to connect in Airbyte Server.",
),
"request_max_retries": Field(
int,
default_value=3,
description="The maximum number of times requests to the Airbyte API should be retried "
"before failing.",
),
"request_retry_delay": Field(
float,
default_value=0.25,
description="Time (in seconds) to wait between each request retry.",
),
"request_timeout": Field(
int,
default_value=15,
description="Time (in seconds) after which the requests to Airbyte are declared timed out.",
),
"request_additional_params": Field(
Permissive(),
description="Any additional kwargs to pass to the requests library when making requests to Airbyte.",
),
"forward_logs": Field(
bool,
default_value=True,
description="Whether to forward Airbyte logs to the compute log, can be expensive for long-running syncs.",
),
},
description="This resource helps manage Airbyte connectors",
)
def airbyte_resource(context) -> AirbyteResource:
"""
This resource allows users to programmatically interface with the Airbyte REST API to launch
syncs and monitor their progress. This currently implements only a subset of the functionality
exposed by the API.
For a complete set of documentation on the Airbyte REST API, including expected response JSON
schema, see the `Airbyte API Docs <https://airbyte-public-api-docs.s3.us-east-2.amazonaws.com/rapidoc-api-docs.html#overview>`_.
To configure this resource, we recommend using the `configured
<https://docs.dagster.io/concepts/configuration/configured>`_ method.
**Examples:**
.. code-block:: python
from dagster import job
from dagster_airbyte import airbyte_resource
my_airbyte_resource = airbyte_resource.configured(
{
"host": {"env": "AIRBYTE_HOST"},
"port": {"env": "AIRBYTE_PORT"},
# If using basic auth
"username": {"env": "AIRBYTE_USERNAME"},
"password": {"env": "AIRBYTE_PASSWORD"},
}
)
@job(resource_defs={"airbyte": my_airbyte_resource})
def my_airbyte_job():
...
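Within an op, the configured resource is then available on the context; a
minimal sketch (the connection ID is a placeholder):
.. code-block:: python
    from dagster import op
    @op(required_resource_keys={"airbyte"})
    def sync_airbyte_connection(context):
        context.resources.airbyte.sync_and_poll("your-connection-id")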
"""
return AirbyteResource(
host=context.resource_config["host"],
port=context.resource_config["port"],
use_https=context.resource_config["use_https"],
request_max_retries=context.resource_config["request_max_retries"],
request_retry_delay=context.resource_config["request_retry_delay"],
request_timeout=context.resource_config["request_timeout"],
request_additional_params=context.resource_config["request_additional_params"],
log=context.log,
forward_logs=context.resource_config["forward_logs"],
username=context.resource_config.get("username"),
password=context.resource_config.get("password"),
)