diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc
index 98bae956499..44cbf3d40f8 100644
--- a/tensorflow/core/kernels/save_restore_v2_ops.cc
+++ b/tensorflow/core/kernels/save_restore_v2_ops.cc
@@ -17,6 +17,7 @@ limitations under the License.

 #include
 #include
+#include

 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -31,6 +32,7 @@ limitations under the License.
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/tensor_bundle/db_writer.h"
 #include "tensorflow/core/util/saved_tensor_slice_util.h"
 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
 #include "tensorflow/core/util/tensor_slice_reader.h"
@@ -106,25 +108,22 @@ class SaveV2 : public OpKernel {
   }

   template <typename TKey, typename TValue>
-  void DumpEvWithGlobalStep(OpKernelContext* context, int variable_index,
-                            const string& tensor_name, BundleWriter& writer,
-                            DataType global_step_type) {
+  void DumpEvWithGlobalStep(OpKernelContext* context, int variable_index,
+                            const string& tensor_name, BundleWriter& writer,
+                            DataType global_step_type) {
     if (global_step_type == DT_INT32) {
-      DumpEv<TKey, TValue, int32>(context, variable_index,
-                                  tensor_name, writer);
+      DumpEv<TKey, TValue, int32>(context, variable_index, tensor_name, writer);
     } else {
-      DumpEv<TKey, TValue, int64>(context, variable_index,
-                                  tensor_name, writer);
+      DumpEv<TKey, TValue, int64>(context, variable_index, tensor_name, writer);
     }
   }

   template <typename TKey, typename TValue, typename TGlobalStep>
-  void DumpEv(OpKernelContext* context, int variable_index,
-              const string& tensor_name, BundleWriter& writer) {
+  void DumpEv(OpKernelContext* context, int variable_index,
+              const string& tensor_name, BundleWriter& writer) {
     EmbeddingVar<TKey, TValue>* variable = nullptr;
     OP_REQUIRES_OK(context,
-                   LookupResource(context,
-                       HandleFromInput(context, variable_index), &variable));
+                   LookupResource(context, HandleFromInput(context, variable_index), &variable));
     const Tensor& global_step = context->input(3);
     Tensor part_offset_tensor;
     context->allocate_temp(DT_INT32,
@@ -136,8 +135,7 @@ class SaveV2 : public OpKernel {
       OP_REQUIRES_OK(context, variable->Shrink());
     else
       OP_REQUIRES_OK(context, variable->Shrink(global_step_scalar));
-    OP_REQUIRES_OK(context, DumpEmbeddingValues(variable, tensor_name,
-                                                &writer, &part_offset_tensor));
+    OP_REQUIRES_OK(context, DumpEmbeddingValues(variable, tensor_name, &writer, &part_offset_tensor));
   }

   void Compute(OpKernelContext* context) override {
     const Tensor& prefix = context->input(0);
     const Tensor& tensor_names = context->input(1);
     const Tensor& shape_and_slices = context->input(2);
     ValidateInputs(true /* is save op */, context, prefix, tensor_names,
                    shape_and_slices);
-
     if (!context->status().ok()) return;

     const int kFixedInputs = 3;  // Prefix, tensor names, shape_and_slices.
-    const int num_tensors = static_cast<int>(tensor_names.NumElements());
+    const int num_tensors = static_cast<int>(tensor_names.NumElements());
     const string& prefix_string = prefix.scalar<string>()();
     const auto& tensor_names_flat = tensor_names.flat<string>();
     const auto& shape_and_slices_flat = shape_and_slices.flat<string>();
-    BundleWriter writer(Env::Default(), prefix_string);
+    const int nosql_marker = 0;
+    auto tempstate = random::New64();
+    string db_prefix_tmp = strings::StrCat(prefix_string, "--temp", tempstate);
+    DBWriter dbwriter(Env::Default(), prefix_string, db_prefix_tmp);
+    OP_REQUIRES_OK(context, dbwriter.status());
+
+    BundleWriter writer(Env::Default(), prefix_string, db_prefix_tmp);
     OP_REQUIRES_OK(context, writer.status());
-    VLOG(1) << "BundleWriter, prefix_string: " << prefix_string;
+
     int start_index = 0;
     if (has_ev_) {
       start_index = 1;
     }
+    int start_ev_key_index = 0;
+
     for (int i = start_index; i < num_tensors; ++i) {
-      const string& tensor_name = tensor_names_flat(i);
+      const string& tensor_name = tensor_names_flat(i);
+
+      if (tensor_types_[i] == DT_RESOURCE) {
         auto& handle = HandleFromInput(context, i + kFixedInputs);
         if (IsHandle<EmbeddingVar<int64, float>>(handle)) {
+          EmbeddingVar<int64, float>* variable = nullptr;
+          OP_REQUIRES_OK(context,
+              LookupResource(context, HandleFromInput(context, i + kFixedInputs), &variable));
+          core::ScopedUnref unref_variable(variable);
+          const Tensor& global_step = context->input(3);
+          Tensor part_offset_tensor;
+          context->allocate_temp(DT_INT32,
+                                 TensorShape({kSavedPartitionNum + 1}),
+                                 &part_offset_tensor);
           if (ev_key_types_[start_ev_key_index] == DT_INT32) {
-            DumpEvWithGlobalStep<int32, float>(context,
-                i + kFixedInputs, tensor_name, writer, tensor_types_[0]);
+            DumpEvWithGlobalStep<int32, float>(context, i + kFixedInputs, tensor_name, writer, tensor_types_[0]);
           } else if (ev_key_types_[start_ev_key_index] == DT_INT64) {
-            DumpEvWithGlobalStep<int64, float>(context,
-                i + kFixedInputs, tensor_name, writer, tensor_types_[0]);
+            DumpEvWithGlobalStep<int64, float>(context, i + kFixedInputs, tensor_name, writer, tensor_types_[0]);
           }
         } else if (IsHandle<HashTableResource>(handle)) {
           auto handles = context->input(i + kFixedInputs).flat<ResourceHandle>();
@@ -205,7 +220,6 @@ class SaveV2 : public OpKernel {
           OP_REQUIRES_OK(context,
                          checkpoint::ParseShapeAndSlice(
                              shape_spec, &shape, &slice, &slice_shape));
-
           std::vector<string> names_lst = str_util::Split(tensor_name, '|');
           for (auto&& name : names_lst) {
             std::vector<string> tensor_name_x =
@@ -218,15 +232,14 @@ class SaveV2 : public OpKernel {
             OP_REQUIRES_OK(context, SaveHashTable(
                 &writer, hashtable, tensibles, table_name, tensible_name,
                 slice.start(0), slice.length(0), slice_shape.dim_size(0)));
+          }
         } else if (IsHandle<HashTableAdmitStrategyResource>(handle)) {
           HashTableAdmitStrategyResource* resource;
           OP_REQUIRES_OK(context,
-                         LookupResource(context,
-                             HandleFromInput(context, i + kFixedInputs), &resource));
+                         LookupResource(context, HandleFromInput(context, i + kFixedInputs), &resource));
           HashTableAdmitStrategy* strategy = resource->Internal();
-          BloomFilterAdmitStrategy* bf =
-              dynamic_cast<BloomFilterAdmitStrategy*>(strategy);
+          BloomFilterAdmitStrategy* bf = dynamic_cast<BloomFilterAdmitStrategy*>(strategy);
           CHECK(bf != nullptr) << "Cannot save Non-BloomFilterAdmitStrategy!";

           string shape_spec = shape_and_slices_flat(i);
@@ -240,33 +253,54 @@ class SaveV2 : public OpKernel {
               &writer, bf, tensor_name, slice.start(0), slice.length(0),
               slice_shape.dim_size(0)));
         }
+        start_ev_key_index++;
-      } else {
+      } else {
         const Tensor& tensor = context->input(i + kFixedInputs);
-
         if (!shape_and_slices_flat(i).empty()) {
           const string& shape_spec = shape_and_slices_flat(i);
           TensorShape shape;
           TensorSlice slice(tensor.dims());
           TensorShape slice_shape;
+
           OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
-                                      shape_spec, &shape, &slice, &slice_shape));
+                             shape_spec, &shape, &slice, &slice_shape));
           OP_REQUIRES(context, slice_shape.IsSameSize(tensor.shape()),
-                      errors::InvalidArgument("Slice in shape_and_slice "
-                                              "specification does not match the "
-                                              "shape of the tensor to save: ",
-                                              shape_spec, ", tensor: ",
-                                              tensor.shape().DebugString()));
-
-          OP_REQUIRES_OK(context,
-                         writer.AddSlice(tensor_name, shape, slice, tensor));
+                      errors::InvalidArgument("Slice in shape_and_slice "
+                                              "specification does not match the "
+                                              "shape of the tensor to save: ",
+                                              shape_spec, ", tensor: ",
+                                              tensor.shape().DebugString()));
+
+          if (nosql_marker == 1) {
+            OP_REQUIRES_OK(context,
+                dbwriter.AddSlice(tensor_name, shape, slice, tensor, "slice_tensor"));
+          } else {
+            OP_REQUIRES_OK(context,
+                writer.AddSlice(tensor_name, shape, slice, tensor));
+          }
         } else {
-          OP_REQUIRES_OK(context, writer.Add(tensor_name, tensor));
+          if (nosql_marker == 1) {
+            OP_REQUIRES_OK(context,
+                dbwriter.Add(tensor_name, tensor, "normal_tensor"));
+          } else {
+            string tmp_dbfile_prefix_string =
+                strings::StrCat(prefix_string, "--temp", tempstate,
+                                "--data--0--1", "--tensor--", tensor_name);
+            OP_REQUIRES_OK(context, writer.Add(tensor_name, tensor, tmp_dbfile_prefix_string));
+          }
         }
       }
     }
-    OP_REQUIRES_OK(context, writer.Finish());
+    if (nosql_marker == 1) {
+      OP_REQUIRES_OK(context, dbwriter.Finish());
+    } else {
+      OP_REQUIRES_OK(context, writer.Finish());
+    }
   }

  private:
  DataTypeVector tensor_types_;
@@ -278,8 +312,7 @@ REGISTER_KERNEL_BUILDER(Name("SaveV2").Device(DEVICE_CPU), SaveV2);

 // Restores a list of named tensors from a tensor bundle (V2 checkpoint format).
 class RestoreHashTableOp : public AsyncOpKernel {
  public:
-  explicit RestoreHashTableOp(OpKernelConstruction* context)
-      : AsyncOpKernel(context) {
+  explicit RestoreHashTableOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("clear", &clear_));
   }

@@ -289,8 +322,7 @@ class RestoreHashTableOp : public AsyncOpKernel {
     const Tensor& shape_and_slices = context->input(2);
     const Tensor& handles = context->input(3);
     const string& prefix_string = prefix.scalar<string>()();
-    const string& shape_and_slices_string =
-        shape_and_slices.scalar<string>()();
+    const string& shape_and_slices_string = shape_and_slices.scalar<string>()();
     auto tensor_names_flat = tensor_names.flat<string>();
     auto handles_flat = handles.flat<ResourceHandle>();
@@ -376,8 +408,7 @@ class RestoreHashTableOp : public AsyncOpKernel {
  private:
   bool clear_;
 };
-REGISTER_KERNEL_BUILDER(Name("RestoreHashTable").Device(DEVICE_CPU),
-                        RestoreHashTableOp);
+REGISTER_KERNEL_BUILDER(Name("RestoreHashTable").Device(DEVICE_CPU), RestoreHashTableOp);

 class RestoreBloomFilterOp : public AsyncOpKernel {
  public:
@@ -408,8 +439,7 @@ class RestoreBloomFilterOp : public AsyncOpKernel {
     OP_REQUIRES_OK_ASYNC(
         context, LookupResource(context, handle_flat, &resource), done);
     strategy = dynamic_cast<BloomFilterAdmitStrategy*>(resource->Internal());
-    CHECK(strategy != nullptr)
-        << "Cannot restore BloomFilter from another strategy";
+    CHECK(strategy != nullptr) << "Cannot restore BloomFilter from another strategy";
   }
   Status st = RestoreBloomFilter(
       reader.get(), strategy, tensor_name_flat, slice.start(0),
@@ -418,8 +448,7 @@ class RestoreBloomFilterOp : public AsyncOpKernel {
     done();
   }
 };
-REGISTER_KERNEL_BUILDER(Name("RestoreBloomFilter").Device(DEVICE_CPU),
-                        RestoreBloomFilterOp);
+REGISTER_KERNEL_BUILDER(Name("RestoreBloomFilter").Device(DEVICE_CPU), RestoreBloomFilterOp);
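
Note on the branch above (an illustration, not part of the patch): with nosql_marker hard-wired to 0, the bundle path is always taken, and each plain tensor is registered under a temporary key of the form <prefix>--temp<N>--data--0--1--tensor--<name>. A minimal sketch of that composition, using only StrCat and the tempstate value the patch itself generates; the helper name is hypothetical:

    #include "tensorflow/core/lib/strings/strcat.h"
    // tempstate is generated once per Compute() call, exactly as above.
    string MakeTmpTensorKey(const string& prefix_string, uint64 tempstate,
                            const string& tensor_name) {
      return strings::StrCat(prefix_string, "--temp", tempstate,
                             "--data--0--1", "--tensor--", tensor_name);
    }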
 // Restores a list of named tensors from a tensor bundle (V2 checkpoint format).
 class RestoreV2 : public OpKernel {
@@ -438,7 +467,6 @@ class RestoreV2 : public OpKernel {
                     " expected dtypes."));
     ValidateInputs(false /* not save op */, context, prefix, tensor_names,
                    shape_and_slices);
-
     if (!context->status().ok()) return;

     const string& prefix_string = prefix.scalar<string>()();
@@ -501,7 +529,7 @@ class MergeV2Checkpoints : public OpKernel {
     const string& merged_prefix = destination_prefix.scalar<string>()();
     OP_REQUIRES_OK(
         context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix));
-
+
     if (delete_old_dirs_) {
       const string merged_dir(io::Dirname(merged_prefix));
       for (const string& input_prefix : input_prefixes) {
diff --git a/tensorflow/core/util/db_writer.cc b/tensorflow/core/util/db_writer.cc
new file mode 100644
index 00000000000..295a7a1ae3b
--- /dev/null
+++ b/tensorflow/core/util/db_writer.cc
@@ -0,0 +1,1677 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/util/tensor_bundle/db_writer.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb_text.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb_text.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
+#include "tensorflow/core/framework/variant_tensor_data.h"
+#include "tensorflow/core/framework/versions.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/lib/bfloat16/bfloat16.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/io/table_builder.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+#include "tensorflow/core/util/tensor_bundle/byte_swap.h"
+#include "tensorflow/core/util/tensor_slice_util.h"
+#include "leveldb/db.h"
+
+namespace tensorflow {
+
+// Versioning of the tensor bundle format.
+//const int kTensorBundleMinProducer = 0;
+//const int kTensorBundleMinConsumer = 0;
+//const int kTensorBundleVersion = 1;
+//
+//// Size of our input buffer for streaming reads
+static const int kBufferSize = 1024 * 1024;
+//
+//// Key to the special BundleHeaderProto entry.  Do not change this, as clients
+//// can make the assumption that the header is always the first entry in the
+//// bundle.
+//const char* const kHeaderEntryKey = "";
+
+namespace {
+
+string Gen_DBKey_Helper(const string& key, int64 offset) {
+  return strings::StrCat(key, ":", offset);
+}
+
+// Reads "num_elements" string elements from file[offset, offset+size) into the
+// length-N "destination".  Discards the original content of "destination".
+//
+// Checksums the string lengths (as restored uint32 or uint64, not varint64
+// bytes) and string bytes, and stores it into "actual_crc32c".
+Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
+                        size_t offset, size_t size, tstring* destination,
+                        uint32* actual_crc32c, bool need_to_swap_bytes) {
+  if (size == 0) return Status::OK();
+  CHECK_GT(size, 0);
+
+  // Reads "num_elements" varint64's from "buffered_file".
+  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
+  std::vector<uint64> string_lengths(num_elements);
+  for (size_t i = 0; i < num_elements; ++i) {
+    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i]));
+    if (string_lengths[i] <= UINT32_MAX) {
+      // We need to do this because older checkpoints only used uint32s and we
+      // should still support them.
+      uint32 elem_size_uint32 = static_cast<uint32>(string_lengths[i]);
+      if (need_to_swap_bytes) {
+        // Checksum would have been computed on the source machine's byte order
+        elem_size_uint32 = BYTE_SWAP_32(elem_size_uint32);
+      }
+      *actual_crc32c = crc32c::Extend(
+          *actual_crc32c, reinterpret_cast<const char*>(&elem_size_uint32),
+          sizeof(uint32));
+    } else {
+      uint64 length = string_lengths[i];
+      if (need_to_swap_bytes) {
+        length = BYTE_SWAP_64(length);
+      }
+      *actual_crc32c =
+          crc32c::Extend(*actual_crc32c, reinterpret_cast<const char*>(&length),
+                         sizeof(uint64));
+    }
+  }
+  if (offset + size < buffered_file->Tell()) {
+    return errors::DataLoss("String lengths longer than expected offset ",
+                            offset + size);
+  }
+
+  // Reads the length-checksum.
+  uint32 raw_length_checksum = 0;  // Bytes in file
+  uint32 length_checksum = 0;      // In-memory representation
+  size_t unused_bytes_read = 0;
+  TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
+      sizeof(uint32), reinterpret_cast<char*>(&raw_length_checksum),
+      &unused_bytes_read));
+  length_checksum = need_to_swap_bytes ? BYTE_SWAP_32(raw_length_checksum)
+                                       : raw_length_checksum;
+  if (crc32c::Unmask(length_checksum) != *actual_crc32c) {
+    return errors::DataLoss(
+        "The length checksum does not match: expected ",
+        strings::Printf("%08u", crc32c::Unmask(length_checksum)),
+        " but actual is ", strings::Printf("%08u", *actual_crc32c));
+  }
+  *actual_crc32c = crc32c::Extend(*actual_crc32c,
+                                  reinterpret_cast<char*>(&raw_length_checksum),
+                                  sizeof(uint32));
+
+  // Reads the actual string bytes.
+  for (size_t i = 0; i < num_elements; ++i) {
+    const uint64 string_length = string_lengths[i];
+    tstring* buffer = &destination[i];
+
+    buffer->resize(string_length);
+    size_t bytes_read = 0;
+    TF_RETURN_IF_ERROR(
+        buffered_file->ReadNBytes(string_length, &(*buffer)[0], &bytes_read));
+    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer->data(), bytes_read);
+  }
+  return Status::OK();
+}
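
For reference, every checksum in this format uses TF's masked-CRC convention: the writer stores Mask(crc) and the reader compares Unmask(stored) against the CRC it recomputes, as in the check above. A small sketch (assumed data; not from the patch) of the round trip:

    #include "tensorflow/core/lib/hash/crc32c.h"
    // Round trip of the masking scheme used for every checksum in this format.
    void CrcMaskRoundTrip(const char* data, size_t n) {
      uint32 crc = tensorflow::crc32c::Value(data, n);       // raw CRC32C
      uint32 stored = tensorflow::crc32c::Mask(crc);         // written to disk
      CHECK_EQ(tensorflow::crc32c::Unmask(stored), crc);     // reader's check
    }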
+
+Status ReadVariantTensor(io::InputBuffer* buffered_file, Tensor* ret,
+                         size_t offset, size_t size, uint32* actual_crc32c) {
+  // On-disk format:
+  //   [varint64 len1][bytes variant1][4 byte checksum]
+  //   ..
+  //   [varint64 lenN][bytes variantN][4 byte checksum]
+  // Var "crc32c" checksums all the lens, variant bytes, individual variant
+  // checksums (as uint32, not varint32 bytes).
+  if (size == 0) return Status::OK();
+  size_t num_elements = ret->NumElements();
+
+  // Reads the actual string bytes.
+  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
+  for (size_t i = 0; i < num_elements; ++i) {
+    // Read the serialized variant length.
+    uint64 string_length = 0;
+    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_length));
+    *actual_crc32c = crc32c::Extend(
+        *actual_crc32c, reinterpret_cast<const char*>(&string_length),
+        sizeof(uint64));
+    // Read the actual serialized variant.
+    string buffer;
+    buffer.resize(string_length);
+    size_t bytes_read = 0;
+    TF_RETURN_IF_ERROR(
+        buffered_file->ReadNBytes(string_length, &buffer[0], &bytes_read));
+    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer.data(), bytes_read);
+    VariantTensorDataProto proto;
+    if (!proto.ParseFromString(buffer)) {
+      return errors::DataLoss("Unable to parse VariantTensorDataProto from ",
+                              "buffer of size ", string_length, ". ",
+                              "Bundle entry offset: ", offset, " size: ", size);
+    }
+    Variant v = proto;
+    if (!DecodeUnaryVariant(&v)) {
+      return errors::Internal("Could not decode variant with type_name: \"",
+                              v.TypeName(), "\".  Perhaps you forgot to ",
+                              "register a decoder via ",
+                              "REGISTER_UNARY_VARIANT_DECODE_FUNCTION?");
+    }
+
+    // Read the checksum.
+    uint32 checksum = 0;
+    size_t unused_bytes_read = 0;
+    TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
+        sizeof(uint32), reinterpret_cast<char*>(&checksum),
+        &unused_bytes_read));
+    if (crc32c::Unmask(checksum) != *actual_crc32c) {
+      return errors::DataLoss(
+          "The checksum after Variant ", i, " does not match.",
+          " Expected: ", strings::Printf("%08u", crc32c::Unmask(checksum)),
+          " Actual: ", strings::Printf("%08u", *actual_crc32c));
+    }
+    *actual_crc32c = crc32c::Extend(
+        *actual_crc32c, reinterpret_cast<char*>(&checksum), sizeof(uint32));
+
+    ret->flat<Variant>()(i) = std::move(v);
+  }
+
+  return Status::OK();
+}
+
+char* GetBackingBuffer(const Tensor& val) {
+  CHECK(DataTypeCanUseMemcpy(val.dtype())) << val.dtype();
+  return const_cast<char*>(val.tensor_data().data());
+}
+
+tstring* GetStringBackingBuffer(const Tensor& val) {
+  CHECK_EQ(DT_STRING, val.dtype());
+  return const_cast<tstring*>(val.flat<tstring>().data());
+}
+
+Status ParseEntryProto(StringPiece key, StringPiece value,
+                       protobuf::MessageLite* out) {
+  if (!out->ParseFromArray(value.data(), value.size())) {
+    return errors::DataLoss("Entry for key ", key, " not parseable.");
+  }
+  return Status::OK();
+}
+
+// Serializes the data bytes of the non-string tensor "val".  Discards the
+// original content of "bytes_written", and on OK updates it with number of
+// bytes written.
+// REQUIRES: val.dtype() != DT_STRING
+Status WriteTensor(const Tensor& val, leveldb::DB* dbout_,
+                   size_t* bytes_written) {
+  // This overload only reports the byte count; it performs no write.
+  *bytes_written = val.TotalBytes();
+  char* buf = GetBackingBuffer(val);
+  return Status::OK();
+}
+
+Status WriteTensor(const Tensor& val, leveldb::DB* dbout_,
+                   size_t* bytes_written, const string db_tensor_prefix_) {
+  *bytes_written = val.TotalBytes();
+  char* buf = GetBackingBuffer(val);
+  // Pass an explicit length: raw tensor bytes may contain NULs, and
+  // leveldb::Slice(const char*) would stop at the first one.
+  dbout_->Put(leveldb::WriteOptions(), db_tensor_prefix_,
+              leveldb::Slice(buf, *bytes_written));
+  return Status::OK();
+}
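
The explicit-length Slice in the second WriteTensor overload matters because tensor bytes routinely contain NULs, while LevelDB's char* Slice constructor uses strlen. A hypothetical round trip against the stock LevelDB API (not part of the patch) illustrates the difference:

    #include <cassert>
    #include <string>
    #include "leveldb/db.h"
    void BinarySafePut(leveldb::DB* db) {
      const char raw[4] = {'a', '\0', 'b', 'c'};
      // Slice(raw) would store only "a"; the (ptr, len) form keeps all 4 bytes.
      db->Put(leveldb::WriteOptions(), "k", leveldb::Slice(raw, sizeof(raw)));
      std::string got;
      db->Get(leveldb::ReadOptions(), "k", &got);
      assert(got.size() == 4);
    }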
+
+// Serializes string tensor "val".  "bytes_written" is treated in the same
+// fashion as WriteTensor().
+//
+// Checksums all bytes written and stores it into "crc32c".
+// REQUIRES: val.dtype() == DT_STRING
+Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
+                         size_t* bytes_written, uint32* crc32c) {
+  // On-disk format:
+  //   [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]
+  // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
+  // the length-checksum, and all the string bytes.
+  DCHECK_EQ(val.dtype(), DT_STRING);
+  const tstring* strings = GetStringBackingBuffer(val);
+
+  // Writes the varint lengths.
+  string lengths;
+  lengths.reserve(val.NumElements());  // At least 1 byte per element.
+  *crc32c = 0;
+  for (int64 i = 0; i < val.NumElements(); ++i) {
+    const tstring* elem = &strings[i];
+    DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
+    const uint64 elem_size = static_cast<uint64>(elem->size());
+
+    core::PutVarint64(&lengths, elem_size);
+    if (elem_size <= UINT32_MAX) {
+      // We need to do this because older checkpoints only used uint32s and we
+      // should still support them.
+      const uint32 elem_size_uint32 = static_cast<uint32>(elem_size);
+      *crc32c = crc32c::Extend(*crc32c,
+                               reinterpret_cast<const char*>(&elem_size_uint32),
+                               sizeof(uint32));
+    } else {
+      *crc32c = crc32c::Extend(
+          *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
+    }
+  }
+  TF_RETURN_IF_ERROR(out->Append(lengths));
+  *bytes_written = lengths.size();
+
+  // Writes the length checksum.
+  const uint32 length_checksum = crc32c::Mask(*crc32c);
+  TF_RETURN_IF_ERROR(out->Append(StringPiece(
+      reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
+  *crc32c = crc32c::Extend(
+      *crc32c, reinterpret_cast<const char*>(&length_checksum), sizeof(uint32));
+  *bytes_written += sizeof(uint32);
+
+  // Writes all the string bytes out.
+  for (int64 i = 0; i < val.NumElements(); ++i) {
+    const tstring* string = &strings[i];
+    TF_RETURN_IF_ERROR(out->Append(*string));
+    *bytes_written += string->size();
+    *crc32c = crc32c::Extend(*crc32c, string->data(), string->size());
+  }
+  return Status::OK();
+}
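
A tiny worked example of the length framing above (assumed inputs, using TF's core::PutVarint64): for the strings {"a", "bc", ""} the lengths section is three varint bytes, followed by the masked CRC of those bytes, followed by the payload "abc".

    #include "tensorflow/core/lib/core/coding.h"
    string lengths;
    tensorflow::core::PutVarint64(&lengths, 1);  // "a"
    tensorflow::core::PutVarint64(&lengths, 2);  // "bc"
    tensorflow::core::PutVarint64(&lengths, 0);  // ""
    // lengths == "\x01\x02\x00"; WriteStringTensor then appends the masked
    // crc32c of this buffer, and finally the raw bytes "abc".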
+
+Status WriteStringTensor(const Tensor& val, leveldb::DB* dbout_,
+                         size_t* bytes_written, uint32* crc32c,
+                         const string db_tensor_prefix_) {
+  // On-disk format:
+  //   [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]
+  // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
+  // the length-checksum, and all the string bytes.
+  DCHECK_EQ(val.dtype(), DT_STRING);
+  const tstring* strings = GetStringBackingBuffer(val);
+  string to_written_data;
+  // Writes the varint lengths.
+  string lengths;
+  lengths.reserve(val.NumElements());  // At least 1 byte per element.
+  *crc32c = 0;
+  for (int64 i = 0; i < val.NumElements(); ++i) {
+    const tstring* elem = &strings[i];
+    DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
+    const uint64 elem_size = static_cast<uint64>(elem->size());
+
+    core::PutVarint64(&lengths, elem_size);
+    if (elem_size <= UINT32_MAX) {
+      // We need to do this because older checkpoints only used uint32s and we
+      // should still support them.
+      const uint32 elem_size_uint32 = static_cast<uint32>(elem_size);
+      *crc32c = crc32c::Extend(*crc32c,
+                               reinterpret_cast<const char*>(&elem_size_uint32),
+                               sizeof(uint32));
+    } else {
+      *crc32c = crc32c::Extend(
+          *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
+    }
+  }
+  // TF_RETURN_IF_ERROR(out->Append(lengths));
+  to_written_data.append(lengths);
+  *bytes_written = lengths.size();
+
+  // Writes the length checksum.
+  // NOTE: the checksum is appended as decimal text here, unlike the 4 raw
+  // bytes used by the file-based writer above.
+  const uint32 length_checksum = crc32c::Mask(*crc32c);
+  std::stringstream ss;
+  ss << length_checksum;
+  string tmp_ss;
+  ss >> tmp_ss;
+  to_written_data.append(tmp_ss);
+
+  *crc32c = crc32c::Extend(
+      *crc32c, reinterpret_cast<const char*>(&length_checksum), sizeof(uint32));
+  *bytes_written += sizeof(uint32);
+
+  // Writes all the string bytes out.
+  for (int64 i = 0; i < val.NumElements(); ++i) {
+    const tstring* string = &strings[i];
+    // TF_RETURN_IF_ERROR(out->Append(*string));
+    to_written_data.append(*string);
+    *bytes_written += string->size();
+    *crc32c = crc32c::Extend(*crc32c, string->data(), string->size());
+  }
+  // Flush the accumulated bytes to LevelDB under the tensor's key, mirroring
+  // the variant overload below (the original body never wrote them out).
+  dbout_->Put(leveldb::WriteOptions(), db_tensor_prefix_, to_written_data);
+  return Status::OK();
+}
+//,const string db_tensor_prefix_
+
+Status WriteVariantTensor(const Tensor& val, FileOutputBuffer* out,
+                          size_t* bytes_written, uint32* crc32c) {
+  // On-disk format:
+  //   [varint64 len1][bytes variant1][4 byte checksum]
+  //   ..
+  //   [varint64 lenN][bytes variantN][4 byte checksum]
+  // Var "crc32c" checksums all the lens, variant bytes, individual variant
+  // checksums (as uint32, not varint32 bytes).
+  DCHECK_EQ(val.dtype(), DT_VARIANT);
+
+  *crc32c = 0;
+  *bytes_written = 0;
+
+  for (int64 i = 0; i < val.NumElements(); ++i) {
+    VariantTensorData data;
+    val.flat<Variant>()(i).Encode(&data);
+    VariantTensorDataProto proto;
+    data.ToProto(&proto);
+    string elem;
+    proto.SerializeToString(&elem);
+
+    // Write the length of the serialized variant.
+    DCHECK_EQ(elem.size(), static_cast<uint64>(elem.size()));
+    const auto elem_size = static_cast<uint64>(elem.size());
+    string len;
+    core::PutVarint64(&len, elem_size);
+    TF_RETURN_IF_ERROR(out->Append(len));
+
+    *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
+                             sizeof(uint64));
+    *bytes_written += len.size();
+
+    // Write the serialized variant.
+    TF_RETURN_IF_ERROR(out->Append(elem));
+
+    *crc32c = crc32c::Extend(*crc32c, elem.data(), elem.size());
+    *bytes_written += elem.size();
+
+    // Write the checksum.
+    const uint32 length_checksum = crc32c::Mask(*crc32c);
+    TF_RETURN_IF_ERROR(out->Append(StringPiece(
+        reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
+    *crc32c =
+        crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&length_checksum),
+                       sizeof(uint32));
+    *bytes_written += sizeof(uint32);
+  }
+
+  return Status::OK();
+}
+
+Status WriteVariantTensor(const Tensor& val, leveldb::DB* dbout_,
+                          size_t* bytes_written, uint32* crc32c,
+                          const string db_tensor_prefix_) {
+  // On-disk format:
+  //   [varint64 len1][bytes variant1][4 byte checksum]
+  //   ..
+  //   [varint64 lenN][bytes variantN][4 byte checksum]
+  // Var "crc32c" checksums all the lens, variant bytes, individual variant
+  // checksums (as uint32, not varint32 bytes).
+  DCHECK_EQ(val.dtype(), DT_VARIANT);
+
+  string to_written_data = "";
+
+  *crc32c = 0;
+  *bytes_written = 0;
+  for (int64 i = 0; i < val.NumElements(); ++i) {
+    VariantTensorData data;
+    val.flat<Variant>()(i).Encode(&data);
+    VariantTensorDataProto proto;
+    data.ToProto(&proto);
+    string elem;
+    proto.SerializeToString(&elem);
+
+    // Write the length of the serialized variant.
+    DCHECK_EQ(elem.size(), static_cast<uint64>(elem.size()));
+    const auto elem_size = static_cast<uint64>(elem.size());
+    string len;
+    core::PutVarint64(&len, elem_size);
+    // TF_RETURN_IF_ERROR(out->Append(len));
+    to_written_data.append(len);
+
+    *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
+                             sizeof(uint64));
+    *bytes_written += len.size();
+
+    // Write the serialized variant.
+    // TF_RETURN_IF_ERROR(out->Append(elem));
+    to_written_data.append(elem);
+
+    *crc32c = crc32c::Extend(*crc32c, elem.data(), elem.size());
+    *bytes_written += elem.size();
+
+    // Write the checksum.
+    const uint32 length_checksum = crc32c::Mask(*crc32c);
+    // TF_RETURN_IF_ERROR(out->Append(StringPiece(
+    //     reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
+    std::stringstream ss;
+    ss << length_checksum;
+    string tmp_ss;
+    ss >> tmp_ss;
+    to_written_data.append(tmp_ss);
+    *crc32c =
+        crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&length_checksum),
+                       sizeof(uint32));
+    *bytes_written += sizeof(uint32);
+  }
+  dbout_->Put(leveldb::WriteOptions(), db_tensor_prefix_, to_written_data);
+  return Status::OK();
+}
+
+// Returns whether "slice_spec" is a full slice, with respect to the full shape.
+//
+// This can happen say, when "slice_spec" is
+// "TensorSlice(full_tensor_shape.dims())", or when it is "TensorSlice({{0,
+// dim(0)}, ..., {0, dim(N)}})" -- a degenerate case we need to guard against.
+bool IsFullSlice(const TensorSlice& slice_spec,
+                 const TensorShape& full_tensor_shape) {
+  if (slice_spec.IsFull()) {
+    return true;
+  } else {
+    TensorShape sliced_shape;
+    slice_spec.SliceTensorShape(full_tensor_shape, &sliced_shape).IgnoreError();
+    return sliced_shape == full_tensor_shape;
+  }
+}
+
+Status CorruptFileError(const Status& in_status, const string& filename,
+                        const string& detail) {
+  if (in_status.ok()) {
+    return errors::Internal("Unable to read file (", filename,
+                            "). Perhaps the file is corrupt or was produced by "
+                            "a newer version of TensorFlow with format changes "
+                            "(",
+                            detail, ")");
+  }
+  return Status(
+      in_status.code(),
+      strings::StrCat("Unable to read file (", filename,
+                      "). Perhaps the file is corrupt or was produced by a "
+                      "newer version of TensorFlow with format changes (",
+                      detail, "): ", in_status.error_message()));
+}
+
+table::Options TableBuilderOptions() {
+  table::Options o;
+  // Compressed tables cannot be read by TensorFlow releases prior to 1.1.
+  // To smoothen the transition, compressed writes are disabled for now
+  // (version 1.2) with the intention that they will be enabled again at
+  // some point (perhaps the 1.3 release?).
+  o.compression = table::kNoCompression;
+  return o;
+}
+
+// Writes zeros to output buffer to align the next write to the requested
+// alignment. "size" is the current size of the buffer and is updated to the
+// new size.
+Status PadAlignment(FileOutputBuffer* out, int alignment, int64* size) {
+  int bytes_over = *size % alignment;
+  if (bytes_over == 0) {
+    return Status::OK();
+  }
+  int bytes_to_write = alignment - bytes_over;
+  Status status = out->Append(string(bytes_to_write, '\0'));
+  if (status.ok()) {
+    *size += bytes_to_write;
+  }
+  return status;
+}
+
+}  // namespace
+
+DBWriter::DBWriter(Env* env, StringPiece prefix, const Options& options)
+    : env_(env),
+      options_(options),
+      prefix_(prefix),
+      tmp_metadata_path_(strings::StrCat(MetaFilename(prefix_), ".tempstate",
+                                         random::New64())),
+      tmp_data_path_(strings::StrCat(DataFilename(prefix_, 0, 1), ".tempstate",
+                                     random::New64())),
+      out_(nullptr),
+      size_(0) {
+  status_ = env_->CreateDir(string(io::Dirname(prefix_)));
+  if (!status_.ok() && !errors::IsAlreadyExists(status_)) {
+    return;
+  }
+  const string filename = DataFilename(prefix_, 0, 1);
+  std::unique_ptr<WritableFile> wrapper;
+  status_ = env_->NewWritableFile(tmp_data_path_, &wrapper);
+  if (!status_.ok()) return;
+  out_ = std::unique_ptr<FileOutputBuffer>(
+      new FileOutputBuffer(wrapper.release(), 8 << 20 /* 8MB write buffer */));
+
+  leveldb::Options optionss;
+  optionss.create_if_missing = true;
+
+  leveldb::Status status = leveldb::DB::Open(optionss, "testdb", &dbout_);
+  // dbout_->Put(leveldb::WriteOptions(), "key-build", "value");  // test ok
+  assert(status.ok());
+}
+
+DBWriter::DBWriter(Env* env, StringPiece prefix, string db_prefix_tmp,
+                   const Options& options)
+    : env_(env),
+      options_(options),
+      prefix_(prefix),
+      db_prefix_tmp_(db_prefix_tmp),
+      tmp_metadata_path_(strings::StrCat(MetaFilename(prefix_), ".tempstate",
+                                         random::New64())),
+      tmp_data_path_(strings::StrCat(DataFilename(prefix_, 0, 1), ".tempstate",
+                                     random::New64())),
+      out_(nullptr),
+      size_(0) {
+  status_ = env_->CreateDir(string(io::Dirname(prefix_)));
+  if (!status_.ok() && !errors::IsAlreadyExists(status_)) {
+    return;
+  }
+  const string filename = DataFilename(prefix_, 0, 1);
+  std::unique_ptr<WritableFile> wrapper;
+  status_ = env_->NewWritableFile(tmp_data_path_, &wrapper);
+  if (!status_.ok()) return;
+  out_ = std::unique_ptr<FileOutputBuffer>(
+      new FileOutputBuffer(wrapper.release(), 8 << 20 /* 8MB write buffer */));
+
+  leveldb::Options optionss;
+  optionss.create_if_missing = true;
+  leveldb::Status status = leveldb::DB::Open(optionss, "testdb", &dbout_);
+  assert(status.ok());
+}
+
+Status DBWriter::Add(StringPiece key, const Tensor& val, std::string tensor_type) {
+  if (!status_.ok()) return status_;
+  CHECK_NE(key, kHeaderEntryKey);
+  const string key_string(key);
+  if (entries_.find(key_string) != entries_.end()) {
+    status_ = errors::InvalidArgument("Adding duplicate key: ", key);
+    return status_;
+  }
+
+  auto random_key = random::New64();
+  std::string random_key_string = std::to_string(random_key);
+  const std::string add_key = std::string(
+      db_prefix_tmp_ + "--data--0--1--" + tensor_type + "--" + key_string +
+      "-" + random_key_string);
+  db_key_list_.push_back(add_key);
+
+  BundleEntryProto* entry = &entries_[key_string];
+  entry->set_dtype(val.dtype());
+  val.shape().AsProto(entry->mutable_shape());
+  entry->set_shard_id(0);
+  entry->set_offset(random_key);  // Repurpose "offset" to carry the LevelDB key suffix.
+
+  // Updates the data file.
+  // Build the LevelDB key and value for the tensor data; the value bytes are
+  // already available.
+  size_t data_bytes_written = 0;
+  uint32 crc32c = 0;
+  out_->clear_crc32c();
+
+  if (val.dtype() == DT_STRING) {
+    status_ = WriteStringTensor(val, dbout_, &data_bytes_written, &crc32c, add_key);
+  } else if (val.dtype() == DT_VARIANT) {
+    status_ = WriteVariantTensor(val, dbout_, &data_bytes_written, &crc32c, add_key);
+  } else {
+    status_ = WriteTensor(val, dbout_, &data_bytes_written, add_key);
+    crc32c = out_->crc32c();
+  }
+
+  if (status_.ok()) {
+    entry->set_crc32c(crc32c::Mask(crc32c));
+  }
+  return status_;
+}
+
+Status DBWriter::AddSlice(StringPiece full_tensor_key,
+                          const TensorShape& full_tensor_shape,
+                          const TensorSlice& slice_spec,
+                          const Tensor& slice_tensor, std::string tensor_type) {
+  if (!status_.ok()) return status_;
+  CHECK_NE(full_tensor_key, kHeaderEntryKey);
+
+  // If just a singleton full slice, use the regular Add() to be more efficient.
+  if (IsFullSlice(slice_spec, full_tensor_shape)) {
+    return Add(full_tensor_key, slice_tensor, "slice_tensor");
+  }
+
+  // Inserts/updates the full tensor's metadata entry.
+  //
+  // In the case of a sharded save, MergeBundles() is responsible for merging
+  // the "slices" field of multiple metadata entries corresponding to the same
+  // full tensor.
+  const string full_tensor_key_string(full_tensor_key);
+
+  BundleEntryProto* full_entry = &entries_[full_tensor_key_string];
+  if (full_entry->dtype() != DT_INVALID) {
+    CHECK_EQ(full_entry->dtype(), slice_tensor.dtype());
+  }
+  if (full_entry->has_shape()) {
+    CHECK(TensorShape(full_entry->shape()) == full_tensor_shape);
+  }
+
+  // Populates dtype, shape, and slices.  Intentionally leaving out shard_id and
+  // offset, which do not make sense for this full tensor entry.
+  full_entry->set_dtype(slice_tensor.dtype());
+  full_tensor_shape.AsProto(full_entry->mutable_shape());
+  TensorSliceProto* slice_proto = full_entry->add_slices();
+  slice_spec.AsProto(slice_proto);
+
+  // The slice itself is handled by a regular Add(), which includes adding its
+  // own metadata entry, and writing out the slice's values.
+  const string slice_name =
+      checkpoint::EncodeTensorNameSlice(full_tensor_key_string, slice_spec);
+  status_ = Add(slice_name, slice_tensor, "slice_tensor");
+  return status_;
+}
+
+Status DBWriter::AddTensorHeader(StringPiece key, DataType dtype) {
+  if (!status_.ok()) return status_;
+  CHECK_NE(key, kHeaderEntryKey);
+  const string key_string(key);
+  if (entries_.find(key_string) != entries_.end()) {
+    status_ = errors::InvalidArgument("Adding duplicate key: ", key);
+    return status_;
+  }
+
+  entry_seg_ = &entries_[key_string];
+  entry_seg_->set_dtype(dtype);
+  entry_seg_->set_shard_id(0);
+  entry_seg_->set_offset(size_);
+
+  out_->clear_crc32c();
+  return status_;
+}
+
+// HashTableDBWriter
+Status DBWriter::AddSliceHeader(string tensor_name, const TensorShape& shape,
+                                DataType type, bool is_hash,
+                                TensorSliceProto** proto) {
+  if (!status_.ok()) return status_;
+  BundleEntryProto* full_entry = &entries_[tensor_name];
+  if (full_entry->dtype() != DT_INVALID) {
+    CHECK_EQ(full_entry->dtype(), type);
+  }
+  if (full_entry->has_shape()) {
+    CHECK(TensorShape(full_entry->shape()) == shape);
+  }
+
+  full_entry->set_is_hash_table(is_hash);
+  full_entry->set_dtype(type);
+  shape.AsProto(full_entry->mutable_shape());
+  *proto = full_entry->add_slices();
+  return Status::OK();
+}
+
+Status DBWriter::AddTensorHeader(StringPiece key, DataType dtype, TensorShape shape) {
+  if (!status_.ok()) return status_;
+  CHECK_NE(key, kHeaderEntryKey);
+  const string key_string(key);
+  if (entries_.find(key_string) != entries_.end()) {
+    status_ = errors::InvalidArgument("Adding duplicate key: ", key);
+    return status_;
+  }
+
+  entry_seg_ = &entries_[key_string];
+  entry_seg_->set_dtype(dtype);
+  shape.AsProto(entry_seg_->mutable_shape());
+  entry_seg_->set_shard_id(0);
+  entry_seg_->set_offset(size_);
+
+  out_->clear_crc32c();
+  return status_;
+}
+
+// Used when the tensor is no larger than the buffer size: dump in one shot.
+Status DBWriter::AddCompeleteData(char* content, int64 data_bytes_written) {
+  uint32 crc32c = 0;
+
+  status_ = out_->Append(StringPiece(content, data_bytes_written));
+  if (!status_.ok()) return status_;
+
+  crc32c = out_->crc32c();
+
+  if (status_.ok()) {
+    entry_seg_->set_size(data_bytes_written);
+    entry_seg_->set_crc32c(crc32c::Mask(crc32c));
+    size_ += data_bytes_written;
+  }
+  return status_;
+}
+
+void DBWriter::FillTensorShape(TensorShape shape) {
+  shape.AsProto(entry_seg_->mutable_shape());
+}
+
+// Dump in multiple passes.
+Status DBWriter::AppendSegmentData(char* content, int64 data_bytes_written) {
+  return out_->AppendSegment(StringPiece(content, data_bytes_written));
+}
+
+void DBWriter::EndSegmentData(int64 total_bytes_written, int64 end_bytes_written) {
+  // out_->EndSegment(end_bytes_written);
+  uint32 crc32c = out_->crc32c();
+  entry_seg_->set_size(total_bytes_written);
+  entry_seg_->set_crc32c(crc32c::Mask(crc32c));
+  size_ += total_bytes_written;
+}
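
Hypothetical usage of the writer API defined so far (paths and tensor values are made up; the Options parameter is assumed to default in the header, as the three-argument call in SaveV2 above implies, and the constructor currently hard-codes the "testdb" LevelDB directory):

    DBWriter writer(Env::Default(), "/tmp/ckpt/model",
                    "/tmp/ckpt/model--temp42");
    TF_CHECK_OK(writer.status());
    Tensor t(DT_FLOAT, TensorShape({2, 2}));
    t.flat<float>().setZero();
    TF_CHECK_OK(writer.Add("layer0/weights", t, "normal_tensor"));
    TF_CHECK_OK(writer.Finish());  // Finish() below records the key list.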
+
+// TODO(zongheng): on metadata write failure or !status_.ok(), consider removing
+// the orphaned data file.
+Status DBWriter::Finish() {
+  if (out_) {
+    status_.Update(out_->Close());
+    out_ = nullptr;
+    if (status_.ok()) {
+      status_ = Env::Default()->RenameFile(tmp_data_path_,
+                                           DataFilename(prefix_, 0, 1));
+    } else {
+      Env::Default()->DeleteFile(tmp_data_path_).IgnoreError();
+    }
+  }
+  if (!status_.ok()) return status_;
+
+  int list_lens = db_key_list_.size();
+  std::string key_list_str;
+  for (int i = 0; i < list_lens; i++) {
+    VLOG(1) << "DBWriter::Finish()::db_key_list_:" << db_key_list_[i] << "|qiguai0913";
+    std::cout << "std::cout:" << db_key_list_[i] << std::endl;
+    key_list_str = strings::StrCat(key_list_str, db_key_list_[i], ";");
+  }
+  dbout_->Put(leveldb::WriteOptions(), db_prefix_tmp_ + "--key_list_str",
+              key_list_str);
+
+  // Build key -> BundleEntryProto table.
+  std::unique_ptr<WritableFile> file;
+  status_ = env_->NewWritableFile(tmp_metadata_path_, &file);
+  if (!status_.ok()) return status_;
+  {
+    // N.B.: the default use of Snappy compression may not be supported on all
+    // platforms (e.g. Android).  The metadata file is small, so this is fine.
+    table::Options options;
+    options.compression = table::kNoCompression;
+    table::TableBuilder builder(options, file.get());
+    // Header entry.
+    BundleHeaderProto header;
+    header.set_num_shards(1);
+    header.set_endianness(BundleHeaderProto::LITTLE);
+    if (!port::kLittleEndian) header.set_endianness(BundleHeaderProto::BIG);
+    VersionDef* version = header.mutable_version();
+    version->set_producer(kTensorBundleVersion);
+    version->set_min_consumer(kTensorBundleMinConsumer);
+
+    builder.Add(kHeaderEntryKey, header.SerializeAsString());
+
+    // All others.
+    for (const auto& p : entries_) {
+      builder.Add(p.first, p.second.SerializeAsString());
+    }
+    status_ = builder.Finish();
+  }
+  status_.Update(file->Close());
+  if (!status_.ok()) {
+    Env::Default()->DeleteFile(tmp_metadata_path_).IgnoreError();
+    return status_;
+  } else {
+    status_ =
+        Env::Default()->RenameFile(tmp_metadata_path_, MetaFilename(prefix_));
+    if (!status_.ok()) return status_;
+  }
+  status_ = errors::Internal("DBWriter is closed");
+  return Status::OK();
+}
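
To recover the per-tensor LevelDB keys that Finish() just recorded, a reader could fetch the index entry back. A sketch, with db and db_prefix_tmp assumed, and the ';' separator matching the loop reconstructed above:

    #include "leveldb/db.h"
    std::string key_list_str;
    leveldb::Status s = db->Get(leveldb::ReadOptions(),
                                db_prefix_tmp + "--key_list_str",
                                &key_list_str);
    // On success, key_list_str holds every tensor's LevelDB key, ';'-separated.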
+//
+//// Merging tensor bundles.
+//
+//// Accumulator of metadata states during a merge.
+//struct MergeState {
+//  // Accumulated from the header entries.
+//  int num_shards = 0;
+//
+//  // Derives "endianness" and "version" from the first bundle merged (hence the
+//  // "seen_first_bundle" guard).  The two fields must be the same for all
+//  // bundles in a merge.
+//  bool seen_first_bundle = false;
+//  BundleHeaderProto_Endianness endianness;
+//  VersionDef version;
+//
+//  // Tensor key -> BundleEntryProto.
+//  std::map<string, BundleEntryProto> entries;
+//  // Data file path -> new shard id in the final merged bundle.
+//  std::unordered_map<string, int32> shard_ids;
+//};
+//
+//// Merges entries of "prefix" into the accumulator state "merge".
+//// Returns OK iff the merge succeeds.
+//static Status MergeOneBundle(Env* env, StringPiece prefix,
+//                             MergeState* merge_state) {
+//  VLOG(1) << "Merging bundle:" << prefix;
+//  const string filename = MetaFilename(prefix);
+//  uint64 file_size;
+//  TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
+//  std::unique_ptr<RandomAccessFile> file;
+//  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
+//
+//  table::Table* table = nullptr;
+//  TF_RETURN_IF_ERROR(
+//      table::Table::Open(TableBuilderOptions(), file.get(), file_size, &table));
+//  std::unique_ptr<table::Table> table_deleter(table);
+//  std::unique_ptr<table::Iterator> iter(table->NewIterator());
+//
+//  int num_shards;
+//  // Process header.
+//  {
+//    iter->Seek(kHeaderEntryKey);
+//    if (!iter->Valid()) {
+//      return CorruptFileError(iter->status(), filename,
+//                              "failed to seek to header entry");
+//    }
+//    BundleHeaderProto header;
+//    Status s = ParseEntryProto(iter->key(), iter->value(), &header);
+//    if (!s.ok()) return CorruptFileError(s, filename, "unable to parse header");
+//
+//    merge_state->num_shards += header.num_shards();
+//    if (!merge_state->seen_first_bundle) {
+//      merge_state->seen_first_bundle = true;
+//      merge_state->endianness = header.endianness();
+//      merge_state->version = header.version();
+//    } else {
+//      // Validates "endianness".
+//      if (merge_state->endianness != header.endianness()) {
+//        return errors::InvalidArgument(
+//            "Merging bundles with conflicting endianness; inputs corrupted?");
+//      }
+//      // Validates "version".
+//      string curr_version, merge_version;
+//      header.version().SerializeToString(&curr_version);
+//      merge_state->version.SerializeToString(&merge_version);
+//      if (curr_version != merge_version) {
+//        return errors::InvalidArgument(
+//            "Merging bundles with different format versions: merged ",
+//            merge_version, " vs. curr ", curr_version);
+//      }
+//    }
+//    num_shards = header.num_shards();
+//    iter->Next();
+//  }
+//
+//  // Loops through the non-header to-merge entries.
+//  BundleEntryProto to_merge_entry;
+//  for (; iter->Valid(); iter->Next()) {
+//    const string key(iter->key());
+//    const auto entry_iter = merge_state->entries.find(key);
+//
+//    // Illegal: the duplicated entry is a non-slice tensor.
+//    if (entry_iter != merge_state->entries.end() &&
+//        entry_iter->second.slices().empty()) {
+//      return errors::InvalidArgument(
+//          "Duplicate tensor keyed by ", key,
+//          " encountered, when merging prefix: ", prefix);
+//    }
+//
+//    TF_RETURN_IF_ERROR(
+//        ParseEntryProto(iter->key(), iter->value(), &to_merge_entry));
+//
+//    // The duplicated entry holds metadata for a sliced full tensor.
+//    // Allows the duplication and merges "slices".
+//    if (entry_iter != merge_state->entries.end()) {
+//      BundleEntryProto& existing_entry = entry_iter->second;
+//      if (to_merge_entry.slices().empty()) {
+//        return errors::Internal(
+//            "Duplicate tensor keyed by ", key,
+//            "; attempting to merge in a non-slice bundle entry");
+//      }
+//      // Only needs merge the "slices" field (and validate dtype/shape).
+//      for (int i = 0; i < to_merge_entry.slices_size(); ++i) {
+//        TensorSliceProto* slot = existing_entry.add_slices();
+//        *slot = to_merge_entry.slices(i);
+//      }
+//      CHECK_EQ(existing_entry.dtype(), to_merge_entry.dtype());
+//      CHECK(TensorShape(existing_entry.shape()) ==
+//            TensorShape(to_merge_entry.shape()));
+//      continue;
+//    }
+//
+//    // Key doesn't duplicate: a fresh tensor/slice entry.
+//    auto result = merge_state->shard_ids.insert(
+//        {DataFilename(prefix, to_merge_entry.shard_id(), num_shards),
+//         merge_state->shard_ids.size()});
+//    to_merge_entry.set_shard_id(result.first->second);
+//    merge_state->entries[key] = to_merge_entry;
+//  }
+//  return Status::OK();
+//}
+//
+//Status FixMergeHashTableBundles(MergeState* state) {
+//  std::unordered_map<string, string> bundle_mapping;
+//  for (auto&& item : state->entries) {
+//    if (!item.second.is_hash_table()) {
+//      continue;
+//    }
+//    std::multimap<int64, TensorSliceProto*> sorter;
+//    for (int slice = 0; slice < item.second.slices_size(); slice++) {
+//      sorter.emplace(item.second.slices(slice).hash_slice_begin(),
+//                     item.second.mutable_slices(slice));
+//    }
+//    int64 idx = 0;
+//    std::vector<TensorSliceProto> slices;
+//    for (auto&& itemx : sorter) {
+//      if (itemx.second->extent(0).length() > 0) {
+//        slices.emplace_back();
+//        TensorSliceProto& slice = slices.back();
+//        slice.CopyFrom(*itemx.second);
+//        slice.mutable_extent(0)->set_start(idx);
+//        idx += slice.extent(0).length();
+//        TensorSlice from_slice(1);
+//        from_slice.set_start(0, slice.hash_slice_begin());
+//        from_slice.set_length(0, slice.hash_slice_length());
+//        string from = checkpoint::EncodeTensorNameSlice(
+//            item.first, from_slice);
+//        string to = checkpoint::EncodeTensorNameSlice(
+//            item.first, TensorSlice(slice));
+//        if (!bundle_mapping.emplace(from, to).second) {
+//          return errors::FailedPrecondition(
+//              "FixMergeHashTableBundles has some error when create bundle mapping.");
+//        }
+//      } else {
+//        TensorSlice from_slice(1);
+//        from_slice.set_start(0, itemx.second->hash_slice_begin());
+//        from_slice.set_length(0, itemx.second->hash_slice_length());
+//        string from = checkpoint::EncodeTensorNameSlice(
+//            item.first, from_slice);
+//        if (!bundle_mapping.emplace(from, "").second) {
+//          return errors::FailedPrecondition(
+//              "FixMergeHashTableBundles has some error when create bundle mapping. 2");
+//        }
+//      }
+//    }
+//    item.second.clear_slices();
+//    for (auto&& itemx : slices) {
+//      item.second.add_slices()->CopyFrom(itemx);
+//    }
+//    item.second.mutable_shape()->mutable_dim(0)->set_size(idx);
+//  }
+//  std::map<string, BundleEntryProto> entries_tmp;
+//  entries_tmp.swap(state->entries);
+//  for (auto&& item : entries_tmp) {
+//    auto iter = bundle_mapping.find(item.first);
+//    string real_name;
+//    if (iter == bundle_mapping.end()) {
+//      real_name = item.first;
+//    } else {
+//      real_name = iter->second;
+//    }
+//    if (real_name == "") {
+//      LOG(INFO) << "Ignore Hash Table: " << str_util::CEscape(item.first);
+//      continue;
+//    }
+//    state->entries.emplace(real_name, item.second);
+//  }
+//  return Status::OK();
+//};
+//
+//Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes,
+//                    StringPiece merged_prefix) {
+//  // Merges all metadata tables.
+//  // TODO(zhifengc): KeyValue sorter if it becomes too big.
+//  MergeState merge;
+//  Status status = env->CreateDir(string(io::Dirname(merged_prefix)));
+//  if (!status.ok() && !errors::IsAlreadyExists(status)) return status;
+//  for (int i = 0; i < prefixes.size(); ++i) {
+//    TF_RETURN_IF_ERROR(MergeOneBundle(env, prefixes[i], &merge));
+//  }
+//
+//  TF_RETURN_IF_ERROR(FixMergeHashTableBundles(&merge));
+//
+//  // Renames data files to contain the merged bundle prefix.
+//  for (const auto& p : merge.shard_ids) {
+//    VLOG(1) << "Renaming " << p.first << " to "
+//            << DataFilename(merged_prefix, p.second, merge.shard_ids.size());
+//    TF_RETURN_IF_ERROR(env->RenameFile(
+//        p.first,
+//        DataFilename(merged_prefix, p.second, merge.shard_ids.size())));
+//  }
+//
+//  // Writes the final metadata table under the merged prefix.
+//  std::unique_ptr<WritableFile> merged_metadata;
+//  TF_RETURN_IF_ERROR(
+//      env->NewWritableFile(MetaFilename(merged_prefix), &merged_metadata));
+//  {
+//    table::TableBuilder builder(TableBuilderOptions(), merged_metadata.get());
+//    // Header entry.
+//    BundleHeaderProto header;
+//    header.set_num_shards(merge.num_shards);
+//    header.set_endianness(merge.endianness);
+//    *header.mutable_version() = merge.version;
+//    builder.Add(kHeaderEntryKey, header.SerializeAsString());
+//    // All others.
+//    for (const auto& p : merge.entries) {
+//      builder.Add(p.first, p.second.SerializeAsString());
+//    }
+//    status = builder.Finish();
+//  }
+//  status.Update(merged_metadata->Close());
+//  if (!status.ok()) return status;
+//  VLOG(1) << "Merged bundles to:" << merged_prefix;
+//
+//  // Cleanup: best effort based and ignores errors.
+//  for (const string& prefix : prefixes) {
+//    env->DeleteFile(MetaFilename(prefix)).IgnoreError();
+//  }
+//  return status;
+//}

+// Interface for reading a tensor bundle.
+
+DBReader::DBReader(Env* env, StringPiece prefix)
+    : env_(env),
+      prefix_(prefix),
+      metadata_(nullptr),
+      table_(nullptr),
+      iter_(nullptr),
+      need_to_swap_bytes_(false) {
+  const string filename = MetaFilename(prefix_);
+  uint64 file_size;
+  status_ = env_->GetFileSize(filename, &file_size);
+  if (!status_.ok()) return;
+
+  // Opens the metadata table.
+  std::unique_ptr<RandomAccessFile> wrapper;
+  status_ = env_->NewRandomAccessFile(filename, &wrapper);
+  if (!status_.ok()) return;
+  metadata_ = wrapper.release();
+  status_ = table::Table::Open(table::Options(), metadata_, file_size, &table_);
+  if (!status_.ok()) return;
+  iter_ = table_->NewIterator();
+
+  // Reads "num_shards_" from the first entry.
+  iter_->Seek(kHeaderEntryKey);
+  if (!iter_->Valid()) {
+    status_ = CorruptFileError(iter_->status(), filename,
+                               "failed to seek to header entry");
+    return;
+  }
+  BundleHeaderProto header;
+  status_ = ParseEntryProto(iter_->key(), iter_->value(), &header);
+  if (!status_.ok()) {
+    status_ = CorruptFileError(status_, filename, "unable to parse header");
+    return;
+  }
+  num_shards_ = header.num_shards();
+  if ((header.endianness() == BundleHeaderProto::BIG && port::kLittleEndian) ||
+      (header.endianness() == BundleHeaderProto::LITTLE &&
+       !port::kLittleEndian)) {
+    need_to_swap_bytes_ = true;
+  }
+  status_ = CheckVersions(header.version(), kTensorBundleVersion,
+                          kTensorBundleMinProducer, "Checkpoint", "checkpoint");
+}
+
+DBReader::~DBReader() {
+  delete metadata_;
+  delete iter_;
+  delete table_;
+  // InputBuffer does not own the underlying RandomAccessFile.
+  for (auto pair : data_) {
+    if (pair.second != nullptr && pair.second->file() != nullptr) {
+      delete pair.second->file();
+    }
+  }
+  gtl::STLDeleteValues(&data_);
+  gtl::STLDeleteValues(&tensor_slices_);
+}
+
+Status DBReader::GetBundleEntryProto(StringPiece key, BundleEntryProto* entry) {
+  entry->Clear();
+  TF_CHECK_OK(status_);
+  Seek(key);
+  if (!iter_->Valid() || iter_->key() != key) {
+    return errors::NotFound("Key ", key, " not found in checkpoint");
+  }
+
+  BundleEntryProto entry_copy;
+  TF_RETURN_IF_ERROR(
+      ParseEntryProto(iter_->key(), iter_->value(), &entry_copy));
+  if (!TensorShape::IsValid(entry_copy.shape())) {
+    return errors::DataLoss("Invalid tensor shape: ", key, " ",
+                            ProtoShortDebugString(entry_copy.shape()));
+  }
+
+  *entry = entry_copy;
+  return Status::OK();
+}
+
+Status DBReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
+  Tensor* ret = val;
+  const TensorShape stored_shape(TensorShape(entry.shape()));
+  if (val->NumElements() == 0) {
+    ret = new Tensor(entry.dtype(), stored_shape);
+  }
+
+  // Validates the "size" field.
+  if (entry.dtype() != DT_STRING && entry.dtype() != DT_VARIANT) {
+    if (entry.size() != ret->TotalBytes()) {
+      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
+                              "; stored size ", entry.size(),
+                              "; expected size ", ret->TotalBytes());
+    }
+  } else if (entry.dtype() == DT_STRING) {
+    // Relaxes the check for string tensors as follows:
+    //   entry.size() == bytes(varint lengths) + bytes(data)
+    //                >= NumElems + bytes(data), since size bytes(varint) >= 1.
+    //   TotalBytes() == sizeof(tstring) * NumElems + bytes(data)
+    // Since we don't know bytes(varint lengths), we just check an inequality.
+    const size_t lower_bound = ret->NumElements() + ret->TotalBytes() -
+                               sizeof(tstring) * ret->NumElements();
+    if (entry.size() < lower_bound) {
+      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
+                              "; stored size ", entry.size(),
+                              "; expected size is at least ", lower_bound);
+    }
+  }
+
+  // Open the data file if it has not been opened.
+  io::InputBuffer* buffered_file = data_[entry.shard_id()];
+  if (buffered_file == nullptr) {
+    std::unique_ptr<RandomAccessFile> file = nullptr;
+    TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(
+        DataFilename(prefix_, entry.shard_id(), num_shards_), &file));
+    buffered_file = new io::InputBuffer(file.release(), kBufferSize);
+    // The InputBuffer and RandomAccessFile objects are both released in dtor.
+    data_[entry.shard_id()] = buffered_file;
+  }
+  CHECK(buffered_file != nullptr);
+
+  TF_RETURN_IF_ERROR(buffered_file->Seek(entry.offset()));
+  uint32 actual_crc32c = 0;
+
+  if (DataTypeCanUseMemcpy(entry.dtype())) {
+    char* backing_buffer = const_cast<char*>((ret->tensor_data().data()));
+    size_t unused_bytes_read;
+    if (entry.size() > kBufferSize) {
+      StringPiece sp;
+      TF_RETURN_IF_ERROR(buffered_file->file()->Read(
+          entry.offset(), entry.size(), &sp, backing_buffer));
+      if (sp.data() != backing_buffer) {
+        memmove(backing_buffer, sp.data(), entry.size());
+      }
+    } else {
+      TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(entry.size(), backing_buffer,
+                                                   &unused_bytes_read));
+    }
+    // Note that we compute the checksum *before* byte-swapping.  The checksum
+    // should be on the bytes in the order they appear in the file.
+    actual_crc32c = crc32c::Value(backing_buffer, entry.size());
+    if (need_to_swap_bytes_) {
+      TF_RETURN_IF_ERROR(ByteSwapTensor(ret));
+    }
+  } else if (entry.dtype() == DT_VARIANT) {
+    if (need_to_swap_bytes_) {
+      return errors::Unimplemented(
+          "TensorBundle at ", prefix_,
+          "is of a different endianness than this machine's hardware, and "
+          "the bundle contains a variant (arbitrary C++ type) tensor. "
+          "Byte-swapping of variant tensors is not currently implemented.");
+    }
+    // Relies on io::InputBuffer's buffering, because we issue many neighboring
+    // reads for a single string tensor.
+    TF_RETURN_IF_ERROR(ReadVariantTensor(buffered_file, ret, entry.offset(),
+                                         entry.size(), &actual_crc32c));
+  } else {
+    // Relies on io::InputBuffer's buffering, because we issue many neighboring
+    // reads for a single string tensor.
+    TF_RETURN_IF_ERROR(ReadStringTensor(
+        buffered_file, ret->NumElements(), entry.offset(), entry.size(),
+        GetStringBackingBuffer(*ret), &actual_crc32c, need_to_swap_bytes_));
+  }
+  if (crc32c::Unmask(entry.crc32c()) != actual_crc32c) {
+    return errors::DataLoss(
+        "Checksum does not match: stored ",
+        strings::Printf("%08u", crc32c::Unmask(entry.crc32c())),
+        " vs. calculated on the restored bytes ", actual_crc32c);
+  }
+
+  *val = *ret;
+  if (ret != val) delete ret;
+  return Status::OK();
+}
+
+Status DBReader::Lookup(StringPiece key, Tensor* val) {
+  CHECK(val != nullptr);
+  BundleEntryProto entry;
+  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
+
+  if (entry.slices().empty()) {
+    return GetValue(entry, val);
+  } else {
+    return GetSliceValue(
+        key, entry,
+        /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
+  }
+}
+
+Status DBReader::LookupHeader(StringPiece tensor_key, int64 total_bytes) {
+  BundleEntryProto entry;
+  TF_RETURN_IF_ERROR(GetBundleEntryProto(tensor_key, &entry));
+  if (entry.size() != total_bytes) {
+    return errors::DataLoss("Invalid size in bundle entry: key ", key(),
+                            "; stored size ", entry.size(),
+                            "; expected size ", total_bytes);
+  }
+  io::InputBuffer* buffered_file = data_[entry.shard_id()];
+  if (buffered_file == nullptr) {
+    std::unique_ptr<RandomAccessFile> file = nullptr;
+    TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(
+        DataFilename(prefix_, entry.shard_id(), num_shards_), &file));
+    buffered_file =
+        new io::InputBuffer(file.release(), 256 << 10 /* 256KB buffer */);
+    // The InputBuffer and RandomAccessFile objects are both released in dtor.
+    data_[entry.shard_id()] = buffered_file;
+  }
+  CHECK(buffered_file != nullptr);
+
+  TF_RETURN_IF_ERROR(buffered_file->Seek(entry.offset()));
+  if (!DataTypeCanUseMemcpy(entry.dtype())) {
+    return errors::DataLoss("segment lookup does not support string tensors");
+  }
+  LookupSegItem seg_item;
+  seg_item.entry = entry;
+  seg_item.total_size = entry.size();
+  seg_item.bytes_read = 0;
+
+  tmp_lookupseg_items_[string(tensor_key)] = seg_item;
+  return Status::OK();
+}
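
A hypothetical read-back pairing with the writer sketch earlier (assumes DBReader exposes status() like BundleReader does; the prefix matches the writer example):

    DBReader reader(Env::Default(), "/tmp/ckpt/model");
    TF_CHECK_OK(reader.status());
    Tensor restored;
    TF_CHECK_OK(reader.Lookup("layer0/weights", &restored));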
", buffer_size, " ", seg_item.total_size, " ", seg_item.entry.offset() + seg_item.bytes_read, " ", desired_bytes, " ", status.ToString()); + } + if (result.size() != desired_bytes) { + return errors::DataLoss("Requested ", desired_bytes, " bytes but read ", + result.size(), " bytes."); + } + // Data is already in the correct location. + seg_item.bytes_read += result.size(); + seg_item.total_size -= result.size(); + real_bytes_read = result.size(); + return Status::OK(); +} + +Status DBReader::LookupSegmentOffset(StringPiece key, uint64_t offset, size_t buffer_size, char* destination, size_t& real_bytes_read) { + LookupSegItem& seg_item = tmp_lookupseg_items_[string(key)]; + const size_t desired_bytes = std::min(buffer_size, seg_item.total_size); + if (desired_bytes == 0) { + real_bytes_read = 0; + return Status::OK(); + } + + io::InputBuffer* buffered_file = data_[seg_item.entry.shard_id()]; + StringPiece result; + Status status = buffered_file->file()->Read(seg_item.entry.offset() + offset, desired_bytes, &result, destination); + + if (!status.ok()) { + return errors::InvalidArgument("Read Error! ", buffer_size, " ", seg_item.total_size, " ", seg_item.entry.offset() + seg_item.bytes_read, " ", desired_bytes, " ", status.ToString()); + } + if (result.size() != desired_bytes) { + return errors::DataLoss("Requested ", desired_bytes, " bytes but read ", + result.size(), " bytes."); + } + // Data is already in the correct location. + seg_item.bytes_read += result.size(); + seg_item.total_size -= result.size(); + real_bytes_read = result.size(); + return Status::OK(); +} + +Status DBReader::GetTensorInfo( + StringPiece key, int64* size, + std::unique_ptr* file, int64* offset) { + BundleEntryProto entry; + TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry)); + TF_RETURN_IF_ERROR(env_->NewRandomAccessFile( + DataFilename(prefix_, entry.shard_id(), num_shards_), file)); + *size = entry.size(); + *offset = entry.offset(); + return Status::OK(); +} + +Status DBReader::ReadCurrent(Tensor* val) { + CHECK(val != nullptr); + BundleEntryProto entry; + TF_RETURN_IF_ERROR(ParseEntryProto(iter_->key(), iter_->value(), &entry)); + if (!TensorShape::IsValid(entry.shape())) { + return errors::DataLoss("Invalid tensor shape: ", iter_->key(), " ", + ProtoShortDebugString(entry.shape())); + } + + if (entry.slices().empty()) { + return GetValue(entry, val); + } else { + return GetSliceValue( + iter_->key(), entry, + /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val); + } +} + +Status DBReader::LookupTensorSlices(StringPiece key, + std::vector* slices) { + slices->clear(); + BundleEntryProto entry; + TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry)); + slices->reserve(entry.slices_size()); + for (const auto& slice : entry.slices()) { + slices->emplace_back(slice); + } + return Status::OK(); +} + +Status DBReader::LookupSlice(StringPiece full_tensor_key, + const TensorSlice& slice_spec, Tensor* val) { + CHECK(val != nullptr); + BundleEntryProto entry; + TF_RETURN_IF_ERROR(GetBundleEntryProto(full_tensor_key, &entry)); + return GetSliceValue(full_tensor_key, entry, slice_spec, val); +} + +Status DBReader::LookupTensorSliceProtos( + StringPiece key, std::vector* slices) { + slices->clear(); + BundleEntryProto entry; + TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry)); + slices->reserve(entry.slices_size()); + for (const auto& slice : entry.slices()) { + slices->emplace_back(slice); + } + return Status::OK(); +} + +Status DBReader::GetSliceValue(StringPiece full_tensor_key, + const 
+Status DBReader::GetSliceValue(StringPiece full_tensor_key,
+                               const BundleEntryProto& full_tensor_entry,
+                               const TensorSlice& slice_spec, Tensor* val) {
+  using checkpoint::RegisterTensorSlice;
+  using checkpoint::TensorSliceSet;
+  DCHECK_GE(full_tensor_entry.slices_size(), 0);
+
+  const TensorShape full_shape(TensorShape(full_tensor_entry.shape()));
+  std::vector<std::pair<TensorSlice, string>> details;
+  const string full_tensor_key_string(full_tensor_key);
+  const TensorSliceSet* tss =
+      gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);
+
+  // Populates the "full tensor key -> TensorSliceSet" cache.
+  if (tss == nullptr) {
+    if (full_tensor_entry.slices().empty()) {
+      // Special case: a writer has saved a tensor fully, but the reader wants
+      // to read in slices. We therefore register the full slice on-demand here
+      // without further complicating the on-disk bundle format.
+      TF_RETURN_IF_ERROR(RegisterTensorSlice(
+          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
+          /* tag */ "",
+          /* full slice */ TensorSlice(full_shape.dims()), &tensor_slices_));
+    }
+    for (const TensorSliceProto& slice : full_tensor_entry.slices()) {
+      TF_RETURN_IF_ERROR(RegisterTensorSlice(
+          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
+          /* tag */ "", TensorSlice(slice), &tensor_slices_));
+    }
+    tss = gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);
+    CHECK_NE(tss, nullptr);
+  }
+  if (!tss->QueryMeta(slice_spec, &details)) {
+    return errors::InvalidArgument(
+        "Does not have sufficient slices for partitioned tensor ",
+        full_tensor_key,
+        " to restore in slice_spec: ", slice_spec.DebugString());
+  }
+
+  // The union of the slices in "details" covers "slice_spec". Performs the
+  // copies from each.
+  BundleEntryProto stored_slice_entry = full_tensor_entry;
+  for (const auto& slice_tag_pair : details) {
+    // Seeks for the stored slice.
+    const TensorSlice& stored_slice = slice_tag_pair.first;
+
+    // We already have the entry for the full tensor, so don't query again if
+    // the slice is full.
+    if (!stored_slice.IsFull()) {
+      const string encoded_stored_slice_name =
+          checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
+                                            stored_slice);
+      status_ =
+          GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
+      if (!status_.ok()) return status_;
+    }
+
+    // TODO(zongheng): should we take an OpKernelContext, so that we can call
+    // allocate_temp()? Note that without major refactorings to Saver, it's
+    // hard for the caller of the tensor bundle module to allocate these
+    // precisely-shaped scratch storage.
+
+    // Optimization for the common case: the stored slice can be directly
+    // copied to the destination without additional slicing. This is true when
+    // either the slices are equal or when they are both full slices having the
+    // same shape.
+    TensorShape stored_slice_shape(stored_slice_entry.shape());
+    if (stored_slice == slice_spec ||
+        (stored_slice_shape == val->shape() &&
+         IsFullSlice(stored_slice, stored_slice_shape) &&
+         IsFullSlice(slice_spec, stored_slice_shape))) {
+      VLOG(1) << "Optimized for common case: directly copying into "
+                 "pre-allocated buffer; spec: "
+              << slice_spec.DebugString();
+      status_ = GetValue(stored_slice_entry, val);
+      return status_;
+    }
+
+    Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
+    status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
+    if (!status_.ok()) return status_;
+
+    // Copies the intersection over.
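+    // For example, a DT_FLOAT entry expands HANDLE_COPY(float) below, calling
+    // CopyDataFromTensorSliceToTensorSlice() on float buffers so that only
+    // the elements where "stored_slice" intersects "slice_spec" are written
+    // into "val".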
+    const DataType common_dtype = full_tensor_entry.dtype();
+    switch (common_dtype) {
+#define HANDLE_COPY(T)                                                 \
+  case DataTypeToEnum<T>::value:                                       \
+    CHECK(CopyDataFromTensorSliceToTensorSlice(                        \
+        full_shape, stored_slice, slice_spec,                          \
+        stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \
+    break;
+
+      HANDLE_COPY(float)
+      HANDLE_COPY(double)
+      HANDLE_COPY(int32)
+      HANDLE_COPY(uint8)
+      HANDLE_COPY(int16)
+      HANDLE_COPY(int8)
+      HANDLE_COPY(complex64)
+      HANDLE_COPY(complex128)
+      HANDLE_COPY(int64)
+      HANDLE_COPY(bool)
+      HANDLE_COPY(qint32)
+      HANDLE_COPY(quint8)
+      HANDLE_COPY(qint8)
+      HANDLE_COPY(bfloat16)
+      default:
+        return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
+                                       " not supported.");
+    }
+#undef HANDLE_COPY
+  }
+  return Status::OK();
+}
+
+bool DBReader::Contains(StringPiece key) {
+  Seek(key);
+  return Valid() && (this->key() == key);
+}
+
+Status DBReader::LookupDtypeAndShape(StringPiece key, DataType* dtype,
+                                     TensorShape* shape) {
+  BundleEntryProto entry;
+  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
+  *dtype = entry.dtype();
+  *shape = TensorShape(entry.shape());
+  return Status::OK();
+}
+
+Status DBReader::LookupTensorShape(StringPiece key, TensorShape* shape) {
+  DataType ignored;
+  return LookupDtypeAndShape(key, &ignored, shape);
+}
+
+string DBReader::DebugString() {
+  // Format used below emulates that of TensorSliceReader::DebugString().
+  string shape_str;
+  BundleEntryProto entry;
+  Seek(kHeaderEntryKey);
+  for (Next(); Valid(); Next()) {
+    CHECK(entry.ParseFromArray(value().data(), value().size()));
+    if (entry.slices_size() > 0) continue;  // Slice of some partitioned var.
+
+    strings::StrAppend(&shape_str, key(), " (",
+                       EnumName_DataType(entry.dtype()), ") ",
+                       TensorShape(entry.shape()).DebugString());
+    strings::StrAppend(&shape_str, "\n");
+  }
+  return shape_str;
+}
+
+SegmentDBWriter::SegmentDBWriter(
+    DBWriter* writer, const string& name,
+    const TensorShape& shape, DataType type, int64 buffer_size)
+    : writer_(writer), name_(name), shape_(shape), type_(type),
+      buffer_size_(buffer_size), buffer_(new char[buffer_size]),
+      buffer_ptr_(0), write_counter_(0) {}
+
+Status SegmentDBWriter::Begin() {
+  return writer_->AddTensorHeader(name_, type_, shape_);
+}
+
+Status SegmentDBWriter::WriteData(const void* data, int64 size) {
+  while (size > 0) {
+    if (buffer_ptr_ + size <= buffer_size_) {
+      memcpy(buffer_.get() + buffer_ptr_, data, size);
+      buffer_ptr_ += size;
+      size = 0;
+    } else {
+      // Fill the buffer, flush it as one segment, then continue with the rest.
+      int64 w = buffer_size_ - buffer_ptr_;
+      memcpy(buffer_.get() + buffer_ptr_, data, w);
+      TF_RETURN_IF_ERROR(
+          writer_->AppendSegmentData(buffer_.get(), buffer_size_));
+      size -= w;
+      data = (const char*)data + w;
+      buffer_ptr_ = 0;
+      write_counter_++;
+    }
+  }
+  return Status::OK();
+}
+
+Status SegmentDBWriter::End() {
+  if (write_counter_ * buffer_size_ + buffer_ptr_ !=
+      shape_.num_elements() * DataTypeSize(type_)) {
+    return errors::Internal("SegmentDBWriter write size error");
+  }
+  if (write_counter_ == 0) {
+    // Everything fit into a single buffer: write it as one complete blob.
+    return writer_->AddCompleteData(buffer_.get(), buffer_ptr_);
+  } else if (buffer_ptr_ > 0) {
+    TF_RETURN_IF_ERROR(writer_->AppendSegmentData(buffer_.get(), buffer_ptr_));
+    writer_->EndSegmentData(
+        write_counter_ * buffer_size_ + buffer_ptr_, buffer_ptr_);
+    return Status::OK();
+  } else {
+    writer_->EndSegmentData(
+        write_counter_ * buffer_size_ + buffer_ptr_, buffer_size_);
+    return Status::OK();
+  }
+}
+
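+// Illustrative SegmentDBWriter usage (a sketch; the name and data pointer are
+// hypothetical). Begin() writes the tensor header, WriteData() may be called
+// repeatedly with consecutive chunks, and End() flushes and validates that
+// exactly shape.num_elements() * DataTypeSize(type) bytes were written:
+//
+//   SegmentDBWriter seg(&writer, "weights", shape, DT_FLOAT);
+//   TF_RETURN_IF_ERROR(seg.Begin());
+//   TF_RETURN_IF_ERROR(seg.WriteData(ptr, num_bytes));
+//   TF_RETURN_IF_ERROR(seg.End());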
+SegmentDBReader::SegmentDBReader(
+    DBReader* reader, const string& name,
+    int64 offset, int64 size, int64 buffer_size)
+    : reader_(reader), name_(name), buffer_size_(buffer_size),
+      offset_(offset), size_(size) {}
+
+Status SegmentDBReader::Begin() {
+  TF_RETURN_WITH_CONTEXT_IF_ERROR(
+      reader_->LookupDtypeAndShape(name_, &type_, &shape_),
+      "SegmentDBReader failed to look up dtype and shape for ", name_);
+  if (size_ == -1) {
+    size_ = shape_.dim_size(0);
+  }
+  if (offset_ + size_ > shape_.dim_size(0)) {
+    return errors::InvalidArgument("SegmentDBReader offset error");
+  }
+  // Bytes per row: element size times the product of the non-leading dims.
+  int64 xsize = DataTypeSize(type_);
+  for (int i = 1; i < shape_.dims(); i++) {
+    xsize *= shape_.dim_size(i);
+  }
+  const int64 real_size = xsize * size_;
+  if (real_size < buffer_size_) {
+    buffer_size_ = real_size;
+  }
+  remain_size_ = real_size;
+  int64 var_offset;
+  int64 var_size;
+  TF_RETURN_IF_ERROR(
+      reader_->GetTensorInfo(name_, &var_size, &file_, &var_offset));
+  input_.reset(new io::InputBuffer(file_.get(), buffer_size_));
+  TF_RETURN_IF_ERROR(input_->Seek(var_offset + xsize * offset_));
+  return Status::OK();
+}
+
+const TensorShape& SegmentDBReader::shape() {
+  return shape_;
+}
+
+DataType SegmentDBReader::type() {
+  return type_;
+}
+
+Status SegmentDBReader::Read(void* data, int64 size) {
+  if (size > remain_size_) {
+    return errors::InvalidArgument("SegmentDBReader read exhausted");
+  }
+  size_t read_size;
+  TF_RETURN_IF_ERROR(input_->ReadNBytes(size, (char*)data, &read_size));
+  remain_size_ -= size;
+  return Status::OK();
+}
+
+Status SegmentDBReader::Skip(int64 size) {
+  if (size > remain_size_) {
+    return errors::InvalidArgument("SegmentDBReader skip exhausted");
+  }
+  TF_RETURN_IF_ERROR(input_->SkipNBytes(size));
+  remain_size_ -= size;
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_bundle/db_writer.h b/tensorflow/core/util/tensor_bundle/db_writer.h
new file mode 100644
index 00000000000..d733fbb631d
--- /dev/null
+++ b/tensorflow/core/util/tensor_bundle/db_writer.h
@@ -0,0 +1,397 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// A tensor bundle is a set of immutable persistent files storing a set of
+// named tensors. It is designed for checkpointing TensorFlow tensors.
+//
+// The paths of the managed files share a common prefix; e.g., with the prefix:
+//   /fs/model/train/ckpt-step/ckpt
+//
+// the bundle may contain a metadata file, and sharded data files:
+//   /fs/model/train/ckpt-step/
+//       ckpt.index
+//       ckpt.data-00000-of-00020
+//       ckpt.data-00001-of-00020
+//       ...
+//       ckpt.data-00019-of-00020
+//
+// The ".index" file is a string-string immutable table
+// (tensorflow::table::Table). Each key is a name of a tensor and its value is
+// a serialized BundleEntryProto. Each BundleEntryProto describes the metadata
+// of a tensor: which of the "data" files contains the content of a tensor, the
+// offset into that file, checksum, some auxiliary data, etc.
+//
+// A tensor bundle can be accessed randomly using a DBReader. Usage:
+//
+//   DBReader reader(env, "/fs/model/train/ckpt-step/ckpt");
+//   reader.Lookup("name", &tensor);
+//
+// A tensor bundle can be built using DBWriter. Each DBWriter builds a
+// single data file bundle. Multiple bundles can then be merged by
+// MergeBundles() without reading and writing large chunks of data: it reads
+// the metadata files and outputs a single merged metadata. Typical usage:
+//
+//   worker 0:
+//     DBWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker0-step");
+//     writer.Add(...);  // Adds the tensors on this worker.
+//     writer.Finish();  // Flushes.
+//   worker 1:
+//     DBWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker1-step");
+//     writer.Add(...);
+//     writer.Finish();
+//   worker 2:
+//     MergeBundles(env,
+//                  {"/fs/model/train/ckpt-step/tmp/worker0-step",
+//                   "/fs/model/train/ckpt-step/tmp/worker1-step"},
+//                  "/fs/model/train/ckpt-step/ckpt" /* merged prefix */);
+
+#ifndef TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_DB_WRITER_H_
+#define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_DB_WRITER_H_
+
+#include "tensorflow/core/protobuf/tensor_bundle.pb.h"
+
+#include <map>
+#include <string>
+#include <unordered_map>
+
+#include "leveldb/db.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/io/table.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/tensor_bundle/naming.h"
+#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
+#include "tensorflow/core/util/tensor_slice_set.h"
+
+namespace tensorflow {
+
+// Versioning of the tensor bundle format.
+// Follows the same rules as 3p/tf/core/public/version.h.
+//
+// History:
+// 0. Any tensor bundles produced before this field was added.
+// 1. Added this field (2016-09-14).
+//extern const int kTensorBundleMinProducer;
+//extern const int kTensorBundleMinConsumer;
+//extern const int kTensorBundleVersion;
+//
+//// The empty string, hence always the first key in the metadata table. Its
+//// corresponding value is a BundleHeaderProto.
+//extern const char* const kHeaderEntryKey;
+
+// Builds a string-string table of tensor names to BundleEntryProto (metadata).
+//
+// On construction, attempts to create a directory given by the dirname of
+// "prefix", so "status()" must be checked before calling any member functions.
+//
+// All threads accessing the same DBWriter must synchronize.
+class DBWriter {
+ public:
+  struct Options {
+    Options() {}
+    // Alignment, in bytes, for tensor data.
+    // Must be >= 1. The default size of 1 densely packs tensors.
+    int data_alignment{1};
+  };
+  DBWriter(Env* env, StringPiece prefix,
+           const Options& options = Options());
+  DBWriter(Env* env, StringPiece prefix, string db_prefix_tmp,
+           const Options& options = Options());
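+  // Illustrative use of the two-prefix constructor (a sketch; the paths are
+  // hypothetical). The second argument supplies a temporary working prefix
+  // for the LevelDB output while the bundle is being written:
+  //
+  //   DBWriter writer(env, "/fs/ckpt/ckpt", "/fs/ckpt/ckpt--temp12345");
+  //   writer.Add("var0", tensor);
+  //   writer.Finish();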
+
+  // Adds the tensor "val" under key "key".
+  // Across calls "key" must be unique but can be added in any order.
+  Status Add(StringPiece key, const Tensor& val,
+             std::string tensor_type = "normal_tensor");
+
+  Status AddTensorHeader(StringPiece key, DataType dtype, TensorShape shape);
+  Status AddTensorHeader(StringPiece key, DataType dtype);
+  void FillTensorShape(TensorShape shape);
+  Status AddCompleteData(char* content, int64 data_bytes_written);
+  Status AppendSegmentData(char* content, int64 data_bytes_written);
+  void EndSegmentData(int64 total_bytes_written, int64 end_bytes_written);
+
+  // Partitioned variables support.
+  // A slice of a full tensor is stored in two entries in the metadata table:
+  //
+  //   full_tensor_key   -> BundleEntryProto, describing all stored slices
+  //                        of this full tensor. Does not append to the data
+  //                        file.
+  //   encoded slice key -> BundleEntryProto, describing one particular slice.
+  //                        Appends values of this slice to the data file.
+  //
+  // Slices of a full tensor can be added in any order.
+  //
+  // If a full tensor has slices placed on N devices and N DBWriters are
+  // concurrently used, the caller must use MergeBundles() to ensure that a
+  // consistent entry for "full_tensor_key" is produced.
+  //
+  // Returns an error if the same slice is added the second time.
+  Status AddSlice(StringPiece full_tensor_key,
+                  const TensorShape& full_tensor_shape,
+                  const TensorSlice& slice_spec, const Tensor& slice_tensor,
+                  std::string tensor_type = "slice_tensor");
+
+  Status AddSliceHeader(
+      string tensor_name, const TensorShape& shape, DataType type,
+      bool is_hash, TensorSliceProto** proto);
+
+  // Finishes the writer and flushes.
+  Status Finish() TF_MUST_USE_RESULT;
+
+  Status status() const { return status_; }
+
+ private:
+  Env* const env_;  // Not owned.
+  const Options options_;
+  const string prefix_;
+  const string tmp_metadata_path_;
+  const string tmp_data_path_;
+  leveldb::DB* dbout_;
+  std::unique_ptr<FileOutputBuffer> out_;
+  int64 size_;  // Number of bytes written into out_.
+  std::map<string, BundleEntryProto> entries_;
+  Status status_;
+  const string db_prefix_tmp_;
+  const string db_prefix_;
+  BundleEntryProto* entry_seg_;
+  std::vector<string> db_key_list_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(DBWriter);
+};
+
+// Merges a set of bundles (given their prefixes) into a single bundle with the
+// given "merged_prefix". The merged metadata is guaranteed to be consistent.
+//
+// If there are N bundles in "prefixes", during the merge the data files will
+// be renamed to contain a proper sharded file spec, with num_shards set to the
+// sum of num_shards across the N input bundles.
+//
+// The caller should only rely on the metadata file of the merged bundle to
+// query information about a tensor. In particular, this function does not
+// guarantee not to re-order the input data files.
+//
+// Once merged, makes a best effort to delete the old metadata files.
+// Returns OK iff all bundles are successfully merged.
+
+// On construction, silently attempts to read the metadata associated with
+// "prefix". If caller intends to call any function afterwards, "status()"
+// must be checked.
+// All threads accessing the same DBReader must synchronize.
+class DBReader {
+ public:
+  DBReader(Env* const env, StringPiece prefix);
+  ~DBReader();
+
+  // Is ok() iff the reader construction is successful (completed the read of
+  // the metadata).
+  Status status() const { return status_; }
+  // Queries whether the bundle contains an entry keyed by "key". Calls Seek()
+  // internally, so this call invalidates the reader's current position.
+  // REQUIRES: status().ok()
+  bool Contains(StringPiece key);
+
+  // Looks up the dtype and the shape of the tensor keyed by "key".
+  // REQUIRES: status().ok()
+  Status LookupDtypeAndShape(StringPiece key, DataType* dtype,
+                             TensorShape* shape) TF_MUST_USE_RESULT;
+
+  // Looks up the shape of the tensor keyed by "key".
+  // Clears "shape" if not found.
+  // REQUIRES: status().ok()
+  Status LookupTensorShape(StringPiece key,
+                           TensorShape* shape) TF_MUST_USE_RESULT;
+
+  // Looks up the tensor keyed by "key". If "key" refers to a partitioned
+  // tensor, attempts to look up the full contents using all stored slices.
+  //
+  // Caller must make sure "val" has the same shape and dtype as the
+  // corresponding contents, so that its buffer can be filled without needing
+  // extra allocation. These can be queried via "LookupDtypeAndShape()".
+  //
+  // On error, "val" may contain nonsense data. Returns a NotFound error if
+  // tensor keyed by "key" does not exist in this bundle.
+  //
+  // Validates the stored crc32c checksum against the restored bytes.
+  // REQUIRES: status().ok()
+  Status Lookup(StringPiece key, Tensor* val) TF_MUST_USE_RESULT;
+
+  Status LookupHeader(StringPiece key, int64 total_bytes);
+  Status LookupSegment(StringPiece key, size_t buffer_size, char* destination,
+                       size_t& real_bytes_read);
+  Status LookupSegmentOffset(StringPiece key, uint64_t offset,
+                             size_t buffer_size, char* destination,
+                             size_t& real_bytes_read);
+
+  Status GetTensorInfo(StringPiece key, int64* size,
+                       std::unique_ptr<RandomAccessFile>* file, int64* offset);
+
+  // Looks up the tensor pointed to by the internal iterator.
+  //
+  // On error, "val" may contain nonsense data.
+  //
+  // Validates the stored crc32c checksum against the restored bytes.
+  // REQUIRES: status().ok() && Valid()
+  Status ReadCurrent(Tensor* val) TF_MUST_USE_RESULT;
+
+  // Looks up the slices of the tensor keyed by "key". On OK, "slices"
+  // is non-empty if and only if the tensor is a partitioned tensor.
+  //
+  // Warning - there is no guaranteed ordering for the returned slices, so
+  // a slice with a larger start index in some dimension could come before
+  // another slice with a smaller start index in the same dimension.
+  // REQUIRES: status().ok()
+  Status LookupTensorSlices(StringPiece key, std::vector<TensorSlice>* slices)
+      TF_MUST_USE_RESULT;
+
+  // Looks up a specific slice of a partitioned tensor.
+  // It is only required that the stored slices cover the requested slice,
+  // namely "slice_spec" is a subset of the union of the stored slices.
+  // REQUIRES: status().ok()
+  Status LookupSlice(StringPiece full_tensor_key,
+                     const TensorSlice& slice_spec,
+                     Tensor* val) TF_MUST_USE_RESULT;
+
+  Status LookupTensorSliceProtos(
+      StringPiece key, std::vector<TensorSliceProto>* slices)
+      TF_MUST_USE_RESULT;
+
+  // Seeks to the first position in the bundle whose key is no less than "key".
+  // REQUIRES: status().ok()
+  void Seek(StringPiece key) { return iter_->Seek(key); }
+  // Moves to the next position in the bundle.
+  // REQUIRES: status().ok()
+  void Next() const { iter_->Next(); }
+  // Returns true iff the reader is positioned to a key/val pair.
+  // REQUIRES: status().ok()
+  bool Valid() const { return iter_->Valid(); }
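+  // Illustrative full scan (a sketch): skip the header entry, then walk the
+  // metadata table in key order with the iterator API below:
+  //
+  //   reader.Seek(kHeaderEntryKey);
+  //   for (reader.Next(); reader.Valid(); reader.Next()) {
+  //     Tensor t;
+  //     TF_CHECK_OK(reader.ReadCurrent(&t));
+  //   }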
+  // Returns the key at the current position.
+  // REQUIRES: status().ok() && Valid()
+  StringPiece key() const { return iter_->key(); }
+  // Returns the raw value at the current position.
+  // REQUIRES: status().ok() && Valid()
+  StringPiece value() const { return iter_->value(); }
+
+  string DebugString();
+
+  struct LookupSegItem {
+    BundleEntryProto entry;
+    size_t total_size;
+    size_t bytes_read;
+  };
+
+ private:
+  // Seeks for "key" and reads the metadata proto.
+  // On non-OK return, clears "entry" for the caller.
+  // REQUIRES: status().ok()
+  Status GetBundleEntryProto(StringPiece key,
+                             BundleEntryProto* entry) TF_MUST_USE_RESULT;
+
+  // Reads the tensor value described by the metadata proto "entry".
+  // Usage for "val" follows the comment of "Lookup()".
+  Status GetValue(const BundleEntryProto& entry,
+                  Tensor* val) TF_MUST_USE_RESULT;
+
+  // Reads the slice described by "slice_spec". The corresponding full tensor
+  // has key "full_tensor_key" and metadata proto "full_tensor_entry".
+  // REQUIRES: full_tensor_entry.slices_size() > 0
+  Status GetSliceValue(StringPiece full_tensor_key,
+                       const BundleEntryProto& full_tensor_entry,
+                       const TensorSlice& slice_spec,
+                       Tensor* val) TF_MUST_USE_RESULT;
+
+  Env* env_;  // Not owned.
+  const string prefix_;
+
+  Status status_;
+  RandomAccessFile* metadata_;  // Owned.
+  table::Table* table_;
+  table::Iterator* iter_;
+  // Owns the InputBuffer objects and their underlying RandomAccessFile's.
+  std::unordered_map<int32, io::InputBuffer*> data_;
+
+  // Maps each partitioned tensor's key to its stored slices (represented in a
+  // TensorSliceSet). Populated on-demand.
+  std::unordered_map<string, checkpoint::TensorSliceSet*> tensor_slices_;
+
+  std::map<string, LookupSegItem> tmp_lookupseg_items_;
+
+  // Expected number of data file shards in the bundle. Extracted by reading
+  // the header entry in the metadata table.
+  int num_shards_;
+
+  // Flag that this class sets to true when the endianness of the target bundle
+  // differs from that of the current system's processor architecture.
+  bool need_to_swap_bytes_;
+
+  friend class TensorBundleAlignmentTest;  // For testing data alignment.
+
+  TF_DISALLOW_COPY_AND_ASSIGN(DBReader);
+};
+
+class SegmentDBWriter {
+ public:
+  SegmentDBWriter(
+      DBWriter* writer, const string& name,
+      const TensorShape& shape, DataType type, int64 buffer_size = 4 << 20);
+  Status Begin();
+  Status WriteData(const void* data, int64 size);
+  Status End();
+
+ private:
+  DBWriter* writer_;
+  string name_;
+  TensorShape shape_;
+  DataType type_;
+  int64 buffer_size_;
+  std::unique_ptr<char[]> buffer_;
+
+  int64 buffer_ptr_;
+  int64 write_counter_;
+};
+
+class SegmentDBReader {
+ public:
+  SegmentDBReader(
+      DBReader* reader, const string& name,
+      int64 offset, int64 size, int64 buffer_size = 4 << 20);
+  Status Begin();
+  const TensorShape& shape();
+  DataType type();
+  Status Read(void* data, int64 size);
+  Status Skip(int64 size);
+
+ private:
+  DBReader* reader_;
+  string name_;
+  int64 buffer_size_;
+  int64 offset_, size_;
+
+  TensorShape shape_;
+  DataType type_;
+
+  int64 remain_size_;
+
+  std::unique_ptr<RandomAccessFile> file_;
+  std::unique_ptr<io::InputBuffer> input_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_DB_WRITER_H_
diff --git a/tensorflow/core/util/tensor_bundle/db_writer_test.cc b/tensorflow/core/util/tensor_bundle/db_writer_test.cc
new file mode 100644
index 00000000000..ed7de1dd467
--- /dev/null
+++ b/tensorflow/core/util/tensor_bundle/db_writer_test.cc
@@ -0,0 +1,697 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/util/tensor_bundle/db_writer.h"
+
+#include <random>
+#include <vector>
+
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/io/table_builder.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/util/tensor_bundle/byte_swap.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Prepend the current test case's working temporary directory to "prefix".
+string Prefix(const string& prefix) {
+  return strings::StrCat(testing::TmpDir(), "/", prefix);
+}
+
+// Construct a data input directory by prepending the test data root
+// directory to "prefix".
+string TestdataPrefix(const string& prefix) {
+  return strings::StrCat(testing::TensorFlowSrcRoot(),
+                         "/core/util/tensor_bundle/testdata/", prefix);
+}
+
+template <typename T>
+Tensor Constant(T v, TensorShape shape) {
+  Tensor ret(DataTypeToEnum<T>::value, shape);
+  ret.flat<T>().setConstant(v);
+  return ret;
+}
+
+template <typename T>
+Tensor Constant_2x3(T v) {
+  return Constant(v, TensorShape({2, 3}));
+}
+
+Tensor ByteSwap(Tensor t) {
+  Tensor ret = tensor::DeepCopy(t);
+  TF_EXPECT_OK(ByteSwapTensor(&ret));
+  return ret;
+}
+
+// Assert that "reader" has a tensor under "key" matching "expected_val", in
+// terms of shape, dtype, and value.
+template <typename T>
+void Expect(DBReader* reader, const string& key,
+            const Tensor& expected_val) {
+  // Tests for Contains().
+  EXPECT_TRUE(reader->Contains(key));
+  // Tests for LookupDtypeAndShape().
+  DataType dtype;
+  TensorShape shape;
+  TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape));
+  EXPECT_EQ(expected_val.dtype(), dtype);
+  EXPECT_EQ(expected_val.shape(), shape);
+  // Tests for Lookup(), checking tensor contents.
+  Tensor val(expected_val.dtype(), shape);
+  TF_ASSERT_OK(reader->Lookup(key, &val));
+  test::ExpectTensorEqual<T>(val, expected_val);
+}
+
+template <typename T>
+void ExpectVariant(DBReader* reader, const string& key,
+                   const Tensor& expected_t) {
+  // Tests for Contains().
+  EXPECT_TRUE(reader->Contains(key));
+  // Tests for LookupDtypeAndShape().
+  DataType dtype;
+  TensorShape shape;
+  TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape));
+  // Tests for Lookup(), checking tensor contents.
+  EXPECT_EQ(expected_t.dtype(), dtype);
+  EXPECT_EQ(expected_t.shape(), shape);
+  Tensor actual_t(dtype, shape);
+  TF_ASSERT_OK(reader->Lookup(key, &actual_t));
+  for (int i = 0; i < expected_t.NumElements(); i++) {
+    Variant actual_var = actual_t.flat<Variant>()(i);
+    Variant expected_var = expected_t.flat<Variant>()(i);
+    EXPECT_EQ(actual_var.TypeName(), expected_var.TypeName());
+    auto* actual_val = actual_var.get<T>();
+    auto* expected_val = expected_var.get<T>();
+    EXPECT_EQ(*expected_val, *actual_val);
+  }
+}
+
+template <typename T>
+void ExpectNext(DBReader* reader, const Tensor& expected_val) {
+  EXPECT_TRUE(reader->Valid());
+  reader->Next();
+  TF_ASSERT_OK(reader->status());
+  Tensor val;
+  TF_ASSERT_OK(reader->ReadCurrent(&val));
+  test::ExpectTensorEqual<T>(val, expected_val);
+}
+
+std::vector<string> AllTensorKeys(DBReader* reader) {
+  std::vector<string> ret;
+  reader->Seek(kHeaderEntryKey);
+  reader->Next();
+  for (; reader->Valid(); reader->Next()) {
+    ret.emplace_back(reader->key());
+  }
+  return ret;
+}
+
+// Writes out the metadata file of a bundle again, with the endianness marker
+// bit flipped.
+Status FlipEndiannessBit(const string& prefix) {
+  Env* env = Env::Default();
+  const string metadata_tmp_path = Prefix("some_tmp_path");
+  std::unique_ptr<WritableFile> metadata_file;
+  TF_RETURN_IF_ERROR(env->NewWritableFile(metadata_tmp_path, &metadata_file));
+  // We create the builder lazily in case we run into an exception earlier, in
+  // which case we'd forget to call Finish() and TableBuilder's destructor
+  // would complain.
+  std::unique_ptr<table::TableBuilder> builder;
+
+  // Reads the existing metadata file, and fills the builder.
+  {
+    const string filename = MetaFilename(prefix);
+    uint64 file_size;
+    TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
+    std::unique_ptr<RandomAccessFile> file;
+    TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
+
+    table::Table* table = nullptr;
+    TF_RETURN_IF_ERROR(
+        table::Table::Open(table::Options(), file.get(), file_size, &table));
+    std::unique_ptr<table::Table> table_deleter(table);
+    std::unique_ptr<table::Iterator> iter(table->NewIterator());
+
+    // Reads the header entry.
+    iter->Seek(kHeaderEntryKey);
+    CHECK(iter->Valid());
+    BundleHeaderProto header;
+    CHECK(header.ParseFromArray(iter->value().data(), iter->value().size()));
+    // Flips the endianness.
+    if (header.endianness() == BundleHeaderProto::LITTLE) {
+      header.set_endianness(BundleHeaderProto::BIG);
+    } else {
+      header.set_endianness(BundleHeaderProto::LITTLE);
+    }
+    builder.reset(
+        new table::TableBuilder(table::Options(), metadata_file.get()));
+    builder->Add(iter->key(), header.SerializeAsString());
+    iter->Next();
+
+    // Adds the non-header entries unmodified.
+    for (; iter->Valid(); iter->Next())
+      builder->Add(iter->key(), iter->value());
+  }
+  TF_RETURN_IF_ERROR(builder->Finish());
+  TF_RETURN_IF_ERROR(env->RenameFile(metadata_tmp_path, MetaFilename(prefix)));
+  return metadata_file->Close();
+}
+
+template <typename T>
+void TestBasic() {
+  {
+    DBWriter writer(Env::Default(), Prefix("foo"));
+    TF_EXPECT_OK(writer.Add("foo_003", Constant_2x3(T(3))));
+    TF_EXPECT_OK(writer.Add("foo_000", Constant_2x3(T(0))));
+    TF_EXPECT_OK(writer.Add("foo_002", Constant_2x3(T(2))));
+    TF_EXPECT_OK(writer.Add("foo_001", Constant_2x3(T(1))));
+    TF_ASSERT_OK(writer.Finish());
+  }
+  {
+    DBReader reader(Env::Default(), Prefix("foo"));
+    TF_ASSERT_OK(reader.status());
+    EXPECT_EQ(
+        AllTensorKeys(&reader),
+        std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"}));
+    Expect<T>(&reader, "foo_000", Constant_2x3(T(0)));
+    Expect<T>(&reader, "foo_001", Constant_2x3(T(1)));
+    Expect<T>(&reader, "foo_002", Constant_2x3(T(2)));
+    Expect<T>(&reader, "foo_003", Constant_2x3(T(3)));
+  }
+  {
+    DBReader reader(Env::Default(), Prefix("foo"));
+    TF_ASSERT_OK(reader.status());
+    ExpectNext<T>(&reader, Constant_2x3(T(0)));
+    ExpectNext<T>(&reader, Constant_2x3(T(1)));
+    ExpectNext<T>(&reader, Constant_2x3(T(2)));
+    ExpectNext<T>(&reader, Constant_2x3(T(3)));
+    EXPECT_TRUE(reader.Valid());
+    reader.Next();
+    EXPECT_FALSE(reader.Valid());
+  }
+  {
+    DBWriter writer(Env::Default(), Prefix("bar"));
+    TF_EXPECT_OK(writer.Add("bar_003", Constant_2x3(T(3))));
+    TF_EXPECT_OK(writer.Add("bar_000", Constant_2x3(T(0))));
+    TF_EXPECT_OK(writer.Add("bar_002", Constant_2x3(T(2))));
+    TF_EXPECT_OK(writer.Add("bar_001", Constant_2x3(T(1))));
+    TF_ASSERT_OK(writer.Finish());
+  }
+  {
+    DBReader reader(Env::Default(), Prefix("bar"));
+    TF_ASSERT_OK(reader.status());
+    EXPECT_EQ(
+        AllTensorKeys(&reader),
+        std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003"}));
+    Expect<T>(&reader, "bar_003", Constant_2x3(T(3)));
+    Expect<T>(&reader, "bar_002", Constant_2x3(T(2)));
+    Expect<T>(&reader, "bar_001", Constant_2x3(T(1)));
+    Expect<T>(&reader, "bar_000", Constant_2x3(T(0)));
+  }
+  {
+    DBReader reader(Env::Default(), Prefix("bar"));
+    TF_ASSERT_OK(reader.status());
+    ExpectNext<T>(&reader, Constant_2x3(T(0)));
+    ExpectNext<T>(&reader, Constant_2x3(T(1)));
+    ExpectNext<T>(&reader, Constant_2x3(T(2)));
+    ExpectNext<T>(&reader, Constant_2x3(T(3)));
+    EXPECT_TRUE(reader.Valid());
+    reader.Next();
+    EXPECT_FALSE(reader.Valid());
+  }
+  TF_ASSERT_OK(MergeBundles(Env::Default(), {Prefix("foo"), Prefix("bar")},
+                            Prefix("merged")));
+  {
+    DBReader reader(Env::Default(), Prefix("merged"));
+    TF_ASSERT_OK(reader.status());
+    EXPECT_EQ(
+        AllTensorKeys(&reader),
+        std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003",
+                             "foo_000", "foo_001", "foo_002", "foo_003"}));
+    Expect<T>(&reader, "bar_000", Constant_2x3(T(0)));
+    Expect<T>(&reader, "bar_001", Constant_2x3(T(1)));
+    Expect<T>(&reader, "bar_002", Constant_2x3(T(2)));
+    Expect<T>(&reader, "bar_003", Constant_2x3(T(3)));
+    Expect<T>(&reader, "foo_000", Constant_2x3(T(0)));
+    Expect<T>(&reader, "foo_001", Constant_2x3(T(1)));
+    Expect<T>(&reader, "foo_002", Constant_2x3(T(2)));
+    Expect<T>(&reader, "foo_003", Constant_2x3(T(3)));
+  }
+  {
+    DBReader reader(Env::Default(), Prefix("merged"));
+    TF_ASSERT_OK(reader.status());
+    ExpectNext<T>(&reader, Constant_2x3(T(0)));
+    ExpectNext<T>(&reader, Constant_2x3(T(1)));
+    ExpectNext<T>(&reader, Constant_2x3(T(2)));
+    ExpectNext<T>(&reader, Constant_2x3(T(3)));
+    ExpectNext<T>(&reader, Constant_2x3(T(0)));
+    ExpectNext<T>(&reader, Constant_2x3(T(1)));
+    ExpectNext<T>(&reader, Constant_2x3(T(2)));
+    ExpectNext<T>(&reader, Constant_2x3(T(3)));
+    EXPECT_TRUE(reader.Valid());
+    reader.Next();
+    EXPECT_FALSE(reader.Valid());
+  }
+}
+
+// Type-specific subroutine of the SwapBytes test below.
+template <typename T>
+void TestByteSwap(const T* forward, const T* swapped, int array_len) {
+  auto bytes_per_elem = sizeof(T);
+
+  // Convert the entire array at once.
+  std::unique_ptr<T[]> forward_copy(new T[array_len]);
+  std::memcpy(forward_copy.get(), forward, array_len * bytes_per_elem);
+  TF_EXPECT_OK(ByteSwapArray(reinterpret_cast<char*>(forward_copy.get()),
+                             bytes_per_elem, array_len));
+  for (int i = 0; i < array_len; i++) {
+    EXPECT_EQ(forward_copy.get()[i], swapped[i]);
+  }
+
+  // Then the array wrapped in a tensor.
+  auto shape = TensorShape({array_len});
+  auto dtype = DataTypeToEnum<T>::value;
+  Tensor forward_tensor(dtype, shape);
+  Tensor swapped_tensor(dtype, shape);
+  std::memcpy(const_cast<char*>(forward_tensor.tensor_data().data()), forward,
+              array_len * bytes_per_elem);
+  std::memcpy(const_cast<char*>(swapped_tensor.tensor_data().data()), swapped,
+              array_len * bytes_per_elem);
+  TF_EXPECT_OK(ByteSwapTensor(&forward_tensor));
+  test::ExpectTensorEqual<T>(forward_tensor, swapped_tensor);
+}
+
+// Unit test of the byte-swapping operations that TensorBundle uses.
+TEST(TensorBundleTest, SwapBytes) {
+  // A bug in the compiler on MacOS causes ByteSwap() and FlipEndiannessBit()
+  // to be removed from the executable if they are only called from templated
+  // functions. As a workaround, we make some dummy calls here.
+  // TODO(frreiss): Remove this workaround when the compiler bug is fixed.
+  ByteSwap(Constant_2x3(42));
+  EXPECT_NE(Status::OK(), FlipEndiannessBit(Prefix("not_a_valid_prefix")));
+
+  // Test patterns, manually swapped so that we aren't relying on the
+  // correctness of our own byte-swapping macros when testing those macros.
+  // At least one of the entries in each list has the sign bit set when
+  // interpreted as a signed int.
+  const int arr_len_16 = 4;
+  const uint16_t forward_16[] = {0x1de5, 0xd017, 0xf1ea, 0xc0a1};
+  const uint16_t swapped_16[] = {0xe51d, 0x17d0, 0xeaf1, 0xa1c0};
+  const int arr_len_32 = 2;
+  const uint32_t forward_32[] = {0x0ddba115, 0xf01dab1e};
+  const uint32_t swapped_32[] = {0x15a1db0d, 0x1eab1df0};
+  const int arr_len_64 = 2;
+  const uint64_t forward_64[] = {0xf005ba11caba1000, 0x5ca1ab1ecab005e5};
+  const uint64_t swapped_64[] = {0x0010baca11ba05f0, 0xe505b0ca1eaba15c};
+
+  // 16-bit types
+  TestByteSwap(forward_16, swapped_16, arr_len_16);
+  TestByteSwap(reinterpret_cast<const int16*>(forward_16),
+               reinterpret_cast<const int16*>(swapped_16), arr_len_16);
+  TestByteSwap(reinterpret_cast<const bfloat16*>(forward_16),
+               reinterpret_cast<const bfloat16*>(swapped_16), arr_len_16);
+
+  // 32-bit types
+  TestByteSwap(forward_32, swapped_32, arr_len_32);
+  TestByteSwap(reinterpret_cast<const int32*>(forward_32),
+               reinterpret_cast<const int32*>(swapped_32), arr_len_32);
+  TestByteSwap(reinterpret_cast<const float*>(forward_32),
+               reinterpret_cast<const float*>(swapped_32), arr_len_32);
+
+  // 64-bit types
+  // Cast to uint64*/int64* to make DataTypeToEnum happy
+  TestByteSwap(reinterpret_cast<const uint64*>(forward_64),
+               reinterpret_cast<const uint64*>(swapped_64), arr_len_64);
+  TestByteSwap(reinterpret_cast<const int64*>(forward_64),
+               reinterpret_cast<const int64*>(swapped_64), arr_len_64);
+  TestByteSwap(reinterpret_cast<const double*>(forward_64),
+               reinterpret_cast<const double*>(swapped_64), arr_len_64);
+
+  // Complex types.
+  // Logic for complex number handling is only in ByteSwapTensor, so don't
+  // test ByteSwapArray.
+  const float* forward_float = reinterpret_cast<const float*>(forward_32);
+  const float* swapped_float = reinterpret_cast<const float*>(swapped_32);
+  const double* forward_double = reinterpret_cast<const double*>(forward_64);
+  const double* swapped_double = reinterpret_cast<const double*>(swapped_64);
+  Tensor forward_complex64 = Constant_2x3<complex64>(
+      std::complex<float>(forward_float[0], forward_float[1]));
+  Tensor swapped_complex64 = Constant_2x3<complex64>(
+      std::complex<float>(swapped_float[0], swapped_float[1]));
+  Tensor forward_complex128 = Constant_2x3<complex128>(
+      std::complex<double>(forward_double[0], forward_double[1]));
+  Tensor swapped_complex128 = Constant_2x3<complex128>(
+      std::complex<double>(swapped_double[0], swapped_double[1]));
+
+  TF_EXPECT_OK(ByteSwapTensor(&forward_complex64));
+  test::ExpectTensorEqual<complex64>(forward_complex64, swapped_complex64);
+
+  TF_EXPECT_OK(ByteSwapTensor(&forward_complex128));
+  test::ExpectTensorEqual<complex128>(forward_complex128, swapped_complex128);
+}
+
+// Basic test of alternate-endianness support. Generates a bundle in
+// the opposite of the current system's endianness and attempts to
+// read the bundle back in. Does not exercise sharding or access to
+// nonaligned tensors. Does cover the major access types exercised
+// in TestBasic.
+template <typename T>
+void TestEndianness() {
+  {
+    // Write out a TensorBundle in the opposite of this host's endianness.
+    DBWriter writer(Env::Default(), Prefix("foo"));
+    TF_EXPECT_OK(writer.Add("foo_003", ByteSwap(Constant_2x3(T(3)))));
+    TF_EXPECT_OK(writer.Add("foo_000", ByteSwap(Constant_2x3(T(0)))));
+    TF_EXPECT_OK(writer.Add("foo_002", ByteSwap(Constant_2x3(T(2)))));
+    TF_EXPECT_OK(writer.Add("foo_001", ByteSwap(Constant_2x3(T(1)))));
+    TF_ASSERT_OK(writer.Finish());
+    TF_ASSERT_OK(FlipEndiannessBit(Prefix("foo")));
+  }
+  {
+    DBReader reader(Env::Default(), Prefix("foo"));
+    TF_ASSERT_OK(reader.status());
+    EXPECT_EQ(
+        AllTensorKeys(&reader),
+        std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"}));
+    Expect<T>(&reader, "foo_000", Constant_2x3(T(0)));
+    Expect<T>(&reader, "foo_001", Constant_2x3(T(1)));
+    Expect<T>(&reader, "foo_002", Constant_2x3(T(2)));
+    Expect<T>(&reader, "foo_003", Constant_2x3(T(3)));
+  }
+  {
+    DBReader reader(Env::Default(), Prefix("foo"));
+    TF_ASSERT_OK(reader.status());
+    ExpectNext<T>(&reader, Constant_2x3(T(0)));
+    ExpectNext<T>(&reader, Constant_2x3(T(1)));
+    ExpectNext<T>(&reader, Constant_2x3(T(2)));
+    ExpectNext<T>(&reader, Constant_2x3(T(3)));
+    EXPECT_TRUE(reader.Valid());
+    reader.Next();
+    EXPECT_FALSE(reader.Valid());
+  }
+  {
+    DBWriter writer(Env::Default(), Prefix("bar"));
+    TF_EXPECT_OK(writer.Add("bar_003", ByteSwap(Constant_2x3(T(3)))));
+    TF_EXPECT_OK(writer.Add("bar_000", ByteSwap(Constant_2x3(T(0)))));
+    TF_EXPECT_OK(writer.Add("bar_002", ByteSwap(Constant_2x3(T(2)))));
+    TF_EXPECT_OK(writer.Add("bar_001", ByteSwap(Constant_2x3(T(1)))));
+    TF_ASSERT_OK(writer.Finish());
+    TF_ASSERT_OK(FlipEndiannessBit(Prefix("bar")));
+  }
+  {
+    DBReader reader(Env::Default(), Prefix("bar"));
+    TF_ASSERT_OK(reader.status());
+    EXPECT_EQ(
+        AllTensorKeys(&reader),
+        std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003"}));
+    Expect<T>(&reader, "bar_003", Constant_2x3(T(3)));
+    Expect<T>(&reader, "bar_002", Constant_2x3(T(2)));
+    Expect<T>(&reader, "bar_001", Constant_2x3(T(1)));
+    Expect<T>(&reader, "bar_000", Constant_2x3(T(0)));
+  }
+  {
+    DBReader reader(Env::Default(), Prefix("bar"));
+    TF_ASSERT_OK(reader.status());
+    ExpectNext<T>(&reader, Constant_2x3(T(0)));
+    ExpectNext<T>(&reader, Constant_2x3(T(1)));
+    ExpectNext<T>(&reader, Constant_2x3(T(2)));
+    ExpectNext<T>(&reader, Constant_2x3(T(3)));
+    EXPECT_TRUE(reader.Valid());
+    reader.Next();
+    EXPECT_FALSE(reader.Valid());
+  }
+  TF_ASSERT_OK(MergeBundles(Env::Default(), {Prefix("foo"), Prefix("bar")},
+                            Prefix("merged")));
+  {
+    DBReader reader(Env::Default(), Prefix("merged"));
+    TF_ASSERT_OK(reader.status());
+    EXPECT_EQ(
+        AllTensorKeys(&reader),
+        std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003",
+                             "foo_000", "foo_001", "foo_002", "foo_003"}));
+    Expect<T>(&reader, "bar_000", Constant_2x3(T(0)));
+    Expect<T>(&reader, "bar_001", Constant_2x3(T(1)));
+    Expect<T>(&reader, "bar_002", Constant_2x3(T(2)));
+    Expect<T>(&reader, "bar_003", Constant_2x3(T(3)));
+    Expect<T>(&reader, "foo_000", Constant_2x3(T(0)));
+    Expect<T>(&reader, "foo_001", Constant_2x3(T(1)));
+    Expect<T>(&reader, "foo_002", Constant_2x3(T(2)));
+    Expect<T>(&reader, "foo_003", Constant_2x3(T(3)));
+  }
+  {
+    DBReader reader(Env::Default(), Prefix("merged"));
+    TF_ASSERT_OK(reader.status());
+    ExpectNext<T>(&reader, Constant_2x3(T(0)));
+    ExpectNext<T>(&reader, Constant_2x3(T(1)));
+    ExpectNext<T>(&reader, Constant_2x3(T(2)));
+    ExpectNext<T>(&reader, Constant_2x3(T(3)));
+    ExpectNext<T>(&reader, Constant_2x3(T(0)));
+    ExpectNext<T>(&reader, Constant_2x3(T(1)));
+    ExpectNext<T>(&reader, Constant_2x3(T(2)));
+    ExpectNext<T>(&reader, Constant_2x3(T(3)));
+    EXPECT_TRUE(reader.Valid());
+    reader.Next();
+    EXPECT_FALSE(reader.Valid());
+  }
+}
+
+template <typename T>
+void TestNonStandardShapes() {
+  {
+    DBWriter writer(Env::Default(), Prefix("nonstandard"));
+    TF_EXPECT_OK(writer.Add("scalar", Constant(T(0), TensorShape())));
+    TF_EXPECT_OK(
+        writer.Add("non_standard0", Constant(T(0), TensorShape({0, 1618}))));
+    TF_EXPECT_OK(
+        writer.Add("non_standard1", Constant(T(0), TensorShape({16, 0, 18}))));
+    TF_ASSERT_OK(writer.Finish());
+  }
+  {
+    DBReader reader(Env::Default(), Prefix("nonstandard"));
+    TF_ASSERT_OK(reader.status());
+    Expect<T>(&reader, "scalar", Constant(T(0), TensorShape()));
+    Expect<T>(&reader, "non_standard0",
+              Constant(T(0), TensorShape({0, 1618})));
+    Expect<T>(&reader, "non_standard1",
+              Constant(T(0), TensorShape({16, 0, 18})));
+  }
+}
+
+// Writes a bundle to disk with a bad "version"; checks for "expected_error".
+void VersionTest(const VersionDef& version, StringPiece expected_error) {
+  const string path = Prefix("version_test");
+  {
+    // Prepare an empty bundle with the given version information.
+    BundleHeaderProto header;
+    *header.mutable_version() = version;
+
+    // Write the metadata file to disk.
+    std::unique_ptr<WritableFile> file;
+    TF_ASSERT_OK(Env::Default()->NewWritableFile(MetaFilename(path), &file));
+    table::TableBuilder builder(table::Options(), file.get());
+    builder.Add(kHeaderEntryKey, header.SerializeAsString());
+    TF_ASSERT_OK(builder.Finish());
+  }
+  // Read it back in and verify that we get the expected error.
+  DBReader reader(Env::Default(), path);
+  EXPECT_TRUE(errors::IsInvalidArgument(reader.status()));
+  EXPECT_TRUE(
+      absl::StartsWith(reader.status().error_message(), expected_error));
+}
+
+}  // namespace
+
+TEST(TensorBundleTest, StringTensors) {
+  constexpr size_t kLongLength = static_cast<size_t>(UINT32_MAX) + 1;
+  Tensor long_string_tensor(DT_STRING, TensorShape({1}));
+
+  {
+    DBWriter writer(Env::Default(), Prefix("foo"));
+    TF_EXPECT_OK(writer.Add("string_tensor",
+                            Tensor(DT_STRING, TensorShape({1}))));  // Empty.
+    TF_EXPECT_OK(writer.Add("scalar", test::AsTensor<tstring>({"hello"})));
+    TF_EXPECT_OK(writer.Add(
+        "strs", test::AsTensor<tstring>(
+                    {"hello", "", "x01", string(1 << 25, 'c')})));
+
+    // Requires a 64-bit length.
+    tstring* backing_string = long_string_tensor.flat<tstring>().data();
+#ifdef USE_TSTRING
+    backing_string->resize_uninitialized(kLongLength);
+    std::char_traits<char>::assign(backing_string->data(), kLongLength, 'd');
+#else   // USE_TSTRING
+    backing_string->assign(kLongLength, 'd');
+#endif  // USE_TSTRING
+    TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor));
+
+    // Mixes in some floats.
+    TF_EXPECT_OK(writer.Add("floats", Constant_2x3<float>(16.18)));
+    TF_ASSERT_OK(writer.Finish());
+  }
+  {
+    DBReader reader(Env::Default(), Prefix("foo"));
+    TF_ASSERT_OK(reader.status());
+    EXPECT_EQ(AllTensorKeys(&reader),
+              std::vector<string>({"floats", "long_scalar", "scalar",
+                                   "string_tensor", "strs"}));
+
+    Expect<tstring>(&reader, "string_tensor",
+                    Tensor(DT_STRING, TensorShape({1})));
+    Expect<tstring>(&reader, "scalar", test::AsTensor<tstring>({"hello"}));
+    Expect<tstring>(
+        &reader, "strs",
+        test::AsTensor<tstring>({"hello", "", "x01", string(1 << 25, 'c')}));
+
+    Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
+
+    // We don't use the Expect function so we can re-use the
+    // `long_string_tensor` buffer for reading out long_scalar to keep memory
+    // usage reasonable.
+    EXPECT_TRUE(reader.Contains("long_scalar"));
+    DataType dtype;
+    TensorShape shape;
+    TF_ASSERT_OK(reader.LookupDtypeAndShape("long_scalar", &dtype, &shape));
+    EXPECT_EQ(DT_STRING, dtype);
+    EXPECT_EQ(TensorShape({1}), shape);
+
+    // Zero-out the string so that we can be sure the new one is read in.
+    tstring* backing_string = long_string_tensor.flat<tstring>().data();
+    backing_string->assign("");
+
+    // Read long_scalar and check it contains kLongLength 'd's.
+    TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor));
+    ASSERT_EQ(backing_string, long_string_tensor.flat<tstring>().data());
+    EXPECT_EQ(kLongLength, backing_string->length());
+    for (size_t i = 0; i < kLongLength; i++) {
+      // Not using ASSERT_EQ('d', c) because this way is twice as fast due to
+      // compiler optimizations.
+      if ((*backing_string)[i] != 'd') {
+        FAIL() << "long_scalar is not full of 'd's as expected.";
+        break;
+      }
+    }
+  }
+}
+
+class VariantObject {
+ public:
+  VariantObject() {}
+  VariantObject(const string& metadata, int64 value)
+      : metadata_(metadata), value_(value) {}
+
+  string TypeName() const { return "TEST VariantObject"; }
+  void Encode(VariantTensorData* data) const {
+    data->set_type_name(TypeName());
+    data->set_metadata(metadata_);
+    Tensor val_t = Tensor(DT_INT64, TensorShape({}));
+    val_t.scalar<int64>()() = value_;
+    *(data->add_tensors()) = val_t;
+  }
+  bool Decode(const VariantTensorData& data) {
+    EXPECT_EQ(data.type_name(), TypeName());
+    data.get_metadata(&metadata_);
+    EXPECT_EQ(data.tensors_size(), 1);
+    value_ = data.tensors(0).scalar<int64>()();
+    return true;
+  }
+  bool operator==(const VariantObject& other) const {
+    return metadata_ == other.metadata_ && value_ == other.value_;
+  }
+  string metadata_;
+  int64 value_;
+};
+
+REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantObject, "TEST VariantObject");
+
+TEST(TensorBundleTest, VariantTensors) {
+  {
+    DBWriter writer(Env::Default(), Prefix("foo"));
+    TF_EXPECT_OK(
+        writer.Add("variant_tensor",
+                   test::AsTensor<Variant>({VariantObject("test", 10),
+                                            VariantObject("test1", 20)})));
+    TF_ASSERT_OK(writer.Finish());
+  }
+  {
+    DBReader reader(Env::Default(), Prefix("foo"));
+    TF_ASSERT_OK(reader.status());
+    ExpectVariant<VariantObject>(
+        &reader, "variant_tensor",
+        test::AsTensor<Variant>(
+            {VariantObject("test", 10), VariantObject("test1", 20)}));
+  }
+}
+
+class TensorBundleAlignmentTest : public ::testing::Test {
+ protected:
+  template <typename T>
+  void ExpectAlignment(DBReader* reader, const string& key, int alignment) {
+    BundleEntryProto full_tensor_entry;
+    TF_ASSERT_OK(reader->GetBundleEntryProto(key, &full_tensor_entry));
+    EXPECT_EQ(0, full_tensor_entry.offset() % alignment);
+  }
+};
+
+static void BM_BundleAlignmentByteOff(int iters, int alignment,
+                                      int tensor_size) {
+  testing::StopTiming();
+  {
+    DBWriter::Options opts;
+    opts.data_alignment = alignment;
+    DBWriter writer(Env::Default(), Prefix("foo"), opts);
+    TF_CHECK_OK(writer.Add("small", Constant(true, TensorShape({1}))));
+    TF_CHECK_OK(writer.Add("big", Constant(32.1, TensorShape({tensor_size}))));
+    TF_CHECK_OK(writer.Finish());
+  }
+  DBReader reader(Env::Default(), Prefix("foo"));
+  TF_CHECK_OK(reader.status());
+  testing::StartTiming();
+  for (int i = 0; i < iters; ++i) {
+    Tensor t;
+    TF_CHECK_OK(reader.Lookup("big", &t));
+  }
+  testing::StopTiming();
+}
+
+#define BM_BundleAlignment(ALIGN, SIZE)                        \
+  static void BM_BundleAlignment_##ALIGN##_##SIZE(int iters) { \
+    BM_BundleAlignmentByteOff(iters, ALIGN, SIZE);             \
+  }                                                            \
+  BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE)
+
+BM_BundleAlignment(1, 512);
+BM_BundleAlignment(1, 4096);
+BM_BundleAlignment(1, 1048576);
+BM_BundleAlignment(4096, 512);
+BM_BundleAlignment(4096, 4096);
+BM_BundleAlignment(4096, 1048576);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/util/leveldb_save.h b/tensorflow/core/util/leveldb_save.h
new file mode 100644
index 00000000000..61d0593aa70
--- /dev/null
+++ b/tensorflow/core/util/leveldb_save.h
@@ -0,0 +1,112 @@
+//
+// Created by Administrator on 2022/7/31.
+//
+
+#ifndef CPP_TUTORIAL_LEVELDB_SAVE_H
+#define CPP_TUTORIAL_LEVELDB_SAVE_H
+
+#include "leveldb/db.h"
+#include "leveldb/comparator.h"
+#include "leveldb/write_batch.h"
+
+#include <iostream>
+#include <sstream>
+#include <string>
+
+using leveldb::DB;
+using leveldb::Options;
+using leveldb::ReadOptions;
+using leveldb::WriteBatch;
+using leveldb::WriteOptions;
+
+namespace tensorflow {
+
+template <typename K, typename V>
+class LevelDBSave {
+ public:
+  // Opens (creating if necessary) the LevelDB database at "db_path".
+  LevelDBSave(std::string db_path) : DBPath(std::move(db_path)) {
+    options.create_if_missing = true;
+    leveldb::Status status = leveldb::DB::Open(options, DBPath, &db);
+    if (!status.ok()) {
+      std::cout << "open db failed: " << status.ToString() << std::endl;
+    }
+  }
+  ~LevelDBSave() {
+    delete db;
+  }
+  void insert(K key, V value) {
+    db->Put(WriteOptions(), keyToString(key), valueToString(value));
+  }
+  void remove(K key) {
+    db->Delete(WriteOptions(), keyToString(key));
+  }
+  void update(K key, V value) {
+    db->Put(WriteOptions(), keyToString(key), valueToString(value));
+  }
+  void query(K key) {
+    std::string valueStr;
+    db->Get(ReadOptions(), keyToString(key), &valueStr);
+    std::cout << valueStr << std::endl;
+  }
+  void queryAll() {
+    leveldb::Iterator* it = db->NewIterator(ReadOptions());
+    for (it->SeekToFirst(); it->Valid(); it->Next()) {
+      std::cout << it->key().ToString() << ":" << it->value().ToString()
+                << std::endl;
+    }
+    delete it;
+  }
+  // The overloads below stage writes in "writeBatch"; the batch still has to
+  // be committed to "db" by the caller.
+  void save(K key, V value) {
+    writeBatch.Put(keyToString(key), valueToString(value));
+  }
+  void save(K key, V value, std::string tableName) {
+    writeBatch.Put(tableName + keyToString(key), valueToString(value));
+  }
+  void save(K key, V value, std::string tableName, std::string keyPrefix) {
+    writeBatch.Put(tableName + keyPrefix + keyToString(key),
+                   valueToString(value));
+  }
+  void save(K key, V value, std::string tableName, std::string keyPrefix,
+            std::string keySuffix) {
+    writeBatch.Put(tableName + keyPrefix + keyToString(key) + keySuffix,
+                   valueToString(value));
+  }
+  void save(K key, V value, std::string tableName, std::string keyPrefix,
+            std::string keySuffix, std::string keySuffix2) {
+    writeBatch.Put(
+        tableName + keyPrefix + keyToString(key) + keySuffix + keySuffix2,
+        valueToString(value));
+  }
+  void save(K key, V value, std::string tableName, std::string keyPrefix,
+            std::string keySuffix, std::string keySuffix2,
+            std::string keySuffix3) {
+    writeBatch.Put(tableName + keyPrefix + keyToString(key) + keySuffix +
+                       keySuffix2 + keySuffix3,
+                   valueToString(value));
+  }
+
+ private:
+  std::string keyToString(K key) {
+    std::stringstream ss;
+    ss << key;
+    return ss.str();
+  }
+  std::string valueToString(V value) {
+    std::stringstream ss;
+    ss << value;
+    return ss.str();
+  }
+  Options options;
+  DB* db;
+  WriteBatch writeBatch;
+  std::string DBPath;
+};
+
+}  // namespace tensorflow
+#endif  // CPP_TUTORIAL_LEVELDB_SAVE_H
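+// Illustrative usage of LevelDBSave (a sketch; the path and key/value types
+// are hypothetical, and both types must be streamable via operator<<):
+//
+//   tensorflow::LevelDBSave<int64, std::string> store("/tmp/testdb");
+//   store.insert(42, "value");
+//   store.query(42);  // prints "value"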