Skip to content

Commit faa37ff

Browse files
authored
feat: lakeformation tags for columns support (dbt-labs#185)
1 parent 514830d commit faa37ff

File tree

8 files changed

+171
-65
lines changed

8 files changed

+171
-65
lines changed

README.md

+16-13
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,19 @@ stored login info. You can configure the AWS profile name to use via `aws_profil
6060

6161
A dbt profile can be configured to run against AWS Athena using the following configuration:
6262

63-
| Option | Description | Required? | Example |
64-
|------------------|--------------------------------------------------------------------------------|-------------|-----------------------|
65-
| s3_staging_dir | S3 location to store Athena query results and metadata | Required | `s3://bucket/dbt/` |
66-
| s3_data_dir | Prefix for storing tables, if different from the connection's `s3_staging_dir` | Optional | `s3://bucket2/dbt/` |
67-
| s3_data_naming | How to generate table paths in `s3_data_dir` | Optional | `schema_table_unique` |
68-
| region_name | AWS region of your Athena instance | Required | `eu-west-1` |
69-
| schema | Specify the schema (Athena database) to build models into (lowercase **only**) | Required | `dbt` |
70-
| database | Specify the database (Data catalog) to build models into (lowercase **only**) | Required | `awsdatacatalog` |
71-
| poll_interval | Interval in seconds to use for polling the status of query results in Athena | Optional | `5` |
72-
| aws_profile_name | Profile to use from your AWS shared credentials file. | Optional | `my-profile` |
73-
| work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` |
74-
| num_retries | Number of times to retry a failing query | Optional | `3` |
75-
| lf_tags | Default lf tags to apply to any database created by dbt | Optional | `{"origin": "dbt", "team": "analytics"}`|
63+
| Option | Description | Required? | Example |
64+
|------------------|--------------------------------------------------------------------------------|-----------|------------------------------------------|
65+
| s3_staging_dir | S3 location to store Athena query results and metadata | Required | `s3://bucket/dbt/` |
66+
| s3_data_dir | Prefix for storing tables, if different from the connection's `s3_staging_dir` | Optional | `s3://bucket2/dbt/` |
67+
| s3_data_naming | How to generate table paths in `s3_data_dir` | Optional | `schema_table_unique` |
68+
| region_name | AWS region of your Athena instance | Required | `eu-west-1` |
69+
| schema | Specify the schema (Athena database) to build models into (lowercase **only**) | Required | `dbt` |
70+
| database | Specify the database (Data catalog) to build models into (lowercase **only**) | Required | `awsdatacatalog` |
71+
| poll_interval | Interval in seconds to use for polling the status of query results in Athena | Optional | `5` |
72+
| aws_profile_name | Profile to use from your AWS shared credentials file. | Optional | `my-profile` |
73+
| work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` |
74+
| num_retries | Number of times to retry a failing query | Optional | `3` |
75+
| lf_tags | Default lf tags to apply to any database created by dbt | Optional | `{"origin": "dbt", "team": "analytics"}` |
7676

7777
**Example profiles.yml entry:**
7878
```yaml
@@ -125,6 +125,9 @@ _Additional information_
125125
* `lf_tags` (`default=none`)
126126
* lf tags to associate with the table
127127
* format: `{"tag1": "value1", "tag2": "value2"}`
128+
* `lf_tags_columns` (`default=none`)
129+
* lf tags to associate with the table columns
130+
* format: `{"tag1": {"value1": ["column1", "column2"]}}`
128131

129132
#### Table location
130133

dbt/adapters/athena/impl.py

+68-43
Original file line numberDiff line numberDiff line change
@@ -56,59 +56,84 @@ def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
5656
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
5757
return "timestamp"
5858

59+
@classmethod
60+
def parse_lf_response(
61+
cls,
62+
response: Dict[str, Any],
63+
database: str,
64+
table: Optional[str],
65+
columns: Optional[List[str]],
66+
lf_tags: Dict[str, str],
67+
) -> str:
68+
failures = response.get("Failures", [])
69+
tbl_appendix = f".{table}" if table else ""
70+
columns_appendix = f" for columns {columns}" if columns else ""
71+
msg_appendix = tbl_appendix + columns_appendix
72+
if failures:
73+
base_msg = f"Failed to add LF tags: {lf_tags} to {database}" + msg_appendix
74+
for failure in failures:
75+
tag = failure.get("LFTag", {}).get("TagKey")
76+
error = failure.get("Error", {}).get("ErrorMessage")
77+
logger.error(f"Failed to set {tag} for {database}" + msg_appendix + f" - {error}")
78+
raise DbtRuntimeError(base_msg)
79+
return f"Added LF tags: {lf_tags} to {database}" + msg_appendix
80+
81+
@classmethod
82+
def lf_tags_columns_is_valid(cls, lf_tags_columns: Dict[str, Dict[str, List[str]]]) -> Optional[bool]:
83+
if not lf_tags_columns:
84+
return False
85+
for tag_key, tag_config in lf_tags_columns.items():
86+
if isinstance(tag_config, Dict):
87+
for tag_value, columns in tag_config.items():
88+
if not isinstance(columns, List):
89+
raise DbtRuntimeError(f"Not a list: {columns}. " + "Expected format: ['c1', 'c2']")
90+
else:
91+
raise DbtRuntimeError(f"Not a dict: {tag_config}. " + "Expected format: {'tag_value': ['c1', 'c2']}")
92+
return True
93+
5994
# TODO: Add more lf-tag unit tests when moto supports lakeformation
6095
# moto issue: https://github.com/getmoto/moto/issues/5964
6196
@available
62-
def add_lf_tags(self, database: str, table: str = None, lf_tags: Dict[str, str] = None):
97+
def add_lf_tags(
98+
self,
99+
database: str,
100+
table: str = None,
101+
lf_tags: Optional[Dict[str, str]] = None,
102+
lf_tags_columns: Optional[Dict[str, Dict[str, List[str]]]] = None,
103+
):
63104
conn = self.connections.get_thread_connection()
64105
client = conn.handle
65106

66107
lf_tags = lf_tags or conn.credentials.lf_tags
67-
if not lf_tags:
68-
logger.debug("No LF tags configured")
69-
return
70-
71-
resource = {
72-
"Database": {"Name": database},
73-
}
74108

75-
if table:
76-
resource = {
77-
"Table": {
78-
"DatabaseName": database,
79-
"Name": table,
80-
}
81-
}
82-
83-
with boto3_client_lock:
84-
lf_client = client.session.client(
85-
"lakeformation", region_name=client.region_name, config=get_boto3_config()
86-
)
109+
if not lf_tags and not lf_tags_columns:
110+
logger.debug("No LF tags configured")
111+
else:
112+
with boto3_client_lock:
113+
lf_client = client.session.client(
114+
"lakeformation", region_name=client.region_name, config=get_boto3_config()
115+
)
87116

88-
response = lf_client.add_lf_tags_to_resource(
89-
Resource=resource,
90-
LFTags=[
91-
{
92-
"TagKey": key,
93-
"TagValues": [
94-
value,
95-
],
96-
}
97-
for key, value in lf_tags.items()
98-
],
99-
)
117+
if lf_tags:
118+
resource = {"Database": {"Name": database}}
119+
if table:
120+
resource = {"Table": {"DatabaseName": database, "Name": table}}
100121

101-
failures = response.get("Failures", [])
102-
tbl_appendix = f".{table}" if table else ""
103-
if failures:
104-
base_msg = f"Failed to add LF tags: {lf_tags} to {database}" + tbl_appendix
105-
for failure in failures:
106-
tag = failure.get("LFTag", {}).get("TagKey")
107-
error = failure.get("Error", {}).get("ErrorMessage")
108-
logger.error(f"Failed to set {tag} for {database}" + tbl_appendix + f" - {error}")
109-
raise DbtRuntimeError(base_msg)
110-
else:
111-
logger.debug(f"Added LF tags: {lf_tags} to {database}" + tbl_appendix)
122+
response = lf_client.add_lf_tags_to_resource(
123+
Resource=resource, LFTags=[{"TagKey": key, "TagValues": [value]} for key, value in lf_tags.items()]
124+
)
125+
logger.debug(self.parse_lf_response(response, database, table, None, lf_tags))
126+
127+
if self.lf_tags_columns_is_valid(lf_tags_columns):
128+
for tag_key, tag_config in lf_tags_columns.items():
129+
for tag_value, columns in tag_config.items():
130+
response = lf_client.add_lf_tags_to_resource(
131+
Resource={
132+
"TableWithColumns": {"DatabaseName": database, "Name": table, "ColumnNames": columns}
133+
},
134+
LFTags=[{"TagKey": tag_key, "TagValues": [tag_value]}],
135+
)
136+
logger.debug(self.parse_lf_response(response, database, table, columns, {tag_key: tag_value}))
112137

113138
@available
114139
def get_work_group_output_location(self) -> Optional[str]:

dbt/include/athena/macros/materializations/models/incremental/incremental.sql

+3-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
{% set on_schema_change = incremental_validate_on_schema_change(config.get('on_schema_change'), default='ignore') %}
77

88
{% set lf_tags = config.get('lf_tags', default=none) %}
9+
{% set lf_tags_columns = config.get('lf_tags_columns', default=none) %}
910
{% set partitioned_by = config.get('partitioned_by', default=none) %}
1011
{% set target_relation = this.incorporate(type='table') %}
1112
{% set existing_relation = load_relation(this) %}
@@ -84,8 +85,8 @@
8485

8586
{{ run_hooks(post_hooks, inside_transaction=False) }}
8687

87-
{% if lf_tags is not none %}
88-
{{ adapter.add_lf_tags(target_relation.schema, target_relation.identifier, lf_tags) }}
88+
{% if lf_tags is not none or lf_tags_columns is not none %}
89+
{{ adapter.add_lf_tags(target_relation.schema, target_relation.identifier, lf_tags, lf_tags_columns) }}
8990
{% endif %}
9091

9192
{{ return({'relations': [target_relation]}) }}

dbt/include/athena/macros/materializations/models/table/table.sql

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
{%- set identifier = model['alias'] -%}
33

44
{%- set lf_tags = config.get('lf_tags', default=none) -%}
5+
{%- set lf_tags_columns = config.get('lf_tags_columns', default=none) -%}
56
{%- set table_type = config.get('table_type', default='hive') | lower -%}
67
{%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%}
78
{%- set target_relation = api.Relation.create(identifier=identifier,
@@ -27,8 +28,8 @@
2728

2829
{{ run_hooks(post_hooks) }}
2930

30-
{% if lf_tags is not none %}
31-
{{ adapter.add_lf_tags(target_relation.schema, identifier, lf_tags) }}
31+
{% if lf_tags is not none or lf_tags_columns is not none %}
32+
{{ adapter.add_lf_tags(target_relation.schema, identifier, lf_tags, lf_tags_columns) }}
3233
{% endif %}
3334

3435
{% do persist_docs(target_relation, model) %}

dbt/include/athena/macros/materializations/models/view/create_or_replace_view.sql

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
{%- set identifier = model['alias'] -%}
33

44
{%- set lf_tags = config.get('lf_tags', default=none) -%}
5+
{%- set lf_tags_columns = config.get('lf_tags_columns', default=none) -%}
56
{%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%}
67
{%- set exists_as_view = (old_relation is not none and old_relation.is_view) -%}
78
{%- set target_relation = api.Relation.create(
@@ -29,8 +30,8 @@
2930
{{ create_view_as(target_relation, sql) }}
3031
{%- endcall %}
3132

32-
{% if lf_tags is not none %}
33-
{{ adapter.add_lf_tags(target_relation.schema, identifier, lf_tags) }}
33+
{% if lf_tags is not none or lf_tags_columns is not none %}
34+
{{ adapter.add_lf_tags(target_relation.schema, identifier, lf_tags, lf_tags_columns) }}
3435
{% endif %}
3536

3637
{{ run_hooks(post_hooks, inside_transaction=True) }}

dbt/include/athena/macros/materializations/seeds/helpers.sql

+3-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
{%- set identifier = model['alias'] -%}
1212

1313
{%- set lf_tags = config.get('lf_tags', default=none) -%}
14+
{%- set lf_tags_columns = config.get('lf_tags_columns', default=none) -%}
1415
{%- set column_override = config.get('column_types', {}) -%}
1516
{%- set quote_seed_column = config.get('quote_columns', None) -%}
1617
{%- set s3_data_dir = config.get('s3_data_dir', default=target.s3_data_dir) -%}
@@ -35,8 +36,8 @@
3536
{{ sql }}
3637
{%- endcall %}
3738

38-
{% if lf_tags is not none %}
39-
{{ adapter.add_lf_tags(model.schema, identifier, lf_tags) }}
39+
{% if lf_tags is not none or lf_tags_columns is not none %}
40+
{{ adapter.add_lf_tags(model.schema, identifier, lf_tags, lf_tags_columns) }}
4041
{% endif %}
4142

4243
{{ return(sql) }}

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ include = '\.pyi?$'
1212
[tool.flake8]
1313
files = '.*\.py'
1414
max-line-length = 120
15-
exclude = ['.git', '.eggs', '__pycache__', 'venv']
15+
exclude = ['.git', '.eggs', '__pycache__', 'venv', '.venv']
1616
ignore = [
1717
# space before : (needed for how black formats slicing)
1818
'E203',

tests/unit/test_adapter.py

+74
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,80 @@ def test_get_columns_in_relation(self):
793793
Column("dt", "date"),
794794
]
795795

796+
@pytest.mark.parametrize(
797+
"response,database,table,columns,lf_tags,expected",
798+
[
799+
pytest.param(
800+
{
801+
"Failures": [
802+
{
803+
"LFTag": {"CatalogId": "test_catalog", "TagKey": "test_key", "TagValues": ["test_values"]},
804+
"Error": {"ErrorCode": "test_code", "ErrorMessage": "test_err_msg"},
805+
}
806+
]
807+
},
808+
"test_database",
809+
"test_table",
810+
["column1", "column2"],
811+
{"tag_key": "tag_value"},
812+
None,
813+
id="lf_tag error",
814+
marks=pytest.mark.xfail,
815+
),
816+
pytest.param(
817+
{"Failures": []},
818+
"test_database",
819+
None,
820+
None,
821+
{"tag_key": "tag_value"},
822+
"Added LF tags: {'tag_key': 'tag_value'} to test_database",
823+
id="lf_tag database",
824+
),
825+
pytest.param(
826+
{"Failures": []},
827+
"test_db",
828+
"test_table",
829+
None,
830+
{"tag_key": "tag_value"},
831+
"Added LF tags: {'tag_key': 'tag_value'} to test_db.test_table",
832+
id="lf_tag database and table",
833+
),
834+
pytest.param(
835+
{"Failures": []},
836+
"test_db",
837+
"test_table",
838+
["column1", "column2"],
839+
{"tag_key": "tag_value"},
840+
"Added LF tags: {'tag_key': 'tag_value'} to test_db.test_table for columns ['column1', 'column2']",
841+
id="lf_tag database table and columns",
842+
),
843+
],
844+
)
845+
def test_parse_lf_response(self, response, database, table, columns, lf_tags, expected):
846+
assert self.adapter.parse_lf_response(response, database, table, columns, lf_tags) == expected
847+
848+
@pytest.mark.parametrize(
849+
"lf_tags_columns,expected",
850+
[
851+
pytest.param({"tag_key": {"tag_value": ["col1, col2"]}}, True, id="valid lf_tags_columns"),
852+
pytest.param(None, False, id="empty lf_tags_columns"),
853+
pytest.param(
854+
{"tag_key": "tag_value"},
855+
None,
856+
id="lf_tags_columns tag config is not a dict",
857+
marks=pytest.mark.xfail(raises=DbtRuntimeError),
858+
),
859+
pytest.param(
860+
{"tag_key": {"tag_value": "col1"}},
861+
None,
862+
id="lf_tags_columns columns config is not a list",
863+
marks=pytest.mark.xfail(raises=DbtRuntimeError),
864+
),
865+
],
866+
)
867+
def test_lf_tags_columns_is_valid(self, lf_tags_columns, expected):
868+
assert self.adapter.lf_tags_columns_is_valid(lf_tags_columns) == expected
869+
796870

797871
class TestAthenaFilterCatalog:
798872
def test__catalog_filter_table(self):

0 commit comments

Comments
 (0)