#ifndef __APP_WEBSOCKET_FRAME_H__ #define __APP_WEBSOCKET_FRAME_H__ #include "memory.hpp" class buffer; struct websocket_frame { websocket_frame(); ~websocket_frame(); static const unsigned int fix_min_len = 2; static const unsigned int fix_masking_len = 4; static const unsigned int fix_126_len = 2; static const unsigned int fix_127_len = 8; bool fin; bool rsv1; bool rsv2; bool rsv3; unsigned char opcode; bool mask; unsigned char payload_len ; unsigned short payload_126_len; unsigned long long payload_127_len; unsigned char masking_key[4]; shared_ptr<buffer> payload; unsigned int get_head_len() const ; unsigned int get_total_len() const ; unsigned int get_payload_len() const ; //解包 int unpakcage( const shared_ptr<buffer> & buf ); //打包 unsigned int package_size() ; shared_ptr<buffer> package() ; protected: void set_masking_key( unsigned char * data, int offset); void get_masking_key( unsigned char * data, int offset); }; #endif
#include "websocket_frame.hpp"
#include "buffer_pool.hpp"
#include <cstring>
#include <glog/logging.h>
#include <cstdlib>

namespace {

// XORs `len` bytes of `src` with the repeating 4-byte `key` into `dst`.
// The operation is its own inverse, so it is used both to mask and to unmask.
void xor_with_key(unsigned char * dst, const unsigned char * src,
                  unsigned int len, const unsigned char * key)
{
    for (unsigned int i = 0; i < len; ++i) {
        dst[i] = src[i] ^ key[i % 4];
    }
}

} // namespace

// Builds an empty frame (all flags cleared, zero lengths) and fills
// masking_key from a single rand() value.
// NOTE(review): rand() is predictable unless seeded elsewhere; RFC 6455 asks
// for unpredictable masking keys -- consider a stronger source.
websocket_frame::websocket_frame()
    : fin(false), rsv1(false), rsv2(false), rsv3(false),
      opcode(0), mask(false),   // listed in declaration order
      payload_len(0), payload_126_len(0), payload_127_len(0)
{
    const int rand_var = rand();
    masking_key[0] = (rand_var & 0xff000000) >> 24;
    masking_key[1] = (rand_var & 0x00ff0000) >> 16;
    masking_key[2] = (rand_var & 0x0000ff00) >> 8;
    masking_key[3] = rand_var & 0x000000ff;
}

// Hands the payload buffer, if any, to recycle_buffer().
websocket_frame::~websocket_frame()
{
    if (payload) {
        recycle_buffer(payload);
    }
}

// Reads the 4-byte masking key from data[offset .. offset+3].
void websocket_frame::set_masking_key(unsigned char * data, int offset)
{
    memcpy(masking_key, data + offset, sizeof(masking_key));
}

// Writes the 4-byte masking key to data[offset .. offset+3].
void websocket_frame::get_masking_key(unsigned char * data, int offset)
{
    memcpy(data + offset, masking_key, sizeof(masking_key));
}

// Header size in bytes: everything in the frame that is not payload.
unsigned int websocket_frame::get_head_len() const
{
    return get_total_len() - get_payload_len();
}

// Full frame size in bytes: fixed header + optional extended length +
// optional masking key + payload. The result is truncated to unsigned int.
unsigned int websocket_frame::get_total_len() const
{
    unsigned int head = fix_min_len;
    if (mask) {
        head += fix_masking_len;
    }
    if (payload_len == 126) {
        head += fix_126_len;
    } else if (payload_len > 126) {
        head += fix_127_len;
    }
    return head + get_payload_len();
}

// Payload size in bytes, taken from the length field selected by payload_len.
unsigned int websocket_frame::get_payload_len() const
{
    if (payload_len < 126) {
        return payload_len;
    }
    if (payload_len == 126) {
        return payload_126_len;
    }
    return static_cast<unsigned int>(payload_127_len);
}

