This is my study note of thread pools.

线程池核心就是线程池中的线程会持续查询任务队列是否有可用工作,如果有可用工作则将其取出并执行。

任务队列

任务队列需要实现的就是插入、删除任务等操作。

  • 判断队列是否为空 SafeQueue::empty()
    1
    2
    3
    4
    5
    bool empty() // 返回队列是否为空
    {
    std::unique_lock<std::mutex> lock(m_mutex); // 互斥信号变量加锁,防止m_queue被改变
    return m_queue.empty();
    }
  • 返回队列长度 SafeQueue::size()
    1
    2
    3
    4
    5
    int size()
    {
    std::unique_lock<std::mutex> lock(m_mutex); // 互斥信号变量加锁,防止m_queue被改变
    return m_queue.size();
    }
  • 插入任务 SafeQueue::enqueue(T &t)
    1
    2
    3
    4
    5
    6
    // 主要是用模板进行函数类型的替代
    void enqueue(T &t)
    {
    std::unique_lock<std::mutex> lock(m_mutex);
    m_queue.emplace(t);
    }
  • 取出任务 SafeQueue::dequeue(T &t)
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    // t参数用于取出任务以后调用,false和true用于判断是否还具有任务需要执行
    bool dequeue(T &t)
    {
    std::unique_lock<std::mutex> lock(m_mutex); // 队列加锁
    if (m_queue.empty())
    return false;
    t = std::move(m_queue.front()); // 取出队首元素,返回队首元素值,并进行右值引用
    m_queue.pop(); // 弹出入队的第一个元素
    return true;
    }

线程池代码

使用内置class(ThreadWorker)执行真正的操作

  • 内置class构造函数
    1
    2
    3
    4
    5
    ThreadWorker(ThreadPool *pool, const int id) : m_pool(pool), m_id(id)
    {
    }
    // 主要是为了将线程池以及当前任务的id导入,也可以不用如此形式,直接赋值也可
    // (注意内置class无法调用非内置的成员变量)
  • 重载()操作
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    // 主要是将任务队列头部的任务取出并执行,若为空则等待任务
    void operator()()
    {
    // std::function类型对象实例可以包装下列这几种可调用实体:函数、函数指针、成员函数、静态函数、lamda表达式和函数对象
    std::function<void()> func; // 定义基础函数类func
    bool dequeued; // 是否正在取出队列中元素
    while (!m_pool->m_shutdown)
    // m_shutdown为true线程池关闭,flase运行
    {
    {
    // 为线程环境加锁,互访问工作线程的休眠和唤醒
    std::unique_lock<std::mutex> lock(m_pool->m_conditional_mutex);
    // 如果任务队列为空,阻塞当前线程
    if (m_pool->m_queue.empty())
    {
    m_pool->m_conditional_lock.wait(lock); // 等待条件变量通知,开启线程
    }
    // 取出任务队列中的元素
    dequeued = m_pool->m_queue.dequeue(func);
    }
    // 如果成功取出,执行工作函数
    if (dequeued)
    func();
    }
    }

线程池初始化 init()

1
2
3
4
5
6
7
8
9
// 主要是根据m_threads的数量创建对应的线程进行执行
void init()
{
for (int i = 0; i < m_threads.size(); ++i)
{
m_threads.at(i) = std::thread(ThreadWorker(this, i)); // 分配工作线程
}
}
// m_threads的数量由ThreadPool构造函数决定

线程池关闭 shutdown()

1
2
3
4
5
6
7
8
9
10
11
12
13
void shutdown()
{
m_shutdown = true; // 设置标志位
m_conditional_lock.notify_all(); // 通知,唤醒所有工作线程

for (int i = 0; i < m_threads.size(); ++i)
{
if (m_threads.at(i).joinable()) // 判断线程是否在等待
{
m_threads.at(i).join(); // 将线程加入到等待队列
}
}
}

线程池任务提交 submit

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
template <typename F, typename... Args>
auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))>
{
// decltype用于识别func的类别,forward()用于识别是左值还是右值,最后用bind结合起来传递给func
std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...); // 连接函数和参数定义,特殊函数类型,避免左右值错误
auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);
std::function<void()> warpper_func = [task_ptr]()
{
(*task_ptr)();
};
// 队列通用安全封包函数,并压入安全队列
m_queue.enqueue(warpper_func);
// 唤醒一个等待中的线程
m_conditional_lock.notify_one();
// 返回先前注册的任务指针
return task_ptr->get_future();
}

