在上下文切换方面,libgo 并没有自己去实现创建和维护栈空间、保存和切换 CPU 寄存器执行状态信息等任务,而是直接使用了 Boost.Context。Boost.Context 作为众多协程库的底层支持库,性能方面一直在被优化。
Boost.Context 所做的工作,就是在传统的线程环境中保存当前执行的抽象状态信息(栈空间、栈指针、CPU 寄存器和状态寄存器、IP 指令指针),然后暂停当前的执行状态,让程序的执行流程跳转到其他位置继续执行。这个基础构件可以用于开辟用户态的线程,从而构建出更加高级的协程等操作接口。同时,因为这个切换发生在用户空间,所以资源损耗很小;又因为保存了栈空间和执行状态的所有信息,所以其中的函数可以被嵌套使用。引用自 https://blog.csdn.net/qq_35976351/article/details/107449235
文章目录
fcontext_t
libgo/context/fcontext.h
Boost.Context 的底层实现是通过 fcontext_t 结构体来保存协程状态,使用 make_fcontext 创建协程,使用 jump_fcontext 实现协程切换。在 libgo 协程中,直接引用了这两个接口函数。
extern "C"
{
typedef void* fcontext_t;
typedef void (FCONTEXT_CALL *fn_t)(intptr_t);
intptr_t jump_fcontext(fcontext_t * ofc, fcontext_t nfc,
intptr_t vp, bool preserve_fpu = false);
fcontext_t make_fcontext(void* stack, std::size_t size, fn_t fn);
} // extern "C"
除此之外,还提供了一系列的栈函数。
namespace co {
// Collection of static helpers for allocating, freeing and guarding
// coroutine stacks.
struct StackTraits
{
// Accessor for the stack allocation function (returned by reference).
static stack_malloc_fn_t& MallocFunc();
// Accessor for the stack release function (returned by reference).
static stack_free_fn_t& FreeFunc();
// Number of guard pages configured at the top of the stack
// (returned by reference).
static int & GetProtectStackPageSize();
// Apply memory protection to the guard pages of `stack`.
// Returns true when the protection was installed.
static bool ProtectStack(void* stack, std::size_t size, int pageSize);
// Remove the memory protection from the guard pages; only called when
// the owning object is destroyed.
static void UnprotectStack(void* stack, int pageSize);
};
} // namespace co
栈保护
libgo 对栈的保护,使用了 mprotect 系统调用实现。我们在给协程创建大小为 N 字节的栈空间时,会对栈顶(低地址端)的一部分空间进行保护,因此,分配的协程栈的大小,应该大于(要保护的内存页数 + 1)个内存页的大小。
为什么提到保护栈,总是以页为单位呢?因为 mprotect 是按照页来进行设置的,因此,对于没有按页对齐的地址,应该先对齐之后再去操作。
// Installs PROT_NONE guard pages at the low-address end of a coroutine stack.
//
// stack    - lowest address of the stack buffer.
// size     - size of the stack buffer in bytes.
// pageSize - number of guard pages to install.
//
// Returns true when the guard pages were installed. Returns false when
// pageSize is not positive, when the stack is not larger than
// (pageSize + 1) pages, or when mprotect fails.
bool StackTraits::ProtectStack(void* stack, std::size_t size, int pageSize)
{
    // Zero means "no guard pages"; a negative count would turn into a huge
    // length once converted to size_t for mprotect.
    if (pageSize <= 0) return false;

    const int sysPageSize = getpagesize();
    const std::size_t pageBytes = static_cast<std::size_t>(sysPageSize);
    const std::size_t pages = static_cast<std::size_t>(pageSize);

    // The stack must be larger than (guard pages + 1) pages: aligning the
    // start address upward can skip almost a whole page. Written with a
    // division so that neither a large `size` nor a large page count can
    // overflow; equivalent to `size <= pageBytes * (pages + 1)`.
    if (size == 0 || (size - 1) / pageBytes <= pages)
        return false;

    // mprotect needs a page-aligned address. The stack grows from high to
    // low addresses, so the guard region sits at the low end: round the
    // buffer start up to the next page boundary inside the stack
    // (e.g. 0xf7234008 -> 0xf7235000 with 4 KiB pages). The real page size
    // is used rather than a fixed 4 KiB so this also works with larger pages.
    const std::size_t addr = reinterpret_cast<std::size_t>(stack);
    const std::size_t misalign = addr % pageBytes;
    void* protect_page_addr = misalign
        ? reinterpret_cast<void*>(addr - misalign + pageBytes)
        : stack;

    // PROT_NONE makes the guard pages inaccessible.
    if (-1 == mprotect(protect_page_addr, pageBytes * pages, PROT_NONE)) {
        DebugPrint(dbg_task, "origin_addr:%p, align_addr:%p, page_size:%d, protect_page:%u, protect stack stack error: %s",
                stack, protect_page_addr, sysPageSize, static_cast<unsigned>(pageSize), strerror(errno));
        return false;
    }

    // Argument order matches the format string: page size first, then the
    // number of protected pages.
    DebugPrint(dbg_task, "origin_addr:%p, align_addr:%p, page_size:%d, protect_page:%u, protect stack success.",
            stack, protect_page_addr, sysPageSize, static_cast<unsigned>(pageSize));
    return true;
}
取消栈保护
取消栈保护只有在释放该协程空间的时候会调用。
void StackTraits::UnprotectStack(void *stack, int pageSize)
{
if (!pageSize) return ;
void *protect_page_addr = ((std::size_t)stack & 0xfff) ? (void*)(((std::size_t)stack & ~(std::size_t)0xfff) + 0x1000) : stack;
// 允许该块内存可读可写
if (-1 == mprotect(protect_page_addr, getpagesize() * pageSize, PROT_READ|PROT_WRITE)) {
DebugPrint(dbg_task, "origin_addr:%p, align_addr:%p, page_size:%d, protect_page:%u, protect stack stack error: %s",
stack, protect_page_addr, getpagesize(), pageSize, strerror(errno));
} else {
DebugPrint(dbg_task, "origin_addr:%p, align_addr:%p, page_size:%d, protect_page:%u, protect stack success.",
stack, protect_page_addr, pageSize, getpagesize());
}
}
mprotect 系统调用使用说明
#include <sys/mman.h>
int mprotect(void *addr, size_t len, int prot);
addr:应该是按页对齐的内存地址
len:要保护的内存区域长度(字节),因此保护的地址范围是 [addr, addr+len-1]
prot:保护类型
PROT_NONE The memory cannot be accessed at all.
PROT_READ The memory can be read.
PROT_WRITE The memory can be modified.
PROT_EXEC The memory can be executed.
Context
Context 是 libgo 中封装的上下文对象,每个协程都会有一份独有的。
// Execution context of one coroutine: owns the coroutine's stack and the
// fcontext handle used to switch into and out of it.
//
// The object owns `stack_` and releases it in the destructor, so it is
// non-copyable: a copy would free the same stack twice.
class Context
{
public:
    // Allocates a stack of `stackSize` bytes, creates a context that will
    // run `fn` on it, and installs guard pages when they are configured.
    //
    // fn        - coroutine entry function.
    // vp        - value passed to the coroutine when it is switched in.
    // stackSize - stack size in bytes.
    Context(fn_t fn, intptr_t vp, std::size_t stackSize)
        : fn_(fn), vp_(vp), stackSize_(stackSize)
    {
        stack_ = (char*)StackTraits::MallocFunc()(stackSize_);
        DebugPrint(dbg_task, "valloc stack. size=%u ptr=%p",
                stackSize_, stack_);

        // The context is created from the high end of the buffer.
        ctx_ = make_fcontext(stack_ + stackSize_, stackSize_, fn_);

        // Remember the guard page count only if protection really
        // succeeded, so the destructor undoes exactly what was installed.
        int protectPage = StackTraits::GetProtectStackPageSize();
        if (protectPage && StackTraits::ProtectStack(stack_, stackSize_, protectPage))
            protectPage_ = protectPage;
    }

