6#include "Functions.hpp"
21static Eigen::MatrixXd formatLabels(std::vector<T> labels,
22 std::tuple<int, int> shape) {
23 int rows = std::get<0>(shape);
24 int cols = std::get<1>(shape);
27 assert(labels.size() == rows &&
28 "The number of labels don't match the number of inputs");
30 Eigen::MatrixXd mLabels(rows, cols);
32 if constexpr (std::is_same<T, std::vector<double>>::value) {
33 std::vector<double> flattenedVector = flatten2DVector(labels, rows, cols);
34 mLabels = Eigen::Map<Eigen::MatrixXd>(flattenedVector.data(), rows, cols);
35 }
else if constexpr (std::is_same<T, double>::value) {
36 mLabels = Eigen::MatrixXd::Zero(rows, cols);
42 for (
int i = 0; i < rows; i++) {
43 int colIndex = labels[i];
44 assert(colIndex < cols);
45 mLabels(i, colIndex) = 1;