9#include "utils/Functions.hpp"
27 checkMetric(metric, metrics);
28 this->metric = metric;
29 this->minDelta = minDelta;
30 this->patience = patience;
33 void onEpochBegin(
Model& model)
override {};
47 std::unordered_map<std::string, Logs> logs = getLogs(model);
48 auto it = logs.find(metric);
50 if (it == logs.end())
throw std::invalid_argument(
"Metric not found");
52 double currentMetric = std::get<double>(it->second);
54 if (previousMetric == 0) {
55 previousMetric = currentMetric;
59 double absCurrentDelta = std::abs(currentMetric - previousMetric);
61 patience = absCurrentDelta <= minDelta ? patience - 1 : patience;
62 previousMetric = currentMetric;
64 if (patience < 0)
throw std::runtime_error(
"Early stopping");
67 void onTrainBegin(
Model& model)
override {};
68 void onTrainEnd(Model& model)
override {};
69 void onBatchBegin(Model& model)
override {};
70 void onBatchEnd(Model& model)
override {};
72 ~EarlyStopping()
override =
default;
76 double minDelta, previousMetric = 0;
78 std::vector<std::string> metrics = {
Definition Callback.hpp:19
Definition EarlyStopping.hpp:12
void onEpochEnd(Model &model) override
This method will be called at the end of each epoch.
Definition EarlyStopping.hpp:46
EarlyStopping(const std::string &metric="LOSS", double minDelta=0, int patience=0)
EarlyStopping is a Callback that stops training when a monitored metric has stopped improving.
Definition EarlyStopping.hpp:25