Skip to content

Commit

Permalink
重构mysql组件
Browse files Browse the repository at this point in the history
  • Loading branch information
xiyoo0812 committed Sep 18, 2023
1 parent a26ff11 commit 4851d20
Show file tree
Hide file tree
Showing 12 changed files with 76 additions and 71 deletions.
9 changes: 3 additions & 6 deletions core/luabus/src/lua_socket_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ int lua_socket_node::transfer_hash(lua_State* L, uint32_t session_id, uint16_t s
return 0;
}

int lua_socket_node::on_recv(slice* slice) {
void lua_socket_node::on_recv(slice* slice) {
if (m_type == eproto_type::proto_pb) {
return on_call_pb(slice);
}
Expand Down Expand Up @@ -268,7 +268,6 @@ int lua_socket_node::on_recv(slice* slice) {
}
break;
}
return header->len;
}

void lua_socket_node::on_forward_error(router_header* header, slice* slice) {
Expand All @@ -291,12 +290,11 @@ void lua_socket_node::on_transfer(transfer_header* header, slice* slice) {
m_lvm->object_call(this, "on_transfer", nullptr, std::tie(), header->len, session_id, service_id, target_id, slice);
}

int lua_socket_node::on_call_pb(slice* slice) {
void lua_socket_node::on_call_pb(slice* slice) {
socket_header* header = (socket_header*)slice->peek(sizeof(socket_header));
uint32_t session_id = header->session_id;
if (session_id > 0) session_id |= m_stoken;
m_lvm->object_call(this, "on_call_pb", nullptr, m_codec, std::tie(), header->len, session_id);
return header->len;
}

void lua_socket_node::on_call(router_header* header, slice* slice) {
Expand All @@ -305,7 +303,6 @@ void lua_socket_node::on_call(router_header* header, slice* slice) {
m_lvm->object_call(this, "on_call", nullptr, m_codec, std::tie(), header->len, session_id, flag);
}

int lua_socket_node::on_call_data(slice* slice) {
void lua_socket_node::on_call_data(slice* slice) {
m_lvm->object_call(this, "on_call_data", nullptr, m_codec, std::tie(), slice->size());
return m_codec->get_packet_len();
}
6 changes: 3 additions & 3 deletions core/luabus/src/lua_socket_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ class lua_socket_node
uint16_t m_sindex = 1;

private:
int on_recv(slice* slice);
int on_call_pb(slice* slice);
int on_call_data(slice* slice);
void on_recv(slice* slice);
void on_call_pb(slice* slice);
void on_call_data(slice* slice);
void on_call(router_header* header, slice* slice);
void on_transfer(transfer_header* header, slice* slice);
void on_forward_broadcast(router_header* header, size_t target_size);
Expand Down
2 changes: 1 addition & 1 deletion core/luabus/src/socket_mgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ void socket_mgr::set_connect_callback(uint32_t token, const std::function<void(b
}
}

void socket_mgr::set_package_callback(uint32_t token, const std::function<int(slice*)> cb) {
void socket_mgr::set_package_callback(uint32_t token, const std::function<void(slice*)> cb) {
auto node = get_object(token);
if (node) {
node->set_package_callback(cb);
Expand Down
4 changes: 2 additions & 2 deletions core/luabus/src/socket_mgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ struct socket_object
virtual void set_accept_callback(const std::function<void(int)> cb) { }
virtual void set_connect_callback(const std::function<void(bool, const char*)> cb) { }
virtual void set_error_callback(const std::function<void(const char*)> cb) { }
virtual void set_package_callback(const std::function<int(slice*)> cb) { }
virtual void set_package_callback(const std::function<void(slice*)> cb) { }
virtual bool is_same_kind(uint32_t kind) { return m_kind == kind; }

#ifdef _MSC_VER
Expand Down Expand Up @@ -103,7 +103,7 @@ class socket_mgr
void set_accept_callback(uint32_t token, const std::function<void(int)> cb);
void set_error_callback(uint32_t token, const std::function<void(const char*)> cb);
void set_connect_callback(uint32_t token, const std::function<void(bool, const char*)> cb);
void set_package_callback(uint32_t token, const std::function<int(slice*)> cb);
void set_package_callback(uint32_t token, const std::function<void(slice*)> cb);

bool watch_listen(socket_t fd, socket_object* object);
bool watch_accepted(socket_t fd, socket_object* object);
Expand Down
7 changes: 5 additions & 2 deletions core/luabus/src/socket_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ void socket_stream::close() {
}
if (m_codec) {
m_codec->clean();
m_codec = nullptr;
}
shutdown(m_socket, SD_RECEIVE);
m_link_status = elink_status::link_colsing;
Expand Down Expand Up @@ -478,12 +479,14 @@ void socket_stream::dispatch_package() {
if (package_size == 0) break;
// 数据回调
slice->attach(data, package_size);
int read_size = m_package_cb(slice);
m_package_cb(slice);
if (!m_codec) break;
// 数据包解析失败
if (m_codec->failed()) {
if (!m_codec || m_codec->failed()) {
on_error(m_codec->err());
break;
}
size_t read_size = m_codec->get_packet_len();
// 数据包还没有收完整
if (read_size == 0) break;
// 接收缓冲读游标调整
Expand Down
4 changes: 2 additions & 2 deletions core/luabus/src/socket_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ struct socket_stream : public socket_object
void close() override;
void set_error_callback(const std::function<void(const char*)> cb) override { m_error_cb = cb; }
void set_connect_callback(const std::function<void(bool, const char*)> cb) override { m_connect_cb = cb; }
void set_package_callback(const std::function<int(slice*)> cb) override { m_package_cb = cb; }
void set_package_callback(const std::function<void(slice*)> cb) override { m_package_cb = cb; }
void set_timeout(int duration) override { m_timeout = duration; }
void set_nodelay(int flag) override { set_no_delay(m_socket, flag); }

Expand Down Expand Up @@ -70,5 +70,5 @@ struct socket_stream : public socket_object

std::function<void(const char*)> m_error_cb = nullptr;
std::function<void(bool, const char*)> m_connect_cb = nullptr;
std::function<int(slice*)> m_package_cb = nullptr;
std::function<void(slice*)> m_package_cb = nullptr;
};
6 changes: 3 additions & 3 deletions extend/lbson/src/bson.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,13 +348,13 @@ namespace lbson {
size_t sz;
const char* dst = (const char*)slice->data(&sz);
for (l = 0; l < sz; ++l) {
if (l == sz - 1) {
throw invalid_argument("invalid bson block : cstring");
}
if (dst[l] == '\0') {
slice->erase(l + 1);
return dst;
}
if (l == sz - 1) {
throw invalid_argument("invalid bson block : cstring");
}
}
throw invalid_argument("invalid bson block : cstring");
return "";
Expand Down
70 changes: 37 additions & 33 deletions extend/lcodec/src/mysql.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ namespace lcodec {
const uint8_t COM_STMT_CLOSE = 0x19;

// constants
inline size_t CLIENT_FLAG = 260047; //0011 1111 0111 1100 1111
inline size_t MAX_PACKET_SIZE = 0xffffff;
inline size_t CLIENT_PLUGIN_AUTH = 1 << 19;
inline size_t CLIENT_DEPRECATE_EOF = 1 << 24;
inline uint32_t CLIENT_FLAG = 260047; //0011 1111 0111 1100 1111
inline uint32_t MAX_PACKET_SIZE = 1024*1024;
inline uint32_t CLIENT_PLUGIN_AUTH = 1 << 19;
inline uint32_t CLIENT_DEPRECATE_EOF = 1 << 24;

// field types
const uint16_t MYSQL_TYPE_TINY = 0x01;
Expand Down Expand Up @@ -62,7 +62,7 @@ namespace lcodec {
if (!m_slice) return 0;
uint32_t* packet_len = (uint32_t*)m_slice->peek(sizeof(uint32_t));
if (!packet_len) return 0;
m_packet_len = ((*packet_len) >> 8);
m_packet_len = ((*packet_len) & 0xffffff);
if (!m_slice->peek(m_packet_len)) return 0;
m_packet_len += sizeof(uint32_t);
if (m_packet_len > data_len) return 0;
Expand All @@ -75,6 +75,8 @@ namespace lcodec {
uint8_t cmd_id = (uint8_t)lua_tointeger(L, index++);
// session_id
size_t session_id = lua_tointeger(L, index++);
//4 byte header placeholder
m_buf->write<uint32_t>(0);
if (cmd_id != COM_CONNECT) {
return comand_encode(L, cmd_id, session_id, index, len);
}
Expand All @@ -85,7 +87,7 @@ namespace lcodec {
int top = lua_gettop(L);
if (sessions.empty()) throw invalid_argument("invalid mysql data");
uint32_t payload = *(uint32_t*)m_slice->read<uint32_t>();
uint32_t length = payload >> 8;
uint32_t length = (payload & 0xffffff);
if (length >= 0xffffff) throw invalid_argument("sharded packet not suppert!");
mysql_cmd cmd = sessions.front();
lua_pushinteger(L, cmd.session_id);
Expand All @@ -108,8 +110,7 @@ namespace lcodec {
uint8_t* comand_encode(lua_State* L, uint8_t cmd_id, size_t session_id, int index, size_t* len) {
m_buf->write<uint8_t>(cmd_id);
int top = lua_gettop(L);
int argnum = top - index;
if (argnum > 1) {
if (index <= top) {
if (lua_type(L, index) == LUA_TNUMBER) {
m_buf->write<uint32_t>(lua_tointeger(L, index++));
}
Expand All @@ -119,18 +120,20 @@ namespace lcodec {
m_buf->push_data(query, data_len);
}
}
if (argnum > 2) {
encode_stmt_args(L, index, argnum - 2);
if (index <= top) {
encode_stmt_args(L, index, top - index + 1);
}
// header
uint32_t size = ((m_buf->size() - 4) & 0xffffff);
m_buf->copy(0, (uint8_t*)&size, 4);
// cmd
if (cmd_id != COM_STMT_CLOSE) {
sessions.push_back(mysql_cmd{ cmd_id, session_id });
}
return m_buf->data(len);
}

uint8_t* auth_encode(lua_State* L, uint8_t cmd_id, size_t session_id, int index, size_t* len) {
//4 byte header placeholder
m_buf->write<uint32_t>(0);
//4 byte client_flag
m_buf->write<uint32_t>(CLIENT_FLAG);
//4 byte max_packet_size
Expand All @@ -142,14 +145,17 @@ namespace lcodec {
// username
uint8_t* user = (uint8_t*)lua_tolstring(L, index++, len);
m_buf->push_data(user, *len);
m_buf->push_data((uint8_t*)"\0", 1);
// auth_data
uint8_t* auth_data = (uint8_t*)lua_tolstring(L, index++, len);
m_buf->write<uint8_t>(*len);
m_buf->push_data(auth_data, *len);
//dbname
// dbname
const uint8_t* dbname = (const uint8_t*)lua_tolstring(L, index++, len);
m_buf->push_data(dbname, *len);
m_buf->push_data((uint8_t*)"\0", 1);
// header
uint32_t size = ((m_buf->size() - 4) << 8) | 0xffffff00;
uint32_t size = ((m_buf->size() - 4) & 0xffffff) | 0x01000000;
m_buf->copy(0, (uint8_t*)&size, 4);
// cmd
sessions.push_back(mysql_cmd{ cmd_id, session_id });
Expand Down Expand Up @@ -281,9 +287,8 @@ namespace lcodec {
lua_pushinteger(L, decode_length_encoded_number());
lua_setfield(L, -2, "warnings");
//info
size_t data_len;
const char* info = read_cstring(m_slice, data_len);
lua_pushlstring(L, info, data_len);
auto info = m_slice->eof();
lua_pushlstring(L, info.data(), info.size());
lua_setfield(L, -2, "info");
}

Expand All @@ -301,21 +306,20 @@ namespace lcodec {
lua_setfield(L, -2, "sql_state");
m_slice->erase(6);
//error_message
size_t data_len;
const char* error_message = read_cstring(m_slice, data_len);
auto error_message = m_slice->eof();
lua_pushlstring(L, error_message.data(), error_message.size());
lua_setfield(L, -2, "error_message");
}

bool eof_packet_decode(lua_State* L) {
//type
m_slice->read<uint8_t>();
if ((m_capability & CLIENT_DEPRECATE_EOF) == CLIENT_DEPRECATE_EOF) {
size_t data_len;
size_t affected_rows = decode_length_encoded_number();
size_t last_insert_id = decode_length_encoded_number();
uint16_t status_flags = *(uint16_t*)m_slice->read<uint16_t>();
uint16_t warnings = *(uint16_t*)m_slice->read<uint16_t>();
const char* info = read_cstring(m_slice, data_len);
auto info = m_slice->eof();
return ((status_flags & SERVER_MORE_RESULTS_EXISTS) == SERVER_MORE_RESULTS_EXISTS);
} else {
uint16_t warnings = *(uint16_t*)m_slice->read<uint16_t>();
Expand Down Expand Up @@ -351,6 +355,7 @@ namespace lcodec {
uint16_t capability_flag_1 = *(uint16_t*)m_slice->read<uint16_t>();
//1 byte character_set
uint8_t character_set = *(uint8_t*)m_slice->read<uint8_t>();
lua_pushinteger(L, character_set);
//2 byte status_flags
uint16_t status_flags = *(uint16_t*)m_slice->read<uint16_t>();
//2 byte capability_flags_2
Expand All @@ -360,34 +365,32 @@ namespace lcodec {
uint8_t auth_plugin_data_len = *(uint8_t*)m_slice->read<uint8_t>();
//10 byte reserved (all 0)
m_slice->erase(10);
uint8_t* scramble2 = nullptr;
//auth-plugin-data-part-2
if (auth_plugin_data_len > 0) {
scramble2 = m_slice->peek(auth_plugin_data_len - 8);
m_slice->erase(auth_plugin_data_len - 8);
}
char* scramble2 = nullptr;
auth_plugin_data_len = std::max(13, auth_plugin_data_len - 8);
scramble2 = (char*)m_slice->peek(auth_plugin_data_len);
m_slice->erase(auth_plugin_data_len);
lua_pushlstring(L, (char*)scramble1, 8);
lua_pushlstring(L, scramble2, 12);
//auth_plugin_name
const char* auth_plugin_name = nullptr;
if ((m_capability & CLIENT_PLUGIN_AUTH) == CLIENT_PLUGIN_AUTH) {
auth_plugin_name = read_cstring(m_slice, data_len);
lua_pushlstring(L, auth_plugin_name, data_len);
}
lua_pushinteger(L, character_set);
lua_pushlstring(L, (char*)scramble1, 8);
lua_pushlstring(L, (char*)scramble2, auth_plugin_data_len - 8);
lua_pushlstring(L, auth_plugin_name, 8);
}

const char* read_cstring(slice* slice, size_t& l) {
size_t sz;
const char* dst = (const char*)slice->data(&sz);
for (l = 0; l < sz; ++l) {
if (l == sz - 1) {
throw invalid_argument("invalid mysql block : cstring");
}
if (dst[l] == '\0') {
slice->erase(l + 1);
return dst;
}
if (l == sz - 1) {
throw invalid_argument("invalid mysql block : cstring");
}
}
throw invalid_argument("invalid mysql block : cstring");
return "";
Expand Down Expand Up @@ -489,6 +492,7 @@ namespace lcodec {
if (length > 0) {
char* data = (char*)m_slice->peek(length);
if (!data) throw invalid_argument("invalid length coded string:" + length);
m_slice->erase(length);
return string_view(data, length);
}
return "";
Expand Down
7 changes: 7 additions & 0 deletions extend/luakit/include/lua_slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,13 @@ namespace luakit {
return std::string_view((const char*)m_head, len);
}

std::string_view eof() {
uint8_t* head = m_head;
m_head = m_tail;
size_t len = (size_t)(m_tail - head);
return std::string_view((const char*)head, len);
}

int string(lua_State* L) {
size_t len = (size_t)(m_tail - m_head);
lua_pushlstring(L, (const char*)m_head, len);
Expand Down
Loading

0 comments on commit 4851d20

Please sign in to comment.