Skip to content

Commit 5e855c5

Browse files
Merge pull request #13976 from benoitsteiner/branch_173415707
Branch 173415707
2 parents 0d2fd1f + 52837da commit 5e855c5

File tree

237 files changed

+7414
-2663
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

237 files changed

+7414
-2663
lines changed

configure.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,19 @@ def set_monolithic():
963963
write_to_bazelrc('build --define framework_shared_object=true')
964964

965965

966+
def create_android_bazelrc_configs():
967+
# Flags for --config=android
968+
write_to_bazelrc('build:android --crosstool_top=//external:android/crosstool')
969+
write_to_bazelrc(
970+
'build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain')
971+
# Flags for --config=android_arm
972+
write_to_bazelrc('build:android_arm --config=android')
973+
write_to_bazelrc('build:android_arm --cpu=armeabi-v7a')
974+
# Flags for --config=android_arm64
975+
write_to_bazelrc('build:android_arm64 --config=android')
976+
write_to_bazelrc('build:android_arm64 --cpu=arm64-v8a')
977+
978+
966979
def main():
967980
# Make a copy of os.environ to be clear when functions and getting and setting
968981
# environment variables.
@@ -1033,7 +1046,7 @@ def main():
10331046
set_cc_opt_flags(environ_cp)
10341047
set_mkl()
10351048
set_monolithic()
1036-
1049+
create_android_bazelrc_configs()
10371050

10381051
if __name__ == '__main__':
10391052
main()

