diff --git a/csrc/mmdeploy/codebase/mmocr/attn.cpp b/csrc/mmdeploy/codebase/mmocr/attn.cpp new file mode 100644 index 0000000000..e40c804577 --- /dev/null +++ b/csrc/mmdeploy/codebase/mmocr/attn.cpp @@ -0,0 +1,108 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include + +#include "mmdeploy/core/device.h" +#include "mmdeploy/core/registry.h" +#include "mmdeploy/core/tensor.h" +#include "mmdeploy/core/utils/device_utils.h" +#include "base.h" +#include "mmocr.h" + +namespace mmdeploy::mmocr { + +using std::string; +using std::vector; + +class AttnConvertor : public BaseConvertor { + public: + explicit AttnConvertor(const Value& cfg) : BaseConvertor(cfg) { + auto model = cfg["context"]["model"].get(); + if (!cfg.contains("params")) { + MMDEPLOY_ERROR("'params' is required, but it's not in the config"); + throw_exception(eInvalidArgument); + } + auto& _cfg = cfg["params"]; + + // unknwon + if (_cfg.value("with_unknown", false)) { + unknown_idx_ = static_cast(idx2char_.size()); + idx2char_.emplace_back(""); + } + + // BOS/EOS + constexpr char start_end_token[] = ""; + constexpr char padding_token[] = ""; + start_idx_ = static_cast(idx2char_.size()); + end_idx_ = start_idx_; + idx2char_.emplace_back(start_end_token); + if (!_cfg.value("start_end_same", true)) { + end_idx_ = static_cast(idx2char_.size()); + idx2char_.emplace_back(start_end_token); + } + + // padding + padding_idx_ = static_cast(idx2char_.size()); + idx2char_.emplace_back(padding_token); + + model_ = model; + } + + Result operator()(const Value& _data, const Value& _prob) { + auto d_conf = _prob["output"].get(); + + if (!(d_conf.shape().size() == 3 && d_conf.data_type() == DataType::kFLOAT)) { + MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}, dtype: {}", d_conf.shape(), + (int)d_conf.data_type()); + return Status(eNotSupported); + } + + OUTCOME_TRY(auto h_conf, MakeAvailableOnDevice(d_conf, Device{0}, stream())); + OUTCOME_TRY(stream().Wait()); + + auto data = h_conf.data(); + + auto shape = d_conf.shape(); + auto w = static_cast(shape[1]); + auto c = static_cast(shape[2]); + + auto valid_ratio = _data["img_metas"]["valid_ratio"].get(); + auto [indexes, scores] = Tensor2Idx(data, w, c, valid_ratio); + + auto text = Idx2Str(indexes); + MMDEPLOY_DEBUG("text: {}", text); + + TextRecognition output{text, scores}; + + return make_pointer(to_value(output)); + } + + std::pair, vector > Tensor2Idx(const float* data, int w, int c, + float valid_ratio) { + auto decode_len = std::min(w, static_cast(std::ceil(w * valid_ratio))); + vector indexes; + indexes.reserve(decode_len); + vector scores; + scores.reserve(decode_len); + + for (int t = 0; t < decode_len; ++t, data += c) { + auto iter = std::max_element(data, data + c); + auto index = static_cast(iter - data); + if (index == padding_idx_) continue; + if (index == end_idx_) break; + indexes.push_back(index); + scores.push_back(*iter); + } + + return {indexes, scores}; + } + +private: + int start_idx_{-1}; + int end_idx_{-1}; + int padding_idx_{-1}; +}; + +MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMOCR, AttnConvertor); + +} // namespace mmdeploy::mmocr diff --git a/csrc/mmdeploy/codebase/mmocr/base.cpp b/csrc/mmdeploy/codebase/mmocr/base.cpp new file mode 100644 index 0000000000..eafbf6915d --- /dev/null +++ b/csrc/mmdeploy/codebase/mmocr/base.cpp @@ -0,0 +1,78 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include + +#include "mmdeploy/codebase/mmocr/base.h" + +namespace mmdeploy { +namespace mmocr { + +using std::string; +using std::vector; + +BaseConvertor::BaseConvertor(const Value& cfg) : MMOCR(cfg) { + auto model = cfg["context"]["model"].get(); + if (!cfg.contains("params")) { + MMDEPLOY_ERROR("'params' is required, but it's not in the config"); + throw_exception(eInvalidArgument); + } + // BaseConverter + auto& _cfg = cfg["params"]; + if (_cfg.contains("dict_file")) { + auto filename = _cfg["dict_file"].get(); + auto content = model.ReadFile(filename).value(); + idx2char_ = SplitLines(content); + } else if (_cfg.contains("dict_list")) { + from_value(_cfg["dict_list"], idx2char_); + } else if (_cfg.contains("dict_type")) { + auto dict_type = _cfg["dict_type"].get(); + if (dict_type == "DICT36") { + idx2char_ = SplitChars(DICT36); + } else if (dict_type == "DICT90") { + idx2char_ = SplitChars(DICT90); + } else { + MMDEPLOY_ERROR("unknown dict_type: {}", dict_type); + throw_exception(eInvalidArgument); + } + } else { + MMDEPLOY_ERROR("either dict_file, dict_list or dict_type must be specified"); + throw_exception(eInvalidArgument); + } + + model_ = model; +} + +string BaseConvertor::Idx2Str(const vector& indexes) { + size_t count = 0; + for (const auto& idx : indexes) { + count += idx2char_[idx].size(); + } + std::string text; + text.reserve(count); + for (const auto& idx : indexes) { + text += idx2char_[idx]; + } + return text; +} + +vector BaseConvertor::SplitLines(const string& s) { + std::istringstream is(s); + vector ret; + string line; + while (std::getline(is, line)) { + ret.push_back(std::move(line)); + } + return ret; +} + +vector BaseConvertor::SplitChars(const string& s) { + vector ret; + ret.reserve(s.size()); + for (char c : s) { + ret.push_back({c}); + } + return ret; +} + +} +} \ No newline at end of file diff --git a/csrc/mmdeploy/codebase/mmocr/base.h b/csrc/mmdeploy/codebase/mmocr/base.h new file mode 100644 index 0000000000..0ec7f5f15e --- /dev/null +++ b/csrc/mmdeploy/codebase/mmocr/base.h @@ -0,0 +1,40 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include +#include + +#include "mmdeploy/core/model.h" +#include "mmocr.h" + +namespace mmdeploy::mmocr { + +using std::string; +using std::vector; + +class BaseConvertor : public MMOCR { + public: + explicit BaseConvertor(const Value& cfg); + + string Idx2Str(const vector& indexes); + + protected: + static vector SplitLines(const string& s); + + static vector SplitChars(const string& s); + + static constexpr const auto DICT36 = R"(0123456789abcdefghijklmnopqrstuvwxyz)"; + static constexpr const auto DICT90 = R"(0123456789abcdefghijklmnopqrstuvwxyz)" + R"(ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'())" + R"(*+,-./:;<=>?@[\]_`~)"; + + static constexpr const auto kHost = Device(0); + + Model model_; + + static constexpr const int blank_idx_{0}; + int unknown_idx_{-1}; + + vector idx2char_; +}; + +} // namespace mmdeploy::mmocr \ No newline at end of file diff --git a/csrc/mmdeploy/codebase/mmocr/crnn.cpp b/csrc/mmdeploy/codebase/mmocr/crnn.cpp index 412cd3a1d0..8dfb63b3a3 100644 --- a/csrc/mmdeploy/codebase/mmocr/crnn.cpp +++ b/csrc/mmdeploy/codebase/mmocr/crnn.cpp @@ -11,43 +11,17 @@ #include "mmdeploy/core/utils/formatter.h" #include "mmdeploy/core/value.h" #include "mmdeploy/experimental/module_adapter.h" -#include "mmocr.h" +#include "base.h" namespace mmdeploy::mmocr { using std::string; using std::vector; -class CTCConvertor : public MMOCR { +class CTCConvertor : public BaseConvertor { public: - explicit CTCConvertor(const Value& cfg) : MMOCR(cfg) { - auto model = cfg["context"]["model"].get(); - if (!cfg.contains("params")) { - MMDEPLOY_ERROR("'params' is required, but it's not in the config"); - throw_exception(eInvalidArgument); - } - // BaseConverter + explicit CTCConvertor(const Value& cfg) : BaseConvertor(cfg) { auto& _cfg = cfg["params"]; - if (_cfg.contains("dict_file")) { - auto filename = _cfg["dict_file"].get(); - auto content = model.ReadFile(filename).value(); - idx2char_ = SplitLines(content); - } else if (_cfg.contains("dict_list")) { - from_value(_cfg["dict_list"], idx2char_); - } else if (_cfg.contains("dict_type")) { - auto dict_type = _cfg["dict_type"].get(); - if (dict_type == "DICT36") { - idx2char_ = SplitChars(DICT36); - } else if (dict_type == "DICT90") { - idx2char_ = SplitChars(DICT90); - } else { - MMDEPLOY_ERROR("unknown dict_type: {}", dict_type); - throw_exception(eInvalidArgument); - } - } else { - MMDEPLOY_ERROR("either dict_file, dict_list or dict_type must be specified"); - throw_exception(eInvalidArgument); - } // CTCConverter idx2char_.insert(begin(idx2char_), ""); @@ -55,8 +29,6 @@ class CTCConvertor : public MMOCR { unknown_idx_ = static_cast(idx2char_.size()); idx2char_.emplace_back(""); } - - model_ = model; } Result operator()(const Value& _data, const Value& _prob) { @@ -110,19 +82,6 @@ class CTCConvertor : public MMOCR { return {indexes, scores}; } - string Idx2Str(const vector& indexes) { - size_t count = 0; - for (const auto& idx : indexes) { - count += idx2char_[idx].size(); - } - std::string text; - text.reserve(count); - for (const auto& idx : indexes) { - text += idx2char_[idx]; - } - return text; - } - // TODO: move softmax & top-k into model static void softmax(const float* src, float* dst, int n) { auto max_val = *std::max_element(src, src + n); @@ -136,39 +95,6 @@ class CTCConvertor : public MMOCR { } } - protected: - static vector SplitLines(const string& s) { - std::istringstream is(s); - vector ret; - string line; - while (std::getline(is, line)) { - ret.push_back(std::move(line)); - } - return ret; - } - - static vector SplitChars(const string& s) { - vector ret; - ret.reserve(s.size()); - for (char c : s) { - ret.push_back({c}); - } - return ret; - } - - static constexpr const auto DICT36 = R"(0123456789abcdefghijklmnopqrstuvwxyz)"; - static constexpr const auto DICT90 = R"(0123456789abcdefghijklmnopqrstuvwxyz)" - R"(ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'())" - R"(*+,-./:;<=>?@[\]_`~)"; - - static constexpr const auto kHost = Device(0); - - Model model_; - - static constexpr const int blank_idx_{0}; - int unknown_idx_{-1}; - - vector idx2char_; }; MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMOCR, CTCConvertor);