[WIP] Inlining worker for ES (DO NOT SUBMIT) #493

Open
wants to merge 4 commits into base: main
82 changes: 70 additions & 12 deletions compiler_opt/es/blackbox_evaluator.py
@@ -64,30 +64,56 @@ class SamplingBlackboxEvaluator(BlackboxEvaluator):
def __init__(self, train_corpus: corpus.Corpus,
estimator_type: blackbox_optimizers.EstimatorType,
total_num_perturbations: int, num_ir_repeats_within_worker: int):
self._samples = []
self._samples: list[list[corpus.LoadedModuleSpec]] = []
self._train_corpus = train_corpus
self._total_num_perturbations = total_num_perturbations
self._num_ir_repeats_within_worker = num_ir_repeats_within_worker
self._estimator_type = estimator_type
self._baselines: list[float | None] | None = None

super().__init__(train_corpus)

def get_results(
self, pool: FixedWorkerPool,
perturbations: list[bytes]) -> list[concurrent.futures.Future]:
def _load_samples(self) -> None:
"""Samples and loads modules if not already done.

Ensures self._samples contains the expected number of loaded samples.
Raises a RuntimeError if the number of loaded samples does not match the expected count.
"""
if not self._samples:
logging.info('Sampling and loading modules for evaluator...')

for _ in range(self._total_num_perturbations):
sample = self._train_corpus.sample(self._num_ir_repeats_within_worker)
self._samples.append(sample)
samples = self._train_corpus.sample(self._num_ir_repeats_within_worker)
loaded_samples = [
self._train_corpus.load_module_spec(sample) for sample in samples
]
self._samples.append(loaded_samples)

# add copy of sample for antithetic perturbation pair
if self._estimator_type == (
blackbox_optimizers.EstimatorType.ANTITHETIC):
self._samples.append(sample)
self._samples.append(loaded_samples)

logging.info('Done sampling and loading modules for evaluator.')

compile_args = zip(perturbations, self._samples)
if self._estimator_type == (blackbox_optimizers.EstimatorType.ANTITHETIC):
expected_count = 2 * self._total_num_perturbations
else:
expected_count = self._total_num_perturbations

if len(self._samples) != expected_count:
raise RuntimeError("Some samples could not be loaded correctly.")

def _launch_compilation_workers(self,
pool: FixedWorkerPool,
perturbations: list[bytes] | None = None
) -> list[concurrent.futures.Future]:
if perturbations is None:
perturbations = [None] * len(self._samples)

compile_args = list(zip(perturbations, self._samples))
_, futures = buffered_scheduler.schedule_on_worker_pool(
action=lambda w, v: w.compile(v[0], v[1]),
action=lambda w, args: w.compile(policy=args[0], modules=args[1]),
jobs=compile_args,
worker_pool=pool)

@@ -97,12 +123,44 @@ def get_results(
# update lists as work gets done
_, not_done = concurrent.futures.wait(
not_done, return_when=concurrent.futures.FIRST_COMPLETED)

return futures

def get_results(
self, pool: FixedWorkerPool,
perturbations: list[bytes]) -> list[concurrent.futures.Future]:
# We should have _samples by now.
if not self._samples:
raise RuntimeError("Loaded samples are not available.")
return self._launch_compilation_workers(pool, perturbations)

def set_baseline(self, pool: FixedWorkerPool) -> None:
del pool # Unused.
pass
if self._baselines is not None:
raise RuntimeError('The baseline has already been set.')
self._load_samples()
results = self._launch_compilation_workers(pool)
self._baselines = super().get_rewards(results)

def get_rewards(
self, results: list[concurrent.futures.Future]) -> list[float | None]:
if self._baselines is None:
raise RuntimeError('The baseline has not been set.')

if len(results) != len(self._baselines):
raise RuntimeError(
'The number of results does not match the number of baselines.')

policy_results = super().get_rewards(results)

rewards = []
for i in range(len(policy_results)):
policy_result = policy_results[i]
baseline = self._baselines[i]
Comment on lines +155 to +157
Collaborator:
range(len()) is almost always a sign to use enumerate() or zip() instead.
Suggested change (replacing the three lines above):
for policy_result, baseline in zip(policy_results, self._baselines, strict=True):

if policy_result is None or baseline is None:
rewards.append(None)
else:
rewards.append(
compilation_runner.calculate_reward(policy_result, baseline))
return rewards


@gin.configurable
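For context, a minimal sketch of the call order the updated SamplingBlackboxEvaluator expects (illustrative only, not part of this diff; `train_corpus`, `pool`, and `perturbations` are assumed to be supplied by the ES learner):

```python
# Illustrative only: one evaluator lifecycle per training run.
evaluator = blackbox_evaluator.SamplingBlackboxEvaluator(
    train_corpus=train_corpus,
    estimator_type=blackbox_optimizers.EstimatorType.ANTITHETIC,
    total_num_perturbations=100,
    num_ir_repeats_within_worker=1)

# Samples and loads modules, then compiles them once with policy=None to
# record the per-sample baseline values.
evaluator.set_baseline(pool)

# Each ES step: one compile job per (perturbation, loaded modules) pair...
futures = evaluator.get_results(pool, perturbations)

# ...whose results are turned into rewards relative to the stored baselines.
rewards = evaluator.get_rewards(futures)
```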
Empty file.
30 changes: 30 additions & 0 deletions compiler_opt/es/inlining/blackbox_learner.gin
@@ -0,0 +1,30 @@
import compiler_opt.es.blackbox_learner
import compiler_opt.rl.gin_external_configurables
import compiler_opt.es.blackbox_optimizers
import compiler_opt.es.blackbox_evaluator
import compiler_opt.es.es_trainer_lib

# Blackbox learner config
BlackboxLearnerConfig.total_steps = 10000
BlackboxLearnerConfig.total_num_perturbations = 100
BlackboxLearnerConfig.blackbox_optimizer = %blackbox_optimizers.Algorithm.MONTE_CARLO
BlackboxLearnerConfig.estimator_type = %blackbox_optimizers.EstimatorType.ANTITHETIC
BlackboxLearnerConfig.fvalues_normalization = True
BlackboxLearnerConfig.hyperparameters_update_method = %blackbox_optimizers.UpdateMethod.NO_METHOD

BlackboxLearnerConfig.num_top_directions = 0

BlackboxLearnerConfig.precision_parameter = 0.5

BlackboxLearnerConfig.step_size = 0.005

blackbox_evaluator.TraceBlackboxEvaluator.bb_trace_path = '<bb trace path>'
blackbox_evaluator.TraceBlackboxEvaluator.function_index_path = '<function index path>'

BlackboxLearnerConfig.evaluator = @blackbox_evaluator.TraceBlackboxEvaluator

#compiler_opt.es.es_trainer_lib.train.worker_class = @RegallocTraceWorker
# Some flags that need to be deleted for successful compilation of XFDO
# binaries. This set will need to be modified depending upon your compilation
# setup.
compiler_opt.es.es_trainer_lib.train.delete_compilation_flags = ('-fprofile-sample-use', '-split-dwarf-file', '-split-dwarf-output', '-fdebug-compilation-dir', '--warning-suppression-mappings')
25 changes: 25 additions & 0 deletions compiler_opt/es/inlining/inlining.gin
@@ -0,0 +1,25 @@
import compiler_opt.rl.gin_external_configurables
import compiler_opt.rl.inlining.config

include 'compiler_opt/rl/inlining/gin_configs/common.gin'

# Inlining model settings
ActorDistributionNetwork.preprocessing_combiner=@tf.keras.layers.Concatenate()
ActorDistributionNetwork.fc_layer_params=(64, 64, 64, 64)
ActorDistributionNetwork.dropout_layer_params=None
ActorDistributionNetwork.activation_fn=@tf.keras.activations.relu


policy_utils.create_actor_policy.actor_network_ctor = @actor_distribution_network.ActorDistributionNetwork

inlining.config.get_observation_processing_layer_creator.quantile_file_dir='compiler_opt/rl/inlining/vocab'
inlining.config.get_observation_processing_layer_creator.with_sqrt = False
inlining.config.get_observation_processing_layer_creator.with_z_score_normalization = False
inlining.config.get_observation_processing_layer_creator.normalize_ir2vec = False

# ToDo: Change IR2Vec vocab JSON to contain dim rather than getting it as input separately
inlining.config.get_inlining_signature_spec.ir2vec_dim=%ir2vec_dim
inlining.config.get_nonnormalized_features.ir2vec_vocab_path=%ir2vec_vocab_path
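If `ir2vec_dim` and `ir2vec_vocab_path` are not already bound elsewhere (for example by the included common.gin or via `--gin_bindings`), a hypothetical sketch of the macro bindings a user would supply (placeholder values, not part of this diff):

```
# Placeholder macro bindings; actual values depend on your IR2Vec vocabulary.
ir2vec_dim = 75
ir2vec_vocab_path = '/path/to/ir2vec_vocab.json'
```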
184 changes: 184 additions & 0 deletions compiler_opt/es/inlining/inlining_worker.py
@@ -0,0 +1,184 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Worker for inlining for size.
"""

from collections.abc import Collection
import concurrent.futures
import json
import logging
import os
import pathlib
import shutil
import subprocess
import tempfile

from absl import flags
import gin

from compiler_opt.distributed import worker
from compiler_opt.es import policy_utils
from compiler_opt.rl import compilation_runner
from compiler_opt.rl import corpus
from compiler_opt.rl import policy_saver


@gin.configurable
class InliningWorker(worker.Worker):
"""A worker that produces rewards for a given Inlining policy.

InliningWorker exposes a compile function, which
compiles a set of modules in parallel remotely, evaluates them with
llvm-size, and then computes the rewards based on the baseline size.
"""

def _setup_base_policy(self):
self._tf_base_temp_dir = tempfile.mkdtemp()
policy = policy_utils.create_actor_policy()
saver = policy_saver.PolicySaver({"policy": policy})
saver.save(self._tf_base_temp_dir)
self._tf_base_policy_path = os.path.join(self._tf_base_temp_dir, "policy")

def __init__(self,
*,
gin_config: str,
clang_path: str,
llvm_size_path: str,
ir2vec_vocab_path: str | None = None,
ir2vec_avg: bool = False,
thread_count: int,
corpus_path: str):
"""Initializes the RegallocTraceWorker class.

Args:
clang_path: The path to the clang binary to use for compiling the corpus.
basic_block_trace_model_path: The path to the basic_block_trace_model
binary to use for trace-based modelling. basic_block_trace_model takes
in a set of modules, a trace, and auxiliary information for
interpreting the trace, simulates the trace against the code in the
passed-in modules, returning estimated cycle counts.
thread_count: The number of threads to use for concurrent compilation
and modelling.
corpus_path: The path to the corpus that modules will be compiled from.
"""
self._clang_path = clang_path
self._thread_count = thread_count
self._corpus_path = corpus_path
self._llvm_size_path = llvm_size_path
self._ir2vec_vocab_path = ir2vec_vocab_path
self._ir2vec_avg = ir2vec_avg
self._compilation_timeout = compilation_runner.COMPILATION_TIMEOUT.value
self._cancellation_manager = compilation_runner.WorkerCancellationManager()

gin.parse_config(gin_config)
self._setup_base_policy()

# Deletion here is best effort as it occurs at GC time. If the shutdown is
# forced, cleanup might not happen as expected. This does not matter too
# much though as resource leakage will be small, and any cloud setups will
# have tempdirs wiped periodically.
def __del__(self):
shutil.rmtree(self._tf_base_temp_dir, ignore_errors=True)

def _compile_module_and_get_size(self,
loaded_module_spec: corpus.LoadedModuleSpec,
output_directory: str,
tflite_policy_path: str | None) -> float:
"""Compiles a single LoadedModuleSpec and returns its native code size."""
working_dir = tempfile.mkdtemp(dir=output_directory)
log_path = os.path.join(working_dir, 'log')
output_native_path = os.path.join(working_dir, 'native.o')

# Build the final command line using LoadedModuleSpec
original_cmd_line = loaded_module_spec.build_command_line(working_dir)

cmdline = []
cmdline.extend([self._clang_path] + list(original_cmd_line))

# Add ML Inliner flags
cmdline.extend(['-mllvm', '-enable-ml-inliner=development'])
if self._ir2vec_vocab_path is not None:
cmdline.extend([
'-mllvm', '-ml-inliner-ir2vec-vocab-file=' + self._ir2vec_vocab_path,
'-mllvm', '-ml-inliner-ir2vec-avg=' + str(self._ir2vec_avg)
])
if tflite_policy_path:
cmdline.extend(
['-mllvm', f'-ml-inliner-model-under-training={tflite_policy_path}'])
# Add other necessary flags (e.g., ir2vec, -mllvm -training-log=...)

cmdline.extend(
['-mllvm', '-training-log=' + log_path, '-o', output_native_path])

# Run Clang Compilation using cancellable process
compilation_runner.start_cancellable_process(
cmdline,
timeout=self._compilation_timeout,
cancellation_manager=self._cancellation_manager)

# Run llvm-size
size_cmd = [self._llvm_size_path, output_native_path]
output_bytes = compilation_runner.start_cancellable_process(
size_cmd,
timeout=self._compilation_timeout,
cancellation_manager=self._cancellation_manager,
want_output=True)

if not output_bytes:
raise RuntimeError(f'Empty llvm-size output: {" ".join(size_cmd)}')

# Parse llvm-size output (adjust parsing as needed)
output = output_bytes.decode('utf-8')
tmp = output.split('\n')
if len(tmp) != 3:
raise RuntimeError(f'Wrong llvm-size output {output}')
tmp = tmp[1].split('\t')
native_size = int(tmp[0])

return native_size

def compile(self, policy: bytes | None,
modules: list[corpus.LoadedModuleSpec]) -> float:
with tempfile.TemporaryDirectory() as compilation_dir:
tflite_policy_path = None
if policy is not None:
tflite_policy_path = policy_utils.convert_to_tflite(
policy, compilation_dir, self._tf_base_policy_path)

with concurrent.futures.ThreadPoolExecutor(
max_workers=self._thread_count) as thread_pool:
compile_futures = {
thread_pool.submit(self._compile_module_and_get_size, module,
compilation_dir, tflite_policy_path):
module for module in modules
}

# Recheck this logic
total_size = 0
for future in concurrent.futures.as_completed(compile_futures):
module = compile_futures[future]
try:
size = future.result()
# Check for failure indicator from the compile function
if size == float('inf'):
logging.warning(
f"Module {module.name} failed compilation/size measurement.")
total_size += size
except Exception as exc:
# Catch unexpected errors during future processing
logging.error(
f'Module {module.name} generated an exception during future processing: {exc}'
)
total_size = float('inf')

return total_size
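For orientation, a minimal usage sketch of the worker added above (illustrative only, not part of this diff; the paths, `loaded_modules`, and `perturbed_params` are placeholders, and in a real run the worker is constructed by the worker pool rather than directly):

```python
# Illustrative only: drive InliningWorker.compile directly.
with open('compiler_opt/es/inlining/inlining.gin') as f:
  gin_config_str = f.read()

inlining_worker = InliningWorker(
    gin_config=gin_config_str,
    clang_path='/usr/bin/clang',
    llvm_size_path='/usr/bin/llvm-size',
    thread_count=8,
    corpus_path='/path/to/corpus')

# With policy=None no -ml-inliner-model-under-training flag is passed, so the
# returned value is a baseline total native size for these modules.
baseline_size = inlining_worker.compile(policy=None, modules=loaded_modules)

# Raw perturbed parameters are converted to a TFLite policy inside compile()
# and passed to clang; the evaluator compares this size against the baseline.
policy_size = inlining_worker.compile(policy=perturbed_params, modules=loaded_modules)
```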
9 changes: 9 additions & 0 deletions compiler_opt/rl/trainer.py
@@ -113,6 +113,15 @@ def __init__(
ckpt_dir=self._root_dir,
agent=self._agent,
global_step=self._global_step)

if self._checkpointer.checkpoint_exists and warmstart_policy_dir:
raise ValueError(
f'Checkpoint exists at {self._root_dir}, but warmstart policy dir is'
' also provided. This is not supported; please provide only one of'
' these. To warmstart, use a different root_dir which does not have'
' a checkpoint. Or, to restore from a checkpoint, do not provide a'
' warmstart policy dir.')

self._checkpointer.initialize_or_restore()

self._start_time = time.time()
6 changes: 6 additions & 0 deletions docs/inlining-demo/demo.md
@@ -311,6 +311,12 @@ rm -rf $OUTPUT_DIR && \
--gin_bindings=train_eval.warmstart_policy_dir=\"$WARMSTART_OUTPUT_DIR/saved_policy\"
```

You can resume training from a previously saved checkpoint by passing a
`root_dir` that already contains `ckpt-*.index` files; typically this is
`$OUTPUT_DIR`. To resume, re-run the training command above without the
`rm -rf $OUTPUT_DIR` step and without the `train_eval.warmstart_policy_dir`
binding (an existing checkpoint cannot be combined with a warmstart policy);
the latest checkpoint is restored and training continues, as sketched below.
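
A minimal sketch of the resume invocation (illustrative; it reuses the flags
from the training command above, minus the `rm -rf` step and the warmstart
binding, and assumes `$OUTPUT_DIR` already holds checkpoints):

```shell
ls "$OUTPUT_DIR"/ckpt-*.index   # verify checkpoints are present
PYTHONPATH=$PYTHONPATH:. python3 \
  compiler_opt/rl/train_locally.py \
  --root_dir=$OUTPUT_DIR \
  --data_path=$CORPUS \
  --gin_files=compiler_opt/rl/inlining/gin_configs/ppo_nn_agent.gin \
  --gin_bindings=clang_path="'$LLVM_INSTALLDIR/bin/clang'" \
  --gin_bindings=llvm_size_path="'$LLVM_INSTALLDIR/bin/llvm-size'"
```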

You may also start a tensorboard to monitor the training process with

```shell