#include "thread_pool.hpp"

#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

#define DEBUG false

#if DEBUG
#include <iostream>
#endif

/// @brief Starts a pool of worker threads, each owning its own task queue.
///
/// Worker i waits on cvs[i] for work in tasks[i] (guarded by mutexes[i]) and
/// runs tasks one at a time until the pool is destroyed and its queue is empty.
///
/// @param num_threads Number of workers to start. A value of 0 is treated as 1,
///        because add() distributes tasks with `% threads.size()` and would
///        otherwise divide by zero (note std::thread::hardware_concurrency()
///        is allowed to return 0).
ThreadPool::ThreadPool(size_t num_threads) {
    if (num_threads == 0) {
        num_threads = 1;
    }

    mutexes = new std::mutex[num_threads];
    tasks = new std::queue<std::function<void()>>[num_threads];
    cvs = new std::condition_variable[num_threads];

    threads.reserve(num_threads);
    for (size_t i = 0; i < num_threads; ++i) {
        threads.emplace_back([this, i] {
#if DEBUG
            std::cerr << "Thread " << i << " started" << std::endl;
#endif
            while (true) {
                std::function<void()> task;
                {
                    std::unique_lock<std::mutex> lock(mutexes[i]);
#if DEBUG
                    std::cerr << "Thread " << i << " waiting for task. Queue " << tasks[i].size() << std::endl;
#endif
                    // The predicate overload re-checks after every wakeup, so we
                    // only get past this line with a queued task or a stop request.
                    cvs[i].wait(lock, [i, this] { return !tasks[i].empty() || stop; });

                    // An empty queue here implies stop was requested. Queued tasks
                    // are drained before the worker exits.
                    if (tasks[i].empty()) {
#if DEBUG
                        std::cerr << "Thread " << i << " finished. Stop " << stop << ", queue size " << tasks[i].size() << std::endl;
#endif
                        return;
                    }
#if DEBUG
                    std::cerr << "Thread " << i << " task acquired" << std::endl;
#endif
                    task = std::move(tasks[i].front());
                    tasks[i].pop();
                }
#if DEBUG
                std::cerr << "Thread " << i << " executing task" << std::endl;
#endif
                active_workers.fetch_add(1);
                task();
                // Decrement active_workers before remaining_tasks: a caller polling
                // size() still sees this task as outstanding until it is fully done.
                active_workers.fetch_sub(1);
                remaining_tasks.fetch_sub(1);
#if DEBUG
                std::cerr << "Remaining tasks " << remaining_tasks.load() << std::endl;
#endif
            }
        });
    }
}

/// @brief Requests shutdown, waits for every worker to drain its queue and
///        exit, then releases the per-worker queue/mutex/condvar arrays.
ThreadPool::~ThreadPool() {
    const size_t n = threads.size();

    // Workers read `stop` only while holding their own mutex. Holding all of
    // them while writing makes the write race-free whatever the declared type
    // of `stop` is. It also guarantees that each worker either observes stop
    // before blocking, or is already blocked and receives the notification
    // below. Without this, a lost wakeup could hang join().
    // Manual lock/unlock avoids allocating a container of locks inside a
    // destructor. Nothing between the lock and unlock loops can throw.
    for (size_t i = 0; i < n; ++i) {
        mutexes[i].lock();
    }
    stop = true;
    for (size_t i = 0; i < n; ++i) {
        mutexes[i].unlock();
    }

    for (size_t i = 0; i < n; ++i) {
        cvs[i].notify_one();
    }
    for (auto& thread : threads) {
        if (thread.joinable()) {
            thread.join();
        }
    }

    // Safe only after every worker (the sole other user of these arrays) has joined.
    delete[] cvs;
    delete[] tasks;
    delete[] mutexes;
}

/// @brief Enqueues a task on the next worker's queue, chosen round-robin, and wakes that worker.
/// @param task Callable to run on a pool thread. It is moved into the queue.
void ThreadPool::add(std::function<void()> task) {
    const size_t idx = current.fetch_add(1) % threads.size();

    // Count the task before a worker can see it. Otherwise a fast worker could
    // run it and decrement remaining_tasks first, and size() would briefly
    // report an underflowed value.
    remaining_tasks.fetch_add(1);
    try {
        std::lock_guard<std::mutex> lock(mutexes[idx]);
        tasks[idx].emplace(std::move(task));
    } catch (...) {
        // The task was never queued; undo the count so size() stays accurate.
        remaining_tasks.fetch_sub(1);
        throw;
    }
#if DEBUG
    std::cerr << "Adding task to thread " << idx << ". Total remaining " << remaining_tasks.load() << std::endl;
#endif
    cvs[idx].notify_one();
}

/// @brief Number of tasks that have been added but have not yet finished executing.
uint32_t ThreadPool::size() {
    return remaining_tasks.load();
}

/// @brief True while at least one worker is inside a task call.
bool ThreadPool::isWorking() {
    return active_workers.load() > 0;
}