    // Removes the guard pages (if any) and frees the stack.
    ~Context()
    {
        if (stack_) {
            DebugPrint(dbg_task, "free stack. ptr=%p", stack_);
            // Guard pages must be writable again before the memory is freed.
            if (protectPage_)
                StackTraits::UnprotectStack(stack_, protectPage_);
            StackTraits::FreeFunc()(stack_);
            stack_ = nullptr;
        }
    }

    // Copying would make two objects free the same stack.
    Context(const Context&) = delete;
    Context& operator=(const Context&) = delete;

    // Switches from the calling thread's context into this coroutine,
    // passing vp_. The thread-local slot is handed to jump_fcontext as the
    // place for the context being left.
    ALWAYS_INLINE void SwapIn()
    {
        jump_fcontext(&GetTlsContext(), ctx_, vp_);
    }

    // Switches from this coroutine directly into `other`, passing other.vp_.
    ALWAYS_INLINE void SwapTo(Context & other)
    {
        jump_fcontext(&ctx_, other.ctx_, other.vp_);
    }

    // Switches from this coroutine back to the context held in the
    // thread-local slot.
    ALWAYS_INLINE void SwapOut()
    {
        jump_fcontext(&ctx_, GetTlsContext(), 0);
    }

    // Returns the per-thread context slot used by SwapIn/SwapOut.
    fcontext_t& GetTlsContext()
    {
        static thread_local fcontext_t tls_context;
        return tls_context;
    }

private:
    fcontext_t ctx_ = nullptr;  // context handle from make_fcontext
    fn_t fn_;                   // coroutine entry function
    intptr_t vp_;               // pointer to the Task object this context belongs to
    char* stack_ = nullptr;     // stack buffer
    uint32_t stackSize_ = 0;    // stack size in bytes
    int protectPage_ = 0;       // number of guard pages actually installed
};
Hook
介绍hook之前,需要先了解两个系统函数:
#include <dlfcn.h>
void *dlopen(const char *filename, int flag);
void *dlsym(void *handle, const char *symbol);
dlopen 以指定模式打开指定的动态链接库文件,并返回一个句柄给调用进程;dlsym 通过句柄和符号名称获取对应函数或变量的地址。hook 的意思是:程序调用某个系统函数时,会优先调用我们自己定义的同名函数(如果定义了的话),否则才调用系统原本的函数。在 libgo 中,定义了:
// Resolve the next definition of each hooked function with
// dlsym(RTLD_NEXT, ...) and keep it in the matching *_f pointer.
// "connect" is looked up first as a probe: the remaining lookups only run
// when it was found.
// NOTE(review): this excerpt is truncated — the closing #endif and the end
// of the if-block are not shown here.
connect_f = (connect_t)dlsym(RTLD_NEXT, "connect");
if (connect_f) {
pipe_f = (pipe_t)dlsym(RTLD_NEXT, "pipe");
socket_f = (socket_t)dlsym(RTLD_NEXT, "socket");
socketpair_f = (socketpair_t)dlsym(RTLD_NEXT, "socketpair");
// NOTE(review): "connect" is resolved a second time here; it repeats the
// probe above.
connect_f = (connect_t)dlsym(RTLD_NEXT, "connect");
read_f = (read_t)dlsym(RTLD_NEXT, "read");
readv_f = (readv_t)dlsym(RTLD_NEXT, "readv");
recv_f = (recv_t)dlsym(RTLD_NEXT, "recv");
recvfrom_f = (recvfrom_t)dlsym(RTLD_NEXT, "recvfrom");
recvmsg_f = (recvmsg_t)dlsym(RTLD_NEXT, "recvmsg");
write_f = (write_t)dlsym(RTLD_NEXT, "write");
writev_f = (writev_t)dlsym(RTLD_NEXT, "writev");
send_f = (send_t)dlsym(RTLD_NEXT, "send");
sendto_f = (sendto_t)dlsym(RTLD_NEXT, "sendto");
sendmsg_f = (sendmsg_t)dlsym(RTLD_NEXT, "sendmsg");
accept_f = (accept_t)dlsym(RTLD_NEXT, "accept");
poll_f = (poll_t)dlsym(RTLD_NEXT, "poll");
select_f = (select_t)dlsym(RTLD_NEXT, "select");
sleep_f = (sleep_t)dlsym(RTLD_NEXT, "sleep");
usleep_f = (usleep_t)dlsym(RTLD_NEXT, "usleep");
nanosleep_f = (nanosleep_t)dlsym(RTLD_NEXT, "nanosleep");
close_f = (close_t)dlsym(RTLD_NEXT, "close");
fcntl_f = (fcntl_t)dlsym(RTLD_NEXT, "fcntl");
ioctl_f = (ioctl_t)dlsym(RTLD_NEXT, "ioctl");
getsockopt_f = (getsockopt_t)dlsym(RTLD_NEXT, "getsockopt");
setsockopt_f = (setsockopt_t)dlsym(RTLD_NEXT, "setsockopt");
dup_f = (dup_t)dlsym(RTLD_NEXT, "dup");
dup2_f = (dup2_t)dlsym(RTLD_NEXT, "dup2");
dup3_f = (dup3_t)dlsym(RTLD_NEXT, "dup3");
fclose_f = (fclose_t)dlsym(RTLD_NEXT, "fclose");
// These lookups are compiled only when LIBGO_SYS_Linux is defined.
#if defined(LIBGO_SYS_Linux)
pipe2_f = (pipe2_t)dlsym(RTLD_NEXT, "pipe2");
gethostbyname_r_f = (gethostbyname_r_t)dlsym(RTLD_NEXT, "gethostbyname_r");
gethostbyname2_r_f = (gethostbyname2_r_t)dlsym(RTLD_NEXT, "gethostbyname2_r");
gethostbyaddr_r_f = (gethostbyaddr_r_t)dlsym(RTLD_NEXT, "gethostbyaddr_r");
epoll_wait_f = (epoll_wait_t)dlsym(RTLD_NEXT, "epoll_wait");
看下sample code:
#define _GNU_SOURCE
#include <dlfcn.h>
#include <stdio.h>
#include <fcntl.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/ioctl.h>
#include <arpa/inet.h>
#include <mysql.h>
/*
 * For every hooked function: a function-pointer type with the function's
 * signature, and a global of that type. init_hook() stores the dlsym result
 * in the global and the wrapper of the same name calls through it.
 */
/* connect */
typedef int(*connect_t)(int, const struct sockaddr *, socklen_t);
connect_t connect_f;
/* read */
typedef ssize_t(*read_t)(int, void *, size_t);
read_t read_f;
/* recv */
typedef ssize_t(*recv_t)(int sockfd, void *buf, size_t len, int flags);
recv_t recv_f;
/* write */
typedef ssize_t(*write_t)(int, const void *, size_t);
write_t write_f;
/* send */
typedef ssize_t(*send_t)(int sockfd, const void *buf, size_t len, int flags);
send_t send_f;
int connect(int sockfd, const struct sockaddr *addr, socklen_t len) {
printf("connect\n");
return connect_f(sockfd, addr, len);
}
ssize_t read(int fd, void *buffer, size_t len) {
printf("read\n");
return read_f(fd, buffer, len);
}
ssize_t recv(int sockfd, void *buf, size_t len, int flags) {
printf("recv\n");
return recv_f(sockfd, buf, len, flags);
}
ssize_t write(int fd, const void *buf, size_t len) {
printf("write\n");
return write_f(fd, buf, len);
}
ssize_t send(int sockfd, const void *buf, size_t len, int flags) {
printf("send\n");
return send_f(sockfd, buf, len, flags);
}
static int init_hook() {
connect_f = dlsym(RTLD_NEXT, "connect");
read_f = dlsym(RTLD_NEXT, "read");
recv_f = dlsym(RTLD_NEXT, "recv");
write_f = dlsym(RTLD_NEXT, "write");
send_f = dlsym(RTLD_NEXT, "send");
}
/*
 * Demo driver: installs the hooks, then opens a MySQL connection so that the
 * client library's socket calls go through the wrappers defined in this file.
 *
 * Returns 0 when the connection succeeds and 1 when mysql_init or
 * mysql_real_connect fails.
 */
int main() {
    init_hook();

    MYSQL* m_mysql = mysql_init(NULL);
    if (!m_mysql) {
        printf("mysql_init failed\n");
        return 1;
    }

    int exit_code = 0;
    if (!mysql_real_connect(m_mysql,
                            "192.168.232.132", "xxx", "xxx",
                            "xxx", 3306,
                            NULL, CLIENT_FOUND_ROWS)) {
        printf("mysql_real_connect failed: %s\n", mysql_error(m_mysql));
        exit_code = 1;
    } else {
        printf("mysql_real_connect success\n");
    }

    /* The handle from mysql_init must be released whether or not the
     * connection attempt succeeded. */
    mysql_close(m_mysql);
    return exit_code;
}
调用的都是我们自己定义的 connect、read、recv 等函数。