// Parses one frame from the start of `buf` (which must be non-null).
//
// Returns 0 when a complete frame was decoded; `payload` then holds the
// unmasked payload. Returns a positive number when `buf` is too short: the
// count of additional bytes needed before the next parsing step can succeed.
// The header fields may already be partially updated in that case.
//
// Frames with the 127 length marker are not supported and trigger LOG(FATAL).
// NOTE(review): if `buf` can carry data from an untrusted peer, that peer can
// abort the process with a single frame -- confirm this is acceptable.
int websocket_frame::unpakcage(const shared_ptr<buffer> & buf)
{
    // Both fixed header bytes are read below, so one byte is not enough.
    if (buf->length() < fix_min_len) {
        return static_cast<int>(fix_min_len - buf->length());
    }

    unsigned char * data = buf->data();
    fin    = data[0] & 0x80;
    rsv1   = data[0] & 0x40;
    rsv2   = data[0] & 0x20;
    rsv3   = data[0] & 0x10;
    opcode = data[0] & 0x0f;
    mask        = data[1] & 0x80;
    payload_len = data[1] & 0x7f;

    // Offset of whatever follows the length fields (masking key or payload).
    unsigned int after_len = fix_min_len;
    if (payload_len == 126) {
        after_len += fix_126_len;
        if (buf->length() < after_len) {
            return static_cast<int>(after_len - buf->length());
        }
        // 16-bit length, most significant byte first.
        payload_126_len = (data[2] << 8) | data[3];
    } else if (payload_len == 127) {
        LOG(FATAL)<<"un support big frame for websocket.";
        return 0;
    }

    if (buf->length() < get_total_len()) {
        return static_cast<int>(get_total_len() - buf->length());
    }

    const unsigned int body_len = get_payload_len();
    payload = get_buffer(body_len);
    unsigned char * pdata = payload->data();
    if (mask) {
        set_masking_key(data, static_cast<int>(after_len));
        xor_with_key(pdata, data + after_len + fix_masking_len, body_len, masking_key);
    } else {
        memcpy(pdata, data + after_len, body_len);
    }
    payload->size(body_len);
    return 0;
}

// Updates payload_len / payload_126_len / payload_127_len from the current
// payload and returns the size of the serialised frame. A frame without a
// payload buffer is treated as having a zero-length payload.
unsigned int websocket_frame::package_size()
{
    unsigned long long body_len = 0;
    if (payload) {
        body_len = payload->length();
    }

    if (body_len < 126) {
        payload_len = static_cast<unsigned char>(body_len);
    } else if (body_len <= 0xffff) {
        // 0xffff itself still fits in the 16-bit extended length.
        payload_len = 126;
        payload_126_len = static_cast<unsigned short>(body_len);
    } else {
        payload_len = 127;
        payload_127_len = body_len;
    }
    return get_total_len();
}

// Serialises the frame (header, optional masking key, payload) into a buffer
// obtained from get_buffer() and returns it. The payload is XOR-masked with
// masking_key when `mask` is set. Payloads that need the 127 length marker
// are not supported and trigger LOG(FATAL).
shared_ptr<buffer> websocket_frame::package()
{
    const unsigned int total_len = package_size();
    shared_ptr<buffer> buf = get_buffer(total_len);
    buf->size(total_len);

    unsigned char * data = buf->data();
    data[0] = opcode & 0x0f;
    if (fin)  { data[0] |= 0x80; }
    if (rsv1) { data[0] |= 0x40; }
    if (rsv2) { data[0] |= 0x20; }
    if (rsv3) { data[0] |= 0x10; }

    data[1] = payload_len & 0x7f;
    if (mask) { data[1] |= 0x80; }

    // Offset of whatever follows the length fields (masking key or payload).
    unsigned int after_len = fix_min_len;
    if (payload_len == 126) {
        data[2] = static_cast<unsigned char>((payload_126_len & 0xff00) >> 8);
        data[3] = static_cast<unsigned char>(payload_126_len & 0x00ff);
        after_len += fix_126_len;
    } else if (payload_len == 127) {
        LOG(FATAL)<<"no support 127 web socket frame.";
        return buf;
    }

    const unsigned char * pdata = 0;
    if (payload) {
        pdata = payload->data();
    }

    const unsigned int body_len = get_payload_len();
    if (mask) {
        get_masking_key(data, static_cast<int>(after_len));
        xor_with_key(data + after_len + fix_masking_len, pdata, body_len, masking_key);
    } else if (body_len > 0) {
        memcpy(data + after_len, pdata, body_len);
    }
    return buf;
}
重点在 unpakcage 方法:解包成功时返回 0;失败(数据不足)时,返回还需要多少字节才能解析出一个完整的包。
另外,暂时没有处理 payload_len=127 的情况,主要是因为大部分环境下用不到那么大的包。解包成功后,调用 get_head_len、get_total_len、get_payload_len 都可以返回正确的数据;如果 unpakcage 的返回值大于 0,此时调用这些与长度相关的函数是非法的。