diff --git a/compiler_opt/es/blackbox_evaluator.py b/compiler_opt/es/blackbox_evaluator.py
index 049ca5b9..cdd43f38 100644
--- a/compiler_opt/es/blackbox_evaluator.py
+++ b/compiler_opt/es/blackbox_evaluator.py
@@ -64,30 +64,56 @@ class SamplingBlackboxEvaluator(BlackboxEvaluator):
   def __init__(self, train_corpus: corpus.Corpus,
                estimator_type: blackbox_optimizers.EstimatorType,
                total_num_perturbations: int,
                num_ir_repeats_within_worker: int):
-    self._samples = []
+    self._samples: list[list[corpus.LoadedModuleSpec]] = []
     self._train_corpus = train_corpus
     self._total_num_perturbations = total_num_perturbations
     self._num_ir_repeats_within_worker = num_ir_repeats_within_worker
     self._estimator_type = estimator_type
+    self._baselines: list[float | None] | None = None
     super().__init__(train_corpus)
 
-  def get_results(
-      self, pool: FixedWorkerPool,
-      perturbations: list[bytes]) -> list[concurrent.futures.Future]:
+  def _load_samples(self) -> None:
+    """Samples and loads modules if that has not already been done.
+
+    Ensures self._samples contains the expected number of loaded samples.
+
+    Raises:
+      RuntimeError: if the number of loaded samples does not match the
+        expected count.
+    """
     if not self._samples:
+      logging.info('Sampling and loading modules for evaluator...')
       for _ in range(self._total_num_perturbations):
-        sample = self._train_corpus.sample(self._num_ir_repeats_within_worker)
-        self._samples.append(sample)
+        samples = self._train_corpus.sample(self._num_ir_repeats_within_worker)
+        loaded_samples = [
+            self._train_corpus.load_module_spec(sample) for sample in samples
+        ]
+        self._samples.append(loaded_samples)
+        # Append a copy of the sample for the antithetic perturbation pair.
         if self._estimator_type == (
             blackbox_optimizers.EstimatorType.ANTITHETIC):
-          self._samples.append(sample)
+          self._samples.append(loaded_samples)
+
+      logging.info('Done sampling and loading modules for evaluator.')
 
-    compile_args = zip(perturbations, self._samples)
+    if self._estimator_type == blackbox_optimizers.EstimatorType.ANTITHETIC:
+      expected_count = 2 * self._total_num_perturbations
+    else:
+      expected_count = self._total_num_perturbations
+
+    if len(self._samples) != expected_count:
+      raise RuntimeError('Some samples could not be loaded correctly.')
+
+  def _launch_compilation_workers(
+      self,
+      pool: FixedWorkerPool,
+      perturbations: list[bytes] | None = None
+  ) -> list[concurrent.futures.Future]:
+    # A perturbation of None asks the worker to compile without a policy,
+    # which is how the baseline is measured.
+    if perturbations is None:
+      perturbations = [None] * len(self._samples)
+
+    compile_args = list(zip(perturbations, self._samples))
 
     _, futures = buffered_scheduler.schedule_on_worker_pool(
-        action=lambda w, v: w.compile(v[0], v[1]),
+        action=lambda w, args: w.compile(policy=args[0], modules=args[1]),
         jobs=compile_args,
         worker_pool=pool)
 
@@ -97,12 +123,44 @@ def get_results(
     # update lists as work gets done
     _, not_done = concurrent.futures.wait(
         not_done, return_when=concurrent.futures.FIRST_COMPLETED)
     return futures
 
+  def get_results(
+      self, pool: FixedWorkerPool,
+      perturbations: list[bytes]) -> list[concurrent.futures.Future]:
+    # set_baseline is expected to have loaded the samples by now.
+    if not self._samples:
+      raise RuntimeError('Loaded samples are not available.')
+    return self._launch_compilation_workers(pool, perturbations)
+
   def set_baseline(self, pool: FixedWorkerPool) -> None:
-    del pool  # Unused.
-    pass
+    if self._baselines is not None:
+      raise RuntimeError('The baseline has already been set.')
+    self._load_samples()
+    results = self._launch_compilation_workers(pool)
+    self._baselines = super().get_rewards(results)
+
+  def get_rewards(
+      self, results: list[concurrent.futures.Future]) -> list[float | None]:
+    if self._baselines is None:
+      raise RuntimeError('The baseline has not been set.')
+
+    if len(results) != len(self._baselines):
+      raise RuntimeError(
+          'The number of results does not match the number of baselines.')
+
+    policy_results = super().get_rewards(results)
+
+    rewards = []
+    for policy_result, baseline in zip(policy_results, self._baselines):
+      if policy_result is None or baseline is None:
+        rewards.append(None)
+      else:
+        rewards.append(
+            compilation_runner.calculate_reward(policy_result, baseline))
+    return rewards
 
 
 @gin.configurable
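(Editorial note: a sketch of the intended calling sequence for the updated
evaluator. `evaluator`, `pool`, and `perturbations` stand in for the learner's
actual objects and are not part of this diff.)

```python
evaluator.set_baseline(pool)  # Samples/loads modules, compiles with no policy.
futures = evaluator.get_results(pool, perturbations)  # One job per sample.
rewards = evaluator.get_rewards(futures)  # Relative to baseline; None on failure.
```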
diff --git a/compiler_opt/es/inlining/__init__.py b/compiler_opt/es/inlining/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/compiler_opt/es/inlining/blackbox_learner.gin b/compiler_opt/es/inlining/blackbox_learner.gin
new file mode 100644
index 00000000..a42537e2
--- /dev/null
+++ b/compiler_opt/es/inlining/blackbox_learner.gin
@@ -0,0 +1,30 @@
+import compiler_opt.es.blackbox_learner
+import compiler_opt.rl.gin_external_configurables
+import compiler_opt.es.blackbox_optimizers
+import compiler_opt.es.blackbox_evaluator
+import compiler_opt.es.es_trainer_lib
+
+# Blackbox learner config
+BlackboxLearnerConfig.total_steps = 10000
+BlackboxLearnerConfig.total_num_perturbations = 100
+BlackboxLearnerConfig.blackbox_optimizer = %blackbox_optimizers.Algorithm.MONTE_CARLO
+BlackboxLearnerConfig.estimator_type = %blackbox_optimizers.EstimatorType.ANTITHETIC
+BlackboxLearnerConfig.fvalues_normalization = True
+BlackboxLearnerConfig.hyperparameters_update_method = %blackbox_optimizers.UpdateMethod.NO_METHOD
+
+BlackboxLearnerConfig.num_top_directions = 0
+
+BlackboxLearnerConfig.precision_parameter = 0.5
+
+BlackboxLearnerConfig.step_size = 0.005
+
+blackbox_evaluator.TraceBlackboxEvaluator.bb_trace_path = ''
+blackbox_evaluator.TraceBlackboxEvaluator.function_index_path = ''
+
+BlackboxLearnerConfig.evaluator = @blackbox_evaluator.TraceBlackboxEvaluator
+
+#compiler_opt.es.es_trainer_lib.train.worker_class = @RegallocTraceWorker
+
+# Flags that need to be deleted for successful compilation of XFDO binaries.
+# This set will need to be modified depending on your compilation setup.
+compiler_opt.es.es_trainer_lib.train.delete_compilation_flags = ('-fprofile-sample-use', '-split-dwarf-file', '-split-dwarf-output', '-fdebug-compilation-dir', '--warning-suppression-mappings')
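(Editorial note: the commented-out `worker_class` binding above still points at
`RegallocTraceWorker`, and the evaluator bindings reference
`TraceBlackboxEvaluator`, both apparently carried over from the regalloc
config. For the inlining setup introduced in this change, the wiring would
presumably look something like the following; the import path and class
reference are assumptions based on the new `inlining_worker.py` below, not
part of this diff.)

```
import compiler_opt.es.inlining.inlining_worker
compiler_opt.es.es_trainer_lib.train.worker_class = @inlining_worker.InliningWorker
```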
+ """ + + def _setup_base_policy(self): + self._tf_base_temp_dir = tempfile.mkdtemp() + policy = policy_utils.create_actor_policy() + saver = policy_saver.PolicySaver({"policy": policy}) + saver.save(self._tf_base_temp_dir) + self._tf_base_policy_path = os.path.join(self._tf_base_temp_dir, "policy") + + def __init__(self, + *, + gin_config: str, + clang_path: str, + llvm_size_path: str, + ir2vec_vocab_path: str | None = None, + ir2vec_avg: bool = False, + thread_count: int, + corpus_path: str): + """Initializes the RegallocTraceWorker class. + + Args: + clang_path: The path to the clang binary to use for compiling the corpus. + basic_block_trace_model_path: The path to the basic_block_trace_model + binary to use for trace-based modelling. basic_block_trace_model takes + in a set of modules, a trace, and auxiliary information for + interpreting the trace, simulates the trace against the code in the + passed-in modules, returning estimated cycle counts. + thread_count: The number of threads to use for concurrent compilation + and modelling. + corpus_path: The path to the corpus that modules will be compiled from. + """ + self._clang_path = clang_path + self._thread_count = thread_count + self._corpus_path = corpus_path + self._llvm_size_path = llvm_size_path + self._ir2vec_vocab_path = ir2vec_vocab_path + self._ir2vec_avg = ir2vec_avg + self._compilation_timeout = compilation_runner.COMPILATION_TIMEOUT.value + self._cancellation_manager = compilation_runner.WorkerCancellationManager() + + gin.parse_config(gin_config) + self._setup_base_policy() + + # Deletion here is best effort as it occurs at GC time. If the shutdown is + # forced, cleanup might not happen as expected. This does not matter too + # much though as resource leakage will be small, and any cloud setups will + # have tempdirs wiped periodically. + def __del__(self): + shutil.rmtree(self._tf_base_temp_dir, ignore_errors=True) + + def _compile_module_and_get_size(self, + loaded_module_spec: corpus.LoadedModuleSpec, + output_directory: str, + tflite_policy_path: str | None) -> float: + """Compiles a single LoadedModuleSpec and returns its native code size.""" + working_dir = tempfile.mkdtemp(dir=output_directory) + log_path = os.path.join(working_dir, 'log') + output_native_path = os.path.join(working_dir, 'native.o') + + # Build the final command line using LoadedModuleSpec + original_cmd_line = loaded_module_spec.build_command_line(working_dir) + + cmdline = [] + cmdline.extend([self._clang_path] + list(original_cmd_line)) + + # Add ML Inliner flags + cmdline.extend(['-mllvm', '-enable-ml-inliner=development']) + if self._ir2vec_vocab_path is not None: + cmdline.extend([ + '-mllvm', '-ml-inliner-ir2vec-vocab-file=' + self._ir2vec_vocab_path, + '-mllvm', '-ml-inliner-ir2vec-avg=' + str(self._ir2vec_avg) + ]) + if tflite_policy_path: + cmdline.extend( + ['-mllvm', f'-ml-inliner-model-under-training={tflite_policy_path}']) + # Add other necessary flags (e.g., ir2vec, -mllvm -training-log=...) 
+    cmdline.extend(
+        ['-mllvm', '-training-log=' + log_path, '-o', output_native_path])
+
+    # Run the clang compilation as a cancellable process.
+    compilation_runner.start_cancellable_process(
+        cmdline,
+        timeout=self._compilation_timeout,
+        cancellation_manager=self._cancellation_manager)
+
+    # Run llvm-size on the compiled module.
+    size_cmd = [self._llvm_size_path, output_native_path]
+    output_bytes = compilation_runner.start_cancellable_process(
+        size_cmd,
+        timeout=self._compilation_timeout,
+        cancellation_manager=self._cancellation_manager,
+        want_output=True)
+
+    if not output_bytes:
+      raise RuntimeError(f'Empty llvm-size output: {" ".join(size_cmd)}')
+
+    # llvm-size prints a header line and a tab-separated data line whose
+    # first field is the text-section size, followed by a trailing newline.
+    output = output_bytes.decode('utf-8')
+    tmp = output.split('\n')
+    if len(tmp) != 3:
+      raise RuntimeError(f'Wrong llvm-size output {output}')
+    tmp = tmp[1].split('\t')
+    native_size = int(tmp[0])
+
+    return native_size
+
+  def compile(self, policy: bytes | None,
+              modules: list[corpus.LoadedModuleSpec]) -> float:
+    with tempfile.TemporaryDirectory() as compilation_dir:
+      tflite_policy_path = None
+      if policy is not None:
+        tflite_policy_path = policy_utils.convert_to_tflite(
+            policy, compilation_dir, self._tf_base_policy_path)
+
+      with concurrent.futures.ThreadPoolExecutor(
+          max_workers=self._thread_count) as thread_pool:
+        compile_futures = {
+            thread_pool.submit(self._compile_module_and_get_size, module,
+                               compilation_dir, tflite_policy_path):
+                module for module in modules
+        }
+
+        total_size = 0.0
+        for future in concurrent.futures.as_completed(compile_futures):
+          module = compile_futures[future]
+          try:
+            total_size += future.result()
+          except Exception as exc:  # pylint: disable=broad-except
+            # One failed module invalidates the measurement for this
+            # perturbation, so report an infinite total size.
+            logging.error('Module %s generated an exception: %s', module.name,
+                          exc)
+            total_size = float('inf')
+
+      return total_size
diff --git a/compiler_opt/rl/trainer.py b/compiler_opt/rl/trainer.py
index 04959359..b9f2e5d9 100644
--- a/compiler_opt/rl/trainer.py
+++ b/compiler_opt/rl/trainer.py
@@ -113,6 +113,15 @@ def __init__(
         ckpt_dir=self._root_dir,
         agent=self._agent,
         global_step=self._global_step)
+
+    if self._checkpointer.checkpoint_exists and warmstart_policy_dir:
+      raise ValueError(
+          f'Checkpoint exists at {self._root_dir}, but a warmstart policy dir'
+          ' is also provided. This is not supported; please provide only one'
+          ' of these. To warmstart, use a different root_dir which does not'
+          ' have a checkpoint. Or, to restore from a checkpoint, do not'
+          ' provide a warmstart policy dir.')
+
     self._checkpointer.initialize_or_restore()
 
     self._start_time = time.time()
diff --git a/docs/inlining-demo/demo.md b/docs/inlining-demo/demo.md
index 863cbe1a..f66df3fa 100644
--- a/docs/inlining-demo/demo.md
+++ b/docs/inlining-demo/demo.md
@@ -311,6 +311,12 @@ rm -rf $OUTPUT_DIR && \
   --gin_bindings=train_eval.warmstart_policy_dir=\"$WARMSTART_OUTPUT_DIR/saved_policy\"
 ```
 
+You can resume training from a previously saved checkpoint by passing a
+directory containing `ckpt-*.index` files as the `root_dir`. Typically this
+is `$OUTPUT_DIR`. So, if `$OUTPUT_DIR` contains previously saved checkpoints,
+re-run the command above without the `rm -rf $OUTPUT_DIR` step and without
+the `warmstart_policy_dir` binding (warmstarting cannot be combined with an
+existing checkpoint); the latest checkpoint is then restored and training
+resumes.
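+For example, resuming might look like this (a sketch: keep whatever other
+flags the training command above used):
+
+```shell
+PYTHONPATH=$PYTHONPATH:. python3 compiler_opt/rl/train_locally.py \
+  --root_dir=$OUTPUT_DIR \
+  --data_path=$CORPUS \
+  --gin_files=compiler_opt/rl/inlining/gin_configs/ppo_nn_agent.gin
+```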
+ You may also start a tensorboard to monitor the training process with ```shell