Skip to content

Commit

Permalink
do a big linting update and fix settings and add ci linter
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiascadee committed Aug 28, 2024
1 parent 360c379 commit 29c887a
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 175 deletions.
37 changes: 37 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
### A CI workflow that runs Ruff linting on pull requests.
### TODO: Modify as needed, e.g. re-enable the commented-out pytest job below to add Python testing.

name: Test target-redshift

on:
pull_request:
types: ["opened", "synchronize", "reopened"]

jobs:
linter:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
# pytest:
# runs-on: ubuntu-latest
# env:
# GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
# strategy:
# matrix:
# python-version: ["3.8", "3.9", "3.10", "3.11"]
# steps:
# - uses: actions/checkout@v3
# - name: Set up Python ${{ matrix.python-version }}
# uses: actions/setup-python@v4
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install Poetry
# run: |
# pip install poetry
# - name: Install dependencies
# run: |
# poetry install
# - name: Test with pytest
# run: |
# poetry run pytest
30 changes: 0 additions & 30 deletions .github/workflows/test.yml

This file was deleted.

127 changes: 73 additions & 54 deletions target_redshift/connector.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,28 @@
"""Common SQL connectors for Streams and Sinks."""

from __future__ import annotations

import typing as t
from contextlib import contextmanager
from typing import cast

from contextlib import contextmanager
import boto3

from singer_sdk.typing import _jsonschema_type_check
from singer_sdk import typing as th
import redshift_connector
from redshift_connector import Cursor
from singer_sdk.connectors import SQLConnector
from singer_sdk.helpers._typing import get_datelike_property_type
from redshift_connector import Cursor
import redshift_connector

from singer_sdk.typing import _jsonschema_type_check
from sqlalchemy import DDL, Column, MetaData, Table
from sqlalchemy.engine.url import URL
from sqlalchemy_redshift.dialect import SUPER, BIGINT, VARCHAR, DOUBLE_PRECISION
from sqlalchemy.schema import CreateSchema, CreateTable, DropTable
from sqlalchemy.types import (
BOOLEAN,
DATE,
DATETIME,
DECIMAL,
TIME,
VARCHAR,
TypeEngine,
)
from sqlalchemy.schema import CreateTable, DropTable, CreateSchema
from sqlalchemy.types import TypeEngine
from sqlalchemy import Table, MetaData, DDL, Column
from sqlalchemy_redshift.dialect import BIGINT, DOUBLE_PRECISION, SUPER, VARCHAR


class RedshiftConnector(SQLConnector):
Expand All @@ -44,6 +40,7 @@ def prepare_schema(self, schema_name: str, cursor: Cursor) -> None:
Args:
schema_name: The target schema name.
cursor: The database cursor.
"""
schema_exists = self.schema_exists(schema_name)
if not schema_exists:
Expand All @@ -54,11 +51,24 @@ def create_schema(self, schema_name: str, cursor: Cursor) -> None:
Args:
schema_name: The target schema to create.
cursor: The database cursor.
"""
cursor.execute(str(CreateSchema(schema_name)))

