I've recently been designing a small shadowsocks-like tool, and this problem came up while writing the local part. The scenario is very simple: the local proxy listens on a local port; whenever a client connects, it opens a connection to a fixed remote server, encrypts the data received from the client and forwards it to the server; likewise, data received from the server is decrypted and forwarded back to the client.
The client->server direction of this process can be sketched in Go as follows (the server->client direction is analogous):
func EncodeCopy(dst *net.TCPConn, src *net.TCPConn) error {
    buf := make([]byte, BufSize)
    for {
        readCount, errRead := src.Read(buf)
        if errRead != nil {
            //handle error...
        }
        if readCount > 0 {
            encode(buf[0:readCount]) // for simplicity, assume encryption does not change the data length
            _, errWrite := dst.Write(buf[0:readCount])
            if errWrite != nil {
                //handle error...
            }
        }
    }
}
EncodeCopy(server_conn, client_conn)
The same functionality implemented with boost.asio:
tcp::socket cli_sock, srv_sock;   // assume both sockets are already connected
uint8_t buf[MAX_BUF_SIZE];
......
void do_read()
{
    cli_sock.async_read_some(buffer(buf, MAX_BUF_SIZE), on_read);
}
void on_read(const error_code& ec, size_t bytes_read)
{
    if (ec) {
        //handle error...
    }
    encode(buf, bytes_read);
    async_write(srv_sock,
                const_buffer(buf, bytes_read),
                [](const error_code& ec, size_t bytes_written)
                {
                    if (ec) {
                        //handle error...
                    }
                    do_read();   // write finished, go read the next chunk
                });
}
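For reference, a minimal sketch of the setup that the ...... above leaves out might look like the following: one io_context, one accepted client connection, one connection to the fixed remote server, and io.run() to drive the callbacks. The local port, the server address, and names such as io, acceptor and resolver are assumptions of this sketch, not something the original code specifies.

#include <boost/asio.hpp>
using boost::asio::ip::tcp;

boost::asio::io_context io;
tcp::socket cli_sock(io), srv_sock(io);   // the globals used by do_read/on_read above

void do_read();                           // as defined in the sketch above

int main()
{
    // accept a single client on a local port (1080 is just a placeholder)
    tcp::acceptor acceptor(io, tcp::endpoint(tcp::v4(), 1080));
    acceptor.accept(cli_sock);

    // connect to the fixed remote server (host/port are placeholders)
    tcp::resolver resolver(io);
    boost::asio::connect(srv_sock, resolver.resolve("server.example.com", "8388"));

    do_read();    // kick off the read -> encode -> write loop
    io.run();     // all callbacks run inside io.run()
}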
This is the simplest possible implementation: read a chunk from src, encrypt it, write it to dst, then read, encrypt and write again, over and over. It is easy to see that even with multiple asio threads, a single connection still executes as a strictly serial sequence, so there is room for performance improvement.
We can split the whole process into three stages: read, encode and write. Across different iterations of the loop these three stages can run in parallel, forming a pipeline. The ideal case looks like this:
+----------+----------+----------+
|   read   |  encode  |  write   |
+----------+----------+----------+----------+
           |   read   |  encode  |  write   |
           +----------+----------+----------+----------+
                      |   read   |  encode  |  write   |
                      +----------+----------+----------+
To achieve this parallelism, the obvious approach is the traditional one of multiple threads communicating through queues:
tcp_socket cli_sock, srv_sock;
block_queue read_buf_que, enc_buf_que, write_buf_que;
void thrd_read()
{
    while (1) {
        buffer *buf = dequeue(read_buf_que);    // take a free buffer
        read_some(cli_sock, buf->data, &buf->size);
        enqueue(enc_buf_que, buf);              // hand it to the encode thread
    }
}
void thrd_encode()
{
    while (1) {
        buffer *buf = dequeue(enc_buf_que);
        encode(buf);
        enqueue(write_buf_que, buf);            // hand it to the write thread
    }
}
void thrd_write()
{
    while (1) {
        buffer *buf = dequeue(write_buf_que);
        write(srv_sock, buf->data, buf->size);
        enqueue(read_buf_que, buf);             // return the buffer to the read thread
    }
}
int main()
{
    //....
    // pre-fill the read queue with free buffers, up to its capacity
    for (int i = 0; i < read_buf_que.size; i++) {
        read_buf_que.push(new_buffer());
    }
    // start all three threads first, then join them
    thread t_read   = new_thread(thrd_read).run();
    thread t_encode = new_thread(thrd_encode).run();
    thread t_write  = new_thread(thrd_write).run();
    t_read.join();
    t_encode.join();
    t_write.join();
    //...
}
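The block_queue above is pseudocode. A minimal sketch of the kind of blocking queue it assumes is shown below, built on std::mutex and std::condition_variable; the class name blocking_queue and the buffer struct are illustrative only. Because only the fixed set of buffers pushed in main ever circulates between the three queues, the queue itself does not need an explicit capacity limit.

#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <mutex>

struct buffer {
    uint8_t data[4096];   // payload; the size is chosen arbitrarily here
    size_t  size = 0;     // number of valid bytes in data
};

class blocking_queue {
public:
    void enqueue(buffer* buf) {
        {
            std::lock_guard<std::mutex> lk(mtx_);
            que_.push_back(buf);
        }
        cv_.notify_one();                                  // wake one waiting consumer
    }
    buffer* dequeue() {
        std::unique_lock<std::mutex> lk(mtx_);
        cv_.wait(lk, [this] { return !que_.empty(); });    // block until a buffer is available
        buffer* buf = que_.front();
        que_.pop_front();
        return buf;
    }
private:
    std::mutex mtx_;
    std::condition_variable cv_;
    std::deque<buffer*> que_;
};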
To get a similar pipeline effect in multithreaded asio, the core idea is to issue the next read/write operation immediately inside the read/write callback. For that we need a circular buffer queue to stage the data:
const size_t max_ring_size = 32;
size_t packet_size = ...;
struct buffer_ring_t
{
    buffer_ring_t() : pending_read(false),
                      pending_write(false),
                      read_pos(0),
                      write_pos(0)
    {
        for (auto& i : ring) {
            i = new uint8_t[packet_size];   // one fixed-size buffer per slot
        }
        for (auto& i : state) {
            i = FREE;
        }
    }
    uint8_t *ring[max_ring_size];
    size_t buf_size[max_ring_size];
    enum {
        USED,     // read in, not yet encoded
        READY,    // encoded, not yet written
        FREE      // written out, slot is free
    } state[max_ring_size];
    bool pending_read;
    bool pending_write;
    size_t read_pos;
    size_t write_pos;
};
Here, buf_size is the actual data length of each buffer. state marks the state of a buffer: USED means the data has been read in but not yet encrypted, READY means it has been encrypted but not yet sent, and FREE means it has already been written to the server and the buffer is free again. read_pos and write_pos are the current read and write positions, and the next position is computed as (pos + 1) % max_ring_size. For async_read_some and async_write, the callback immediately uses the next read/write position to issue the next async_read_some/async_write and keeps pending_read/pending_write set to true, until there is no FREE (or READY) buffer left; at that point the read/write call chain breaks and the corresponding pending flag is set to false. In the read callback we then check pending_write: if it is false, the write call chain has broken, so we post an async_write (via do_write) to restart it. The write callback does the same for the read chain. The complete code:
tcp_socket cli_sock, srv_sock;
buffer_ring_t buf_ring;

void do_read()
{
    // no FREE buffer to read into: break the read call chain
    if (buf_ring.state[buf_ring.read_pos] != buffer_ring_t::FREE) {
        buf_ring.pending_read = false;
        return;
    }
    buf_ring.pending_read = true;
    cli_sock.async_read_some(buffer(buf_ring.ring[buf_ring.read_pos], packet_size),
                             on_read);
}

void on_read(const error_code& ec, size_t bytes_read)
{
    if (ec) {
        //...
    }
    size_t pos = buf_ring.read_pos;
    buf_ring.buf_size[pos] = bytes_read;
    buf_ring.state[pos] = buffer_ring_t::USED;      // read in, not yet encoded
    buf_ring.read_pos = (pos + 1) % max_ring_size;
    do_read();                                      // issue the next read immediately
    encode(buf_ring.ring[pos], bytes_read);
    buf_ring.state[pos] = buffer_ring_t::READY;     // encoded, waiting to be written
    if (buf_ring.pending_write == false) {
        do_write();                                 // restart the write chain if it has stopped
    }
}

void do_write()
{
    // no READY buffer to send: break the write call chain
    if (buf_ring.state[buf_ring.write_pos] != buffer_ring_t::READY) {
        buf_ring.pending_write = false;
        return;
    }
    buf_ring.pending_write = true;
    async_write(srv_sock,
                const_buffer(buf_ring.ring[buf_ring.write_pos], buf_ring.buf_size[buf_ring.write_pos]),
                on_write);
}

void on_write(const error_code& ec, size_t bytes_written)
{
    if (ec) {
        //...
    }
    size_t pos = buf_ring.write_pos;
    buf_ring.state[pos] = buffer_ring_t::FREE;      // sent, buffer is free again
    buf_ring.write_pos = (pos + 1) % max_ring_size;
    do_write();                                     // issue the next write immediately
    if (!buf_ring.pending_read) {
        do_read();                                  // restart the read chain if it has stopped
    }
}
int main() {
//...
do_read();
//...
}
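Since the whole point is to take advantage of multithreaded asio, main would typically run the io_context on a small thread pool, as in the sketch below (the thread count is arbitrary, and the socket setup is the same as in the earlier sketch). One caveat worth noting as an aside: on_read and on_write both update buf_ring, so once more than one thread calls io.run(), those two handler chains need to be serialized per connection, for example with boost::asio::strand; the socket I/O itself still overlaps with the encode work.

#include <boost/asio.hpp>
#include <thread>
#include <vector>

boost::asio::io_context io;

void do_read();   // the pipelined version defined above

int main()
{
    //... accept cli_sock and connect srv_sock, as in the earlier setup sketch

    do_read();    // start the read chain; it restarts the write chain by itself

    // run the event loop on a few threads (4 is an arbitrary choice)
    std::vector<std::thread> pool;
    for (int i = 0; i < 4; ++i) {
        pool.emplace_back([] { io.run(); });
    }
    for (auto& t : pool) {
        t.join();
    }
}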