2025-09-28 14:58:27 +02:00

269 lines
8.3 KiB
C++
Executable File

#include "sorter.hpp"
#include "single_task_handler.hpp"
#include "thread_pool.hpp"
#include <algorithm>
#include <bitset>
#include <cassert>
#include <climits>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <iterator>
#include <thread>
#include <utility>
#include <vector>
#define DEBUG false
namespace ae {
/// Creates a sorter whose per-bucket work is dispatched through a TaskHandler.
///
/// @param num  Number of threads. Values greater than 1 select a ThreadPool(num);
///             0 or 1 select a SingleTaskHandler.
///
/// NOTE(review): the default argument is written on this out-of-line definition, so it is
/// only visible to code in this translation unit — confirm whether sorter.hpp should own it.
/// NOTE(review): `handler` is allocated with `new` here; confirm the destructor (not visible
/// in this file) deletes it.
sorter::sorter(uint32_t num = 1) {
    num_threads = num;
    // Implicit derived-to-base conversion rather than a C-style cast: if either handler type
    // stops deriving from TaskHandler this fails to compile instead of silently becoming a
    // reinterpret_cast.
    if (num > 1) {
        handler = new ThreadPool(num);
    } else {
        handler = new SingleTaskHandler();
    }
}
/// Sorts `data` in place.
///
/// Runs an MSD radix sort over the whole range (msd_inplace_radix_sort); ranges that become
/// small are finished by robin_hood_sort. Bucket work is queued on `handler` (possibly a
/// ThreadPool running it on other threads), so this blocks until the handler reports no
/// queued and no running tasks.
void sorter::sort(container& data) {
#if DEBUG
    std::vector<container::element_type> copy;
    for (auto element : data) {
        copy.push_back(element);
    }
    auto begin = data.begin();
    for (std::size_t i = 0; i < copy.size(); i++) {
        std::cerr << i << " before:" << begin[i] << std::endl;
    }
    std::sort(copy.begin(), copy.end());
#endif
    // Queued tasks keep a *reference* to this std::function, so it must be a named object
    // that outlives the wait loop below. Passing a lambda directly would bind the parameter
    // to a temporary std::function destroyed at the end of the call statement while worker
    // tasks may still be invoking it.
    const std::function<void(container::Iterator, container::Iterator)> bucket_sort =
        [this](container::Iterator first, container::Iterator last) {
            sorter::robin_hood_sort(first, last);
        };
    sorter::msd_inplace_radix_sort(data.begin(), data.end(), 0, bucket_sort);

    // Wait until every queued bucket task has finished; yield rather than spin at full speed.
    // NOTE(review): if a worker can dequeue a task (size() reaches 0) before isWorking()
    // becomes true, this check could pass early — confirm TaskHandler keeps the two
    // consistent, or replace polling with an explicit completion signal.
    while (sorter::handler->size() > 0 || sorter::handler->isWorking()) {
        std::this_thread::yield();
    }
#if DEBUG
    std::cerr << "Final check if sorted" << std::endl;
    for (std::size_t i = 0; i < copy.size(); i++) {
        if (copy[i] != begin[i])
            std::cerr << i << " " << "sorted: " << copy[i] << " actual:" << begin[i] << std::endl;
    }
#endif
}
/// Binary MSD radix sort (one bit per pass) over [begin, end), recursing synchronously.
///
/// Each pass partitions the range by the bit at position (key_bits - passes - 1), i.e. the
/// `passes`-th bit counted from the most significant one: elements with the bit clear go to
/// the front, elements with it set to the back. After RADIX_ITERATIONS passes the remaining
/// range is finished by `bucket_sort` (two-element ranges are ordered inline).
///
/// @param begin, end   Range to sort.
/// @param passes       Number of leading bits already partitioned.
/// @param bucket_sort  Sort used for ranges of more than two elements once partitioning stops.
///
/// NOTE(review): this function is not called from the code visible in this file.
void sorter::msd_inplace_radix_sort_binary(
    container::Iterator begin,
    container::Iterator end,
    size_t passes,
    const std::function<void(container::Iterator begin, container::Iterator end)>& bucket_sort
) {
    static_assert(sizeof(container::element_type) <= sizeof(std::uint64_t),
                  "bit selection works on 64-bit unsigned values");
    constexpr std::size_t key_bits = sizeof(container::element_type) * CHAR_BIT;

    if (begin >= end) {
        return;
    }
    // Stop after RADIX_ITERATIONS bits, or once every bit of the key has been used
    // (continuing would make the shift amount below negative).
    if (sorter::RADIX_ITERATIONS == passes || passes >= key_bits) {
        const auto length = end - begin;
        if (length == 2) {
            if (begin[1] < begin[0]) {
                std::swap(begin[0], begin[1]);
            }
        } else if (length > 2) {
            bucket_sort(begin, end);
        }
        return;
    }

    // Bit under test for this pass, computed in uint64_t so that selecting the top bit neither
    // overflows a signed type nor exceeds the width of a 32-bit `long`.
    const std::uint64_t bit = std::uint64_t{1} << (key_bits - passes - 1);
    auto lower = begin;
    auto upper = end;
    while (lower < upper) {
        if (static_cast<std::uint64_t>(*lower) & bit) {
            // Bit set: move the element to the back partition and shrink it from the top.
            --upper;
            std::swap(*upper, *lower);
        } else {
            ++lower;
        }
    }
    sorter::msd_inplace_radix_sort_binary(begin, lower, passes + 1, bucket_sort);
    sorter::msd_inplace_radix_sort_binary(lower, end, passes + 1, bucket_sort);
}
/// One MSD radix pass over [begin, end) with RADIX_SIZE-bit digits, then one queued task per
/// bucket (with at least two elements) that recurses on the next digit via `handler`.
///
/// Ranges of at most SMALL_SORT_THRESHHOLD elements are passed straight to `bucket_sort`.
///
/// @param begin, end   Range to sort. All elements are expected to share their top
///                     RADIX_SIZE * passes bits; this holds for the initial call with
///                     passes == 0 and for every recursive call queued here.
/// @param passes       Number of digits already consumed above this range.
/// @param bucket_sort  Sort for small ranges. Queued tasks capture it by reference, so the
///                     referenced object must outlive every task the handler runs.
void sorter::msd_inplace_radix_sort(
    container::Iterator begin,
    container::Iterator end,
    size_t passes,
    const std::function<void(container::Iterator begin, container::Iterator end)>& bucket_sort
) {
    static_assert(sizeof(container::element_type) <= sizeof(std::uint64_t),
                  "digit extraction works on 64-bit unsigned values");
    constexpr std::size_t key_bits = sizeof(container::element_type) * CHAR_BIT;

    if (begin >= end) {
        return;
    }
    if ((end - begin) <= sorter::SMALL_SORT_THRESHHOLD) {
        bucket_sort(begin, end);
        return;
    }
    if (sorter::RADIX_SIZE * passes >= key_bits) {
        // Every bit is consumed: by the precondition all keys here are equal, i.e. sorted.
        // Without this stop, long runs of duplicates would recurse forever.
        return;
    }
    if (sorter::RADIX_SIZE * (passes + 1) > key_bits) {
        // Only a partial digit remains; let the small-range sort finish.
        bucket_sort(begin, end);
        return;
    }

    // Digit for this pass: bits [shift, shift + RADIX_SIZE) of the key viewed as uint64_t.
    const std::size_t shift = key_bits - sorter::RADIX_SIZE * (passes + 1);
    const std::uint64_t digit_mask = (std::uint64_t{1} << sorter::RADIX_SIZE) - 1;
    const auto digit_of = [shift, digit_mask](const container::element_type& value) {
        return static_cast<std::size_t>((static_cast<std::uint64_t>(value) >> shift) & digit_mask);
    };

    // Pass 1: histogram of digits (O(RADIX_BUCKETS) stack space per call).
    std::size_t bucket_sizes[sorter::RADIX_BUCKETS] = {};
    for (auto element = begin; element < end; ++element) {
        const std::size_t digit = digit_of(*element);
        assert(digit < sorter::RADIX_BUCKETS);
        ++bucket_sizes[digit];
    }
#if DEBUG
    std::cerr << "Bucket sizes: ";
    for (auto size : bucket_sizes) {
        std::cerr << size << " ";
    }
    std::cerr << std::endl;
#endif

    // Offsets from `begin`: where each bucket starts and its next slot still to be filled.
    // Plain offsets on the stack — no per-bucket heap allocation.
    std::ptrdiff_t bucket_start[sorter::RADIX_BUCKETS];
    std::ptrdiff_t bucket_fill[sorter::RADIX_BUCKETS];
    std::ptrdiff_t offset = 0;
    for (std::size_t digit = 0; digit < sorter::RADIX_BUCKETS; ++digit) {
        bucket_start[digit] = offset;
        bucket_fill[digit] = offset;
        offset += static_cast<std::ptrdiff_t>(bucket_sizes[digit]);
    }
    assert(offset == end - begin);

    // Pass 2: in-place permutation. Invariant: every position before `pos` holds an element
    // of the bucket that position belongs to. Each iteration either places one element into
    // its bucket (swap) or skips an already-placed run, so the loop is linear.
    const std::ptrdiff_t length = end - begin;
    std::ptrdiff_t pos = 0;
    while (pos < length) {
        const std::size_t digit = digit_of(*(begin + pos));
        if (pos >= bucket_start[digit] && pos < bucket_fill[digit]) {
            pos = bucket_fill[digit];
        } else {
            std::swap(*(begin + pos), *(begin + bucket_fill[digit]));
            ++bucket_fill[digit];
        }
    }

    // Recurse on the next digit. Buckets with fewer than two elements are already sorted.
    for (std::size_t digit = 0; digit < sorter::RADIX_BUCKETS; ++digit) {
        assert(bucket_fill[digit] == bucket_start[digit] + static_cast<std::ptrdiff_t>(bucket_sizes[digit]));
        if (bucket_sizes[digit] < 2) {
            continue;
        }
#if DEBUG
        std::cerr << "Putting in task with depth " << passes << " of bucket " << digit << std::endl;
#endif
        const container::Iterator bucket_begin = begin + bucket_start[digit];
        const container::Iterator bucket_end = begin + bucket_fill[digit];
        // Iterators are captured by value; bucket_sort by reference (see contract above).
        sorter::handler->add([bucket_begin, bucket_end, &bucket_sort, passes, this]() {
            sorter::msd_inplace_radix_sort(bucket_begin, bucket_end, passes + 1, bucket_sort);
        });
    }
}
/// Sorts a small range [begin, end) in place, ascending by each element's value converted
/// to uint64_t (the same ordering the radix bucketing uses).
///
/// Elements are first scattered into a scratch table at a slot proportional to their low
/// (key_bits - RADIX_ITERATIONS) bits, which leaves the table nearly sorted. Collisions use
/// linear probing with wrap-around. The occupied slots are then compacted and finished with
/// an insertion sort.
void sorter::robin_hood_sort(container::Iterator begin, container::Iterator end) {
    using key_type = std::uint64_t;
    static_assert(sizeof(container::element_type) <= sizeof(key_type),
                  "keys are compared as 64-bit unsigned values");
    constexpr std::size_t key_bits = sizeof(container::element_type) * CHAR_BIT;

    const std::ptrdiff_t count = end - begin;
    if (count < 2) {
        return;  // zero or one element: already sorted
    }

    // Only the bits below the top RADIX_ITERATIONS bits feed the slot estimate.
    const std::size_t low_bits = key_bits - sorter::RADIX_ITERATIONS;
    const key_type mask = low_bits >= 64 ? ~key_type{0} : (key_type{1} << low_bits) - 1;
    // Monotone map of [0, mask] onto [0, count) that cannot overflow (unlike masked * count):
    // step > mask / count, hence masked / step <= mask / step < count.
    const key_type step = mask / static_cast<key_type>(count) + 1;
    const std::size_t slots = static_cast<std::size_t>(count) + sorter::OVERHEAD_SIZE;

    // Occupancy is tracked separately, so no key value is reserved as an "empty" marker.
    std::vector<container::element_type> table(slots);
    std::vector<char> occupied(slots, 0);
    for (auto element = begin; element < end; ++element) {
        const key_type masked = static_cast<key_type>(*element) & mask;
        std::size_t slot = static_cast<std::size_t>(masked / step);
        // Linear probing with wrap-around; slots >= count guarantees a free slot exists,
        // so no element is ever overwritten.
        while (occupied[slot]) {
            slot = (slot + 1 == slots) ? 0 : slot + 1;
        }
        table[slot] = *element;
        occupied[slot] = 1;
    }

    // Compact occupied slots to the front, preserving slot order.
    std::size_t filled = 0;
    for (std::size_t slot = 0; slot < slots; ++slot) {
        if (occupied[slot]) {
            table[filled++] = table[slot];
        }
    }

    // Insertion sort fixes the few inversions left by probing; bounds are checked before
    // any element access.
    const auto unsigned_less = [](const container::element_type& a, const container::element_type& b) {
        return static_cast<key_type>(a) < static_cast<key_type>(b);
    };
    for (std::size_t i = 1; i < filled; ++i) {
        const container::element_type value = table[i];
        std::size_t j = i;
        while (j > 0 && unsigned_less(value, table[j - 1])) {
            table[j] = table[j - 1];
            --j;
        }
        table[j] = value;
    }

    // Copy the sorted keys back into the original range.
    std::size_t next = 0;
    for (auto element = begin; element < end; ++element) {
        *element = table[next++];
    }
}
} // namespace ae