Skip to content

Commit

Permalink
Create credentials upload plugins for GCP and AWS (#438)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbalabka committed Oct 4, 2024
1 parent 6eaf2db commit 418801a
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 0 deletions.
57 changes: 57 additions & 0 deletions dask_cloudprovider/aws/plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import logging
import os
from pathlib import Path

from distributed import WorkerPlugin

logger = logging.getLogger(__name__)

class UploadAWSCredentials(WorkerPlugin):
# """Automatically upload a GCP key to the worker."""

name = "upload_aws_credentials"

def __init__(self):
"""
Initialize the plugin by reading in the data from the given file.
"""
config_path = os.getenv("AWS_CONFIG_FILE", Path.home() / Path(".aws/config"))
credentials_path = os.getenv(
"AWS_SHARED_CREDENTIALS_FILE", Path.home() / Path(".aws/credentials")
)
config_path, credentials_path = Path(config_path), Path(credentials_path)

if not config_path.exists():
raise ValueError(
f"Config file {config_path} does not exist. If you store AWS config "
"in a different location, please set AWS_CONFIG_FILE environment variable."
)

if not credentials_path.exists():
raise ValueError(
f"Credentials file {credentials_path} does not exist. If you store AWS credentials "
"in a different location, please set AWS_SHARED_CREDENTIALS_FILE environment variable."
)

self.config_filename = config_path.name
self.credentials_filename = credentials_path.name

with open(config_path, "rb") as f:
self.config = f.read()
with open(credentials_path, "rb") as f:
self.credentials = f.read()

async def setup(self, worker):
await worker.upload_file(
filename=self.config_filename, data=self.config, load=False
)
worker_config_path = Path(worker.local_directory) / self.config_filename
os.environ["AWS_CONFIG_FILE"] = str(worker_config_path)

await worker.upload_file(
filename=self.credentials_filename, data=self.credentials, load=False
)
worker_credentials_path = (
Path(worker.local_directory) / self.credentials_filename
)
os.environ["AWS_SHARED_CREDENTIALS_FILE"] = str(worker_credentials_path)
36 changes: 36 additions & 0 deletions dask_cloudprovider/gcp/plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import logging
import os
from pathlib import Path

from distributed import WorkerPlugin
from google.auth._cloud_sdk import get_application_default_credentials_path

logger = logging.getLogger(__name__)

class UploadGCPKey(WorkerPlugin):
"""Automatically upload a GCP key to the worker."""

name = "upload_gcp_key"

def __init__(self):
"""
Initialize the plugin by reading in the data from the given file.
"""
key_path = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
if key_path is None:
key_path = Path(get_application_default_credentials_path())
if not key_path.exists():
raise ValueError("GOOGLE_APPLICATION_CREDENTIALS is not set or `gcloud auth application-default login` wasn't run.")

key_path = Path(key_path)
self.filename = key_path.name

logger.info("Uploading GCP key from %s.", str(key_path))

with open(key_path, "rb") as f:
self.data = f.read()

async def setup(self, worker):
await worker.upload_file(filename=self.filename, data=self.data, load=False)
worker_key_path = Path(worker.local_directory) / self.filename
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = str(worker_key_path)

0 comments on commit 418801a

Please sign in to comment.