/// Short alias for std::filesystem, used by the path helpers in this header.
namespace fs = std::filesystem;
/**
 * @brief Draws a uniformly distributed real number from [min, max).
 *
 * @param min Inclusive lower bound.
 * @param max Exclusive upper bound; must satisfy min <= max
 *            (precondition of std::uniform_real_distribution).
 * @return A pseudo-random double in [min, max).
 */
inline double mtRand(double min, double max) {
  // One engine per thread, seeded once. Constructing a std::random_device and
  // reseeding a fresh Mersenne Twister on every call is slow and throws away
  // the statistical properties of a single long-running engine.
  // thread_local means no locking is needed.
  static thread_local std::mt19937_64 rng{std::random_device{}()};
  std::uniform_real_distribution<double> dist(min, max);
  return dist(rng);
}
42inline std::vector<double> randDVector(
int size,
double min = -10,
44 std::vector<double> v;
46 for (
int i = 0; i < size; i++) {
47 v.push_back(mtRand(min, max));
/**
 * @brief Checks that `filePath` exists and has exactly the given extension.
 *
 * @param filePath  Path to test.
 * @param extension Expected extension, compared case-sensitively against
 *                  std::filesystem::path::extension(). That function includes
 *                  the leading dot, so pass e.g. ".csv", not "csv".
 * @return true if the path exists and its extension equals `extension`;
 *         false otherwise, including when the path does not exist.
 * @throws std::filesystem::filesystem_error if the existence check fails for
 *         a reason other than the path being absent.
 */
inline bool fileExistsWithExtension(const std::string &filePath,
                                    const std::string &extension) {
  namespace fs = std::filesystem;
  if (!fs::exists(filePath)) {
    return false;
  }
  const fs::path file(filePath);
  return file.has_extension() && file.extension() == extension;
}
/**
 * @brief Checks, purely lexically, whether `filePath` has the given extension.
 *
 * The filesystem is not touched, so the path need not exist.
 *
 * @param filePath  Path to inspect.
 * @param extension Expected extension including its leading dot (".txt"),
 *                  as returned by std::filesystem::path::extension().
 * @return true if the path has an extension equal to `extension`.
 */
inline bool fileHasExtension(const std::string &filePath,
                             const std::string &extension) {
  namespace fs = std::filesystem;
  const fs::path file(filePath);
  return file.has_extension() && file.extension() == extension;
}
/**
 * @brief Returns true if `folderPath` names an existing directory.
 *
 * @param folderPath Path to test.
 * @return true for an existing directory; false if the path is missing or is
 *         not a directory.
 * @throws std::filesystem::filesystem_error on underlying OS errors other than
 *         "not found".
 */
inline bool folderExists(const std::string &folderPath) {
  // is_directory() already reports false for a non-existent path, so a
  // separate exists() call (a second filesystem stat) is unnecessary.
  return std::filesystem::is_directory(folderPath);
}
/**
 * @brief Joins a folder path and a file name with the platform separator.
 *
 * A separator ('\\' on Windows, '/' elsewhere) is inserted only when
 * `folderPath` is non-empty and does not already end with it. An empty
 * `folderPath` yields `fileName` unchanged.
 *
 * NOTE(review): the platform guard was missing from this copy of the file;
 * `_WIN32` is assumed. Confirm against the original.
 *
 * @param folderPath Directory part; may be empty.
 * @param fileName   File name appended verbatim.
 * @return The concatenated path string.
 */
inline std::string constructFilePath(const std::string &folderPath,
                                     const std::string &fileName) {
#ifdef _WIN32
  constexpr char kSeparator = '\\';
#else
  constexpr char kSeparator = '/';
#endif
  std::string filepath = folderPath;
  // Compare against a char. Comparing back() with a string literal such as
  // "\\" is a char-vs-pointer comparison and is ill-formed.
  if (!filepath.empty() && filepath.back() != kSeparator) {
    filepath += kSeparator;
  }
  filepath += fileName;
  return filepath;
}
/// @brief Returns x squared (x * x); cheaper than std::pow(x, 2).
inline constexpr double sqr(const double x) noexcept { return x * x; }
/**
 * @brief Appends `rows` empty inner vectors to `v`, each with capacity for at
 *        least `cols` elements.
 *
 * NOTE(review): the template header and the per-row reservation were missing
 * from this copy of the file and were reconstructed from the function name.
 * Confirm against the original intent.
 *
 * @param v    Outer vector to grow; existing rows are left untouched.
 * @param rows Number of rows to append; values <= 0 append nothing.
 * @param cols Capacity to reserve in each new row; values <= 0 reserve none.
 */
template <typename T>
inline void reserve2d(std::vector<std::vector<T>> &v, int rows, int cols) {
  if (rows <= 0) {
    return;
  }
  v.reserve(v.size() + static_cast<std::size_t>(rows));
  for (int i = 0; i < rows; ++i) {
    v.emplace_back();
    // Use back(), not v[i]: rows are appended, so index i only refers to the
    // new row when v started out empty.
    if (cols > 0) {
      v.back().reserve(static_cast<std::size_t>(cols));
    }
  }
}
173inline int findRowIndexOfMaxEl(
const Eigen::MatrixXd &m) {
175 double maxVal = m.maxCoeff();
178 for (
int i = 0; i < m.rows(); ++i) {
179 if ((m.row(i).array() == maxVal).any()) {
/**
 * @brief Concatenates the rows of a 2-D vector into one row-major vector.
 *
 * @param input Rows to flatten.
 * @param rows  Expected number of rows (checked by assert in debug builds).
 * @param cols  Expected length of every row (checked by assert in debug
 *              builds); used to pre-size the result.
 * @return All elements of input[0], then input[1], and so on.
 */
template <typename T>
inline std::vector<T> flatten2DVector(const std::vector<std::vector<T>> &input,
                                      size_t rows, size_t cols) {
#ifndef NDEBUG
  // Shape validation is debug-only. Guarding it avoids an empty loop and
  // unused-variable warnings in release builds.
  assert(input.size() == rows);
  for (const std::vector<T> &row : input) {
    assert(row.size() == cols);
  }
#endif
  std::vector<T> result;
  result.reserve(rows * cols);
  for (const std::vector<T> &row : input) {
    result.insert(result.end(), row.begin(), row.end());
  }
  return result;
}
/**
 * @brief Returns the index of the first element of `v` equal to `el`.
 *
 * @param v  Vector to search linearly.
 * @param el Value to look for (compared with operator==).
 * @return Zero-based index of the first match, or -1 if `el` is absent.
 */
template <typename T>
inline int findIndexOf(const std::vector<T> &v, const T &el) {
  const auto it = std::find(v.begin(), v.end(), el);
  if (it == v.end()) {
    return -1;
  }
  // Explicit narrowing from the iterator difference type to the int return type.
  return static_cast<int>(std::distance(v.begin(), it));
}
239inline Eigen::MatrixXd zeroMatrix(
const std::tuple<int, int> size) {
240 return Eigen::MatrixXd::Zero(std::get<0>(size), std::get<1>(size));
243inline Eigen::MatrixXd vectorToMatrixXd(std::vector<std::vector<double>> &v) {
244 if (v.empty() || v[0].empty())
return Eigen::MatrixXd(0, 0);
247 int cols = v[0].size();
250 std::vector<double> flat;
251 flat.reserve(rows * cols);
252 for (
const auto &row : v) {
253 flat.insert(flat.end(), row.begin(), row.end());
257 Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
258 flat.data(), rows, cols);
271static void randomWeightInit(Eigen::MatrixXd *weightsMatrix,
double min = -1.0,
273 for (
int col = 0; col < weightsMatrix->cols(); col++) {
274 for (
int row = 0; row < weightsMatrix->rows(); row++) {
275 weightsMatrix->operator()(row, col) = mtRand(min, max);
291static void randomDistMatrixInit(Eigen::MatrixXd *weightsMatrix,
double mean,
293 std::random_device rseed;
294 std::default_random_engine generator(rseed());
295 std::normal_distribution<double> distribution(mean, stddev);
297 for (
int col = 0; col < weightsMatrix->cols(); col++) {
298 for (
int row = 0; row < weightsMatrix->rows(); row++) {
299 weightsMatrix->operator()(row, col) = distribution(generator);
313static Eigen::MatrixXd hardmax(
const Eigen::MatrixXd &mat) {
314 Eigen::MatrixXd hardmaxMatrix = Eigen::MatrixXd::Zero(mat.rows(), mat.cols());
316 for (
int i = 0; i < mat.rows(); ++i) {
318 mat.row(i).maxCoeff(&maxIndex);
320 hardmaxMatrix(i, maxIndex) = 1;
323 return hardmaxMatrix;
335static Eigen::MatrixXd trim(
const Eigen::MatrixXd &logits,
336 double threshold = 0.01) {
337 return (logits.array() < threshold).select(0, logits);
348static Eigen::MatrixXd thresh(
const Eigen::MatrixXd &logits,
349 double threshold = 0.01) {
350 return (logits.array() < threshold).select(threshold, logits);
// Signal handler that prints the received signal number to std::cout.
//
// @param signum Signal number delivered by the runtime; it is printed in the
//        message below.
//
// NOTE(review): this copy of the file stops after the print statement. The
// rest of the original handler body (possibly cleanup and/or termination) is
// missing here; confirm against the original before relying on it.
// NOTE(review): std::cout is not async-signal-safe, so writing to it inside a
// signal handler is undefined behavior. Consider only setting a
// `volatile std::sig_atomic_t` flag here and reporting from normal code.
354static void signalHandler(
int signum) {
355 std::cout <<
"Interrupt signal (" << signum <<
") received.\n";
364static std::string removeTrailingZeros(std::string str) {
365 str.erase(str.find_last_not_of(
'0') + 1, std::string::npos);
367 if (str.back() ==
'.') {