26 : xTrain(xTrain), yTrain(yTrain), xTest(xTest), yTest(yTest) {
27 assert(xTrain.size() == yTrain.size());
28 if (!xTest.empty() && !yTest.empty()) assert(xTest.size() == xTest.size());
// Returns the mini-batches accumulated in `miniBatches`.
// NOTE(review): the vector is returned by value, so every call copies all
// stored (inputs, labels) pairs.
31 std::vector<std::pair<X, Y>> getMiniBatches() {
return this->miniBatches; }
// Splits the stored training data (xTrain, yTrain) into mini-batches by
// delegating to normalMiniBatch or stratifiedMiniBatch. Both helpers append
// their results to `miniBatches`.
// NOTE(review): the condition that chooses between the two helpers is not
// visible in this excerpt; presumably `stratified` selects stratifiedMiniBatch.
//
// batchSize  - requested number of samples per mini-batch.
// stratified - presumably requests class-proportional batches (see note above).
// shuffle    - forwarded to the chosen helper.
// dropLast   - forwarded; both helpers report a final batch smaller than
//              batchSize through printDropLast when this is set.
// verbose    - forwarded; both helpers print the total batch count when set.
//
// NOTE(review): xTrain and yTrain are passed by non-const reference, and
// normalMiniBatch erases its x/y arguments. A non-stratified call therefore
// leaves xTrain/yTrain empty, and a second call would trip normalMiniBatch's
// `nInputs > 0` assert (in builds where assert is active).
39 void batch(
int batchSize,
bool stratified =
false,
bool shuffle =
false,
40 bool dropLast =
false,
bool verbose =
false) {
43 return normalMiniBatch(batchSize, xTrain, yTrain, shuffle, dropLast,
45 return stratifiedMiniBatch(batchSize, xTrain, yTrain, shuffle, dropLast,
// (inputs, labels) pairs appended by normalMiniBatch / stratifiedMiniBatch and
// returned by getMiniBatches(). Nothing visible in this excerpt clears it, so
// repeated batch() calls presumably accumulate batches; confirm.
54 std::vector<std::pair<X, Y>> miniBatches;
// Prints to stdout that the trailing mini-batch (of `size` samples) is being
// dropped because it is smaller than `requiredSize`.
57 void printDropLast(
const size_t size,
const size_t requiredSize) {
58 std::cout <<
"Dropping last mini-batch (size : " << size <<
" < "
59 << requiredSize <<
")" << std::endl;
// Prints the total number of mini-batches created to stdout.
62 void printNBatchesCreated(
const size_t nBatches) {
63 std::cout <<
"Total mini-batches created : " << nBatches << std::endl;
// Shuffles `inputs` and `labels` in place with one shared random permutation,
// so each input stays paired with its label.
// NOTE(review): labels[i] is read for every i < inputs.size() without a size
// check; this assumes labels.size() >= inputs.size().
// NOTE(review): elements are copied (not moved) into and out of the temporary
// pair vector.
66 template <
typename Ti,
typename Tl>
67 void shuffle(std::vector<Ti>& inputs, std::vector<Tl>& labels) {
// Zip inputs with labels so both are permuted together.
69 std::vector<std::pair<Ti, Tl>> combined;
70 for (
size_t i = 0; i < inputs.size(); ++i) {
71 combined.emplace_back(inputs[i], labels[i]);
// A fresh random_device is created on every call; the generator `g` used
// below is declared on lines not shown in this excerpt (presumably seeded
// from rd; confirm).
75 std::random_device rd;
79 std::shuffle(combined.begin(), combined.end(), g);
// Unzip the permuted pairs back into the caller's vectors.
82 for (
size_t i = 0; i < combined.size(); ++i) {
83 inputs[i] = combined[i].first;
84 labels[i] = combined[i].second;
// Splits x / y into consecutive chunks of `batchSize` samples and appends each
// (xChunk, yChunk) pair to `miniBatches`. The last chunk may be shorter than
// batchSize. When dropLast is set, the short chunk is reported through
// printDropLast; the statement that actually skips it is not visible in this
// excerpt.
//
// batchSize - samples per chunk.
// x, y      - data and labels; taken by non-const reference and erased
//             (emptied) by this function.
// shuffle   - NOTE(review): not used on any visible line of this excerpt.
// dropLast  - drop (and report) a final chunk smaller than batchSize.
// verbose   - print the total number of batches at the end.
88 void normalMiniBatch(
int batchSize, X& x, Y& y,
bool shuffle =
false,
89 bool dropLast =
false,
bool verbose =
false) {
90 int nInputs = x.size();
// NOTE(review): `batchSize < nInputs` rejects batchSize == nInputs (a single
// full batch), and batchSize <= 0 is not checked (it reaches the division
// below). Confirm whether `batchSize > 0 && batchSize <= nInputs` was intended.
91 assert(nInputs > 0 && batchSize < nInputs);
// Ceiling division: a partial final chunk still counts as a batch.
92 int nMiniBatches = (nInputs + batchSize - 1) / batchSize;
// reserve() sets total capacity, not room beyond batches already stored.
98 miniBatches.reserve(nMiniBatches);
100 for (
int i = 0; i < nMiniBatches; ++i) {
101 int startIdx = i * batchSize;
// Clamp so the final chunk stops at the end of the data.
102 int endIdx = std::min((i + 1) * batchSize,
static_cast<int>(nInputs));
104 X xMiniBatch(x.begin() + startIdx, x.begin() + endIdx);
105 Y yMiniBatch(y.begin() + startIdx, y.begin() + endIdx);
// Each chunk is copied into this pair and again by push_back below.
107 std::pair<X, Y> miniBatch = std::make_pair(xMiniBatch, yMiniBatch);
109 if (xMiniBatch.size() <
static_cast<size_t>(batchSize) && dropLast) {
111 printDropLast(xMiniBatch.size(),
static_cast<size_t>(batchSize));
115 miniBatches.push_back(miniBatch);
// Erase all of the caller's x and y (presumably after the loop; its closing
// brace is not visible). batch() passes xTrain/yTrain here, so those members
// end up empty.
119 x.erase(x.begin(), x.end());
120 y.erase(y.begin(), y.end());
122 if (verbose) printNBatchesCreated(miniBatches.size());
// Builds mini-batches whose class mix roughly mirrors the overall label
// distribution:
//   1. Group samples by label into classDataMap.
//   2. Give each class a per-batch quota of
//      floor(classCount * batchSize / totalSamples).
//   3. Fill a batch by taking each class's quota from the back of its bucket,
//      then top it up round-robin across non-empty classes until it reaches
//      batchSize or data runs out. This presumably repeats while moreData is
//      set; the enclosing loop header is not visible in this excerpt.
// A short batch is reported through printDropLast when dropLast is set. Each
// non-empty batch is optionally shuffled and appended to miniBatches.
//
// NOTE(review): unlike normalMiniBatch, x and y are not erased on any visible
// line.
// NOTE(review): the buckets are typed X, not std::vector<Tx>; presumably X is
// std::vector<Tx>. Confirm.
125 template <
typename Tx,
typename Ty>
126 void stratifiedMiniBatch(
int batchSize, std::vector<Tx>& x,
127 std::vector<Ty>& y,
bool shuffle =
false,
128 bool dropLast =
false,
bool verbose =
false) {
// Bucket samples by label; operator[] creates an empty bucket on first use.
130 std::map<Ty, X> classDataMap;
131 for (
size_t i = 0; i < x.size(); i++) {
132 classDataMap[y[i]].push_back(x[i]);
// Per-class quota proportional to the class frequency, rounded down.
// NOTE(review): classCount * batchSize is computed in int and can overflow for
// large inputs. Classes whose share rounds down to 0 get samples only from the
// top-up loop.
136 std::map<Ty, int> classBatchSize;
137 int totalSamples = x.size();
138 for (
const auto& classPair : classDataMap) {
139 int classCount = classPair.second.size();
140 classBatchSize[classPair.first] = (classCount * batchSize) / totalSamples;
// Diagnostic output; its guarding condition (presumably `verbose`) is not
// visible in this excerpt.
143 std::cout <<
"Class count for (" << classPair.first
144 <<
") = " << classCount
145 <<
" - classBatchSize = " << classBatchSize[classPair.first]
150 bool moreData =
true;
// Take each class's quota from the back of its bucket (batchData/batchLabels
// are declared on lines not shown in this excerpt).
157 for (
auto& classPair : classDataMap) {
158 Ty classLabel = classPair.first;
159 X& classSamples = classPair.second;
160 int nSamplesToAdd = classBatchSize[classLabel];
162 for (
int i = 0; i < nSamplesToAdd && !classSamples.empty(); i++) {
163 batchData.push_back(classSamples.back());
164 batchLabels.push_back(classLabel);
165 classSamples.pop_back();
// Body not visible in this excerpt; presumably keeps moreData set while any
// class still has samples. Confirm.
168 if (!classSamples.empty()) {
// Top up round-robin, one sample per non-empty class per pass.
// NOTE(review): if every bucket empties while the batch is still short and no
// hidden line clears moreData, this loop never terminates. Confirm against the
// full source.
173 while (batchData.size() <
static_cast<size_t>(batchSize) && moreData) {
175 for (
auto& classPair : classDataMap) {
176 if (!classPair.second.empty()) {
177 batchData.push_back(classPair.second.back());
178 batchLabels.push_back(classPair.first);
179 classPair.second.pop_back();
181 if (batchData.size() ==
static_cast<size_t>(batchSize))
break;
// A short batch is reported when dropLast is set; the skip itself is not
// visible in this excerpt.
188 if (batchData.size() <
static_cast<size_t>(batchSize) && dropLast) {
191 printDropLast(batchData.size(),
static_cast<size_t>(batchSize));
195 if (!batchData.empty()) {
// `this->` is required: the bool parameter `shuffle` shadows the member
// function of the same name.
197 if (shuffle) this->shuffle(batchData, batchLabels);
198 miniBatches.emplace_back(batchData, batchLabels);
202 if (verbose) printNBatchesCreated(miniBatches.size());