tensorflow/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ filegroup(
331331
"//tensorflow/compiler/jit/kernels:all_files",
332332
"//tensorflow/compiler/jit/legacy_flags:all_files",
333333
"//tensorflow/compiler/jit/ops:all_files",
334+
"//tensorflow/compiler/plugin:all_files",
334335
"//tensorflow/compiler/tests:all_files",
335336
"//tensorflow/compiler/tf2xla:all_files",
336337
"//tensorflow/compiler/tf2xla/cc:all_files",
@@ -456,7 +457,6 @@ filegroup(
456457
"//tensorflow/contrib/training:all_files",
457458
"//tensorflow/contrib/util:all_files",
458459
"//tensorflow/contrib/verbs:all_files",
459-
"//tensorflow/contrib/xla_tf_graph:all_files",
460460
"//tensorflow/core:all_files",
461461
"//tensorflow/core/debug:all_files",
462462
"//tensorflow/core/distributed_runtime:all_files",

tensorflow/c/eager/c_api.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -440,11 +440,19 @@ tensorflow::Status ValidateInputTypeAndPlacement(
440440
if (expected_device != actual_device) {
441441
switch (ctx->policy) {
442442
case TFE_DEVICE_PLACEMENT_EXPLICIT:
443+
// TODO(xpan): See if we could bubble python related error up
444+
// to python level.
443445
return tensorflow::errors::InvalidArgument(
444-
"cannot compute ", op->name, " as input #", i,
445-
" was expected to be on ", expected_device->name(),
446-
" but is actually on ", actual_device->name(),
447-
" (operation running on ", op_device->name(), ")");
446+
"Tensors on conflicting devices:"
447+
" cannot compute ",
448+
op->name, " as input #", i, " was expected to be on ",
449+
expected_device->name(), " but is actually on ",
450+
actual_device->name(), " (operation running on ",
451+
op_device->name(), ")",
452+
" Tensors can be copied explicitly using .gpu() or .cpu(),"
453+
" or transparently copied by using tfe.enable_eager_execution("
454+
"tfe.DEVICE_PLACEMENT_SILENT). Copying tensors between devices"
455+
" may slow down your model");
448456
case TFE_DEVICE_PLACEMENT_WARN:
449457
LOG(WARNING) << "before computing " << op->name << " input #" << i
450458
<< " was expected to be on " << expected_device->name()

tensorflow/compiler/aot/compile.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,12 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
100100
if (!flags.out_session_module.empty()) {
101101
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module,
102102
computation.Snapshot());
103+
// Serialize the SessionModule deterministically so that all the outputs of
104+
// a tf_library genrule are deterministic.
105+
string proto;
106+
TF_RET_CHECK(SerializeToStringDeterministic(*module, &proto));
103107
TF_RETURN_IF_ERROR(
104-
WriteBinaryProto(Env::Default(), flags.out_session_module, *module));
108+
WriteStringToFile(Env::Default(), flags.out_session_module, proto));
105109
}
106110
xla::cpu::CpuAotCompilationOptions aot_opts(
107111
flags.target_triple, flags.target_cpu, flags.target_features,

tensorflow/compiler/aot/tfcompile.bzl

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ def tf_library(name, graph, config,
129129
# Rule that runs tfcompile to produce the header and object file.
130130
header_file = name + ".h"
131131
object_file = name + ".o"
132-
session_module_pb = name + "_session_module.pb"
133132
ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_")
134133
native.genrule(
135134
name=("gen_" + name),
@@ -140,7 +139,6 @@ def tf_library(name, graph, config,
140139
outs=[
141140
header_file,
142141
object_file,
143-
session_module_pb,
144142
],
145143
cmd=("$(location " + tfcompile_tool + ")" +
146144
" --graph=$(location " + tfcompile_graph + ")" +
@@ -150,7 +148,6 @@ def tf_library(name, graph, config,
150148
" --target_triple=" + target_llvm_triple() +
151149
" --out_header=$(@D)/" + header_file +
152150
" --out_object=$(@D)/" + object_file +
153-
" --out_session_module=$(@D)/" + session_module_pb +
154151
" " + (tfcompile_flags or "")),
155152
tools=[tfcompile_tool],
156153
visibility=visibility,
@@ -168,6 +165,34 @@ def tf_library(name, graph, config,
168165
tags=tags,
169166
)
170167

168+
# Rule that runs tfcompile to produce the SessionModule proto, useful for
169+
# debugging. TODO(b/64813587): Once the SessionModule proto is
170+
# deterministic, move this into the main rule above.
171+
session_module_pb = name + "_session_module.pb"
172+
native.genrule(
173+
name=(name + "_session_module"),
174+
srcs=[
175+
tfcompile_graph,
176+
config,
177+
],
178+
outs=[
179+
session_module_pb,
180+
],
181+
cmd=("$(location " + tfcompile_tool + ")" +
182+
" --graph=$(location " + tfcompile_graph + ")" +
183+
" --config=$(location " + config + ")" +
184+
" --entry_point=" + ep +
185+
" --cpp_class=" + cpp_class +
186+
" --target_triple=" + target_llvm_triple() +
187+
" --out_session_module=$(@D)/" + session_module_pb +
188+
" " + (tfcompile_flags or "")),
189+
tools=[tfcompile_tool],
190+
visibility=visibility,
191+
testonly=testonly,
192+
local=1,
193+
tags=tags,
194+
)
195+
171196
# The cc_library rule packaging up the header and object file, and needed
172197
# kernel implementations.
173198
need_xla_data_proto = (tfcompile_flags and

tensorflow/compiler/plugin/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,17 @@ cc_library(
4040
#"//tensorflow/compiler/plugin/example:example_lib",
4141
],
4242
)
43+
44+
#-----------------------------------------------------------------------------
45+
46+
filegroup(
47+
name = "all_files",
48+
srcs = glob(
49+
["**/*"],
50+
exclude = [
51+
"**/METADATA",
52+
"**/OWNERS",
53+
],
54+
),
55+
visibility = ["//tensorflow:__subpackages__"],
56+
)

tensorflow/compiler/tf2xla/kernels/BUILD

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ package(
55
)
66

77
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
8-
load("//tensorflow/compiler/xla:xla.bzl", "export_dynamic_linkopts")
98

109
tf_kernel_library(
1110
name = "xla_ops",
@@ -83,6 +82,7 @@ tf_kernel_library(
8382
"//tensorflow/compiler/tf2xla:xla_compiler",
8483
"//tensorflow/compiler/tf2xla/ops:sendrecv_ops",
8584
"//tensorflow/compiler/xla:literal_util",
85+
"//tensorflow/compiler/xla:shape_util",
8686
"//tensorflow/compiler/xla:util",
8787
"//tensorflow/compiler/xla:xla_data_proto",
8888
"//tensorflow/compiler/xla/client:client_library",
@@ -152,6 +152,7 @@ cc_library(
152152
srcs = ["index_ops_kernel_argmax_float_1d.cc"],
153153
visibility = ["//visibility:public"],
154154
deps = [
155+
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
155156
"//tensorflow/core:framework_lite",
156157
"//third_party/eigen3",
157158
],
@@ -163,6 +164,7 @@ cc_library(
163164
srcs = ["index_ops_kernel_argmax_float_2d.cc"],
164165
visibility = ["//visibility:public"],
165166
deps = [
167+
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
166168
"//tensorflow/core:framework_lite",
167169
"//third_party/eigen3",
168170
],

tensorflow/compiler/tf2xla/kernels/index_ops.cc

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License.
2222
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
2323
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
2424
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
25+
#include "tensorflow/compiler/xla/shape_util.h"
2526
#include "tensorflow/core/framework/kernel_def_builder.h"
2627
#include "tensorflow/core/framework/op_kernel.h"
2728
#include "tensorflow/core/framework/register_types.h"
@@ -82,16 +83,24 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
8283
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
8384
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
8485
// Compute a mask that has 1s for elements equal to the maximum.
85-
xla::ComputationDataHandle mask = b->ConvertElementType(
86+
xla::ComputationDataHandle partial_mask = b->ConvertElementType(
8687
b->Eq(input, input_max, broadcast_dims), xla_index_type);
8788

88-
// Multiply by the vector [0, 1, 2, ...] to convert each 1 into its index.
89-
// TODO(phawkins): add a bitwise And operator to HLO, use a bitwise and
90-
// instead of a multiplication here.
89+
// In order to make identity elements for a bitwise And, we:
90+
// Left shift the 1 to the leftmost bit, yielding 0x10...0
91+
// Arithmetic right shift the 1 back to the rightmost bit, yielding 0xFF...F
92+
int32 bits_in_type =
93+
xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_index_type) * 8 - 1;
94+
xla::ComputationDataHandle shift_amount =
95+
XlaHelpers::IntegerLiteral(b, index_type, bits_in_type);
96+
xla::ComputationDataHandle full_mask = b->ShiftRightArithmetic(
97+
b->ShiftLeft(partial_mask, shift_amount), shift_amount);
98+
99+
// And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its index.
91100
xla::ComputationDataHandle iota;
92101
OP_REQUIRES_OK(ctx, XlaHelpers::Iota(b, index_type, axis_size, &iota));
93102
xla::ComputationDataHandle product =
94-
b->Mul(mask, iota, /*broadcast_dimensions=*/{axis});
103+
b->And(full_mask, iota, /*broadcast_dimensions=*/{axis});
95104

96105
// If there are multiple maximum elements, choose the one with the highest
97106
// index.

tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#define EIGEN_USE_THREADS
1717

1818
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
19+
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
1920
#include "tensorflow/core/framework/tensor_types.h"
2021
#include "tensorflow/core/platform/dynamic_annotations.h"
2122
#include "tensorflow/core/platform/macros.h"
@@ -47,3 +48,5 @@ EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) {
4748
extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) {
4849
tensorflow::argmax_float_1d_xla_impl(out, data);
4950
}
51+
52+
REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl);

tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#define EIGEN_USE_THREADS
1717

1818
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
19+
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
1920
#include "tensorflow/core/framework/tensor_types.h"
2021
#include "tensorflow/core/platform/dynamic_annotations.h"
2122
#include "tensorflow/core/platform/macros.h"
@@ -49,3 +50,5 @@ EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) {
4950
extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) {
5051
tensorflow::argmax_float_2d_xla_impl(out, data);
5152
}
53+
54+
REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl);

0 commit comments

Comments
 (0)