diff --git a/pylearn2/train.py b/pylearn2/train.py
index 7f384177b6..3e61f9b7a3 100644
--- a/pylearn2/train.py
+++ b/pylearn2/train.py
@@ -94,6 +94,15 @@ def setup_extensions(self):
         for ext in self.extensions:
             ext.setup(self.model, self.dataset, self.algorithm)
 
+    def tear_down_extensions(self):
+        """ Calls tear_down on all extensions."""
+        for ext in self.extensions:
+            try:
+                ext.tear_down(self.model, self.dataset, self.algorithm)
+            except Exception:
+                log.debug('%s train extension failed to terminate gracefully',
+                          ext, exc_info=True)
+
     def exceeded_time_budget(self, t0, time_budget):
         """
         .. todo::
@@ -126,6 +135,12 @@ def setup(self):
         # make sure the constraints are enforced from the start.
         self.model.enforce_constraints()
 
+    def tear_down(self):
+        """
+        Called at the end of the main loop.
+        """
+        self.tear_down_extensions()
+
     def main_loop(self, time_budget=None):
         """
         Repeatedly runs an epoch of the training algorithm, runs any
@@ -139,95 +154,103 @@ def main_loop(self, time_budget=None):
         """
         t0 = datetime.now()
         self.setup()
-        if self.algorithm is None:
-            self.run_callbacks_and_monitoring()
-            while True:
-                if self.exceeded_time_budget(t0, time_budget):
-                    break
+        try:
+            if self.algorithm is None:
+                self.run_callbacks_and_monitoring()
+                while True:
+                    if self.exceeded_time_budget(t0, time_budget):
+                        break
 
-                rval = self.model.train_all(dataset=self.dataset)
-                if rval is not None:
-                    raise ValueError("Model.train_all should not return " +
-                                     "anything. Use Model.continue_learning " +
-                                     "to control whether learning continues.")
-                self.model.monitor.report_epoch()
-                extension_continue = self.run_callbacks_and_monitoring()
-                freq = self.save_freq
-                if freq > 0 and self.model.monitor.get_epochs_seen() % freq == 0:
-                    self.save()
-                continue_learning = (self.model.continue_learning() and
-                                     extension_continue)
-                assert continue_learning in [True, False, 0, 1]
-                if not continue_learning:
-                    break
-        else:
-            if not hasattr(self.model, 'monitor'):
-                # TODO: is this really necessary? I just put this error here
-                # to prevent an AttributeError later, but I think we could
-                # rewrite to avoid the AttributeError
-                raise RuntimeError("The algorithm is responsible for setting"
-                                   " up the Monitor, but failed to.")
-            if len(self.model.monitor._datasets) > 0:
-                # This monitoring channel keeps track of a shared variable,
-                # which does not need inputs nor data.
-                self.training_seconds.__doc__ = """\
+                    rval = self.model.train_all(dataset=self.dataset)
+                    if rval is not None:
+                        raise ValueError("Model.train_all should not return " +
+                                         "anything. Use Model.continue_learning " +
+                                         "to control whether learning continues.")
+                    self.model.monitor.report_epoch()
+                    extension_continue = self.run_callbacks_and_monitoring()
+                    freq = self.save_freq
+                    if freq > 0 and self.model.monitor.get_epochs_seen() % freq == 0:
+                        self.save()
+                    continue_learning = (self.model.continue_learning() and
+                                         extension_continue)
+                    assert continue_learning in [True, False, 0, 1]
+                    if not continue_learning:
+                        break
+            else:
+                if not hasattr(self.model, 'monitor'):
+                    # TODO: is this really necessary? I just put this error here
+                    # to prevent an AttributeError later, but I think we could
+                    # rewrite to avoid the AttributeError
+                    raise RuntimeError("The algorithm is responsible for setting"
+                                       " up the Monitor, but failed to.")
+                if len(self.model.monitor._datasets) > 0:
+                    # This monitoring channel keeps track of a shared variable,
+                    # which does not need inputs nor data.
+                    self.training_seconds.__doc__ = """\
 The number of seconds that were spent in actual training during the most
 recent epoch. This excludes seconds that were spent running callbacks for
 the extensions, computing monitoring channels, etc."""
-                self.model.monitor.add_channel(
-                    name="training_seconds_this_epoch",
-                    ipt=None,
-                    val=self.training_seconds,
-                    data_specs=(NullSpace(), ''),
-                    dataset=self.model.monitor._datasets[0])
-                self.total_seconds.__doc__ = """\
+                    self.model.monitor.add_channel(
+                        name="training_seconds_this_epoch",
+                        ipt=None,
+                        val=self.training_seconds,
+                        data_specs=(NullSpace(), ''),
+                        dataset=self.model.monitor._datasets[0])
+                    self.total_seconds.__doc__ = """\
 The number of seconds that were spent on the entirety of processing for the
 previous epoch. This includes not only training but also the computation of
 the monitoring channels, running TrainExtension callbacks, etc. This value
 is reported for the *previous* epoch because the amount of time spent on
 monitoring for this epoch is not known until the monitoring channels have
 already been reported."""
-                self.model.monitor.add_channel(
-                    name="total_seconds_last_epoch",
-                    ipt=None,
-                    val=self.total_seconds,
-                    data_specs=(NullSpace(), ''),
-                    dataset=self.model.monitor._datasets[0])
-            self.run_callbacks_and_monitoring()
+                    self.model.monitor.add_channel(
+                        name="total_seconds_last_epoch",
+                        ipt=None,
+                        val=self.total_seconds,
+                        data_specs=(NullSpace(), ''),
+                        dataset=self.model.monitor._datasets[0])
+                self.run_callbacks_and_monitoring()
 
-            while True:
-                if self.exceeded_time_budget(t0, time_budget):
-                    break
+                while True:
+                    if self.exceeded_time_budget(t0, time_budget):
+                        break
 
-                with log_timing(log, None, level=logging.DEBUG,
-                                callbacks=[self.total_seconds.set_value]):
-                    with log_timing(
-                            log, None, final_msg='Time this epoch:',
-                            callbacks=[self.training_seconds.set_value]):
-                        rval = self.algorithm.train(dataset=self.dataset)
-                    if rval is not None:
-                        raise ValueError("TrainingAlgorithm.train should not "
-                                         "return anything. Use "
-                                         "TrainingAlgorithm.continue_learning "
-                                         "to control whether learning "
-                                         "continues.")
-                    self.model.monitor.report_epoch()
-                    extension_continue = self.run_callbacks_and_monitoring()
-                    if self.save_freq > 0 and \
-                       self.model.monitor.get_epochs_seen() % self.save_freq == 0:
-                        self.save()
-                continue_learning = (
-                    self.algorithm.continue_learning(self.model) and
-                    extension_continue
-                )
-                assert continue_learning in [True, False, 0, 1]
-                if not continue_learning:
-                    break
+                    with log_timing(log, None, level=logging.DEBUG,
+                                    callbacks=[self.total_seconds.set_value]):
+                        with log_timing(
+                                log, None, final_msg='Time this epoch:',
+                                callbacks=[self.training_seconds.set_value]):
+                            rval = self.algorithm.train(dataset=self.dataset)
+                        if rval is not None:
+                            raise ValueError("TrainingAlgorithm.train should not "
+                                             "return anything. Use "
+                                             "TrainingAlgorithm.continue_learning "
+                                             "to control whether learning "
+                                             "continues.")
+                        self.model.monitor.report_epoch()
+                        extension_continue = self.run_callbacks_and_monitoring()
+                        if self.save_freq > 0 and \
+                           self.model.monitor.get_epochs_seen() % self.save_freq == 0:
+                            self.save()
+                    continue_learning = (
+                        self.algorithm.continue_learning(self.model) and
+                        extension_continue
+                    )
+                    assert continue_learning in [True, False, 0, 1]
+                    if not continue_learning:
+                        break
 
-        self.model.monitor.training_succeeded = True
+            self.model.monitor.training_succeeded = True
 
-        if self.save_freq > 0:
-            self.save()
+            if self.save_freq > 0:
+                self.save()
+        except Exception:
+            self.tear_down()
+            log.error("Uncaught exception in Train's main loop",
+                      exc_info=True)
+            raise
+        else:
+            self.tear_down()
 
     def run_callbacks_and_monitoring(self):
         """
diff --git a/pylearn2/train_extensions/__init__.py b/pylearn2/train_extensions/__init__.py
index 27f24427aa..3e24863591 100644
--- a/pylearn2/train_extensions/__init__.py
+++ b/pylearn2/train_extensions/__init__.py
@@ -78,6 +78,23 @@ def setup(self, model, dataset, algorithm):
             used to train the model.
         """
 
+    def tear_down(self, model, dataset, algorithm):
+        """
+        Train calls this after the main loop.
+
+        Parameters
+        ----------
+        model : pylearn2.models.Model
+            The model object being trained.
+
+        dataset : pylearn2.datasets.Dataset
+            The dataset object used to train the model.
+
+        algorithm : pylearn2.training_algorithms.TrainingAlgorithm
+            The object representing the training algorithm being
+            used to train the model.
+        """
+
 class SharedSetter(TrainExtension):
     """
     Sets shared variables to take on the specified values after the
diff --git a/pylearn2/train_extensions/live_monitoring.py b/pylearn2/train_extensions/live_monitoring.py
index ea4a6e15fd..ad2b76f84c 100644
--- a/pylearn2/train_extensions/live_monitoring.py
+++ b/pylearn2/train_extensions/live_monitoring.py
@@ -170,22 +170,32 @@ def __init__(self, address='*', req_port=5555, pub_port=5556):
         assert(pub_port > 1024 and pub_port < 65536)
         self.pub_port = pub_port
 
-        address_template = self.address + ':%d'
+        self.address_template = self.address + ':%d'
         self.context = zmq.Context()
 
-        self.req_sock = None
-        if self.req_port > 0:
-            self.req_sock = self.context.socket(zmq.REP)
-            self.req_sock.bind(address_template % self.req_port)
-
-        self.pub_sock = None
-        if self.pub_port > 0:
-            self.pub_sock = self.context.socket(zmq.PUB)
-            self.req_sock.bind(address_template % self.pub_port)
-
         # Tracks the number of times on_monitor has been called
         self.counter = 0
 
+    @wraps(TrainExtension.setup)
+    def setup(self, model, dataset, algorithm):
+        if self.req_port > 0:
+            self.req_sock = self.context.socket(zmq.REP)
+            self.req_sock.bind(self.address_template % self.req_port)
+
+        if self.pub_port > 0:
+            self.pub_sock = self.context.socket(zmq.PUB)
+            self.pub_sock.bind(self.address_template % self.pub_port)
+
+    @wraps(TrainExtension.tear_down)
+    def tear_down(self, model, dataset, algorithm):
+        if self.req_sock:
+            self.req_sock.unbind(self.address_template % self.req_port)
+            self.req_sock = None
+
+        if self.pub_sock:
+            self.pub_sock.unbind(self.address_template % self.pub_port)
+            self.pub_sock = None
+
     @wraps(TrainExtension.on_monitor)
     def on_monitor(self, model, dataset, algorithm):
         monitor = Monitor.get_monitor(model)
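For context, here is a minimal sketch of how an extension can use the new hook: acquire a resource in setup and release it in tear_down, relying on Train.main_loop to call tear_down on every extension whether the loop exits normally or raises. The EpochLogger class and its log path are hypothetical and only illustrate the lifecycle; they are not part of this change.

from pylearn2.train_extensions import TrainExtension


class EpochLogger(TrainExtension):
    """Hypothetical extension: opens a file in setup, closes it in tear_down."""

    def __init__(self, path='epoch_log.txt'):
        self.path = path
        self.log_file = None

    def setup(self, model, dataset, algorithm):
        # Acquire the resource once, before Train's main loop starts.
        self.log_file = open(self.path, 'w')

    def on_monitor(self, model, dataset, algorithm):
        # Called after each epoch's monitoring; record the epoch count.
        self.log_file.write('epochs seen: %d\n' %
                            model.monitor.get_epochs_seen())

    def tear_down(self, model, dataset, algorithm):
        # Called by Train after the main loop, on both normal exit and error.
        if self.log_file is not None:
            self.log_file.close()
            self.log_file = None

Note that an exception raised inside an extension's tear_down is caught by tear_down_extensions and only logged at debug level, so a misbehaving extension cannot mask the original error from the main loop.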