9#include "utils/Functions.hpp"
20 CSVLogger(
const std::string &filepath,
const std::string &separator =
",") {
21 assert(fileHasExtension(filepath,
".csv") &&
22 "filepath must have .csv extension");
23 this->filepath = filepath;
24 this->separator = separator;
27 void onEpochBegin(
Model &model)
override {};
38 std::unordered_map<std::string, Logs> logs = getLogs(model);
39 std::vector<double> row;
41 row.reserve(logs.size());
43 std::transform(logs.begin(), logs.end(), std::back_inserter(row),
45 const auto &value = log.second;
46 if (std::holds_alternative<int>(value)) {
47 return static_cast<double>(std::get<int>(value));
49 return std::get<double>(value);
63 std::unordered_map<std::string, Logs> logs = getLogs(model);
65 for (
const auto &log : logs) {
66 const auto &value = log.second;
67 if (std::holds_alternative<int>(value) ||
68 std::holds_alternative<double>(value)) {
69 headers.push_back(log.first);
82 std::ofstream file(filepath);
84 if (!file.is_open()) {
85 throw std::runtime_error(
"Couldn't open csv file");
88 file << formatRow(headers);
90 std::for_each(data.begin(), data.end(),
91 [&file,
this](
auto &row) { file << this->formatRow(row); });
96 void onBatchBegin(
Model &model)
override {};
97 void onBatchEnd(Model &model)
override {};
/// Path of the output file; the constructor asserts it ends in ".csv".
std::string filepath{};
/// Delimiter that formatRow places between cells.
std::string separator{};
/// Column names: keys of the int/double log entries collected in onTrainBegin.
std::vector<std::string> headers{};
/// Numeric rows written after the header line in onTrainEnd
/// (presumably one per epoch, built in onEpochEnd — confirm).
std::vector<std::vector<double>> data{};
113 template <
typename T>
114 std::string formatRow(
const std::vector<T> &v) {
118 csvRow += std::to_string(el) + separator;
124 return csvRow +
"\n";
127 std::string formatRow(
const std::vector<std::string> &v) {
130 for (
const std::string &el : v) {
131 csvRow += el + separator;
137 return csvRow +
"\n";
Definition CSVLogger.hpp:12
void onEpochEnd(Model &model) override
This method will be called at the end of each epoch.
Definition CSVLogger.hpp:37
void onTrainEnd(Model &model) override
This method will be called at the end of the training.
Definition CSVLogger.hpp:81
CSVLogger(const std::string &filepath, const std::string &separator=",")
CSVLogger is a Callback that streams epoch results to a csv file.
Definition CSVLogger.hpp:20
void onTrainBegin(Model &model) override
This method will be called at the beginning of the training.
Definition CSVLogger.hpp:62
Definition Callback.hpp:19