libgo的上下文切换

在 libgo 的上下文切换上,并没有自己去实现创建和维护栈空间、保存和切换 CPU 寄存器执行状态信息等的任务,而是直接使用了 Boost.Context。Boost.Context 作为众多协程底层支持库,性能方面一直在被优化。

Boost.Context所做的工作,就是在传统的线程环境中保存当前执行的抽象状态信息(栈空间、栈指针、CPU寄存器和状态寄存器、IP指令指针),然后暂停当前的执行状态,让程序的执行流程跳转到其他位置继续执行。这个基础构件可以用于开辟用户态的线程,从而构建出更加高级的协程等操作接口。同时因为这个切换是在用户空间完成的,所以资源损耗很小;又因为保存了栈空间和执行状态的所有信息,所以其中的函数可以被嵌套使用。(引用自 https://blog.csdn.net/qq_35976351/article/details/107449235)

文章目录

fcontext_t

libgo/context/fcontext.h

Boost.Context 的底层实现是通过 fcontext_t 结构体来保存协程状态,使用 make_fcontext 创建协程,使用 jump_fcontext 实现协程切换。在 libgo 协程中,直接引用了这两个接口函数。

// C-linkage declarations of the two Boost.Context primitives used by libgo.
extern "C"
{

// Handle to a saved execution context (an untyped pointer).
typedef void* fcontext_t;
// Signature of a context entry function: takes one intptr_t, returns nothing.
typedef void (FCONTEXT_CALL *fn_t)(intptr_t);

// Switches execution to the context `nfc`, handing it the value `vp`.
// NOTE(review): presumably the context being suspended is stored through
// `ofc`, and the return value is whatever is passed when control is switched
// back — confirm against the Boost.Context documentation.
intptr_t jump_fcontext(fcontext_t * ofc, fcontext_t nfc,
        intptr_t vp, bool preserve_fpu = false);

// Creates a context that will run `fn` using the given stack memory.
// NOTE(review): the Context class in this file passes the high end of its
// buffer (base + size) as `stack` — confirm this is what the API expects.
fcontext_t make_fcontext(void* stack, std::size_t size, fn_t fn);

} // extern "C"

除此之外,还提供了一系列的栈函数。

namespace co {

// Static helpers for allocating coroutine stacks and for guarding them with
// inaccessible pages.
struct StackTraits
{
    // Returns a reference to the function used to allocate a stack buffer.
    static stack_malloc_fn_t& MallocFunc();

    // Returns a reference to the function used to release a stack buffer.
    static stack_free_fn_t& FreeFunc();
    // Returns a reference to the number of guard pages configured for the
    // top of each stack.
    static int & GetProtectStackPageSize();
    // Makes the guard pages of `stack` inaccessible; returns true on success.
    static bool ProtectStack(void* stack, std::size_t size, int pageSize);
    // Removes the memory protection from the guard pages; only called when
    // the owning object is destroyed.
    static void UnprotectStack(void* stack, int pageSize);
};

} // namespace co

栈保护

libgo 对栈的保护,使用了 mprotect 系统调用实现。我们在给协程创建大小为 N 字节的栈空间时,会对栈顶(低地址处)的一部分空间进行保护,因此,分配的协程栈大小必须大于(保护页数 + 1)个内存页的大小。

为什么提到保护栈,总是以页为单位呢?因为 mprotect 是按页来设置保护属性的,因此,对于没有按页对齐的地址,应该先对齐到页边界之后再去操作。

    // Marks `pageSize` pages at the low-address end of a coroutine stack as
    // inaccessible (PROT_NONE), so running off the end of the stack faults
    // instead of silently overwriting neighbouring memory.
    //
    //   stack    - lowest address of the stack buffer
    //   size     - size of the buffer in bytes
    //   pageSize - number of pages to protect
    //
    // Returns true when the pages were protected; false when protection is
    // disabled (pageSize <= 0), the stack is too small, or mprotect fails.
    bool StackTraits::ProtectStack(void* stack, std::size_t size, int pageSize)
    {
        if (pageSize <= 0) return false;

        const std::size_t sysPageBytes = static_cast<std::size_t>(getpagesize());

        // The stack must be strictly larger than (guard pages + 1) pages; the
        // extra page covers the bytes skipped when `stack` is rounded up to a
        // page boundary. The comparison is done in std::size_t so that a
        // large `size` is not truncated by a cast to int.
        if (size <= sysPageBytes * (static_cast<std::size_t>(pageSize) + 1))
            return false;

        // mprotect requires a page-aligned address. The stack grows from high
        // to low addresses, so the guard region sits at the low end: round
        // `stack` up to the next boundary inside the buffer,
        // e.g. 0xf7234008 -> 0xf7235000.
        // NOTE(review): the mask hard-codes 4 KiB pages while the length below
        // uses getpagesize(); the two disagree on systems with another page
        // size. Kept as is so it matches UnprotectStack's address computation.
        void *protect_page_addr = ((std::size_t)stack & 0xfff) ? (void*)(((std::size_t)stack & ~(std::size_t)0xfff) + 0x1000) : stack;

        // PROT_NONE: the region can no longer be accessed at all.
        if (-1 == mprotect(protect_page_addr, sysPageBytes * pageSize, PROT_NONE)) {
            DebugPrint(dbg_task, "origin_addr:%p, align_addr:%p, page_size:%d, protect_page:%u, protect stack stack error: %s",
                    stack, protect_page_addr, getpagesize(), pageSize, strerror(errno));
            return false;
        }

        // Argument order matches the format string: page_size, then protect_page.
        DebugPrint(dbg_task, "origin_addr:%p, align_addr:%p, page_size:%d, protect_page:%u, protect stack success.",
                stack, protect_page_addr, getpagesize(), pageSize);
        return true;
    }

取消栈保护

取消栈保护只有在释放该协程空间的时候会调用。

    void StackTraits::UnprotectStack(void *stack, int pageSize)
    {
        if (!pageSize) return ;
	
        void *protect_page_addr = ((std::size_t)stack & 0xfff) ? (void*)(((std::size_t)stack & ~(std::size_t)0xfff) + 0x1000) : stack;
        // 允许该块内存可读可写
        if (-1 == mprotect(protect_page_addr, getpagesize() * pageSize, PROT_READ|PROT_WRITE)) {
            DebugPrint(dbg_task, "origin_addr:%p, align_addr:%p, page_size:%d, protect_page:%u, protect stack stack error: %s",
                    stack, protect_page_addr, getpagesize(), pageSize, strerror(errno));
        } else {
            DebugPrint(dbg_task, "origin_addr:%p, align_addr:%p, page_size:%d, protect_page:%u, protect stack success.",
                    stack, protect_page_addr, pageSize, getpagesize());
        }
    }

mprotect 系统调用使用说明

#include <sys/mman.h>
int mprotect(void *addr, size_t len, int prot);
    addr:应该是按页对齐的内存地址
    len:要保护的内存区域长度(字节),因此保护的地址范围应该是[addr, addr+len-1]
    prot:保护类型
        PROT_NONE  The memory cannot be accessed at all.
        PROT_READ  The memory can be read.
        PROT_WRITE The memory can be modified.
        PROT_EXEC  The memory can be executed.

Context

Context 是 libgo 中封装的上下文对象,每个协程都会有一份独有的。

// Execution context of one coroutine: owns the coroutine's stack buffer and
// the fcontext_t handle used to switch into and out of it.
class Context
{
public:
    // Allocates a stack of `stackSize` bytes, creates a context that will run
    // `fn` on it, and protects the guard pages when a non-zero guard page
    // count is configured.
    //
    //   fn        - entry function of the coroutine
    //   vp        - value handed to the context when it is switched in
    //   stackSize - stack size in bytes
    Context(fn_t fn, intptr_t vp, std::size_t stackSize)
        : fn_(fn), vp_(vp), stackSize_(stackSize)
    {
        // NOTE(review): the allocator's result is not checked; if it can
        // return null, the pointer arithmetic below is invalid — confirm the
        // allocator's failure behavior.
        stack_ = (char*)StackTraits::MallocFunc()(stackSize_);
        DebugPrint(dbg_task, "valloc stack. size=%u ptr=%p",
                stackSize_, stack_);

        // make_fcontext is given the high end of the buffer.
        ctx_ = make_fcontext(stack_ + stackSize_, stackSize_, fn_);

        // Remember the guard page count only when protection succeeded, so
        // the destructor unprotects exactly what was protected.
        int protectPage = StackTraits::GetProtectStackPageSize();
        if (protectPage && StackTraits::ProtectStack(stack_, stackSize_, protectPage))
            protectPage_ = protectPage;
    }

    // Lifts the guard page protection (if any) and releases the stack buffer.
    ~Context()
    {
        if (stack_) {
            DebugPrint(dbg_task, "free stack. ptr=%p", stack_);
            if (protectPage_)
                StackTraits::UnprotectStack(stack_, protectPage_);
            StackTraits::FreeFunc()(stack_);
            stack_ = nullptr;
        }
    }

    // The object owns stack_ and frees it in the destructor; a copy would
    // make two objects free the same buffer, so copying is disabled.
    Context(const Context&) = delete;
    Context& operator=(const Context&) = delete;

    // Switches from the calling thread's context into this coroutine; the
    // thread-local slot is passed as the place for the outgoing context.
    ALWAYS_INLINE void SwapIn()
    {
        jump_fcontext(&GetTlsContext(), ctx_, vp_);
    }

    // Switches from this coroutine directly into `other`.
    ALWAYS_INLINE void SwapTo(Context & other)
    {
        jump_fcontext(&ctx_, other.ctx_, other.vp_);
    }

    // Switches from this coroutine back to the thread-local context.
    ALWAYS_INLINE void SwapOut()
    {
        jump_fcontext(&ctx_, GetTlsContext(), 0);
    }

    // Returns the calling thread's context slot used by SwapIn/SwapOut.
    fcontext_t& GetTlsContext()
    {
        static thread_local fcontext_t tls_context;
        return tls_context;
    }

private:
    fcontext_t ctx_ = nullptr; // handle produced by make_fcontext
    fn_t fn_;  // coroutine entry function
    intptr_t vp_; // pointer to the coroutine Task object this context belongs to
    char* stack_ = nullptr;  // stack buffer (lowest address)
    uint32_t stackSize_ = 0;// stack size in bytes
    int protectPage_ = 0; // number of guard pages currently protected
};

Hook

介绍hook之前,需要先了解两个系统函数:

#include <dlfcn.h>
void *dlopen(const char *filename, int flag);
void *dlsym(void *handle, const char *symbol);

dlopen 以指定模式打开指定的动态链接库文件,并返回一个句柄给调用进程;dlsym 通过句柄和符号名称获取对应函数或变量的地址。hook 的意思是:程序调用某个系统函数时,会先查看我们自己是否定义过同名函数,如果有则调用自己定义的函数,否则调用系统函数;在自己定义的函数内部,再通过 dlsym(RTLD_NEXT, ...) 取得的原函数地址去调用真正的系统函数。在libgo中,定义了:

   // Excerpt: the closing braces and the matching #endif are not part of it.
   // `connect` is looked up first as a probe; the remaining lookups only run
   // when that lookup returned a non-null address.
   connect_f = (connect_t)dlsym(RTLD_NEXT, "connect");
    if (connect_f) {
        // Each *_f pointer receives the address dlsym(RTLD_NEXT, ...) returns
        // for the function of the same name.
        pipe_f = (pipe_t)dlsym(RTLD_NEXT, "pipe");
        socket_f = (socket_t)dlsym(RTLD_NEXT, "socket");
        socketpair_f = (socketpair_t)dlsym(RTLD_NEXT, "socketpair");
        // Repeats the lookup already done before the `if`.
        connect_f = (connect_t)dlsym(RTLD_NEXT, "connect");
        read_f = (read_t)dlsym(RTLD_NEXT, "read");
        readv_f = (readv_t)dlsym(RTLD_NEXT, "readv");
        recv_f = (recv_t)dlsym(RTLD_NEXT, "recv");
        recvfrom_f = (recvfrom_t)dlsym(RTLD_NEXT, "recvfrom");
        recvmsg_f = (recvmsg_t)dlsym(RTLD_NEXT, "recvmsg");
        write_f = (write_t)dlsym(RTLD_NEXT, "write");
        writev_f = (writev_t)dlsym(RTLD_NEXT, "writev");
        send_f = (send_t)dlsym(RTLD_NEXT, "send");
        sendto_f = (sendto_t)dlsym(RTLD_NEXT, "sendto");
        sendmsg_f = (sendmsg_t)dlsym(RTLD_NEXT, "sendmsg");
        accept_f = (accept_t)dlsym(RTLD_NEXT, "accept");
        poll_f = (poll_t)dlsym(RTLD_NEXT, "poll");
        select_f = (select_t)dlsym(RTLD_NEXT, "select");
        sleep_f = (sleep_t)dlsym(RTLD_NEXT, "sleep");
        usleep_f = (usleep_t)dlsym(RTLD_NEXT, "usleep");
        nanosleep_f = (nanosleep_t)dlsym(RTLD_NEXT, "nanosleep");
        close_f = (close_t)dlsym(RTLD_NEXT, "close");
        fcntl_f = (fcntl_t)dlsym(RTLD_NEXT, "fcntl");
        ioctl_f = (ioctl_t)dlsym(RTLD_NEXT, "ioctl");
        getsockopt_f = (getsockopt_t)dlsym(RTLD_NEXT, "getsockopt");
        setsockopt_f = (setsockopt_t)dlsym(RTLD_NEXT, "setsockopt");
        dup_f = (dup_t)dlsym(RTLD_NEXT, "dup");
        dup2_f = (dup2_t)dlsym(RTLD_NEXT, "dup2");
        dup3_f = (dup3_t)dlsym(RTLD_NEXT, "dup3");
        fclose_f = (fclose_t)dlsym(RTLD_NEXT, "fclose");
        // Lookups that are only compiled when LIBGO_SYS_Linux is defined.
#if defined(LIBGO_SYS_Linux)
        pipe2_f = (pipe2_t)dlsym(RTLD_NEXT, "pipe2");
        gethostbyname_r_f = (gethostbyname_r_t)dlsym(RTLD_NEXT, "gethostbyname_r");
        gethostbyname2_r_f = (gethostbyname2_r_t)dlsym(RTLD_NEXT, "gethostbyname2_r");
        gethostbyaddr_r_f = (gethostbyaddr_r_t)dlsym(RTLD_NEXT, "gethostbyaddr_r");
        epoll_wait_f = (epoll_wait_t)dlsym(RTLD_NEXT, "epoll_wait");

看下sample code:




#define _GNU_SOURCE
#include <dlfcn.h>
#include <stdio.h>

#include <fcntl.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/ioctl.h>
#include <arpa/inet.h>


#include <mysql.h>


/* Function-pointer type with connect()'s signature, and the global that
 * init_hook() fills with the address returned by dlsym(RTLD_NEXT, "connect"). */
typedef int(*connect_t)(int, const struct sockaddr *, socklen_t);
connect_t connect_f;

/* Same pair for read(). */
typedef ssize_t(*read_t)(int, void *, size_t);
read_t read_f;

/* Same pair for recv(). */
typedef ssize_t(*recv_t)(int sockfd, void *buf, size_t len, int flags);
recv_t recv_f;

/* Same pair for write(). */
typedef ssize_t(*write_t)(int, const void *, size_t);
write_t write_f;

/* Same pair for send(). */
typedef ssize_t(*send_t)(int sockfd, const void *buf, size_t len, int flags);
send_t send_f;


int connect(int sockfd, const struct sockaddr *addr, socklen_t len) {
	printf("connect\n");
	return connect_f(sockfd, addr, len);

}

ssize_t read(int fd, void *buffer, size_t len) {
	printf("read\n");
	return read_f(fd, buffer, len);
}

ssize_t recv(int sockfd, void *buf, size_t len, int flags) {
	printf("recv\n");
	return recv_f(sockfd, buf, len, flags);
}

ssize_t write(int fd, const void *buf, size_t len) {
	printf("write\n");
	return write_f(fd, buf, len);
}


ssize_t send(int sockfd, const void *buf, size_t len, int flags) {
	printf("send\n");
	return send_f(sockfd, buf, len, flags);
}


static int init_hook() {

	connect_f = dlsym(RTLD_NEXT, "connect");
	
	read_f = dlsym(RTLD_NEXT, "read");
	recv_f = dlsym(RTLD_NEXT, "recv");
	
	write_f = dlsym(RTLD_NEXT, "write");
	send_f = dlsym(RTLD_NEXT, "send");
}



/*
 * Demo driver: installs the hooks, then opens a MySQL connection so that the
 * client library's socket calls go through the replacement functions.
 *
 * Returns 0 when the connection succeeded, 1 on any failure.
 */
int main() {

	init_hook();

	MYSQL* m_mysql = mysql_init(NULL);
	if (!m_mysql) {
		printf("mysql_init failed\n");
		return 1;
	}

	if (!mysql_real_connect(m_mysql,
				"192.168.232.132", "xxx", "xxx",
				"xxx", 3306,
				NULL, CLIENT_FOUND_ROWS)) {
		/* Read the error text before the handle is released. */
		printf("mysql_real_connect failed: %s\n", mysql_error(m_mysql));
		mysql_close(m_mysql);
		return 1;
	}

	printf("mysql_real_connect success\n");

	/* Release the handle obtained from mysql_init(). */
	mysql_close(m_mysql);
	return 0;
}

运行后可以看到,调用的都是我们自己定义的 connect, read, recv… 函数。

上一篇:第1年7月12日 iOS运行加载动态库


下一篇:论文笔记1---HVAC fault detection using a system identification approach