NeuralNet 1.0
Loading...
Searching...
No Matches
CSVLogger.hpp
1#pragma once
2
3#include <fstream>
4#include <stdexcept>
5#include <string>
6#include <vector>
7
8#include "Callback.hpp"
9#include "utils/Functions.hpp" // fileExistsWithExtension
10
11namespace NeuralNet {
12class CSVLogger : public Callback {
13 public:
20 CSVLogger(const std::string &filepath, const std::string &separator = ",") {
21 assert(fileHasExtension(filepath, ".csv") &&
22 "filepath must have .csv extension");
23 this->filepath = filepath;
24 this->separator = separator;
25 };
26
27 void onEpochBegin(Model &model) override {};
28
37 void onEpochEnd(Model &model) override {
38 std::unordered_map<std::string, Logs> logs = getLogs(model);
39 std::vector<double> row;
40
41 row.reserve(logs.size());
42
43 std::transform(logs.begin(), logs.end(), std::back_inserter(row),
44 [](const auto &log) {
45 const auto &value = log.second;
46 if (std::holds_alternative<int>(value)) {
47 return static_cast<double>(std::get<int>(value));
48 }
49 return std::get<double>(value);
50 });
51
52 data.push_back(row);
53 };
54
62 void onTrainBegin(Model &model) override {
63 std::unordered_map<std::string, Logs> logs = getLogs(model);
64 // Initializing the headers with the logs keys
65 for (const auto &log : logs) {
66 const auto &value = log.second;
67 if (std::holds_alternative<int>(value) ||
68 std::holds_alternative<double>(value)) {
69 headers.push_back(log.first);
70 }
71 };
72 };
73
81 void onTrainEnd(Model &model) override {
82 std::ofstream file(filepath);
83
84 if (!file.is_open()) {
85 throw std::runtime_error("Couldn't open csv file");
86 }
87
88 file << formatRow(headers);
89
90 std::for_each(data.begin(), data.end(),
91 [&file, this](auto &row) { file << this->formatRow(row); });
92
93 file.close();
94 };
95
96 void onBatchBegin(Model &model) override {};
97 void onBatchEnd(Model &model) override {};
98
99 private:
100 std::string filepath;
101 std::string separator;
102 std::vector<std::string> headers;
103 std::vector<std::vector<double>> data;
104
113 template <typename T>
114 std::string formatRow(const std::vector<T> &v) {
115 std::string csvRow;
116
117 for (T el : v) {
118 csvRow += std::to_string(el) + separator;
119 }
120
121 // Remove last ","
122 csvRow.pop_back();
123
124 return csvRow + "\n";
125 };
126
127 std::string formatRow(const std::vector<std::string> &v) {
128 std::string csvRow;
129
130 for (const std::string &el : v) {
131 csvRow += el + separator;
132 }
133
134 // Remove last ","
135 csvRow.pop_back();
136
137 return csvRow + "\n";
138 };
139};
140} // namespace NeuralNet
Definition CSVLogger.hpp:12
void onEpochEnd(Model &model) override
This method will be called at the end of each epoch.
Definition CSVLogger.hpp:37
void onTrainEnd(Model &model) override
This method will be called at the end of the training.
Definition CSVLogger.hpp:81
CSVLogger(const std::string &filepath, const std::string &separator=",")
CSVLogger is a Callback that records epoch results and writes them to a csv file when training ends.
Definition CSVLogger.hpp:20
void onTrainBegin(Model &model) override
This method will be called at the beginning of the training.
Definition CSVLogger.hpp:62
Definition Callback.hpp:19
Definition Model.hpp:17