NeuralNet 1.0
Loading...
Searching...
No Matches
Functions.hpp
#pragma once

#include <Eigen/Dense>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <random>
#include <string>
#include <tuple>
#include <vector>
10
11namespace fs = std::filesystem;
12
13namespace NeuralNet {
14
/// Returns a uniformly distributed random double in [min, max).
/// @param min Lower bound (inclusive); must be strictly less than max.
/// @param max Upper bound (exclusive).
inline double mtRand(double min, double max) {
  assert(min < max);

  // Seed one engine per thread instead of constructing (and re-seeding) a
  // fresh std::random_device + std::mt19937_64 on every call — the original
  // pattern is slow and squanders the generator's statistical quality.
  static thread_local std::mt19937_64 rng{std::random_device{}()};
  std::uniform_real_distribution<double> dist(min, max);

  return dist(rng);
}
31
42inline std::vector<double> randDVector(int size, double min = -10,
43 double max = 10) {
44 std::vector<double> v;
45
46 for (int i = 0; i < size; i++) {
47 v.push_back(mtRand(min, max));
48 }
49
50 return v;
51}
52
/// Checks that `filePath` exists on disk AND carries exactly `extension`
/// (compared with the leading dot, e.g. ".csv").
inline bool fileExistsWithExtension(const std::string &filePath,
                                    const std::string &extension) {
  // A path that is absent from the filesystem can never match.
  if (!std::filesystem::exists(filePath)) {
    return false;
  }

  const std::filesystem::path candidate(filePath);
  return candidate.has_extension() && candidate.extension() == extension;
}
73
/// Returns true when `filePath` ends with exactly `extension` (dot included,
/// e.g. ".json"). Does not touch the filesystem.
inline bool fileHasExtension(const std::string &filePath,
                             const std::string &extension) {
  const std::filesystem::path candidate(filePath);
  if (!candidate.has_extension()) {
    return false;
  }
  return candidate.extension() == extension;
}
88
/// Returns true when `folderPath` names an existing directory.
inline bool folderExists(const std::string &folderPath) {
  const std::filesystem::path dir(folderPath);
  return std::filesystem::exists(dir) && std::filesystem::is_directory(dir);
}
99
/// Joins a folder path and a file name with the platform separator,
/// avoiding a doubled separator when `folderPath` already ends with one.
/// @param folderPath Directory part; may be empty (then fileName is returned).
/// @param fileName   File part to append.
/// @return The combined path string.
inline std::string constructFilePath(const std::string &folderPath,
                                     const std::string &fileName) {
// Pick the separator once; the original Windows branch compared a char
// against the string literal "\\", which does not compile under _WIN32.
#ifdef _WIN32
  constexpr char kSep = '\\';
#else
  constexpr char kSep = '/';
#endif

  // Append a separator only when the folder is non-empty and does not
  // already end with one.
  if (folderPath.empty() || folderPath.back() == kSep) {
    return folderPath + fileName;
  }
  return folderPath + kSep + fileName;
}
128
129/* MATHEMATICAL FUNCTIONS */
130
/// Squares its argument; usable in constant expressions.
inline constexpr double sqr(const double value) { return value * value; }
139
140/* VECTOR OPERATIONS */
141
/// Fills `v` with `rows` empty row vectors, each with capacity for `cols`
/// elements. NOTE: this appends rows (changes v.size()), it does not merely
/// reserve — matching the established behavior callers rely on.
template <typename T>
inline void reserve2d(std::vector<std::vector<T>> &v, int rows, int cols) {
  v.reserve(rows);

  for (int r = 0; r < rows; ++r) {
    // Create the row in place, then pre-size its storage.
    v.emplace_back();
    v.back().reserve(cols);
  }
}
164
173inline int findRowIndexOfMaxEl(const Eigen::MatrixXd &m) {
174 // Find the maximum value in the matrix
175 double maxVal = m.maxCoeff();
176
177 // Find the row index by iterating through rows
178 for (int i = 0; i < m.rows(); ++i) {
179 if ((m.row(i).array() == maxVal).any()) {
180 return i;
181 }
182 }
183
184 // Return -1 if not found (this can be handled based on your use case)
185 return -1;
186};
187
/// Flattens a rows×cols 2D vector into a single row-major 1D vector.
/// Debug builds verify the declared shape with asserts.
template <typename T>
inline std::vector<T> flatten2DVector(const std::vector<std::vector<T>> &input,
                                      size_t rows, size_t cols) {
  // Shape validation (no-op under NDEBUG).
  assert(input.size() == rows);
  for (size_t r = 0; r < input.size(); ++r) {
    assert(input[r].size() == cols);
  }

  std::vector<T> flat;
  flat.reserve(rows * cols);

  // Append each row in order to produce row-major layout.
  for (size_t r = 0; r < input.size(); ++r) {
    flat.insert(flat.end(), input[r].cbegin(), input[r].cend());
  }

  return flat;
}
219
/// Returns the index of the first occurrence of `el` in `v`, or -1 when
/// the element is absent.
template <typename T>
inline int findIndexOf(const std::vector<T> &v, const T &el) {
  const auto pos = std::find(v.cbegin(), v.cend(), el);
  if (pos == v.cend()) {
    return -1;
  }
  return static_cast<int>(std::distance(v.cbegin(), pos));
}
237
238/* MATRIX OPERATIONS */
239inline Eigen::MatrixXd zeroMatrix(const std::tuple<int, int> size) {
240 return Eigen::MatrixXd::Zero(std::get<0>(size), std::get<1>(size));
241}
242
243inline Eigen::MatrixXd vectorToMatrixXd(std::vector<std::vector<double>> &v) {
244 if (v.empty() || v[0].empty()) return Eigen::MatrixXd(0, 0);
245
246 int rows = v.size();
247 int cols = v[0].size();
248
249 // Flatten the vector of vectors into a single vector
250 std::vector<double> flat;
251 flat.reserve(rows * cols);
252 for (const auto &row : v) {
253 flat.insert(flat.end(), row.begin(), row.end());
254 }
255
256 return Eigen::Map<
257 Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
258 flat.data(), rows, cols);
259};
260
271static void randomWeightInit(Eigen::MatrixXd *weightsMatrix, double min = -1.0,
272 double max = 1.0) {
273 for (int col = 0; col < weightsMatrix->cols(); col++) {
274 for (int row = 0; row < weightsMatrix->rows(); row++) {
275 weightsMatrix->operator()(row, col) = mtRand(min, max);
276 }
277 }
278
279 return;
280};
281
291static void randomDistMatrixInit(Eigen::MatrixXd *weightsMatrix, double mean,
292 double stddev) {
293 std::random_device rseed;
294 std::default_random_engine generator(rseed());
295 std::normal_distribution<double> distribution(mean, stddev);
296
297 for (int col = 0; col < weightsMatrix->cols(); col++) {
298 for (int row = 0; row < weightsMatrix->rows(); row++) {
299 weightsMatrix->operator()(row, col) = distribution(generator);
300 }
301 }
302
303 return;
304};
305
313static Eigen::MatrixXd hardmax(const Eigen::MatrixXd &mat) {
314 Eigen::MatrixXd hardmaxMatrix = Eigen::MatrixXd::Zero(mat.rows(), mat.cols());
315
316 for (int i = 0; i < mat.rows(); ++i) {
317 int maxIndex;
318 mat.row(i).maxCoeff(&maxIndex);
319
320 hardmaxMatrix(i, maxIndex) = 1;
321 }
322
323 return hardmaxMatrix;
324}
325
335static Eigen::MatrixXd trim(const Eigen::MatrixXd &logits,
336 double threshold = 0.01) {
337 return (logits.array() < threshold).select(0, logits);
338}
339
348static Eigen::MatrixXd thresh(const Eigen::MatrixXd &logits,
349 double threshold = 0.01) {
350 return (logits.array() < threshold).select(threshold, logits);
351}
352
353/* SIGNAL HANDLING */
/// Process-wide signal handler: logs the signal number and terminates,
/// using the signal number as the exit status.
/// NOTE(review): std::cout and std::exit are not async-signal-safe; from a
/// real signal context only async-signal-safe calls (e.g. write, _Exit) are
/// guaranteed. Acceptable here since the process is shutting down anyway,
/// but worth confirming against how the handler is installed.
static void signalHandler(int signum) {
  std::cout << "Interrupt signal (" << signum << ") received.\n";

  // cleanup and close up stuff here
  // Qualified std::exit (from <cstdlib>) instead of relying on a
  // transitively included unqualified exit().
  std::exit(signum);
}
361
362/* STRING OPERATIONS */
363
/// Strips trailing '0' characters from a numeric string, and a trailing '.'
/// left behind afterwards (e.g. "1.2500" -> "1.25", "42.000" -> "42").
/// Intended for decimal representations; "000" becomes the empty string.
static std::string removeTrailingZeros(std::string str) {
  // npos + 1 == 0 when the string is all zeros, erasing everything.
  str.erase(str.find_last_not_of('0') + 1, std::string::npos);

  // Guard against the all-zeros/empty case: calling back() on an empty
  // string is undefined behavior.
  if (!str.empty() && str.back() == '.') {
    str.pop_back();
  }

  return str;
}
373
374} // namespace NeuralNet