NeuralNet 1.0
Loading...
Searching...
No Matches
TrainingData.hpp
1#pragma once
2
#include <algorithm>
#include <cassert>
#include <iostream>
#include <map>
#include <random>
#include <utility> // For std::pair
#include <vector>
7
8namespace NeuralNet {
9template <typename X, typename Y>
11 friend class Network;
12
13 public:
25 TrainingData(X xTrain, Y yTrain, X xTest = X(), Y yTest = Y())
26 : xTrain(xTrain), yTrain(yTrain), xTest(xTest), yTest(yTest) {
27 assert(xTrain.size() == yTrain.size());
28 if (!xTest.empty() && !yTest.empty()) assert(xTest.size() == xTest.size());
29 }
30
31 std::vector<std::pair<X, Y>> getMiniBatches() { return this->miniBatches; }
32
39 void batch(int batchSize, bool stratified = false, bool shuffle = false,
40 bool dropLast = false, bool verbose = false) {
41 batched = true;
42 if (!stratified)
43 return normalMiniBatch(batchSize, xTrain, yTrain, shuffle, dropLast,
44 verbose);
45 return stratifiedMiniBatch(batchSize, xTrain, yTrain, shuffle, dropLast,
46 verbose);
47 };
48
49 private:
  X xTrain;  // Training inputs; emptied by normalMiniBatch after batching.
  Y yTrain;  // Training labels, parallel to xTrain.
  X xTest;   // Optional held-out test inputs (may be empty).
  Y yTest;   // Optional held-out test labels (may be empty).
  std::vector<std::pair<X, Y>> miniBatches;  // (inputs, labels) batches built by batch().
  bool batched = false;  // Set once batch() has been called.
56
57 void printDropLast(const size_t size, const size_t requiredSize) {
58 std::cout << "Dropping last mini-batch (size : " << size << " < "
59 << requiredSize << ")" << std::endl;
60 };
61
62 void printNBatchesCreated(const size_t nBatches) {
63 std::cout << "Total mini-batches created : " << nBatches << std::endl;
64 }
65
66 template <typename Ti, typename Tl>
67 void shuffle(std::vector<Ti>& inputs, std::vector<Tl>& labels) {
68 // Create a vector of pairs (data, label)
69 std::vector<std::pair<Ti, Tl>> combined;
70 for (size_t i = 0; i < inputs.size(); ++i) {
71 combined.emplace_back(inputs[i], labels[i]);
72 }
73
74 // Create a random number generator
75 std::random_device rd;
76 std::mt19937 g(rd());
77
78 // Shuffle the combined vector
79 std::shuffle(combined.begin(), combined.end(), g);
80
81 // Separate the shuffled data and labels
82 for (size_t i = 0; i < combined.size(); ++i) {
83 inputs[i] = combined[i].first;
84 labels[i] = combined[i].second;
85 }
86 };
87
88 void normalMiniBatch(int batchSize, X& x, Y& y, bool shuffle = false,
89 bool dropLast = false, bool verbose = false) {
90 int nInputs = x.size();
91 assert(nInputs > 0 && batchSize < nInputs);
92 int nMiniBatches = (nInputs + batchSize - 1) / batchSize;
93
94 if (shuffle) {
95 this->shuffle(x, y);
96 }
97
98 miniBatches.reserve(nMiniBatches);
99
100 for (int i = 0; i < nMiniBatches; ++i) {
101 int startIdx = i * batchSize;
102 int endIdx = std::min((i + 1) * batchSize, static_cast<int>(nInputs));
103
104 X xMiniBatch(x.begin() + startIdx, x.begin() + endIdx);
105 Y yMiniBatch(y.begin() + startIdx, y.begin() + endIdx);
106
107 std::pair<X, Y> miniBatch = std::make_pair(xMiniBatch, yMiniBatch);
108
109 if (xMiniBatch.size() < static_cast<size_t>(batchSize) && dropLast) {
110 if (verbose)
111 printDropLast(xMiniBatch.size(), static_cast<size_t>(batchSize));
112 continue;
113 }
114
115 miniBatches.push_back(miniBatch);
116 }
117
118 // empty the respective containers
119 x.erase(x.begin(), x.end());
120 y.erase(y.begin(), y.end());
121
122 if (verbose) printNBatchesCreated(miniBatches.size());
123 };
124
  /// @brief Builds mini-batches that approximately preserve the per-class
  /// label proportions of the full data set, storing them in miniBatches.
  /// Unlike normalMiniBatch, x and y are read but not cleared.
  /// @param batchSize Target number of samples per mini-batch.
  /// @param x Inputs to batch.
  /// @param y Labels parallel to x; used as the stratification key.
  /// @param shuffle When true, shuffle each finished batch internally.
  /// @param dropLast When true, discard a trailing under-sized batch.
  /// @param verbose When true, log class counts, dropped batches, and totals.
  template <typename Tx, typename Ty>
  void stratifiedMiniBatch(int batchSize, std::vector<Tx>& x,
                           std::vector<Ty>& y, bool shuffle = false,
                           bool dropLast = false, bool verbose = false) {
    // Group data by class
    std::map<Ty, X> classDataMap;
    for (size_t i = 0; i < x.size(); i++) {
      classDataMap[y[i]].push_back(x[i]);
    }

    // Calculate number of samples per class in each batch.
    // Integer floor division: rare classes may get a quota of 0 here; the
    // top-up loop below still pulls their samples in round-robin order.
    std::map<Ty, int> classBatchSize;
    int totalSamples = x.size();
    for (const auto& classPair : classDataMap) {
      int classCount = classPair.second.size();
      classBatchSize[classPair.first] = (classCount * batchSize) / totalSamples;

      if (verbose)
        std::cout << "Class count for (" << classPair.first
                  << ") = " << classCount
                  << " - classBatchSize = " << classBatchSize[classPair.first]
                  << std::endl;
    }

    // Create batches until every per-class bucket has been drained.
    bool moreData = true;
    while (moreData) {
      X batchData;
      Y batchLabels;
      moreData = false;

      // Fill mini-batch with value of each class (taken from the back of
      // each class bucket, up to that class's quota).
      for (auto& classPair : classDataMap) {
        Ty classLabel = classPair.first;
        X& classSamples = classPair.second;
        int nSamplesToAdd = classBatchSize[classLabel];

        for (int i = 0; i < nSamplesToAdd && !classSamples.empty(); i++) {
          batchData.push_back(classSamples.back());
          batchLabels.push_back(classLabel);
          classSamples.pop_back();
        }

        if (!classSamples.empty()) {
          moreData = true;
        }
      }

      // Top-up pass: quotas can under-fill the batch (floor division), so
      // take one extra sample at a time, round-robin across classes, until
      // the batch is full or everything is exhausted.
      while (batchData.size() < static_cast<size_t>(batchSize) && moreData) {
        bool added = false;
        for (auto& classPair : classDataMap) {
          if (!classPair.second.empty()) {
            batchData.push_back(classPair.second.back());
            batchLabels.push_back(classPair.first);
            classPair.second.pop_back();
            added = true;
            if (batchData.size() == static_cast<size_t>(batchSize)) break;
          }
        }

        // Nothing left in any bucket; stop topping up.
        if (!added) break;
      }

      if (batchData.size() < static_cast<size_t>(batchSize) && dropLast) {
        // skip placing it in miniBatches
        if (verbose)
          printDropLast(batchData.size(), static_cast<size_t>(batchSize));
        continue;
      }

      if (!batchData.empty()) {
        // Shuffle batches if indicated (keeps input/label pairing intact).
        if (shuffle) this->shuffle(batchData, batchLabels);
        miniBatches.emplace_back(batchData, batchLabels);
      }
    }

    if (verbose) printNBatchesCreated(miniBatches.size());
  }
204};
205} // namespace NeuralNet
Definition Network.hpp:29
Definition TrainingData.hpp:10
TrainingData(X xTrain, Y yTrain, X xTest=X(), Y yTest=Y())
Construct a new Training Data object. This object is used to store the inputs and labels data,...
Definition TrainingData.hpp:25
void batch(int batchSize, bool stratified=false, bool shuffle=false, bool dropLast=false, bool verbose=false)
This method will separate the inputs and labels data into batches of the specified size.
Definition TrainingData.hpp:39