NeuralNet 1.0
Loading...
Searching...
No Matches
Callback.hpp
#pragma once

#include <algorithm>
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include "Model.hpp"
#include "utils/Variants.hpp"
14
15namespace NeuralNet {
16
17class Model;
18
19class Callback {
20 public:
21 virtual void onTrainBegin(Model &model) = 0;
22 virtual void onTrainEnd(Model &model) = 0;
23 virtual void onEpochBegin(Model &model) = 0;
24 virtual void onEpochEnd(Model &model) = 0;
25 virtual void onBatchBegin(Model &model) = 0;
26 virtual void onBatchEnd(Model &model) = 0;
27
28 virtual ~Callback() = default;
29
42 template <typename T>
43 static void callMethod(std::shared_ptr<T> callback,
44 const std::string &methodName, Model &model) {
45 static const std::unordered_map<std::string,
46 std::function<void(T *, Model &)>>
47 methods = {{"onTrainBegin",
48 [](T *callback, Model &model) {
49 return callback->onTrainBegin(model);
50 }},
51 {"onTrainEnd",
52 [](T *callback, Model &model) {
53 return callback->onTrainEnd(model);
54 }},
55 {"onEpochBegin",
56 [](T *callback, Model &model) {
57 return callback->onEpochBegin(model);
58 }},
59 {"onEpochEnd",
60 [](T *callback, Model &model) {
61 return callback->onEpochEnd(model);
62 }},
63 {"onBatchBegin",
64 [](T *callback, Model &model) {
65 return callback->onBatchBegin(model);
66 }},
67 {"onBatchEnd", [](T *callback, Model &model) {
68 return callback->onBatchEnd(model);
69 }}};
70
71 auto it = methods.find(methodName);
72
73 if (it == methods.end()) return;
74
75 it->second(callback.get(), model);
76 }
77
78 protected:
79 static void checkMetric(const std::string &metric,
80 const std::vector<std::string> &metrics) {
81 if (std::find(metrics.begin(), metrics.end(), metric) == metrics.end())
82 throw std::invalid_argument("Metric not found");
83 };
84
85 static std::unordered_map<std::string, Logs> getLogs(Model &model) {
86 std::unordered_map<std::string, Logs> logs;
87
88 logs["EPOCH"] = model.cEpoch;
89 logs["ACCURACY"] = model.accuracy;
90 logs["LOSS"] = model.loss;
91 logs["TEST_ACCURACY"] = model.testAccuracy;
92 logs["TEST_LOSS"] = model.testLoss;
93
94 return logs;
95 };
96};
97} // namespace NeuralNet
Definition Callback.hpp:19
static void callMethod(std::shared_ptr< T > callback, const std::string &methodName, Model &model)
Calls the named method of the callback, passing it the given model.
Definition Callback.hpp:43
Definition Model.hpp:17