13 ModelCheckpoint(
const std::string &folderPath,
const bool saveBestOnly =
true,
14 const int numEpochs = 1,
const bool verbose =
false) {
15 assert(folderExists(folderPath) &&
"Folder doesn't exist");
16 this->folderPath = folderPath;
17 this->saveBestOnly = saveBestOnly;
18 this->numEpochs = numEpochs;
19 this->verbose = verbose;
22 void onEpochBegin(
Model &model)
override {};
36 logs = getLogs(model);
37 int epoch = std::get<int>(logs.at(
"EPOCH"));
39 const std::type_info &modelInfo =
typeid(model);
40 std::string filename = formatCheckpointFilepath(modelInfo.name());
42 if ((epoch % numEpochs) != 0)
return;
45 double currLoss = std::get<double>(logs.at(
"LOSS"));
46 double currAccuracy = std::get<double>(logs.at(
"ACCURACY"));
48 if (currLoss > bestLoss && bestAccuracy > currAccuracy)
return;
52 bestAccuracy = currAccuracy;
55 if (verbose) verboseOutput(filename);
57 model.to_file(filename);
60 void onTrainBegin(
Model &model)
override {};
61 void onTrainEnd(Model &model)
override {};
62 void onBatchBegin(Model &model)
override {};
63 void onBatchEnd(Model &model)
override {};
65 ~ModelCheckpoint()
override =
default;
68 std::string folderPath, filename;
69 bool saveBestOnly, verbose;
70 double bestLoss = std::numeric_limits<double>::max(), bestAccuracy = 0;
71 int numEpochs, bestEpoch;
72 std::unordered_map<std::string, Logs> logs;
74 void verboseOutput(
const std::string filename) {
75 std::cout <<
"Saving checkpoint in file: " << filename <<
"\n";
78 std::string formatCheckpointFilepath(
const std::string &modelName) {
79 std::string checkpointId =
80 saveBestOnly ?
"best" : std::to_string(std::get<int>(logs[
"EPOCH"]));
81 std::string fileName = modelName +
"-checkpoint-" + checkpointId +
".bin";
82 return constructFilePath(folderPath, fileName);