diff options
Diffstat (limited to 'tesseract/unittest/lstm_test.h')
-rw-r--r-- | tesseract/unittest/lstm_test.h | 189 |
1 files changed, 189 insertions, 0 deletions
diff --git a/tesseract/unittest/lstm_test.h b/tesseract/unittest/lstm_test.h new file mode 100644 index 00000000..4f3d9572 --- /dev/null +++ b/tesseract/unittest/lstm_test.h @@ -0,0 +1,189 @@ +// (C) Copyright 2017, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TESSERACT_UNITTEST_LSTM_TEST_H_ +#define TESSERACT_UNITTEST_LSTM_TEST_H_ + +#include <memory> +#include <string> +#include <utility> + +#include "include_gunit.h" + +#include "absl/strings/str_cat.h" +#include "tprintf.h" +#include "helpers.h" + +#include "functions.h" +#include "lang_model_helpers.h" +#include "log.h" // for LOG +#include "lstmtrainer.h" +#include "unicharset.h" + +namespace tesseract { + +#if DEBUG_DETAIL == 0 +// Number of iterations to run all the trainers. +const int kTrainerIterations = 600; +// Number of iterations between accuracy checks. +const int kBatchIterations = 100; +#else +// Number of iterations to run all the trainers. +const int kTrainerIterations = 2; +// Number of iterations between accuracy checks. +const int kBatchIterations = 1; +#endif + +// The fixture for testing LSTMTrainer. +class LSTMTrainerTest : public testing::Test { + protected: + void SetUp() { + std::locale::global(std::locale("")); + file::MakeTmpdir(); + } + + LSTMTrainerTest() {} + std::string TestDataNameToPath(const std::string& name) { + return file::JoinPath(TESTDATA_DIR, + "" + name); + } + std::string TessDataNameToPath(const std::string& name) { + return file::JoinPath(TESSDATA_DIR, + "" + name); + } + std::string TestingNameToPath(const std::string& name) { + return file::JoinPath(TESTING_DIR, + "" + name); + } + + void SetupTrainerEng(const std::string& network_spec, const std::string& model_name, + bool recode, bool adam) { + SetupTrainer(network_spec, model_name, "eng/eng.unicharset", + "eng.Arial.exp0.lstmf", recode, adam, 5e-4, false, "eng"); + } + void SetupTrainer(const std::string& network_spec, const std::string& model_name, + const std::string& unicharset_file, const std::string& lstmf_file, + bool recode, bool adam, double learning_rate, + bool layer_specific, const std::string& kLang) { +// constexpr char kLang[] = "eng"; // Exact value doesn't matter. + std::string unicharset_name = TestDataNameToPath(unicharset_file); + UNICHARSET unicharset; + ASSERT_TRUE(unicharset.load_from_file(unicharset_name.c_str(), false)); + std::string script_dir = file::JoinPath( + LANGDATA_DIR, ""); + std::vector<STRING> words; + EXPECT_EQ(0, CombineLangModel(unicharset, script_dir, "", FLAGS_test_tmpdir, + kLang, !recode, words, words, words, false, + nullptr, nullptr)); + std::string model_path = file::JoinPath(FLAGS_test_tmpdir, model_name); + std::string checkpoint_path = model_path + "_checkpoint"; + trainer_.reset(new LSTMTrainer(model_path.c_str(), checkpoint_path.c_str(), + 0, 0)); + trainer_->InitCharSet(file::JoinPath(FLAGS_test_tmpdir, kLang, + absl::StrCat(kLang, ".traineddata"))); + int net_mode = adam ? NF_ADAM : 0; + // Adam needs a higher learning rate, due to not multiplying the effective + // rate by 1/(1-momentum). + if (adam) learning_rate *= 20.0; + if (layer_specific) net_mode |= NF_LAYER_SPECIFIC_LR; + EXPECT_TRUE(trainer_->InitNetwork(network_spec.c_str(), -1, net_mode, 0.1, + learning_rate, 0.9, 0.999)); + std::vector<STRING> filenames; + filenames.push_back(STRING(TestDataNameToPath(lstmf_file).c_str())); + EXPECT_TRUE(trainer_->LoadAllTrainingData(filenames, CS_SEQUENTIAL, false)); + LOG(INFO) << "Setup network:" << model_name << "\n" ; + } + // Trains for a given number of iterations and returns the char error rate. + double TrainIterations(int max_iterations) { + int iteration = trainer_->training_iteration(); + int iteration_limit = iteration + max_iterations; + double best_error = 100.0; + do { + STRING log_str; + int target_iteration = iteration + kBatchIterations; + // Train a few. + double mean_error = 0.0; + while (iteration < target_iteration && iteration < iteration_limit) { + trainer_->TrainOnLine(trainer_.get(), false); + iteration = trainer_->training_iteration(); + mean_error += trainer_->LastSingleError(ET_CHAR_ERROR); + } + trainer_->MaintainCheckpoints(nullptr, &log_str); + iteration = trainer_->training_iteration(); + mean_error *= 100.0 / kBatchIterations; + if (mean_error < best_error) best_error = mean_error; + } while (iteration < iteration_limit); + LOG(INFO) << "Trainer error rate = " << best_error << "\n"; + return best_error; + } + // Tests for a given number of iterations and returns the char error rate. + double TestIterations(int max_iterations) { + CHECK_GT(max_iterations, 0); + int iteration = trainer_->sample_iteration(); + double mean_error = 0.0; + int error_count = 0; + while (error_count < max_iterations) { + const ImageData& trainingdata = + *trainer_->mutable_training_data()->GetPageBySerial(iteration); + NetworkIO fwd_outputs, targets; + if (trainer_->PrepareForBackward(&trainingdata, &fwd_outputs, &targets) != + UNENCODABLE) { + mean_error += trainer_->NewSingleError(ET_CHAR_ERROR); + ++error_count; + } + trainer_->SetIteration(++iteration); + } + mean_error *= 100.0 / max_iterations; + LOG(INFO) << "Tester error rate = " << mean_error << "\n" ; + return mean_error; + } + // Tests that the current trainer_ can be converted to int mode and still gets + // within 1% of the error rate. Returns the increase in error from float to + // int. + double TestIntMode(int test_iterations) { + std::vector<char> trainer_data; + EXPECT_TRUE(trainer_->SaveTrainingDump(NO_BEST_TRAINER, trainer_.get(), + &trainer_data)); + // Get the error on the next few iterations in float mode. + double float_err = TestIterations(test_iterations); + // Restore the dump, convert to int and test error on that. + EXPECT_TRUE(trainer_->ReadTrainingDump(trainer_data, trainer_.get())); + trainer_->ConvertToInt(); + double int_err = TestIterations(test_iterations); + EXPECT_LT(int_err, float_err + 1.0); + return int_err - float_err; + } + // Sets up a trainer with the given language and given recode+ctc condition. + // It then verifies that the given str encodes and decodes back to the same + // string. + void TestEncodeDecode(const std::string& lang, const std::string& str, bool recode) { + std::string unicharset_name = lang + "/" + lang + ".unicharset"; + std::string lstmf_name = lang + ".Arial_Unicode_MS.exp0.lstmf"; + SetupTrainer("[1,1,0,32 Lbx100 O1c1]", "bidi-lstm", unicharset_name, + lstmf_name, recode, true, 5e-4, true, lang); + std::vector<int> labels; + EXPECT_TRUE(trainer_->EncodeString(str.c_str(), &labels)); + STRING decoded = trainer_->DecodeLabels(labels); + std::string decoded_str(&decoded[0], decoded.length()); + EXPECT_EQ(str, decoded_str); + } + // Calls TestEncodeDeode with both recode on and off. + void TestEncodeDecodeBoth(const std::string& lang, const std::string& str) { + TestEncodeDecode(lang, str, false); + TestEncodeDecode(lang, str, true); + } + + std::unique_ptr<LSTMTrainer> trainer_; +}; + +} // namespace tesseract. + +#endif // THIRD_PARTY_TESSERACT_UNITTEST_LSTM_TEST_H_ |