Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(datasets): Fix credentials handling in pandas GBQ datasets #1047

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions kedro-datasets/RELEASE.md
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@
- Fixed `polars.CSVDataset` `save` method on Windows using `utf-8` as default encoding.
- Made `table_name` a keyword argument in the `ibis.FileDataset` implementation to be compatible with Ibis 10.0.
- Fixed how sessions are handled in the `snowflake.SnowflakeTableDataset` implementation.
- Fixed credentials handling in `pandas.GBQQueryDataset` and `pandas.GBQTableDataset`.

## Breaking Changes

40 changes: 24 additions & 16 deletions kedro-datasets/kedro_datasets/pandas/gbq_dataset.py
Original file line number Diff line number Diff line change
@@ -11,9 +11,10 @@
import fsspec
import pandas as pd
import pandas_gbq as pd_gbq
from google.auth.credentials import Credentials
from google.cloud import bigquery
from google.cloud.exceptions import NotFound
from google.oauth2.credentials import Credentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from kedro.io.core import (
AbstractDataset,
DatasetError,
@@ -25,6 +26,15 @@
from kedro_datasets._utils import ConnectionMixin


def _get_credentials(credentials: dict[str, Any] | str) -> ServiceAccountCredentials:
    """Resolve user-supplied credentials into a service-account credential.

    Args:
        credentials: Either the parsed service-account key json as a mapping,
            or a filesystem path to a service-account key json file.

    Returns:
        A ``google.oauth2.service_account.Credentials`` instance.
    """
    # A mapping is treated as the already-parsed key json; any other value is
    # interpreted as a path to the key file on disk.
    factory = (
        ServiceAccountCredentials.from_service_account_info
        if isinstance(credentials, dict)
        else ServiceAccountCredentials.from_service_account_file
    )
    return factory(credentials)


class GBQTableDataset(ConnectionMixin, AbstractDataset[None, pd.DataFrame]):
"""``GBQTableDataset`` loads and saves data from/to Google BigQuery.
It uses pandas-gbq to read and write from/to BigQuery table.
@@ -78,7 +88,7 @@ def __init__( # noqa: PLR0913
dataset: str,
table_name: str,
project: str | None = None,
credentials: dict[str, Any] | Credentials | None = None,
credentials: dict[str, Any] | str | Credentials | None = None,
load_args: dict[str, Any] | None = None,
save_args: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
@@ -92,10 +102,9 @@ def __init__( # noqa: PLR0913
Optional when available from the environment.
https://cloud.google.com/resource-manager/docs/creating-managing-projects
credentials: Credentials for accessing Google APIs.
Either ``google.auth.credentials.Credentials`` object or dictionary with
parameters required to instantiate ``google.oauth2.credentials.Credentials``.
Here you can find all the arguments:
https://google-auth.readthedocs.io/en/latest/reference/google.oauth2.credentials.html
Either a credential that bases on ``google.auth.credentials.Credentials`` OR
a service account json as a dictionary OR
a path to a service account key json file.
load_args: Pandas options for loading BigQuery table into DataFrame.
Here you can find all available arguments:
https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_gbq.html
@@ -121,6 +130,10 @@ def __init__( # noqa: PLR0913
self._dataset = dataset
self._table_name = table_name
self._project_id = project

if (not isinstance(credentials, Credentials)) and (credentials is not None):
credentials = _get_credentials(credentials)

self._connection_config = {
"project": self._project_id,
"credentials": credentials,
@@ -138,14 +151,9 @@ def _describe(self) -> dict[str, Any]:
}

def _connect(self) -> bigquery.Client:
credentials = self._connection_config["credentials"]
if isinstance(credentials, dict):
# Only create `Credentials` object once for consistent hash.
credentials = Credentials(**credentials)

return bigquery.Client(
project=self._connection_config["project"],
credentials=credentials,
credentials=self._connection_config["credentials"],
location=self._connection_config["location"],
)

@@ -276,10 +284,10 @@ def __init__( # noqa: PLR0913

self._project_id = project

if isinstance(credentials, dict):
credentials = Credentials(**credentials)

self._credentials = credentials
if (not isinstance(credentials, Credentials)) and (credentials is not None):
self._credentials = _get_credentials(credentials)
else:
self._credentials = credentials

# load sql query from arg or from file
if sql:
62 changes: 56 additions & 6 deletions kedro-datasets/tests/pandas/test_gbq_dataset.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
from kedro.io.core import DatasetError
from pandas.testing import assert_frame_equal

from kedro_datasets._utils import ConnectionMixin
from kedro_datasets.pandas import GBQQueryDataset, GBQTableDataset

DATASET = "dataset"
@@ -191,11 +192,12 @@ def test_validation_of_dataset_and_table_name(self, dataset, table_name):
with pytest.raises(DatasetError, match=pattern):
GBQTableDataset(dataset=dataset, table_name=table_name)

def test_credentials_propagation(self, mocker):
# NOTE: tests for json and filepath are not DRY, keeping them separate for clarity
def test_credentials_propagation_json(self, mocker):
credentials = {"token": "my_token"}
credentials_obj = "credentials"
mocked_credentials = mocker.patch(
"kedro_datasets.pandas.gbq_dataset.Credentials",
"kedro_datasets.pandas.gbq_dataset.ServiceAccountCredentials.from_service_account_info",
return_value=credentials_obj,
)
mocked_bigquery = mocker.patch("kedro_datasets.pandas.gbq_dataset.bigquery")
@@ -208,11 +210,39 @@ def test_credentials_propagation(self, mocker):
)
dataset.exists() # Do something to trigger the client creation.

mocked_credentials.assert_called_once_with(**credentials)
mocked_credentials.assert_called_once_with(credentials)
mocked_bigquery.Client.assert_called_once_with(
project=PROJECT, credentials=credentials_obj, location=None
)

# Clear connections
ConnectionMixin._connections = {}

def test_credentials_propagation_filepath(self, mocker):
credentials = "path/to/credentials.json"
credentials_obj = "credentials"
mocked_credentials = mocker.patch(
"kedro_datasets.pandas.gbq_dataset.ServiceAccountCredentials.from_service_account_file",
return_value=credentials_obj,
)
mocked_bigquery = mocker.patch("kedro_datasets.pandas.gbq_dataset.bigquery")

dataset = GBQTableDataset(
dataset=DATASET,
table_name=TABLE_NAME,
credentials=credentials,
project=PROJECT,
)
dataset.exists() # Do something to trigger the client creation.

mocked_credentials.assert_called_once_with(credentials)
mocked_bigquery.Client.assert_called_once_with(
project=PROJECT, credentials=credentials_obj, location=None
)

# Clear connections
ConnectionMixin._connections = {}


class TestGBQQueryDataset:
def test_empty_query_error(self):
@@ -232,11 +262,12 @@ def test_load_extra_params(self, gbq_sql_dataset, load_args):
for key, value in load_args.items():
assert gbq_sql_dataset._load_args[key] == value

def test_credentials_propagation(self, mocker):
# NOTE: tests for json and filepath are not DRY, keeping them separate for clarity
def test_credentials_propagation_json(self, mocker):
credentials = {"token": "my_token"}
credentials_obj = "credentials"
mocked_credentials = mocker.patch(
"kedro_datasets.pandas.gbq_dataset.Credentials",
"kedro_datasets.pandas.gbq_dataset.ServiceAccountCredentials.from_service_account_info",
return_value=credentials_obj,
)

@@ -245,9 +276,28 @@ def test_credentials_propagation(self, mocker):
credentials=credentials,
project=PROJECT,
)
dataset.exists() # Do something to trigger the client creation.

assert dataset._credentials == credentials_obj
mocked_credentials.assert_called_once_with(credentials)

def test_credentials_propagation_filepath(self, mocker):
credentials = "path/to/credentials.json"
credentials_obj = "credentials"
mocked_credentials = mocker.patch(
"kedro_datasets.pandas.gbq_dataset.ServiceAccountCredentials.from_service_account_file",
return_value=credentials_obj,
)

dataset = GBQQueryDataset(
sql=SQL_QUERY,
credentials=credentials,
project=PROJECT,
)
dataset.exists() # Do something to trigger the client creation.

assert dataset._credentials == credentials_obj
mocked_credentials.assert_called_once_with(**credentials)
mocked_credentials.assert_called_once_with(credentials)

def test_load(self, mocker, gbq_sql_dataset, dummy_dataframe):
"""Test `load` method invocation"""