NeuralNet 1.0
Loading...
Searching...
No Matches
Model.hpp
1#pragma once
2
#include <cassert>
#include <csignal>
#include <fstream>
#include <string>
#include <type_traits>

#include <cereal/access.hpp>
#include <cereal/archives/binary.hpp>
#include <cereal/archives/portable_binary.hpp>
#include <cereal/types/base_class.hpp>
#include <cereal/types/common.hpp>
#include <cereal/types/polymorphic.hpp>

#include "utils/Functions.hpp"
15
16namespace NeuralNet {
17class Model {
18 public:
26 template <typename T, typename = typename std::enable_if<
27 std::is_base_of<Model, T>::value>::type>
28 static void save_to_file(const std::string &filename, T model) {
29 // Serializing model to a binary file
30 std::ofstream file(filename, std::ios::binary);
31 cereal::BinaryOutputArchive archive(file);
32 archive(model);
33 };
34
47 template <typename T, typename = typename std::enable_if<
48 std::is_base_of<Model, T>::value>::type>
49 static void load_from_file(const std::string &filename, T &model) {
50 // Making sure the file exists and is binary
51 assert(fileExistsWithExtension(filename, ".bin") &&
52 "The file doesn't exists or is not binary '.bin'");
53
54 // Deserializing the model from the binary file
55 std::ifstream file(filename, std::ios::binary);
56 cereal::BinaryInputArchive archive(file);
57 archive(model);
58 };
59
60 virtual void to_file(const std::string &filename) = 0;
61 virtual void from_file(const std::string &filename) = 0;
62
63 // Declare at least one virtual function
64 virtual ~Model() = default;
65
66 private:
67 friend class cereal::access;
68
69 // Serialization function
70 template <class Archive>
71 void save(Archive &archive) const {
72 archive(loss, accuracy);
73 };
74
75 template <class Archive>
76 void load(Archive &archive) {
77 archive(loss, accuracy);
78 };
79
80 protected:
81 friend class Callback;
82 int cEpoch = 0; // Current epoch
83 double loss = 0, accuracy = 0, testLoss = 0, testAccuracy = 0;
84
85 void registerSignals() const {
86 // Registering signals
87 signal(SIGINT, signalHandler);
88 signal(SIGTERM, signalHandler);
89 }
90};
91} // namespace NeuralNet
92
93CEREAL_REGISTER_TYPE(NeuralNet::Model);
Definition Model.hpp:17
static void save_to_file(const std::string &filename, T model)
This method saves (serializes) the model passed as an argument to a .bin file.
Definition Model.hpp:28
static void load_from_file(const std::string &filename, T &model)
This static method loads a Model from a file and assigns it to the supposedly "empty" model passed as...
Definition Model.hpp:49