NeuralNet 1.0
Loading...
Searching...
No Matches
ModelCheckpoint.hpp
1#pragma once
2
3#include <memory>
4#include <typeinfo>
5
6#include "Callback.hpp"
7#include "Model.hpp"
8#include "utils/Functions.hpp"
9
10namespace NeuralNet {
11class ModelCheckpoint : public Callback {
12 public:
13 ModelCheckpoint(const std::string &folderPath, const bool saveBestOnly = true,
14 const int numEpochs = 1, const bool verbose = false) {
15 assert(folderExists(folderPath) && "Folder doesn't exist");
16 this->folderPath = folderPath;
17 this->saveBestOnly = saveBestOnly;
18 this->numEpochs = numEpochs;
19 this->verbose = verbose;
20 };
21
22 void onEpochBegin(Model &model) override {};
23
35 void onEpochEnd(Model &model) override {
36 logs = getLogs(model);
37 int epoch = std::get<int>(logs.at("EPOCH"));
38 // to get the Model's name
39 const std::type_info &modelInfo = typeid(model);
40 std::string filename = formatCheckpointFilepath(modelInfo.name());
41
42 if ((epoch % numEpochs) != 0) return;
43
44 if (saveBestOnly) {
45 double currLoss = std::get<double>(logs.at("LOSS"));
46 double currAccuracy = std::get<double>(logs.at("ACCURACY"));
47
48 if (currLoss > bestLoss && bestAccuracy > currAccuracy) return;
49
50 // Save best model for later saving
51 bestLoss = currLoss;
52 bestAccuracy = currAccuracy;
53 }
54
55 if (verbose) verboseOutput(filename);
56
57 model.to_file(filename);
58 };
59
60 void onTrainBegin(Model &model) override {};
61 void onTrainEnd(Model &model) override {};
62 void onBatchBegin(Model &model) override {};
63 void onBatchEnd(Model &model) override {};
64
65 ~ModelCheckpoint() override = default;
66
67 private:
68 std::string folderPath, filename;
69 bool saveBestOnly, verbose;
70 double bestLoss = std::numeric_limits<double>::max(), bestAccuracy = 0;
71 int numEpochs, bestEpoch;
72 std::unordered_map<std::string, Logs> logs;
73
74 void verboseOutput(const std::string filename) {
75 std::cout << "Saving checkpoint in file: " << filename << "\n";
76 }
77
78 std::string formatCheckpointFilepath(const std::string &modelName) {
79 std::string checkpointId =
80 saveBestOnly ? "best" : std::to_string(std::get<int>(logs["EPOCH"]));
81 std::string fileName = modelName + "-checkpoint-" + checkpointId + ".bin";
82 return constructFilePath(folderPath, fileName);
83 }
84};
85} // namespace NeuralNet
Definition Callback.hpp:19
Definition ModelCheckpoint.hpp:11
void onEpochEnd(Model &model) override
This method will be called at the end of each epoch.
Definition ModelCheckpoint.hpp:35
Definition Model.hpp:17