@contextmanager
def _connect_cursor(self) -> t.Iterator[Cursor]:
def connect_cursor(self) -> t.Iterator[Cursor]:
"""Connect to a redshift connector cursor.
Returns:
-------
t.Iterator[Cursor]
A redshift connector cursor.
Yields:
------
Iterator[t.Iterator[Cursor]]
A redshift connector cursor.
"""
user, password = self.get_credentials()
with redshift_connector.connect(
user=user,
Expand All @@ -71,14 +81,13 @@ def _connect_cursor(self) -> t.Iterator[Cursor]:
yield cursor
connection.commit()

def prepare_table( # type: ignore[override]
def prepare_table( # type: ignore[override] # noqa: D417
self,
full_table_name: str,
schema: dict,
primary_keys: t.Sequence[str],
cursor: Cursor,
partition_keys: list[str] | None = None,
as_temp_table: bool = False,
as_temp_table: bool = False, # noqa: FBT001, FBT002
) -> Table:
"""Adapt target table to provided schema if possible.
Expand All @@ -87,7 +96,6 @@ def prepare_table( # type: ignore[override]
schema: the JSON Schema for the table.
primary_keys: list of key properties.
cursor: the database cursor.
partition_keys: list of partition keys.
as_temp_table: True to create a temp table.
Returns:
Expand Down Expand Up @@ -116,7 +124,6 @@ def prepare_table( # type: ignore[override]
meta=meta,
schema=schema,
primary_keys=primary_keys,
partition_keys=partition_keys,
as_temp_table=as_temp_table,
cursor=cursor,
)
Expand Down Expand Up @@ -149,14 +156,15 @@ def copy_table_structure(
full_table_name: str,
from_table: Table,
cursor: Cursor,
as_temp_table: bool = False,
as_temp_table: bool = False, # noqa: FBT001, FBT002
) -> Table:
"""Copy table structure.
Args:
full_table_name: the target table name potentially including schema
from_table: the source table
connection: the database connection.
cursor: A redshift connector cursor.
as_temp_table: True to create a temp table.
Returns:
Expand All @@ -165,26 +173,27 @@ def copy_table_structure(
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
meta = MetaData(schema=schema_name)
new_table: Table
columns = []
if self.table_exists(full_table_name=full_table_name):
raise RuntimeError("Table already exists")
for column in from_table.columns:
columns.append(column._copy())
msg = "Table already exists"
raise RuntimeError(msg)
columns = [column._copy() for column in from_table.columns] # noqa: SLF001
if as_temp_table:
new_table = Table(table_name, meta, *columns, prefixes=["TEMPORARY"])
else:
new_table = Table(table_name, meta, *columns)

create_table_ddl = str(CreateTable(new_table).compile(dialect=self._engine.dialect))
create_table_ddl = str(
CreateTable(new_table).compile(dialect=self._engine.dialect)
)
cursor.execute(create_table_ddl)
return new_table

def drop_table(self, table: Table, cursor: Cursor):
def drop_table(self, table: Table, cursor: Cursor) -> None:
"""Drop the given table by executing its compiled DROP TABLE DDL on the cursor."""
drop_table_ddl = str(DropTable(table).compile(dialect=self._engine.dialect))
cursor.execute(drop_table_ddl)

def to_sql_type(self, jsonschema_type: dict) -> TypeEngine:
def to_sql_type(self, jsonschema_type: dict) -> TypeEngine: # noqa: PLR0911
"""Convert JSON Schema type to a SQL type.
Args:
Expand Down Expand Up @@ -216,15 +225,14 @@ def to_sql_type(self, jsonschema_type: dict) -> TypeEngine:

return VARCHAR(self.default_varchar_length)

def create_empty_table( # type: ignore[override]
def create_empty_table( # type: ignore[override] # noqa: PLR0913
self,
table_name: str,
meta: MetaData,
schema: dict,
cursor: Cursor,
primary_keys: t.Sequence[str] | None = None,
partition_keys: list[str] | None = None,
as_temp_table: bool = False,
as_temp_table: bool = False, # noqa: FBT001, FBT002
) -> Table:
"""Create an empty target table.
Expand All @@ -249,7 +257,11 @@ def create_empty_table( # type: ignore[override]
try:
properties: dict = schema["properties"]
except KeyError:
raise RuntimeError(f"Schema for table_name: '{table_name}'" f"does not define properties: {schema}")
msg = (
f"Schema for table_name: '{table_name}'"
f"does not define properties: {schema}"
)
raise RuntimeError(msg) # noqa: B904

for property_name, property_jsonschema in properties.items():
is_primary_key = property_name in primary_keys
Expand All @@ -266,7 +278,9 @@ def create_empty_table( # type: ignore[override]
else:
new_table = Table(table_name, meta, *columns)

create_table_ddl = str(CreateTable(new_table).compile(dialect=self._engine.dialect))
create_table_ddl = str(
CreateTable(new_table).compile(dialect=self._engine.dialect)
)
cursor.execute(create_table_ddl)
return new_table

Expand All @@ -288,7 +302,9 @@ def prepare_column(
column_object: a SQLAlchemy column. optional.
"""
column_name = column_name.lower().replace(" ", "_")
column_exists = column_object is not None or self.column_exists(full_table_name, column_name)
column_exists = column_object is not None or self.column_exists(
full_table_name, column_name
)

if not column_exists:
self._create_empty_column(
Expand Down Expand Up @@ -321,6 +337,7 @@ def _create_empty_column(
full_table_name: The target table name.
column_name: The name of the new column.
sql_type: SQLAlchemy type engine to be used in creating the new column.
cursor: a database cursor.
Raises:
NotImplementedError: if adding columns is not supported.
Expand Down Expand Up @@ -361,7 +378,10 @@ def get_column_add_ddl( # type: ignore[override]
column = Column(column_name, column_type)

return DDL(
('ALTER TABLE "%(schema_name)s"."%(table_name)s"' "ADD COLUMN %(column_name)s %(column_type)s"),
(
'ALTER TABLE "%(schema_name)s"."%(table_name)s"'
"ADD COLUMN %(column_name)s %(column_type)s"
),
{
"schema_name": schema_name,
"table_name": table_name,
Expand All @@ -383,6 +403,7 @@ def _adapt_column_type(
full_table_name: The target table name.
column_name: The target column name.
sql_type: The new SQLAlchemy type.
cursor: a database cursor.
Raises:
NotImplementedError: if altering columns is not supported.
Expand Down Expand Up @@ -452,7 +473,10 @@ def get_column_alter_ddl( # type: ignore[override]
"""
column = Column(column_name, column_type)
return DDL(
('ALTER TABLE "%(schema_name)s"."%(table_name)s"' "ALTER COLUMN %(column_name)s %(column_type)s"),
(
'ALTER TABLE "%(schema_name)s"."%(table_name)s"'
"ALTER COLUMN %(column_name)s %(column_type)s"
),
{
"schema_name": schema_name,
"table_name": table_name,
Expand All @@ -467,20 +491,17 @@ def get_sqlalchemy_url(self, config: dict) -> str:
Args:
config: The configuration for the connector.
"""
if config.get("sqlalchemy_url"):
return cast(str, config["sqlalchemy_url"])
else:
user, password = self.get_credentials()
sqlalchemy_url = URL.create(
drivername=config["dialect+driver"],
username=user,
password=password,
host=config["host"],
port=config["port"],
database=config["dbname"],
query=self.get_sqlalchemy_query(config),
)
return cast(str, sqlalchemy_url)
user, password = self.get_credentials()
sqlalchemy_url = URL.create(
drivername="redshift+redshift_connector",
username=user,
password=password,
host=config["host"],
port=config["port"],
database=config["dbname"],
query=self.get_sqlalchemy_query(config),
)
return cast(str, sqlalchemy_url)

def get_sqlalchemy_query(self, config: dict) -> dict:
"""Get query values to be used for sqlalchemy URL creation.
Expand All @@ -500,9 +521,9 @@ def get_sqlalchemy_query(self, config: dict) -> dict:
return query

def get_credentials(self) -> tuple[str, str]:
"""Use boto3 to get temporary cluster credentials
"""Use boto3 to get temporary cluster credentials.
Returns
Returns:
-------
tuple[str, str]
username and password
Expand All @@ -517,6 +538,4 @@ def get_credentials(self) -> tuple[str, str]:
AutoCreate=False,
)
return response["DbUser"], response["DbPassword"]
else:
return self.config["user"], self.config["password"]

return self.config["user"], self.config["password"]
Loading

0 comments on commit 29c887a

Please sign in to comment.