NeuralNet 1.0
Loading...
Searching...
No Matches
EarlyStopping.hpp
1#pragma once
2
3#include <cmath>
4#include <stdexcept>
5#include <string>
6#include <vector>
7
8#include "Callback.hpp"
9#include "utils/Functions.hpp"
10
11namespace NeuralNet {
12class EarlyStopping : public Callback {
13 public:
25 EarlyStopping(const std::string& metric = "LOSS", double minDelta = 0,
26 int patience = 0) {
27 checkMetric(metric, metrics);
28 this->metric = metric;
29 this->minDelta = minDelta;
30 this->patience = patience;
31 };
32
33 void onEpochBegin(Model& model) override {};
34
46 void onEpochEnd(Model& model) override {
47 std::unordered_map<std::string, Logs> logs = getLogs(model);
48 auto it = logs.find(metric);
49
50 if (it == logs.end()) throw std::invalid_argument("Metric not found");
51
52 double currentMetric = std::get<double>(it->second);
53
54 if (previousMetric == 0) {
55 previousMetric = currentMetric;
56 return;
57 }
58
59 double absCurrentDelta = std::abs(currentMetric - previousMetric);
60
61 patience = absCurrentDelta <= minDelta ? patience - 1 : patience;
62 previousMetric = currentMetric;
63
64 if (patience < 0) throw std::runtime_error("Early stopping");
65 };
66
67 void onTrainBegin(Model& model) override {};
68 void onTrainEnd(Model& model) override {};
69 void onBatchBegin(Model& model) override {};
70 void onBatchEnd(Model& model) override {};
71
72 ~EarlyStopping() override = default;
73
74 private:
75 std::string metric;
76 double minDelta, previousMetric = 0;
77 int patience;
78 std::vector<std::string> metrics = {
79 "LOSS", "ACCURACY"}; // Available metrics for this Callback
80};
81
82} // namespace NeuralNet
Definition Callback.hpp:19
Definition EarlyStopping.hpp:12
void onEpochEnd(Model &model) override
This method will be called at the end of each epoch.
Definition EarlyStopping.hpp:46
EarlyStopping(const std::string &metric="LOSS", double minDelta=0, int patience=0)
EarlyStopping is a Callback that stops training when a monitored metric has stopped improving.
Definition EarlyStopping.hpp:25
Definition Model.hpp:17