diff --git a/nmt/nmt.py b/nmt/nmt.py index f5823d893..fce143bf8 100644 --- a/nmt/nmt.py +++ b/nmt/nmt.py @@ -620,10 +620,6 @@ def run_main(flags, default_hparams, train_fn, inference_fn, target_session=""): num_workers = flags.num_workers utils.print_out("# Job id %d" % jobid) - # GPU device - utils.print_out( - "# Devices visible to TensorFlow: %s" % repr(tf.Session().list_devices())) - # Random random_seed = flags.random_seed if random_seed is not None and random_seed > 0: @@ -653,6 +649,14 @@ def run_main(flags, default_hparams, train_fn, inference_fn, target_session=""): out_dir, default_hparams, flags.hparams_path, save_hparams=(jobid == 0)) + # GPU device + config_proto = utils.get_config_proto( + allow_soft_placement=True, + num_intra_threads=hparams.num_intra_threads, + num_inter_threads=hparams.num_inter_threads) + utils.print_out( + "# Devices visible to TensorFlow: %s" % repr(tf.Session(config=config_proto).list_devices())) + ## Train / Decode if flags.inference_input_file: # Inference output directory