-
-
Notifications
You must be signed in to change notification settings - Fork 111
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create credentials upload plugins for GCP and AWS (#438)
- Loading branch information
Showing
2 changed files
with
93 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import logging | ||
import os | ||
from pathlib import Path | ||
|
||
from distributed import WorkerPlugin | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class UploadAWSCredentials(WorkerPlugin): | ||
# """Automatically upload a GCP key to the worker.""" | ||
|
||
name = "upload_aws_credentials" | ||
|
||
def __init__(self): | ||
""" | ||
Initialize the plugin by reading in the data from the given file. | ||
""" | ||
config_path = os.getenv("AWS_CONFIG_FILE", Path.home() / Path(".aws/config")) | ||
credentials_path = os.getenv( | ||
"AWS_SHARED_CREDENTIALS_FILE", Path.home() / Path(".aws/credentials") | ||
) | ||
config_path, credentials_path = Path(config_path), Path(credentials_path) | ||
|
||
if not config_path.exists(): | ||
raise ValueError( | ||
f"Config file {config_path} does not exist. If you store AWS config " | ||
"in a different location, please set AWS_CONFIG_FILE environment variable." | ||
) | ||
|
||
if not credentials_path.exists(): | ||
raise ValueError( | ||
f"Credentials file {credentials_path} does not exist. If you store AWS credentials " | ||
"in a different location, please set AWS_SHARED_CREDENTIALS_FILE environment variable." | ||
) | ||
|
||
self.config_filename = config_path.name | ||
self.credentials_filename = credentials_path.name | ||
|
||
with open(config_path, "rb") as f: | ||
self.config = f.read() | ||
with open(credentials_path, "rb") as f: | ||
self.credentials = f.read() | ||
|
||
async def setup(self, worker): | ||
await worker.upload_file( | ||
filename=self.config_filename, data=self.config, load=False | ||
) | ||
worker_config_path = Path(worker.local_directory) / self.config_filename | ||
os.environ["AWS_CONFIG_FILE"] = str(worker_config_path) | ||
|
||
await worker.upload_file( | ||
filename=self.credentials_filename, data=self.credentials, load=False | ||
) | ||
worker_credentials_path = ( | ||
Path(worker.local_directory) / self.credentials_filename | ||
) | ||
os.environ["AWS_SHARED_CREDENTIALS_FILE"] = str(worker_credentials_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import logging | ||
import os | ||
from pathlib import Path | ||
|
||
from distributed import WorkerPlugin | ||
from google.auth._cloud_sdk import get_application_default_credentials_path | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class UploadGCPKey(WorkerPlugin): | ||
"""Automatically upload a GCP key to the worker.""" | ||
|
||
name = "upload_gcp_key" | ||
|
||
def __init__(self): | ||
""" | ||
Initialize the plugin by reading in the data from the given file. | ||
""" | ||
key_path = os.getenv("GOOGLE_APPLICATION_CREDENTIALS") | ||
if key_path is None: | ||
key_path = Path(get_application_default_credentials_path()) | ||
if not key_path.exists(): | ||
raise ValueError("GOOGLE_APPLICATION_CREDENTIALS is not set or `gcloud auth application-default login` wasn't run.") | ||
|
||
key_path = Path(key_path) | ||
self.filename = key_path.name | ||
|
||
logger.info("Uploading GCP key from %s.", str(key_path)) | ||
|
||
with open(key_path, "rb") as f: | ||
self.data = f.read() | ||
|
||
async def setup(self, worker): | ||
await worker.upload_file(filename=self.filename, data=self.data, load=False) | ||
worker_key_path = Path(worker.local_directory) / self.filename | ||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = str(worker_key_path) |