-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdigit_detector.h
35 lines (27 loc) · 953 Bytes
/
digit_detector.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#ifndef SUDOKU__DIGIT_DETECTOR_H_
#define SUDOKU__DIGIT_DETECTOR_H_
#include <cstddef>
#include <cstdint>

#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "opencv2/core.hpp"
#include "opencv2/ml.hpp"
namespace sudoku {

// Recognizes Sudoku digits in images using an OpenCV machine-learning model
// (cv::ml::StatModel).
//
// `model_` starts out empty (a default-constructed cv::Ptr); presumably it is
// populated by Init() or Train() — confirm against the implementation.
class DigitDetector {
 public:
  // Loads a previously saved model from `model_path`.
  // On any error the binary is terminated via CHECK; no error is returned.
  void Init(absl::string_view model_path);

  // Classifies `image` as a digit.
  // Returns absl::nullopt if the image could not be recognized.
  absl::optional<std::int32_t> Detect(const cv::Mat& image) const;

  // Trains the model using cv::ml::StatModel::train.
  //   mnist_directory: directory holding the MNIST training data (expected
  //                    layout is defined by the implementation — confirm).
  //   model_path:      NOTE(review): presumably where the trained model is
  //                    saved — confirm.
  //   synthetic_count: NOTE(review): presumably the number of synthetic
  //                    samples to generate — confirm.
  // Exceptions thrown by OpenCV propagate to the caller; if none is thrown,
  // returns true.
  bool Train(absl::string_view mnist_directory, absl::string_view model_path,
             std::size_t synthetic_count);

  // Returns the underlying model downcast to `T` (a cv::ml::StatModel
  // subclass) via cv::Ptr::dynamicCast. The returned pointer is empty if no
  // model is held or if the model is not a `T`, so callers must check it.
  template <typename T>
  cv::Ptr<T> GetModelAs() const {
    return model_.dynamicCast<T>();
  }

 private:
  // The classifier; empty until a model has been set.
  cv::Ptr<cv::ml::StatModel> model_;
};

}  // namespace sudoku
#endif // SUDOKU__DIGIT_DETECTOR_H_