Skip to content

Commit

Permalink
Added debug messages and use the logging framework for stdout (#178)
Browse files Browse the repository at this point in the history
* Added debug messages and use the logging framework for stdout
* Let debug defaults be configurable and code cleanup
  • Loading branch information
douglaz authored and nchammas committed Apr 2, 2017
1 parent 3d53326 commit 818fa91
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 40 deletions.
2 changes: 2 additions & 0 deletions flintrock/config.yaml.template
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,5 @@ launch:
num-slaves: 1
# install-hdfs: True
# install-spark: False

debug: false
16 changes: 10 additions & 6 deletions flintrock/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import shlex
import sys
import time
import logging
from concurrent.futures import FIRST_EXCEPTION

# External modules
Expand All @@ -25,6 +26,9 @@
SCRIPTS_DIR = os.path.join(THIS_DIR, 'scripts')


logger = logging.getLogger('flintrock.core')


class StorageDirs:
def __init__(self, *, root, ephemeral, persistent):
self.root = root
Expand Down Expand Up @@ -530,7 +534,7 @@ def ensure_java8(client: paramiko.client.SSHClient):
java_major_version = get_java_major_version(client)

if not java_major_version or java_major_version < (1, 8):
print("[{h}] Installing Java 1.8...".format(h=host))
logger.info("[{h}] Installing Java 1.8...".format(h=host))

ssh_check_output(
client=client,
Expand Down Expand Up @@ -583,7 +587,7 @@ def setup_node(
localpath=os.path.join(SCRIPTS_DIR, 'setup-ephemeral-storage.py'),
remotepath='/tmp/setup-ephemeral-storage.py')

print("[{h}] Configuring ephemeral storage...".format(h=host))
logger.info("[{h}] Configuring ephemeral storage...".format(h=host))
# TODO: Print some kind of warning if storage is large, since formatting
# will take several minutes (~4 minutes for 2TB).
storage_dirs_raw = ssh_check_output(
Expand Down Expand Up @@ -806,7 +810,7 @@ def run_command_node(*, user: str, host: str, identity_file: str, command: tuple
host=host,
identity_file=identity_file)

print("[{h}] Running command...".format(h=host))
logger.info("[{h}] Running command...".format(h=host))

command_str = ' '.join(command)

Expand All @@ -815,7 +819,7 @@ def run_command_node(*, user: str, host: str, identity_file: str, command: tuple
client=ssh_client,
command=command_str)

print("[{h}] Command complete.".format(h=host))
logger.info("[{h}] Command complete.".format(h=host))


def copy_file_node(
Expand Down Expand Up @@ -850,11 +854,11 @@ def copy_file_node(
raise Exception("Remote directory does not exist: {d}".format(d=remote_dir))

with ssh_client.open_sftp() as sftp:
print("[{h}] Copying file...".format(h=host))
logger.info("[{h}] Copying file...".format(h=host))

sftp.put(localpath=local_path, remotepath=remote_path)

print("[{h}] Copy complete.".format(h=host))
logger.info("[{h}] Copy complete.".format(h=host))


# This is necessary down here since we have a circular import dependency between
Expand Down
20 changes: 14 additions & 6 deletions flintrock/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
import urllib.request
import base64
import logging
from collections import namedtuple
from datetime import datetime

Expand All @@ -26,6 +27,9 @@
from .ssh import generate_ssh_key_pair


logger = logging.getLogger('flintrock.ec2')


class NoDefaultVPC(Error):
def __init__(self, *, region: str):
super().__init__(
Expand All @@ -48,7 +52,7 @@ def wrapper(*args, **kwargs):
start = datetime.now().replace(microsecond=0)
res = func(*args, **kwargs)
end = datetime.now().replace(microsecond=0)
print("{f} finished in {t}.".format(f=func.__name__, t=(end - start)))
logger.info("{f} finished in {t}.".format(f=func.__name__, t=(end - start)))
return res
return wrapper

Expand Down Expand Up @@ -120,6 +124,11 @@ def wait_for_state(self, state: str):
ec2 = boto3.resource(service_name='ec2', region_name=self.region)

while any([i.state['Name'] != state for i in self.instances]):
if logger.isEnabledFor(logging.DEBUG):
waiting_instances = [i for i in self.instances if i.state['Name'] != state]
sample = ', '.join(["'{}'".format(i.id) for i in waiting_instances][:3])
logger.debug("{size} instances not in state '{state}': {sample}, ...".format(size=len(waiting_instances), state=state, sample=sample))
time.sleep(3)
# Update metadata for all instances in one shot. We don't want
# to make a call to AWS for each of potentially hundreds of
# instances.
Expand All @@ -131,7 +140,6 @@ def wait_for_state(self, state: str):
{'Name': 'instance-id', 'Values': [i.id for i in self.instances]}
]))
(self.master_instance, self.slave_instances) = _get_cluster_master_slaves(instances)
time.sleep(3)

def destroy(self):
self.destroy_check()
Expand Down Expand Up @@ -692,7 +700,7 @@ def _create_instances(
try:
if spot_price:
user_data = base64.b64encode(user_data.encode('utf-8')).decode()
print("Requesting {c} spot instances at a max price of ${p}...".format(
logger.info("Requesting {c} spot instances at a max price of ${p}...".format(
c=num_instances, p=spot_price))
client = ec2.meta.client
spot_requests = client.request_spot_instances(
Expand All @@ -717,7 +725,7 @@ def _create_instances(
pending_request_ids = request_ids

while pending_request_ids:
print("{grant} of {req} instances granted. Waiting...".format(
logger.info("{grant} of {req} instances granted. Waiting...".format(
grant=num_instances - len(pending_request_ids),
req=num_instances))
time.sleep(30)
Expand All @@ -737,7 +745,7 @@ def _create_instances(
r['SpotInstanceRequestId'] for r in spot_requests
if r['State'] == 'open']

print("All {c} instances granted.".format(c=num_instances))
logger.info("All {c} instances granted.".format(c=num_instances))

cluster_instances = list(
ec2.instances.filter(
Expand All @@ -746,7 +754,7 @@ def _create_instances(
]))
else:
# Move this to flintrock.py?
print("Launching {c} instance{s}...".format(
logger.info("Launching {c} instance{s}...".format(
c=num_instances,
s='' if num_instances == 1 else 's'))

Expand Down
63 changes: 42 additions & 21 deletions flintrock/flintrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import urllib.parse
import urllib.request
import warnings
import logging

# External modules
import click
Expand Down Expand Up @@ -38,6 +39,9 @@
THIS_DIR = os.path.dirname(os.path.realpath(__file__))


logger = logging.getLogger('flintrock.flintrock')


def format_message(*, message: str, indent: int=4, wrap: int=70):
"""
Format a lengthy message for printing to screen.
Expand Down Expand Up @@ -156,15 +160,30 @@ def get_config_file() -> str:
return config_file


def configure_log(debug: bool):
    """
    Configure the root 'flintrock' logger to write to stdout.

    In debug mode the logger emits DEBUG-and-above records with a timestamped,
    module-tagged format; otherwise it emits INFO-and-above records as bare
    messages (matching the plain ``print()`` output this replaced).

    :param debug: Whether to enable verbose, timestamped debug output.
    """
    root_logger = logging.getLogger('flintrock')
    # Drop any handlers installed by a previous call so that reconfiguring
    # (e.g. re-entering the CLI in the same process, or in tests) does not
    # stack handlers and duplicate every log line.
    for existing_handler in list(root_logger.handlers):
        root_logger.removeHandler(existing_handler)
    handler = logging.StreamHandler(sys.stdout)
    # The handler passes everything through; filtering happens at the
    # logger level set below.
    handler.setLevel(logging.DEBUG)
    if debug:
        root_logger.setLevel(logging.DEBUG)
        handler.setFormatter(logging.Formatter(
            '%(asctime)s - flintrock.%(module)-9s - %(levelname)-5s - %(message)s'))
    else:
        root_logger.setLevel(logging.INFO)
        handler.setFormatter(logging.Formatter('%(message)s'))
    root_logger.addHandler(handler)


@click.group()
@click.option(
'--config',
help="Path to a Flintrock configuration file.",
default=get_config_file())
@click.option('--provider', default='ec2', type=click.Choice(['ec2']))
@click.version_option(version=__version__)
# TODO: implement some solution like in https://github.com/pallets/click/issues/108
@click.option('--debug/--no-debug', default=False, help="Show debug information.")
@click.pass_context
def cli(cli_context, config, provider):
def cli(cli_context, config, provider, debug):
"""
Flintrock
Expand All @@ -175,12 +194,14 @@ def cli(cli_context, config, provider):
if os.path.isfile(config):
with open(config) as f:
config_raw = yaml.safe_load(f)
debug = config_raw.get('debug') or debug
config_map = config_to_click(normalize_keys(config_raw))

cli_context.default_map = config_map
else:
if config != get_config_file():
raise FileNotFoundError(errno.ENOENT, 'No such file', config)
configure_log(debug=debug)


@cli.command()
Expand Down Expand Up @@ -327,12 +348,12 @@ def launch(
download_source=spark_download_source,
)
elif spark_git_commit:
print(
logger.warning(
"Warning: Building Spark takes a long time. "
"e.g. 15-20 minutes on an m3.xlarge instance on EC2.")
if spark_git_commit == 'latest':
spark_git_commit = get_latest_commit(spark_git_repository)
print("Building Spark at latest commit: {c}".format(c=spark_git_commit))
logger.info("Building Spark at latest commit: {c}".format(c=spark_git_commit))
spark = Spark(
git_commit=spark_git_commit,
git_repository=spark_git_repository,
Expand Down Expand Up @@ -424,7 +445,7 @@ def destroy(cli_context, cluster_name, assume_yes, ec2_region, ec2_vpc_id):
text="Are you sure you want to destroy this cluster?",
abort=True)

print("Destroying {c}...".format(c=cluster.name))
logger.info("Destroying {c}...".format(c=cluster.name))
cluster.destroy()


Expand Down Expand Up @@ -474,21 +495,21 @@ def describe(
if cluster_name:
cluster = clusters[0]
if master_hostname_only:
print(cluster.master_host)
logger.info(cluster.master_host)
else:
cluster.print()
else:
if master_hostname_only:
for cluster in sorted(clusters, key=lambda x: x.name):
print(cluster.name + ':', cluster.master_host)
logger.info(cluster.name + ':', cluster.master_host)
else:
print("Found {n} cluster{s}{space}{search_area}.".format(
logger.info("Found {n} cluster{s}{space}{search_area}.".format(
n=len(clusters),
s='' if len(clusters) == 1 else 's',
space=' ' if search_area else '',
search_area=search_area))
if clusters:
print('---')
logger.info('---')
for cluster in sorted(clusters, key=lambda x: x.name):
cluster.print()

Expand Down Expand Up @@ -572,7 +593,7 @@ def start(cli_context, cluster_name, ec2_region, ec2_vpc_id, ec2_identity_file,
raise UnsupportedProviderError(provider)

cluster.start_check()
print("Starting {c}...".format(c=cluster_name))
logger.info("Starting {c}...".format(c=cluster_name))
cluster.start(user=user, identity_file=identity_file)


Expand Down Expand Up @@ -610,9 +631,9 @@ def stop(cli_context, cluster_name, ec2_region, ec2_vpc_id, assume_yes):
text="Are you sure you want to stop this cluster?",
abort=True)

print("Stopping {c}...".format(c=cluster_name))
logger.info("Stopping {c}...".format(c=cluster_name))
cluster.stop()
print("{c} is now stopped.".format(c=cluster_name))
logger.info("{c} is now stopped.".format(c=cluster_name))


@cli.command(name='add-slaves')
Expand Down Expand Up @@ -740,7 +761,7 @@ def remove_slaves(
raise UnsupportedProviderError(provider)

if num_slaves > cluster.num_slaves:
print(
logger.warning(
"Warning: Cluster has {c} slave{cs}. "
"You asked to remove {n} slave{ns}."
.format(
Expand All @@ -759,10 +780,10 @@ def remove_slaves(
s='' if num_slaves == 1 else 's')),
abort=True)

print("Removing {n} slave{s}..."
.format(
n=num_slaves,
s='' if num_slaves == 1 else 's'))
logger.info("Removing {n} slave{s}..."
.format(
n=num_slaves,
s='' if num_slaves == 1 else 's'))
cluster.remove_slaves(
user=user,
identity_file=identity_file,
Expand Down Expand Up @@ -823,7 +844,7 @@ def run_command(

cluster.run_command_check()

print("Running command on {target}...".format(
logger.info("Running command on {target}...".format(
target="master only" if master_only else "cluster"))

cluster.run_command(
Expand Down Expand Up @@ -903,8 +924,8 @@ def copy_file(
total_size_bytes = file_size_bytes * num_nodes

if total_size_bytes > 10 ** 6:
print("WARNING:")
print(
logger.warning("WARNING:")
logger.warning(
format_message(
message="""\
You are trying to upload {total_size} bytes ({size} bytes x {count}
Expand All @@ -924,7 +945,7 @@ def copy_file(
default=True,
abort=True)

print("Copying file to {target}...".format(
logger.info("Copying file to {target}...".format(
target="master only" if master_only else "cluster"))

cluster.copy_file(
Expand Down Expand Up @@ -995,7 +1016,7 @@ def configure(cli_context, locate):
config_file = get_config_file()

if not os.path.isfile(config_file):
print("Initializing config file from template...")
logger.info("Initializing config file from template...")
os.makedirs(os.path.dirname(config_file), exist_ok=True)
shutil.copyfile(
src=os.path.join(THIS_DIR, 'config.yaml.template'),
Expand Down
Loading

0 comments on commit 818fa91

Please sign in to comment.