Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add debug messages and use the logging framework for stdout #178

Merged
merged 5 commits into from
Apr 2, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flintrock/config.yaml.template
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,5 @@ launch:
num-slaves: 1
# install-hdfs: True
# install-spark: False

debug: false
16 changes: 10 additions & 6 deletions flintrock/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import shlex
import sys
import time
import logging
from concurrent.futures import FIRST_EXCEPTION

# External modules
Expand All @@ -25,6 +26,9 @@
SCRIPTS_DIR = os.path.join(THIS_DIR, 'scripts')


logger = logging.getLogger('flintrock.core')


class StorageDirs:
def __init__(self, *, root, ephemeral, persistent):
self.root = root
Expand Down Expand Up @@ -530,7 +534,7 @@ def ensure_java8(client: paramiko.client.SSHClient):
java_major_version = get_java_major_version(client)

if not java_major_version or java_major_version < (1, 8):
print("[{h}] Installing Java 1.8...".format(h=host))
logger.info("[{h}] Installing Java 1.8...".format(h=host))

ssh_check_output(
client=client,
Expand Down Expand Up @@ -583,7 +587,7 @@ def setup_node(
localpath=os.path.join(SCRIPTS_DIR, 'setup-ephemeral-storage.py'),
remotepath='/tmp/setup-ephemeral-storage.py')

print("[{h}] Configuring ephemeral storage...".format(h=host))
logger.info("[{h}] Configuring ephemeral storage...".format(h=host))
# TODO: Print some kind of warning if storage is large, since formatting
# will take several minutes (~4 minutes for 2TB).
storage_dirs_raw = ssh_check_output(
Expand Down Expand Up @@ -806,7 +810,7 @@ def run_command_node(*, user: str, host: str, identity_file: str, command: tuple
host=host,
identity_file=identity_file)

print("[{h}] Running command...".format(h=host))
logger.info("[{h}] Running command...".format(h=host))

command_str = ' '.join(command)

Expand All @@ -815,7 +819,7 @@ def run_command_node(*, user: str, host: str, identity_file: str, command: tuple
client=ssh_client,
command=command_str)

print("[{h}] Command complete.".format(h=host))
logger.info("[{h}] Command complete.".format(h=host))


def copy_file_node(
Expand Down Expand Up @@ -850,11 +854,11 @@ def copy_file_node(
raise Exception("Remote directory does not exist: {d}".format(d=remote_dir))

with ssh_client.open_sftp() as sftp:
print("[{h}] Copying file...".format(h=host))
logger.info("[{h}] Copying file...".format(h=host))

sftp.put(localpath=local_path, remotepath=remote_path)

print("[{h}] Copy complete.".format(h=host))
logger.info("[{h}] Copy complete.".format(h=host))


# This is necessary down here since we have a circular import dependency between
Expand Down
20 changes: 14 additions & 6 deletions flintrock/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
import urllib.request
import base64
import logging
from collections import namedtuple
from datetime import datetime

Expand All @@ -26,6 +27,9 @@
from .ssh import generate_ssh_key_pair


logger = logging.getLogger('flintrock.ec2')


class NoDefaultVPC(Error):
def __init__(self, *, region: str):
super().__init__(
Expand All @@ -48,7 +52,7 @@ def wrapper(*args, **kwargs):
start = datetime.now().replace(microsecond=0)
res = func(*args, **kwargs)
end = datetime.now().replace(microsecond=0)
print("{f} finished in {t}.".format(f=func.__name__, t=(end - start)))
logger.info("{f} finished in {t}.".format(f=func.__name__, t=(end - start)))
return res
return wrapper

Expand Down Expand Up @@ -120,6 +124,11 @@ def wait_for_state(self, state: str):
ec2 = boto3.resource(service_name='ec2', region_name=self.region)

while any([i.state['Name'] != state for i in self.instances]):
if logger.isEnabledFor(logging.DEBUG):
waiting_instances = [i for i in self.instances if i.state['Name'] != state]
sample = ', '.join(["'{}'".format(i.id) for i in waiting_instances][:3])
logger.debug("{size} instances not in state '{state}': {sample}, ...".format(size=len(waiting_instances), state=state, sample=sample))
time.sleep(3)
# Update metadata for all instances in one shot. We don't want
# to make a call to AWS for each of potentially hundreds of
# instances.
Expand All @@ -131,7 +140,6 @@ def wait_for_state(self, state: str):
{'Name': 'instance-id', 'Values': [i.id for i in self.instances]}
]))
(self.master_instance, self.slave_instances) = _get_cluster_master_slaves(instances)
time.sleep(3)

def destroy(self):
self.destroy_check()
Expand Down Expand Up @@ -692,7 +700,7 @@ def _create_instances(
try:
if spot_price:
user_data = base64.b64encode(user_data.encode('utf-8')).decode()
print("Requesting {c} spot instances at a max price of ${p}...".format(
logger.info("Requesting {c} spot instances at a max price of ${p}...".format(
c=num_instances, p=spot_price))
client = ec2.meta.client
spot_requests = client.request_spot_instances(
Expand All @@ -717,7 +725,7 @@ def _create_instances(
pending_request_ids = request_ids

while pending_request_ids:
print("{grant} of {req} instances granted. Waiting...".format(
logger.info("{grant} of {req} instances granted. Waiting...".format(
grant=num_instances - len(pending_request_ids),
req=num_instances))
time.sleep(30)
Expand All @@ -737,7 +745,7 @@ def _create_instances(
r['SpotInstanceRequestId'] for r in spot_requests
if r['State'] == 'open']

print("All {c} instances granted.".format(c=num_instances))
logger.info("All {c} instances granted.".format(c=num_instances))

cluster_instances = list(
ec2.instances.filter(
Expand All @@ -746,7 +754,7 @@ def _create_instances(
]))
else:
# Move this to flintrock.py?
print("Launching {c} instance{s}...".format(
logger.info("Launching {c} instance{s}...".format(
c=num_instances,
s='' if num_instances == 1 else 's'))

Expand Down
63 changes: 42 additions & 21 deletions flintrock/flintrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import urllib.parse
import urllib.request
import warnings
import logging

# External modules
import click
Expand Down Expand Up @@ -38,6 +39,9 @@
THIS_DIR = os.path.dirname(os.path.realpath(__file__))


logger = logging.getLogger('flintrock.flintrock')


def format_message(*, message: str, indent: int=4, wrap: int=70):
"""
Format a lengthy message for printing to screen.
Expand Down Expand Up @@ -156,15 +160,30 @@ def get_config_file() -> str:
return config_file


def configure_log(debug: bool):
    """Configure the 'flintrock' root logger to write to stdout.

    In debug mode, messages are logged at DEBUG level with a timestamped
    format; otherwise only INFO-and-above messages are emitted as plain
    text (matching the old print() output).

    This function is idempotent: calling it more than once will not
    register duplicate handlers, which would otherwise cause every log
    line to be printed multiple times.
    """
    root_logger = logging.getLogger('flintrock')
    if debug:
        root_logger.setLevel(logging.DEBUG)
        log_format = '%(asctime)s - flintrock.%(module)-9s - %(levelname)-5s - %(message)s'
    else:
        root_logger.setLevel(logging.INFO)
        log_format = '%(message)s'
    # Reuse an existing stdout handler if one is already registered so
    # repeated calls (e.g. from tests) don't duplicate every message.
    handler = next(
        (h for h in root_logger.handlers
         if isinstance(h, logging.StreamHandler) and h.stream is sys.stdout),
        None)
    if handler is None:
        handler = logging.StreamHandler(sys.stdout)
        # The handler passes everything through; the logger's level gates.
        handler.setLevel(logging.DEBUG)
        root_logger.addHandler(handler)
    handler.setFormatter(logging.Formatter(log_format))


@click.group()
@click.option(
'--config',
help="Path to a Flintrock configuration file.",
default=get_config_file())
@click.option('--provider', default='ec2', type=click.Choice(['ec2']))
@click.version_option(version=__version__)
# TODO: implement some solution like in https://github.com/pallets/click/issues/108
@click.option('--debug/--no-debug', default=False, help="Show debug information.")
@click.pass_context
def cli(cli_context, config, provider):
def cli(cli_context, config, provider, debug):
"""
Flintrock

Expand All @@ -175,12 +194,14 @@ def cli(cli_context, config, provider):
if os.path.isfile(config):
with open(config) as f:
config_raw = yaml.safe_load(f)
debug = config_raw.get('debug') or debug
config_map = config_to_click(normalize_keys(config_raw))

cli_context.default_map = config_map
else:
if config != get_config_file():
raise FileNotFoundError(errno.ENOENT, 'No such file', config)
configure_log(debug=debug)


@cli.command()
Expand Down Expand Up @@ -327,12 +348,12 @@ def launch(
download_source=spark_download_source,
)
elif spark_git_commit:
print(
logger.warning(
"Warning: Building Spark takes a long time. "
"e.g. 15-20 minutes on an m3.xlarge instance on EC2.")
if spark_git_commit == 'latest':
spark_git_commit = get_latest_commit(spark_git_repository)
print("Building Spark at latest commit: {c}".format(c=spark_git_commit))
logger.info("Building Spark at latest commit: {c}".format(c=spark_git_commit))
spark = Spark(
git_commit=spark_git_commit,
git_repository=spark_git_repository,
Expand Down Expand Up @@ -424,7 +445,7 @@ def destroy(cli_context, cluster_name, assume_yes, ec2_region, ec2_vpc_id):
text="Are you sure you want to destroy this cluster?",
abort=True)

print("Destroying {c}...".format(c=cluster.name))
logger.info("Destroying {c}...".format(c=cluster.name))
cluster.destroy()


Expand Down Expand Up @@ -474,21 +495,21 @@ def describe(
if cluster_name:
cluster = clusters[0]
if master_hostname_only:
print(cluster.master_host)
logger.info(cluster.master_host)
else:
cluster.print()
else:
if master_hostname_only:
for cluster in sorted(clusters, key=lambda x: x.name):
print(cluster.name + ':', cluster.master_host)
logger.info(cluster.name + ':', cluster.master_host)
else:
print("Found {n} cluster{s}{space}{search_area}.".format(
logger.info("Found {n} cluster{s}{space}{search_area}.".format(
n=len(clusters),
s='' if len(clusters) == 1 else 's',
space=' ' if search_area else '',
search_area=search_area))
if clusters:
print('---')
logger.info('---')
for cluster in sorted(clusters, key=lambda x: x.name):
cluster.print()

Expand Down Expand Up @@ -572,7 +593,7 @@ def start(cli_context, cluster_name, ec2_region, ec2_vpc_id, ec2_identity_file,
raise UnsupportedProviderError(provider)

cluster.start_check()
print("Starting {c}...".format(c=cluster_name))
logger.info("Starting {c}...".format(c=cluster_name))
cluster.start(user=user, identity_file=identity_file)


Expand Down Expand Up @@ -610,9 +631,9 @@ def stop(cli_context, cluster_name, ec2_region, ec2_vpc_id, assume_yes):
text="Are you sure you want to stop this cluster?",
abort=True)

print("Stopping {c}...".format(c=cluster_name))
logger.info("Stopping {c}...".format(c=cluster_name))
cluster.stop()
print("{c} is now stopped.".format(c=cluster_name))
logger.info("{c} is now stopped.".format(c=cluster_name))


@cli.command(name='add-slaves')
Expand Down Expand Up @@ -740,7 +761,7 @@ def remove_slaves(
raise UnsupportedProviderError(provider)

if num_slaves > cluster.num_slaves:
print(
logger.warning(
"Warning: Cluster has {c} slave{cs}. "
"You asked to remove {n} slave{ns}."
.format(
Expand All @@ -759,10 +780,10 @@ def remove_slaves(
s='' if num_slaves == 1 else 's')),
abort=True)

print("Removing {n} slave{s}..."
.format(
n=num_slaves,
s='' if num_slaves == 1 else 's'))
logger.info("Removing {n} slave{s}..."
.format(
n=num_slaves,
s='' if num_slaves == 1 else 's'))
cluster.remove_slaves(
user=user,
identity_file=identity_file,
Expand Down Expand Up @@ -823,7 +844,7 @@ def run_command(

cluster.run_command_check()

print("Running command on {target}...".format(
logger.info("Running command on {target}...".format(
target="master only" if master_only else "cluster"))

cluster.run_command(
Expand Down Expand Up @@ -903,8 +924,8 @@ def copy_file(
total_size_bytes = file_size_bytes * num_nodes

if total_size_bytes > 10 ** 6:
print("WARNING:")
print(
logger.warning("WARNING:")
logger.warning(
format_message(
message="""\
You are trying to upload {total_size} bytes ({size} bytes x {count}
Expand All @@ -924,7 +945,7 @@ def copy_file(
default=True,
abort=True)

print("Copying file to {target}...".format(
logger.info("Copying file to {target}...".format(
target="master only" if master_only else "cluster"))

cluster.copy_file(
Expand Down Expand Up @@ -995,7 +1016,7 @@ def configure(cli_context, locate):
config_file = get_config_file()

if not os.path.isfile(config_file):
print("Initializing config file from template...")
logger.info("Initializing config file from template...")
os.makedirs(os.path.dirname(config_file), exist_ok=True)
shutil.copyfile(
src=os.path.join(THIS_DIR, 'config.yaml.template'),
Expand Down
Loading