总体代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
class SafeQueue
{
private:
std::queue<T> m_queue; //利用模板函数构造队列
std::mutex m_mutex; // 访问互斥信号量

public:
SafeQueue() {}
SafeQueue(SafeQueue &&other) {}
~SafeQueue() {}

bool empty() // 返回队列是否为空
{
std::unique_lock<std::mutex> lock(m_mutex); // 互斥信号变量加锁,防止m_queue被改变

return m_queue.empty();
}

int size()
{
std::unique_lock<std::mutex> lock(m_mutex); // 互斥信号变量加锁,防止m_queue被改变

return m_queue.size();
}

// 队列添加元素
void enqueue(T &t)
{
std::unique_lock<std::mutex> lock(m_mutex);
m_queue.emplace(t);
}

// 队列取出元素
bool dequeue(T &t)
{
std::unique_lock<std::mutex> lock(m_mutex); // 队列加锁

if (m_queue.empty())
return false;
t = std::move(m_queue.front()); // 取出队首元素,返回队首元素值,并进行右值引用

m_queue.pop(); // 弹出入队的第一个元素

return true;
}
};

class ThreadPool
{
private:
class ThreadWorker // 内置线程工作类
{
private:
int m_id; // 工作id

ThreadPool *m_pool; // 所属线程池
public:
// 构造函数
ThreadWorker(ThreadPool *pool, const int id) : m_pool(pool), m_id(id)
{
}

// 重载()操作
void operator()()
{
std::function<void()> func; // 定义基础函数类func

bool dequeued; // 是否正在取出队列中元素

while (!m_pool->m_shutdown)
{
{
// 为线程环境加锁,互访问工作线程的休眠和唤醒
std::unique_lock<std::mutex> lock(m_pool->m_conditional_mutex);

// 如果任务队列为空,阻塞当前线程
if (m_pool->m_queue.empty())
{
m_pool->m_conditional_lock.wait(lock); // 等待条件变量通知,开启线程
}
// 取出任务队列中的元素
dequeued = m_pool->m_queue.dequeue(func);
}

// 如果成功取出,执行工作函数
if (dequeued)
func();
}
}
};

bool m_shutdown; // 线程池是否关闭

SafeQueue<std::function<void()>> m_queue; // 执行函数安全队列,即任务队列

std::vector<std::thread> m_threads; // 工作线程队列

std::mutex m_conditional_mutex; // 线程休眠锁互斥变量

std::condition_variable m_conditional_lock; // 线程环境锁,可以让线程处于休眠或者唤醒状态

public:
// 线程池构造函数
ThreadPool(const int n_threads = 4)
: m_threads(std::vector<std::thread>(n_threads)), m_shutdown(false)
{
}

ThreadPool(const ThreadPool &) = delete;

ThreadPool(ThreadPool &&) = delete;

ThreadPool &operator=(const ThreadPool &) = delete;

ThreadPool &operator=(ThreadPool &&) = delete;

// Inits thread pool
void init()
{
for (int i = 0; i < m_threads.size(); ++i)
{
m_threads.at(i) = std::thread(ThreadWorker(this, i)); // 分配工作线程
}
}

// Waits until threads finish their current task and shutdowns the pool
void shutdown()
{
m_shutdown = true;
m_conditional_lock.notify_all(); // 通知,唤醒所有工作线程

for (int i = 0; i < m_threads.size(); ++i)
{
if (m_threads.at(i).joinable()) // 判断线程是否在等待
{
m_threads.at(i).join(); // 将线程加入到等待队列
}
}
}

// Submit a function to be executed asynchronously by the pool
template <typename F, typename... Args>
auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))>
{
// Create a function with bounded parameter ready to execute
std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...); // 连接函数和参数定义,特殊函数类型,避免左右值错误

// Encapsulate it into a shared pointer in order to be able to copy construct
auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);

// Warp packaged task into void function
std::function<void()> warpper_func = [task_ptr]()
{
(*task_ptr)();
};

// 队列通用安全封包函数,并压入安全队列
m_queue.enqueue(warpper_func);

// 唤醒一个等待中的线程
m_conditional_lock.notify_one();

// 返回先前注册的任务指针
return task_ptr->get_future();
}
};