diff --git a/core/luabus/src/lua_socket_node.cpp b/core/luabus/src/lua_socket_node.cpp index 907905b3..b6ee069c 100644 --- a/core/luabus/src/lua_socket_node.cpp +++ b/core/luabus/src/lua_socket_node.cpp @@ -213,7 +213,7 @@ int lua_socket_node::transfer_hash(lua_State* L, uint32_t session_id, uint16_t s return 0; } -int lua_socket_node::on_recv(slice* slice) { +void lua_socket_node::on_recv(slice* slice) { if (m_type == eproto_type::proto_pb) { return on_call_pb(slice); } @@ -268,7 +268,6 @@ int lua_socket_node::on_recv(slice* slice) { } break; } - return header->len; } void lua_socket_node::on_forward_error(router_header* header, slice* slice) { @@ -291,12 +290,11 @@ void lua_socket_node::on_transfer(transfer_header* header, slice* slice) { m_lvm->object_call(this, "on_transfer", nullptr, std::tie(), header->len, session_id, service_id, target_id, slice); } -int lua_socket_node::on_call_pb(slice* slice) { +void lua_socket_node::on_call_pb(slice* slice) { socket_header* header = (socket_header*)slice->peek(sizeof(socket_header)); uint32_t session_id = header->session_id; if (session_id > 0) session_id |= m_stoken; m_lvm->object_call(this, "on_call_pb", nullptr, m_codec, std::tie(), header->len, session_id); - return header->len; } void lua_socket_node::on_call(router_header* header, slice* slice) { @@ -305,7 +303,6 @@ void lua_socket_node::on_call(router_header* header, slice* slice) { m_lvm->object_call(this, "on_call", nullptr, m_codec, std::tie(), header->len, session_id, flag); } -int lua_socket_node::on_call_data(slice* slice) { +void lua_socket_node::on_call_data(slice* slice) { m_lvm->object_call(this, "on_call_data", nullptr, m_codec, std::tie(), slice->size()); - return m_codec->get_packet_len(); } diff --git a/core/luabus/src/lua_socket_node.h b/core/luabus/src/lua_socket_node.h index d959c070..d717ca17 100644 --- a/core/luabus/src/lua_socket_node.h +++ b/core/luabus/src/lua_socket_node.h @@ -73,9 +73,9 @@ class lua_socket_node uint16_t m_sindex = 1; private: - int on_recv(slice* slice); - int on_call_pb(slice* slice); - int on_call_data(slice* slice); + void on_recv(slice* slice); + void on_call_pb(slice* slice); + void on_call_data(slice* slice); void on_call(router_header* header, slice* slice); void on_transfer(transfer_header* header, slice* slice); void on_forward_broadcast(router_header* header, size_t target_size); diff --git a/core/luabus/src/socket_mgr.cpp b/core/luabus/src/socket_mgr.cpp index bbca18b7..b00f92e7 100644 --- a/core/luabus/src/socket_mgr.cpp +++ b/core/luabus/src/socket_mgr.cpp @@ -324,7 +324,7 @@ void socket_mgr::set_connect_callback(uint32_t token, const std::function cb) { +void socket_mgr::set_package_callback(uint32_t token, const std::function cb) { auto node = get_object(token); if (node) { node->set_package_callback(cb); diff --git a/core/luabus/src/socket_mgr.h b/core/luabus/src/socket_mgr.h index 3b905a24..4d909316 100644 --- a/core/luabus/src/socket_mgr.h +++ b/core/luabus/src/socket_mgr.h @@ -53,7 +53,7 @@ struct socket_object virtual void set_accept_callback(const std::function cb) { } virtual void set_connect_callback(const std::function cb) { } virtual void set_error_callback(const std::function cb) { } - virtual void set_package_callback(const std::function cb) { } + virtual void set_package_callback(const std::function cb) { } virtual bool is_same_kind(uint32_t kind) { return m_kind == kind; } #ifdef _MSC_VER @@ -103,7 +103,7 @@ class socket_mgr void set_accept_callback(uint32_t token, const std::function cb); void set_error_callback(uint32_t token, const std::function cb); void set_connect_callback(uint32_t token, const std::function cb); - void set_package_callback(uint32_t token, const std::function cb); + void set_package_callback(uint32_t token, const std::function cb); bool watch_listen(socket_t fd, socket_object* object); bool watch_accepted(socket_t fd, socket_object* object); diff --git a/core/luabus/src/socket_stream.cpp b/core/luabus/src/socket_stream.cpp index 1171d18e..56bc86a4 100644 --- a/core/luabus/src/socket_stream.cpp +++ b/core/luabus/src/socket_stream.cpp @@ -64,6 +64,7 @@ void socket_stream::close() { } if (m_codec) { m_codec->clean(); + m_codec = nullptr; } shutdown(m_socket, SD_RECEIVE); m_link_status = elink_status::link_colsing; @@ -478,12 +479,14 @@ void socket_stream::dispatch_package() { if (package_size == 0) break; // 数据回调 slice->attach(data, package_size); - int read_size = m_package_cb(slice); + m_package_cb(slice); + if (!m_codec) break; // 数据包解析失败 - if (m_codec->failed()) { + if (!m_codec || m_codec->failed()) { on_error(m_codec->err()); break; } + size_t read_size = m_codec->get_packet_len(); // 数据包还没有收完整 if (read_size == 0) break; // 接收缓冲读游标调整 diff --git a/core/luabus/src/socket_stream.h b/core/luabus/src/socket_stream.h index b79d7847..bcec044f 100644 --- a/core/luabus/src/socket_stream.h +++ b/core/luabus/src/socket_stream.h @@ -19,7 +19,7 @@ struct socket_stream : public socket_object void close() override; void set_error_callback(const std::function cb) override { m_error_cb = cb; } void set_connect_callback(const std::function cb) override { m_connect_cb = cb; } - void set_package_callback(const std::function cb) override { m_package_cb = cb; } + void set_package_callback(const std::function cb) override { m_package_cb = cb; } void set_timeout(int duration) override { m_timeout = duration; } void set_nodelay(int flag) override { set_no_delay(m_socket, flag); } @@ -70,5 +70,5 @@ struct socket_stream : public socket_object std::function m_error_cb = nullptr; std::function m_connect_cb = nullptr; - std::function m_package_cb = nullptr; + std::function m_package_cb = nullptr; }; diff --git a/extend/lbson/src/bson.h b/extend/lbson/src/bson.h index f2077687..4249159a 100644 --- a/extend/lbson/src/bson.h +++ b/extend/lbson/src/bson.h @@ -348,13 +348,13 @@ namespace lbson { size_t sz; const char* dst = (const char*)slice->data(&sz); for (l = 0; l < sz; ++l) { - if (l == sz - 1) { - throw invalid_argument("invalid bson block : cstring"); - } if (dst[l] == '\0') { slice->erase(l + 1); return dst; } + if (l == sz - 1) { + throw invalid_argument("invalid bson block : cstring"); + } } throw invalid_argument("invalid bson block : cstring"); return ""; diff --git a/extend/lcodec/src/mysql.h b/extend/lcodec/src/mysql.h index 0d1bbd5e..ca9467bb 100644 --- a/extend/lcodec/src/mysql.h +++ b/extend/lcodec/src/mysql.h @@ -15,10 +15,10 @@ namespace lcodec { const uint8_t COM_STMT_CLOSE = 0x19; // constants - inline size_t CLIENT_FLAG = 260047; //0011 1111 0111 1100 1111 - inline size_t MAX_PACKET_SIZE = 0xffffff; - inline size_t CLIENT_PLUGIN_AUTH = 1 << 19; - inline size_t CLIENT_DEPRECATE_EOF = 1 << 24; + inline uint32_t CLIENT_FLAG = 260047; //0011 1111 0111 1100 1111 + inline uint32_t MAX_PACKET_SIZE = 1024*1024; + inline uint32_t CLIENT_PLUGIN_AUTH = 1 << 19; + inline uint32_t CLIENT_DEPRECATE_EOF = 1 << 24; // field types const uint16_t MYSQL_TYPE_TINY = 0x01; @@ -62,7 +62,7 @@ namespace lcodec { if (!m_slice) return 0; uint32_t* packet_len = (uint32_t*)m_slice->peek(sizeof(uint32_t)); if (!packet_len) return 0; - m_packet_len = ((*packet_len) >> 8); + m_packet_len = ((*packet_len) & 0xffffff); if (!m_slice->peek(m_packet_len)) return 0; m_packet_len += sizeof(uint32_t); if (m_packet_len > data_len) return 0; @@ -75,6 +75,8 @@ namespace lcodec { uint8_t cmd_id = (uint8_t)lua_tointeger(L, index++); // session_id size_t session_id = lua_tointeger(L, index++); + //4 byte header placeholder + m_buf->write(0); if (cmd_id != COM_CONNECT) { return comand_encode(L, cmd_id, session_id, index, len); } @@ -85,7 +87,7 @@ namespace lcodec { int top = lua_gettop(L); if (sessions.empty()) throw invalid_argument("invalid mysql data"); uint32_t payload = *(uint32_t*)m_slice->read(); - uint32_t length = payload >> 8; + uint32_t length = (payload & 0xffffff); if (length >= 0xffffff) throw invalid_argument("sharded packet not suppert!"); mysql_cmd cmd = sessions.front(); lua_pushinteger(L, cmd.session_id); @@ -108,8 +110,7 @@ namespace lcodec { uint8_t* comand_encode(lua_State* L, uint8_t cmd_id, size_t session_id, int index, size_t* len) { m_buf->write(cmd_id); int top = lua_gettop(L); - int argnum = top - index; - if (argnum > 1) { + if (index <= top) { if (lua_type(L, index) == LUA_TNUMBER) { m_buf->write(lua_tointeger(L, index++)); } @@ -119,9 +120,13 @@ namespace lcodec { m_buf->push_data(query, data_len); } } - if (argnum > 2) { - encode_stmt_args(L, index, argnum - 2); + if (index <= top) { + encode_stmt_args(L, index, top - index + 1); } + // header + uint32_t size = ((m_buf->size() - 4) & 0xffffff); + m_buf->copy(0, (uint8_t*)&size, 4); + // cmd if (cmd_id != COM_STMT_CLOSE) { sessions.push_back(mysql_cmd{ cmd_id, session_id }); } @@ -129,8 +134,6 @@ namespace lcodec { } uint8_t* auth_encode(lua_State* L, uint8_t cmd_id, size_t session_id, int index, size_t* len) { - //4 byte header placeholder - m_buf->write(0); //4 byte client_flag m_buf->write(CLIENT_FLAG); //4 byte max_packet_size @@ -142,14 +145,17 @@ namespace lcodec { // username uint8_t* user = (uint8_t*)lua_tolstring(L, index++, len); m_buf->push_data(user, *len); + m_buf->push_data((uint8_t*)"\0", 1); + // auth_data uint8_t* auth_data = (uint8_t*)lua_tolstring(L, index++, len); m_buf->write(*len); m_buf->push_data(auth_data, *len); - //dbname + // dbname const uint8_t* dbname = (const uint8_t*)lua_tolstring(L, index++, len); m_buf->push_data(dbname, *len); + m_buf->push_data((uint8_t*)"\0", 1); // header - uint32_t size = ((m_buf->size() - 4) << 8) | 0xffffff00; + uint32_t size = ((m_buf->size() - 4) & 0xffffff) | 0x01000000; m_buf->copy(0, (uint8_t*)&size, 4); // cmd sessions.push_back(mysql_cmd{ cmd_id, session_id }); @@ -281,9 +287,8 @@ namespace lcodec { lua_pushinteger(L, decode_length_encoded_number()); lua_setfield(L, -2, "warnings"); //info - size_t data_len; - const char* info = read_cstring(m_slice, data_len); - lua_pushlstring(L, info, data_len); + auto info = m_slice->eof(); + lua_pushlstring(L, info.data(), info.size()); lua_setfield(L, -2, "info"); } @@ -301,8 +306,8 @@ namespace lcodec { lua_setfield(L, -2, "sql_state"); m_slice->erase(6); //error_message - size_t data_len; - const char* error_message = read_cstring(m_slice, data_len); + auto error_message = m_slice->eof(); + lua_pushlstring(L, error_message.data(), error_message.size()); lua_setfield(L, -2, "error_message"); } @@ -310,12 +315,11 @@ namespace lcodec { //type m_slice->read(); if ((m_capability & CLIENT_DEPRECATE_EOF) == CLIENT_DEPRECATE_EOF) { - size_t data_len; size_t affected_rows = decode_length_encoded_number(); size_t last_insert_id = decode_length_encoded_number(); uint16_t status_flags = *(uint16_t*)m_slice->read(); uint16_t warnings = *(uint16_t*)m_slice->read(); - const char* info = read_cstring(m_slice, data_len); + auto info = m_slice->eof(); return ((status_flags & SERVER_MORE_RESULTS_EXISTS) == SERVER_MORE_RESULTS_EXISTS); } else { uint16_t warnings = *(uint16_t*)m_slice->read(); @@ -351,6 +355,7 @@ namespace lcodec { uint16_t capability_flag_1 = *(uint16_t*)m_slice->read(); //1 byte character_set uint8_t character_set = *(uint8_t*)m_slice->read(); + lua_pushinteger(L, character_set); //2 byte status_flags uint16_t status_flags = *(uint16_t*)m_slice->read(); //2 byte capability_flags_2 @@ -360,34 +365,32 @@ namespace lcodec { uint8_t auth_plugin_data_len = *(uint8_t*)m_slice->read(); //10 byte reserved (all 0) m_slice->erase(10); - uint8_t* scramble2 = nullptr; //auth-plugin-data-part-2 - if (auth_plugin_data_len > 0) { - scramble2 = m_slice->peek(auth_plugin_data_len - 8); - m_slice->erase(auth_plugin_data_len - 8); - } + char* scramble2 = nullptr; + auth_plugin_data_len = std::max(13, auth_plugin_data_len - 8); + scramble2 = (char*)m_slice->peek(auth_plugin_data_len); + m_slice->erase(auth_plugin_data_len); + lua_pushlstring(L, (char*)scramble1, 8); + lua_pushlstring(L, scramble2, 12); //auth_plugin_name const char* auth_plugin_name = nullptr; if ((m_capability & CLIENT_PLUGIN_AUTH) == CLIENT_PLUGIN_AUTH) { auth_plugin_name = read_cstring(m_slice, data_len); + lua_pushlstring(L, auth_plugin_name, data_len); } - lua_pushinteger(L, character_set); - lua_pushlstring(L, (char*)scramble1, 8); - lua_pushlstring(L, (char*)scramble2, auth_plugin_data_len - 8); - lua_pushlstring(L, auth_plugin_name, 8); } const char* read_cstring(slice* slice, size_t& l) { size_t sz; const char* dst = (const char*)slice->data(&sz); for (l = 0; l < sz; ++l) { - if (l == sz - 1) { - throw invalid_argument("invalid mysql block : cstring"); - } if (dst[l] == '\0') { slice->erase(l + 1); return dst; } + if (l == sz - 1) { + throw invalid_argument("invalid mysql block : cstring"); + } } throw invalid_argument("invalid mysql block : cstring"); return ""; @@ -489,6 +492,7 @@ namespace lcodec { if (length > 0) { char* data = (char*)m_slice->peek(length); if (!data) throw invalid_argument("invalid length coded string:" + length); + m_slice->erase(length); return string_view(data, length); } return ""; diff --git a/extend/luakit/include/lua_slice.h b/extend/luakit/include/lua_slice.h index f38e99ac..641a1efd 100644 --- a/extend/luakit/include/lua_slice.h +++ b/extend/luakit/include/lua_slice.h @@ -98,6 +98,13 @@ namespace luakit { return std::string_view((const char*)m_head, len); } + std::string_view eof() { + uint8_t* head = m_head; + m_head = m_tail; + size_t len = (size_t)(m_tail - head); + return std::string_view((const char*)head, len); + } + int string(lua_State* L) { size_t len = (size_t)(m_tail - m_head); lua_pushlstring(L, (const char*)m_head, len); diff --git a/script/driver/mysql.lua b/script/driver/mysql.lua index 5f6cfd8c..efd3329f 100644 --- a/script/driver/mysql.lua +++ b/script/driver/mysql.lua @@ -52,7 +52,7 @@ function MysqlDB:__init(conf, id) self.passwd = conf.passwd --setup self:set_options(conf.opts) - self:setup_pool(conf) + self:setup_pool(conf.hosts) --update update_mgr:attach_hour(self) end @@ -94,7 +94,7 @@ end function MysqlDB:setup_pool(hosts) if not next(hosts) then - log_err("[MysqlDB][setup_pool] mongo config err: hosts is empty") + log_err("[MysqlDB][setup_pool] mysql config err: hosts is empty") return end local count = 1 @@ -116,6 +116,7 @@ function MysqlDB:check_alive() thread_mgr:entry(self:address(), function() local channel = makechan("check mysql") for _, sock in pairs(self.connections) do + sock:close() channel:push(function() return self:login(sock) end) @@ -137,7 +138,6 @@ function MysqlDB:login(socket) local ok, res = self:auth(socket) if not ok then self:delive(socket) - socket:close() log_err("[MysqlDB][login] auth db(%s:%s:%s) auth failed! because: %s", ip, port, id, res) return false end @@ -150,22 +150,15 @@ end function MysqlDB:auth(socket) local session_id = thread_mgr:build_session_id() socket:set_codec(mysqlcodec(session_id)) - local ok, charset, scramble1, scramble2 = thread_mgr:yield(session_id, "mysql server auth", DB_TIMEOUT) - if not ok then - return false, charset - end + local charset, scramble1, scramble2, atype = thread_mgr:yield(session_id, "mysql server auth", DB_TIMEOUT) local scramble = scramble1 .. scramble2 - local stage1 = lsha1(lsha1(self.passwd)) + local stage1 = lsha1(self.passwd) local stage2 = lsha1(scramble .. lsha1(stage1)) local auth_passwd = lxor_byte(stage1, stage2) - if not socket:send_data(COM_CONNECT, session_id, charset, auth_passwd, self.name) then + if not socket:send_data(COM_CONNECT, session_id, charset, self.user, auth_passwd, self.name) then return false, "send failed" end - local sok, res = thread_mgr:yield(session_id, "mysql client auth", DB_TIMEOUT) - if not sok then - return false, res - end - return true + return thread_mgr:yield(session_id, "mysql client auth", DB_TIMEOUT) end function MysqlDB:delive(sock) @@ -188,7 +181,7 @@ end function MysqlDB:on_socket_recv(socket, session_id, ...) if session_id > 0 then - thread_mgr:response(session_id, true, ...) + thread_mgr:response(session_id, ...) end end @@ -203,7 +196,7 @@ function MysqlDB:request(cmd, quote, ...) end function MysqlDB:query(query) - return self:request(COM_QUERY, "mysql prepare", query) + return self:request(COM_QUERY, "mysql query", query) end -- 注册预处理语句 diff --git a/script/driver/socket.lua b/script/driver/socket.lua index 88d78024..21519d4c 100644 --- a/script/driver/socket.lua +++ b/script/driver/socket.lua @@ -39,6 +39,7 @@ function Socket:close() self.session.close() self.alive = false self.session = nil + self.codec = nil self.token = nil end end @@ -80,7 +81,7 @@ function Socket:connect(ip, port, ptype) end local session, cerr = socket_mgr.connect(ip, port, CONNECT_TIMEOUT, proto_text) if not session then - log_err("[Socket][connect] failed to connect: %s:%d err=%s", ip, port, cerr) + log_err("[Socket][connect] failed to connect: %s:%s err=%s", ip, port, cerr) return false, cerr end --设置阻塞id diff --git a/script/store/mysql_mgr.lua b/script/store/mysql_mgr.lua index 54c798cf..c05100ad 100644 --- a/script/store/mysql_mgr.lua +++ b/script/store/mysql_mgr.lua @@ -26,7 +26,7 @@ function MysqlMgr:setup() local drivers = environ.driver("QUANTA_MYSQL_URLS") for i, conf in ipairs(drivers) do local mysql_db = MysqlDB(conf, i) - self.mysql_dbs[conf.id] = mysql_db + self.mysql_dbs[i] = mysql_db end end @@ -49,7 +49,7 @@ end function MysqlMgr:execute(db_id, primary_id, stmt, ...) local mysqldb = self:get_db(db_id) - if mysqldb and mysqldb:set_executer() then + if mysqldb and mysqldb:set_executer(primary_id) then local ok, res_oe = mysqldb:execute(stmt, ...) if not ok then log_err("[MysqlMgr][execute] execute %s failed, because: %s", stmt, res_oe)