From 567c7e9a3cc31058f876a941d36d9a3f0c79fb32 Mon Sep 17 00:00:00 2001 From: xiyoo0812 Date: Tue, 19 Sep 2023 09:32:51 +0000 Subject: [PATCH] =?UTF-8?q?!222=20=E7=BD=91=E7=BB=9C=E9=A9=B1=E5=8A=A8?= =?UTF-8?q?=E5=B1=82=E4=BC=98=E5=8C=96=201=E3=80=81mysql=E9=A9=B1=E5=8A=A8?= =?UTF-8?q?=E8=B0=83=E7=94=A8=E4=BC=98=E5=8C=96=202=E3=80=81pb=E5=8D=8F?= =?UTF-8?q?=E8=AE=AE=E8=B0=83=E7=94=A8=E4=BC=98=E5=8C=96=203=E3=80=81redis?= =?UTF-8?q?=E3=80=81wss=E5=8D=8F=E8=AE=AE=E4=BC=98=E5=8C=96=204=E3=80=81?= =?UTF-8?q?=E7=BD=91=E7=BB=9C=E5=B1=82=E8=A7=A3=E6=9E=90=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=E7=BB=9F=E4=B8=80=205=E3=80=81lua=E4=BB=A3=E7=A0=81=E6=9B=B4?= =?UTF-8?q?=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .luacheckrc | 2 +- Makefile | 2 +- README.md | 8 +- bin/database.conf | 6 +- core/luabus/src/lua_socket_mgr.cpp | 60 +- core/luabus/src/lua_socket_mgr.h | 15 +- core/luabus/src/lua_socket_node.cpp | 290 +++--- core/luabus/src/lua_socket_node.h | 48 +- core/luabus/src/luabus.cpp | 14 +- core/luabus/src/socket_helper.h | 3 + core/luabus/src/socket_listener.cpp | 8 +- core/luabus/src/socket_listener.h | 4 +- core/luabus/src/socket_mgr.cpp | 34 +- core/luabus/src/socket_mgr.h | 42 +- core/luabus/src/socket_stream.cpp | 94 +- core/luabus/src/socket_stream.h | 12 +- core/quanta/quanta.mak | 1 + extend/laoi/laoi.mak | 1 + extend/lbson/lbson.mak | 1 + extend/lbson/src/bson.h | 29 +- extend/lbson/src/lbson.cpp | 8 +- extend/lcodec/lcodec.mak | 1 + extend/lcodec/src/http.h | 31 +- extend/lcodec/src/lcodec.cpp | 55 +- extend/lcodec/src/lcodec.h | 1 + extend/lcodec/src/mysql.h | 508 +++++++++- extend/lcodec/src/redis.h | 24 +- extend/lcodec/src/websocket.h | 26 +- extend/lcrypt/lcrypt.mak | 1 + extend/lcurl/lcurl.mak | 1 + extend/ldetour/ldetour.mak | 1 + extend/ljson/ljson.mak | 1 + extend/ljson/src/ljson.cpp | 8 +- extend/ljson/src/ljson.h | 6 + extend/lmake/share.lua | 1 + extend/lstdfs/lstdfs.mak | 1 + extend/ltimer/ltimer.mak | 1 + extend/lua/lua.mak | 1 + extend/lua/lua/lapi.c | 4 +- extend/lua/lua/ldebug.c | 28 +- extend/lua/lua/ldebug.h | 1 + extend/lua/lua/lgc.c | 16 +- extend/lua/lua/lmathlib.c | 31 +- extend/lua/lua/lobject.c | 2 +- extend/lua/lua/lobject.h | 18 +- extend/lua/lua/lparser.c | 12 +- extend/lua/lua/lstate.c | 2 +- extend/lua/lua/lstate.h | 2 +- extend/lua/lua/lstring.c | 11 +- extend/lua/lua/lundump.c | 4 +- extend/lua/lua/lundump.h | 3 +- extend/lua/lua/lvm.c | 30 +- extend/lua/luac.mak | 1 + extend/lua/lualib.mak | 1 + extend/luakit/include/lua_codec.h | 47 +- extend/luakit/include/lua_function.h | 2 +- extend/luakit/include/lua_kit.h | 20 +- extend/luakit/include/lua_slice.h | 66 +- extend/lualog/lualog.mak | 1 + .../lua-protobuf.lmak => luapb/luapb.lmak} | 15 +- .../lua-protobuf.mak => luapb/luapb.mak} | 17 +- .../luapb.vcxproj} | 17 +- .../luapb.vcxproj.filters} | 7 +- extend/luapb/src/luapb.cpp | 144 +++ .../{protobuf/lua-protobuf => luapb/src}/pb.c | 0 .../{protobuf/lua-protobuf => luapb/src}/pb.h | 0 extend/luaxlsx/luaxlsx.mak | 1 + extend/lworker/lworker.mak | 1 + extend/lworker/src/scheduler.h | 9 +- extend/lworker/src/worker.h | 12 +- extend/mimalloc/mimalloc.mak | 1 + quanta.sln | 12 +- script/basic/library.lua | 2 +- script/basic/logger.lua | 2 +- script/driver/influx.lua | 2 +- script/driver/mongo.lua | 7 +- script/driver/mysql.lua | 908 +++--------------- script/driver/redis.lua | 45 +- script/driver/redisps.lua | 2 - script/driver/socket.lua | 25 +- script/driver/websocket.lua | 28 +- script/network/http_server.lua | 13 +- script/network/net_client.lua | 71 +- script/network/net_server.lua | 113 +-- script/network/rpc_server.lua | 15 +- script/store/clickhouse_mgr.lua | 40 +- script/store/mysql_mgr.lua | 34 +- server/business/component/attr_component.lua | 2 +- server/cache/cache_gm.lua | 2 +- server/gateway/gateway.lua | 8 +- server/gateway/group_mgr.lua | 12 +- server/gateway/player.lua | 4 + server/router/transfer_mgr.lua | 6 +- server/test/clickhouse_test.lua | 10 +- server/test/codec_test.lua | 10 +- server/test/json_test.lua | 4 +- server/test/mysql_test.lua | 31 +- tools/accord/accord.lua | 2 +- 98 files changed, 1703 insertions(+), 1583 deletions(-) rename extend/{protobuf/lua-protobuf.lmak => luapb/luapb.lmak} (69%) rename extend/{protobuf/lua-protobuf.mak => luapb/luapb.mak} (82%) rename extend/{protobuf/lua-protobuf.vcxproj => luapb/luapb.vcxproj} (85%) rename extend/{protobuf/lua-protobuf.vcxproj.filters => luapb/luapb.vcxproj.filters} (78%) create mode 100644 extend/luapb/src/luapb.cpp rename extend/{protobuf/lua-protobuf => luapb/src}/pb.c (100%) rename extend/{protobuf/lua-protobuf => luapb/src}/pb.h (100%) diff --git a/.luacheckrc b/.luacheckrc index 0b17254d..6ffdf425 100644 --- a/.luacheckrc +++ b/.luacheckrc @@ -6,7 +6,7 @@ stds.quanta = { "quanta", "environ", "signal", "luabt", "service", "logger", "import", "class", "enum", "mixin", "property", "singleton", "super", "implemented", "logfeature", "db_property", "classof", "is_class", "is_subclass", "instanceof", "conv_class", - "codec", "crypt", "stdfs", "luabus", "json", "protobuf", "curl", "timer", "aoi", "log", "worker", "http", "bson", "detour" + "codec", "crypt", "stdfs", "luabus", "luakit", "json", "protobuf", "curl", "timer", "aoi", "log", "worker", "http", "bson", "detour" } } std = "max+quanta" diff --git a/Makefile b/Makefile index 763b9abc..2b31a0ca 100644 --- a/Makefile +++ b/Makefile @@ -33,9 +33,9 @@ luaext: cd extend/lstdfs; make SOLUTION_DIR=$(CUR_DIR) -f lstdfs.mak; cd extend/ltimer; make SOLUTION_DIR=$(CUR_DIR) -f ltimer.mak; cd extend/lualog; make SOLUTION_DIR=$(CUR_DIR) -f lualog.mak; + cd extend/luapb; make SOLUTION_DIR=$(CUR_DIR) -f luapb.mak; cd extend/luaxlsx; make SOLUTION_DIR=$(CUR_DIR) -f luaxlsx.mak; cd extend/lworker; make SOLUTION_DIR=$(CUR_DIR) -f lworker.mak; - cd extend/protobuf; make SOLUTION_DIR=$(CUR_DIR) -f lua-protobuf.mak; share: cd extend/mimalloc; make SOLUTION_DIR=$(CUR_DIR) -f mimalloc.mak; diff --git a/README.md b/README.md index 6dfc6759..309fd1df 100644 --- a/README.md +++ b/README.md @@ -53,18 +53,16 @@ cd bin # 依赖 - lua -- bson -- mongo +- lbson - luabt -- lhttp - lcurl +- ljson - luabus - lcrypt -- lcjson - lstdfs - luakit - lualog -- lbuffer +- lcodec - luaxlsx - lua-protobuf diff --git a/bin/database.conf b/bin/database.conf index f99c499e..2c7dd418 100644 --- a/bin/database.conf +++ b/bin/database.conf @@ -22,6 +22,6 @@ set_env("QUANTA_REDIS_URLS", [[ redis://root:123456@127.0.0.1:6379; ]]) --mysql ---set_env("QUANTA_MYSQL_URLS", [[ --- mysql://root:123456@127.0.0.1:6379/quanta; ---]]) +set_env("QUANTA_MYSQL_URLS", [[ + mysql://root:123456@127.0.0.1:3306/quanta; +]]) diff --git a/core/luabus/src/lua_socket_mgr.cpp b/core/luabus/src/lua_socket_mgr.cpp index d23891c0..82de8c2e 100644 --- a/core/luabus/src/lua_socket_mgr.cpp +++ b/core/luabus/src/lua_socket_mgr.cpp @@ -5,7 +5,6 @@ bool lua_socket_mgr::setup(lua_State* L, int max_fd) { m_lvm = L; m_mgr = std::make_shared(); - m_codec = std::make_shared(); m_router = std::make_shared(m_mgr); return m_mgr->setup(max_fd); } @@ -14,15 +13,17 @@ int lua_socket_mgr::listen(lua_State* L, const char* ip, int port) { if (ip == nullptr || port <= 0) { return luakit::variadic_return(L, nullptr, "invalid param"); } - std::string err; int token = m_mgr->listen(err, ip, port); if (token == 0) { return luakit::variadic_return(L, nullptr, err); } - auto listener = new lua_socket_node(token, L, m_mgr, m_router, true); - listener->set_codec(m_codec.get()); + eproto_type proto_type = (eproto_type)luaL_optinteger(L, 3, (int)eproto_type::proto_rpc); + auto listener = new lua_socket_node(token, L, m_mgr, m_router, proto_type); + if (proto_type == eproto_type::proto_rpc) { + listener->create_codec(); + } return luakit::variadic_return(L, listener, "ok"); } @@ -30,15 +31,17 @@ int lua_socket_mgr::connect(lua_State* L, const char* ip, const char* port, int if (ip == nullptr || port == nullptr) { return luakit::variadic_return(L, nullptr, "invalid param"); } - std::string err; int token = m_mgr->connect(err, ip, port, timeout); if (token == 0) { return luakit::variadic_return(L, nullptr, err); } - auto socket_node = new lua_socket_node(token, L, m_mgr, m_router, false); - socket_node->set_codec(m_codec.get()); + eproto_type proto_type = (eproto_type)luaL_optinteger(L, 4, (int)eproto_type::proto_rpc); + auto socket_node = new lua_socket_node(token, L, m_mgr, m_router, proto_type); + if (proto_type == eproto_type::proto_rpc) { + socket_node->create_codec(); + } return luakit::variadic_return(L, socket_node, "ok"); } @@ -50,10 +53,49 @@ int lua_socket_mgr::get_recvbuf_size(uint32_t token) { return m_mgr->get_recvbuf_size(token); } -void lua_socket_mgr::set_proto_type(uint32_t token, eproto_type type) { - return m_mgr->set_proto_type(token, type); +void lua_socket_mgr::set_codec(uint32_t token, codec_base* codec) { + return m_mgr->set_codec(token, codec); } int lua_socket_mgr::map_token(uint32_t node_id, uint32_t token) { return m_router->map_token(node_id, token); } + +int lua_socket_mgr::broadcast(lua_State* L, codec_base* codec, uint32_t kind) { + size_t data_len = 0; + char* data = (char*)codec->encode(L, 3, &data_len); + socket_header* header = (socket_header*)data; + if (data_len <= USHRT_MAX) { + //组装数据 + header->len = data_len; + header->session_id = 0; + //发送数据 + m_mgr->broadcast(kind, data, data_len); + lua_pushboolean(L, true); + return 1; + } + lua_pushboolean(L, false); + return 1; +} + +int lua_socket_mgr::broadgroup(lua_State* L, codec_base* codec) { + size_t data_len = 0; + std::vector groups; + if (!lua_to_native(L, 2, groups)) { + lua_pushboolean(L, false); + return 1; + } + char* data = (char*)codec->encode(L, 3, &data_len); + socket_header* header = (socket_header*)data; + if (data_len <= USHRT_MAX) { + //组装数据 + header->len = data_len; + header->session_id = 0; + //发送数据 + m_mgr->broadgroup(groups, data, data_len); + lua_pushboolean(L, true); + return 1; + } + lua_pushboolean(L, false); + return 1; +} diff --git a/core/luabus/src/lua_socket_mgr.h b/core/luabus/src/lua_socket_mgr.h index 72d78f16..3ae9308a 100644 --- a/core/luabus/src/lua_socket_mgr.h +++ b/core/luabus/src/lua_socket_mgr.h @@ -10,20 +10,19 @@ struct lua_socket_mgr final public: ~lua_socket_mgr(){} bool setup(lua_State* L, int max_fd); - int wait(int64_t now, int timeout) { return m_mgr->wait(now, timeout); } int get_sendbuf_size(uint32_t token); int get_recvbuf_size(uint32_t token); int map_token(uint32_t node_id, uint32_t token); int listen(lua_State* L, const char* ip, int port); int connect(lua_State* L, const char* ip, const char* port, int timeout); - void set_proto_type(uint32_t token, eproto_type type); - - std::shared_ptr get_router() { return m_router; } + int wait(int64_t now, int timeout) { return m_mgr->wait(now, timeout); } + int broadcast(lua_State* L, codec_base* codec, uint32_t kind); + int broadgroup(lua_State* L, codec_base* codec); + void set_codec(uint32_t token, codec_base* codec); private: - lua_State* m_lvm = nullptr; - std::shared_ptr m_codec; - std::shared_ptr m_mgr; - std::shared_ptr m_router; + lua_State* m_lvm; + stdsptr m_mgr; + stdsptr m_router; }; diff --git a/core/luabus/src/lua_socket_node.cpp b/core/luabus/src/lua_socket_node.cpp index 9564150f..d18f2cc6 100644 --- a/core/luabus/src/lua_socket_node.cpp +++ b/core/luabus/src/lua_socket_node.cpp @@ -1,30 +1,26 @@ #include "stdafx.h" #include "lua_socket_node.h" -lua_socket_node::lua_socket_node(uint32_t token, lua_State* L, std::shared_ptr& mgr, std::shared_ptr& router, bool blisten) - : m_token(token), m_mgr(mgr), m_router(router) { +lua_socket_node::lua_socket_node(uint32_t token, lua_State* L, stdsptr mgr, stdsptr router, eproto_type type) + : m_token(token), m_type(type), m_mgr(mgr), m_router(router) { m_stoken = (m_token & 0xffff) << 16; - m_luakit = std::make_shared(L); + m_lvm = std::make_shared(L); m_mgr->get_remote_ip(m_token, m_ip); - if (blisten) { - m_mgr->set_accept_callback(token, [=](uint32_t steam_token) { - auto node = new lua_socket_node(steam_token, m_luakit->L(), m_mgr, m_router, false); - node->set_codec(m_codec); - m_luakit->object_call(this, "on_accept", nullptr, std::tie(), node); - }); - } m_mgr->set_connect_callback(token, [=](bool ok, const char* reason) { - m_luakit->object_call(this, "on_connect", nullptr, std::tie(), ok ? "ok" : reason); + m_lvm->object_call(this, "on_connect", nullptr, std::tie(), ok ? "ok" : reason); }); - m_mgr->set_error_callback(token, [=](const char* err) { auto token = m_token; m_token = 0; - m_luakit->object_call(this, "on_error", nullptr, std::tie(), token, err); + m_lvm->object_call(this, "on_error", nullptr, std::tie(), token, err); }); - - m_mgr->set_package_callback(token, [=](slice* data, eproto_type proto_type){ - return on_recv(data, proto_type); + m_mgr->set_package_callback(token, [=](slice* slice){ + return on_recv(slice); + }); + m_mgr->set_accept_callback(token, [=](uint32_t steam_token) { + auto node = new lua_socket_node(steam_token, L, m_mgr, m_router, m_type); + node->set_codec(m_codec); + m_lvm->object_call(this, "on_accept", nullptr, std::tie(), node); }); } @@ -37,116 +33,132 @@ void lua_socket_node::close() { m_mgr->close(m_token); m_token = 0; } + m_router = nullptr; + m_codec = nullptr; + m_mgr = nullptr; } int lua_socket_node::call_data(lua_State* L) { - size_t data_len = 0; - char* data = (char*)m_codec->encode(L, 1, &data_len); - if (data_len > SOCKET_PACKET_MAX) return 0; - m_mgr->send(m_token, data, data_len); - lua_pushinteger(L, data_len); + if (m_codec) { + size_t data_len = 0; + char* data = (char*)m_codec->encode(L, 1, &data_len); + if (data_len <= SOCKET_PACKET_MAX){ + m_mgr->send(m_token, data, data_len); + lua_pushinteger(L, data_len); + return 1; + } + } + lua_pushinteger(L, 0); return 1; } -int lua_socket_node::call_head(uint16_t cmd_id, uint8_t flag, uint8_t type, uint8_t crc8, uint32_t session_id, const char* data, uint32_t data_len) { - size_t length = data_len + sizeof(socket_header); - if (length <= USHRT_MAX) { - //组装数据 - socket_header header; - header.flag = flag; - header.type = type; - header.len = length; - header.crc8 = crc8; - header.cmd_id = cmd_id; - header.session_id = (session_id & 0xffff); - //发送数据 - sendv_item items[] = { { &header, sizeof(socket_header) }, {data, data_len} }; - m_mgr->sendv(m_token, items, _countof(items)); - return length; +int lua_socket_node::call_pb(lua_State* L, uint32_t session_id) { + if (m_codec) { + size_t data_len = 0; + char* data = (char*)m_codec->encode(L, 2, &data_len); + socket_header* header = (socket_header*)data; + if (data_len <= USHRT_MAX) { + //组装数据 + header->len = data_len; + header->session_id = (session_id & 0xffff); + //发送数据 + m_mgr->send(m_token, data, data_len); + lua_pushinteger(L, data_len); + return 1; + } } - return 0; + lua_pushinteger(L, 0); + return 1; } int lua_socket_node::call(lua_State* L, uint32_t session_id, uint8_t flag) { - size_t data_len = 0; - char* data = (char*)m_codec->encode(L, 3, &data_len); - size_t length = data_len + sizeof(router_header); - if (length <= SOCKET_PACKET_MAX) { - //组装数据 - router_header header; - header.len = length; - header.target_id = 0; - header.session_id = session_id; - header.context = (uint8_t)rpc_type::remote_call << 4 | flag; - //发送数据 - sendv_item items[] = { { &header, sizeof(router_header)}, {data, data_len}}; - m_mgr->sendv(m_token, items, _countof(items)); - lua_pushinteger(L, length); - return 1; + if (m_codec) { + size_t data_len = 0; + char* data = (char*)m_codec->encode(L, 3, &data_len); + size_t length = data_len + sizeof(router_header); + if (length <= SOCKET_PACKET_MAX) { + //组装数据 + router_header header; + header.len = length; + header.target_id = 0; + header.session_id = session_id; + header.context = (uint8_t)rpc_type::remote_call << 4 | flag; + //发送数据 + sendv_item items[] = { { &header, sizeof(router_header)}, {data, data_len}}; + m_mgr->sendv(m_token, items, _countof(items)); + lua_pushinteger(L, length); + return 1; + } } lua_pushinteger(L, 0); return 1; } int lua_socket_node::forward_transfer(lua_State* L, uint32_t session_id, uint32_t target_id, uint8_t service_id) { - size_t data_len = 0; - char* data = (char*)m_codec->encode(L, 4, &data_len); - size_t length = data_len + sizeof(transfer_header); - if (length <= SOCKET_PACKET_MAX) { - //组装数据 - transfer_header header; - header.len = length; - header.target_id = target_id; - header.service_id = service_id; - header.session_id = session_id; - header.context = (uint8_t)rpc_type::transfer_call << 4; - //发送数据 - sendv_item items[] = { { &header, sizeof(transfer_header)}, {data, data_len}}; - m_mgr->sendv(m_token, items, _countof(items)); - lua_pushinteger(L, length); - return 1; - } + if (m_codec) { + size_t data_len = 0; + char* data = (char*)m_codec->encode(L, 4, &data_len); + size_t length = data_len + sizeof(transfer_header); + if (length <= SOCKET_PACKET_MAX) { + //组装数据 + transfer_header header; + header.len = length; + header.target_id = target_id; + header.service_id = service_id; + header.session_id = session_id; + header.context = (uint8_t)rpc_type::transfer_call << 4; + //发送数据 + sendv_item items[] = { { &header, sizeof(transfer_header)}, {data, data_len}}; + m_mgr->sendv(m_token, items, _countof(items)); + lua_pushinteger(L, length); + return 1; + } + } lua_pushinteger(L, 0); return 1; } int lua_socket_node::forward_target(lua_State* L, uint32_t session_id, uint8_t flag, uint32_t target_id) { - size_t data_len = 0; - char* data = (char*)m_codec->encode(L, 4, &data_len); - size_t length = data_len + sizeof(router_header); - if (length <= SOCKET_PACKET_MAX) { - //组装数据 - router_header header; - header.len = length; - header.target_id = target_id; - header.session_id = session_id; - header.context = (uint8_t)rpc_type::forward_target << 4 | flag; - //发送数据 - sendv_item items[] = { { &header, sizeof(router_header)}, {data, data_len} }; - m_mgr->sendv(m_token, items, _countof(items)); - lua_pushinteger(L, length); - return 1; + if (m_codec) { + size_t data_len = 0; + char* data = (char*)m_codec->encode(L, 4, &data_len); + size_t length = data_len + sizeof(router_header); + if (length <= SOCKET_PACKET_MAX) { + //组装数据 + router_header header; + header.len = length; + header.target_id = target_id; + header.session_id = session_id; + header.context = (uint8_t)rpc_type::forward_target << 4 | flag; + //发送数据 + sendv_item items[] = { { &header, sizeof(router_header)}, {data, data_len} }; + m_mgr->sendv(m_token, items, _countof(items)); + lua_pushinteger(L, length); + return 1; + } } lua_pushinteger(L, 0); return 1; } int lua_socket_node::forward_hash(lua_State* L, uint32_t session_id, uint8_t flag, uint16_t service_id, uint16_t hash) { - size_t data_len = 0; - char* data = (char*)m_codec->encode(L, 5, &data_len); - size_t length = data_len + sizeof(router_header); - if (length <= SOCKET_PACKET_MAX) { - //组装数据 - router_header header; - header.len = length; - header.session_id = session_id; - header.target_id = service_id << 16 | hash; - header.context = (uint8_t)rpc_type::forward_hash << 4 | flag; - //发送数据 - sendv_item items[] = { { &header, sizeof(router_header)}, {data, data_len} }; - m_mgr->sendv(m_token, items, _countof(items)); - lua_pushinteger(L, length); - return 1; + if (m_codec) { + size_t data_len = 0; + char* data = (char*)m_codec->encode(L, 5, &data_len); + size_t length = data_len + sizeof(router_header); + if (length <= SOCKET_PACKET_MAX) { + //组装数据 + router_header header; + header.len = length; + header.session_id = session_id; + header.target_id = service_id << 16 | hash; + header.context = (uint8_t)rpc_type::forward_hash << 4 | flag; + //发送数据 + sendv_item items[] = { { &header, sizeof(router_header)}, {data, data_len} }; + m_mgr->sendv(m_token, items, _countof(items)); + lua_pushinteger(L, length); + return 1; + } } lua_pushinteger(L, 0); return 1; @@ -180,36 +192,34 @@ int lua_socket_node::transfer_call(lua_State* L, uint32_t session_id, uint32_t t } int lua_socket_node::transfer_hash(lua_State* L, uint32_t session_id, uint16_t service_id, uint16_t hash) { - size_t data_len = 0; - char* data = (char*)m_codec->encode(L, 4, &data_len); - size_t length = data_len + sizeof(router_header); - if (length <= SOCKET_PACKET_MAX) { - //组装数据 - router_header header; - header.len = length; - header.session_id = session_id; - header.context = (uint8_t)rpc_type::remote_call << 4 | 0x01; - header.target_id = service_id << 16 | hash; - if (m_router->do_forward_hash(&header, data, data_len)) { - lua_pushinteger(L, length); - return 1; + if (m_codec) { + size_t data_len = 0; + char* data = (char*)m_codec->encode(L, 4, &data_len); + size_t length = data_len + sizeof(router_header); + if (length <= SOCKET_PACKET_MAX) { + //组装数据 + router_header header; + header.len = length; + header.session_id = session_id; + header.context = (uint8_t)rpc_type::remote_call << 4 | 0x01; + header.target_id = service_id << 16 | hash; + if (m_router->do_forward_hash(&header, data, data_len)) { + lua_pushinteger(L, length); + return 1; + } } } lua_pushinteger(L, 0); return 0; } -int lua_socket_node::on_recv(slice* slice, eproto_type proto_type) { - if (eproto_type::proto_head == proto_type) { - return on_call_head(slice); - } - if (eproto_type::proto_text == proto_type) { - return on_call_text(slice); +void lua_socket_node::on_recv(slice* slice) { + if (m_type == eproto_type::proto_pb) { + return on_call_pb(slice); } - if (eproto_type::proto_rpc != proto_type) { + if (m_type == eproto_type::proto_text) { return on_call_data(slice); } - size_t data_len; size_t header_len = sizeof(router_header); auto hdata = slice->peek(header_len); @@ -258,19 +268,18 @@ int lua_socket_node::on_recv(slice* slice, eproto_type proto_type) { } break; } - return header->len; } void lua_socket_node::on_forward_error(router_header* header, slice* slice) { if (header->session_id > 0) { m_codec->set_slice(slice); - m_luakit->object_call(this, "on_forward_error", nullptr, m_codec, std::tie(), header->session_id, header->target_id); + m_lvm->object_call(this, "on_forward_error", nullptr, m_codec, std::tie(), header->session_id, header->target_id); } } void lua_socket_node::on_forward_broadcast(router_header* header, size_t broadcast_num) { if (header->session_id > 0) { - m_luakit->object_call(this, "on_forward_broadcast", nullptr, std::tie(), header->session_id, broadcast_num); + m_lvm->object_call(this, "on_forward_broadcast", nullptr, std::tie(), header->session_id, broadcast_num); } } @@ -278,43 +287,22 @@ void lua_socket_node::on_transfer(transfer_header* header, slice* slice) { uint8_t service_id = header->service_id; uint32_t target_id = header->target_id; uint32_t session_id = header->session_id; - m_luakit->object_call(this, "on_transfer", nullptr, std::tie(), header->len, session_id, service_id, target_id, slice); + m_lvm->object_call(this, "on_transfer", nullptr, std::tie(), header->len, session_id, service_id, target_id, slice); } -int lua_socket_node::on_call_head(slice* slice) { - size_t header_len = sizeof(socket_header); - auto data = slice->peek(header_len); - socket_header* header = (socket_header*)data; - uint8_t crc8 = header->crc8; - uint8_t flag = header->flag; - uint8_t type = header->type; - uint16_t cmd_id = header->cmd_id; +void lua_socket_node::on_call_pb(slice* slice) { + socket_header* header = (socket_header*)slice->peek(sizeof(socket_header)); uint32_t session_id = header->session_id; if (session_id > 0) session_id |= m_stoken; - slice->erase(header_len); - std::string body((char*)slice->head(), slice->size()); - m_luakit->object_call(this, "on_call_head", nullptr, std::tie(), header->len, cmd_id, flag, type, crc8, session_id, body); - return header->len; + m_lvm->object_call(this, "on_call_pb", nullptr, m_codec, std::tie(), header->len, session_id); } void lua_socket_node::on_call(router_header* header, slice* slice) { - m_codec->set_slice(slice); uint8_t flag = header->context & 0xff; uint32_t session_id = header->session_id; - m_luakit->object_call(this, "on_call", nullptr, m_codec, std::tie(), header->len, session_id, flag); -} - -int lua_socket_node::on_call_data(slice* slice) { - m_codec->set_slice(slice); - size_t buf_size = slice->size(); - m_luakit->object_call(this, "on_call_data", nullptr, m_codec, std::tie(), buf_size); - return buf_size; + m_lvm->object_call(this, "on_call", nullptr, m_codec, std::tie(), header->len, session_id, flag); } -int lua_socket_node::on_call_text(slice* slice) { - bool success = true; - m_codec->set_slice(slice); - size_t buf_size = slice->size(); - m_luakit->object_call(this, "on_call_data", [&](std::string_view) { success = false; }, m_codec, std::tie(), buf_size); - return success ? (buf_size - slice->size()) : -1; +void lua_socket_node::on_call_data(slice* slice) { + m_lvm->object_call(this, "on_call_data", nullptr, m_codec, std::tie(), slice->size()); } diff --git a/core/luabus/src/lua_socket_node.h b/core/luabus/src/lua_socket_node.h index 22baf93f..d717ca17 100644 --- a/core/luabus/src/lua_socket_node.h +++ b/core/luabus/src/lua_socket_node.h @@ -8,23 +8,37 @@ class lua_socket_node { public: - lua_socket_node(uint32_t token, lua_State* L, std::shared_ptr& mgr, std::shared_ptr& router, bool blisten = false); + lua_socket_node(uint32_t token, lua_State* L, stdsptr mgr, stdsptr router, eproto_type type); ~lua_socket_node(); void close(); - uint32_t build_session_id() { return m_stoken | m_sindex++; } - uint32_t get_route_count() { return m_router->get_route_count(); } - void set_codec(codec_base* codec) { m_codec = codec; } - void set_timeout(int ms) { m_mgr->set_timeout(m_token, ms); } - void set_nodelay(bool flag) { m_mgr->set_nodelay(m_token, flag); } - void set_proto_type(eproto_type proto_type) { m_mgr->set_proto_type(m_token, proto_type); } + uint32_t build_session_id() { + return m_stoken | m_sindex++; + } + uint32_t get_route_count() { + return m_router->get_route_count(); + } + void set_timeout(int ms) { + m_mgr->set_timeout(m_token, ms); + } + void set_nodelay(bool flag) { + m_mgr->set_nodelay(m_token, flag); + } + void set_codec(codec_base* codec) { + m_codec = codec; + m_mgr->set_codec(m_token, codec); + } + void create_codec() { + m_codec = m_lvm->create_codec(); + m_mgr->set_codec(m_token, m_codec); + } int call_data(lua_State* L); + int call_pb(lua_State* L, uint32_t session_id); int call(lua_State* L, uint32_t session_id, uint8_t flag); - int call_head(uint16_t cmd_id, uint8_t flag, uint8_t type, uint8_t crc8, uint32_t session_id, const char* data, uint32_t data_len); - int forward_target(lua_State* L, uint32_t session_id, uint8_t flag, uint32_t target_id); + int forward_target(lua_State* L, uint32_t session_id, uint8_t flag, uint32_t target_id); int forward_hash(lua_State* L, uint32_t session_id, uint8_t flag, uint16_t service_id, uint16_t hash); int forward_transfer(lua_State* L, uint32_t session_id, uint32_t target_id, uint8_t service_id); @@ -59,17 +73,17 @@ class lua_socket_node uint16_t m_sindex = 1; private: - int on_call_head(slice* slice); - int on_call_text(slice* slice); - int on_call_data(slice* slice); - int on_recv(slice* slice, eproto_type proto_type); + void on_recv(slice* slice); + void on_call_pb(slice* slice); + void on_call_data(slice* slice); void on_call(router_header* header, slice* slice); void on_transfer(transfer_header* header, slice* slice); void on_forward_broadcast(router_header* header, size_t target_size); void on_forward_error(router_header* header, slice* slice); - codec_base* m_codec = nullptr; - std::shared_ptr m_mgr; - std::shared_ptr m_luakit; - std::shared_ptr m_router; + eproto_type m_type; + codec_base* m_codec; + stdsptr m_lvm; + stdsptr m_mgr; + stdsptr m_router; }; diff --git a/core/luabus/src/luabus.cpp b/core/luabus/src/luabus.cpp index b00a70b6..7e1df266 100644 --- a/core/luabus/src/luabus.cpp +++ b/core/luabus/src/luabus.cpp @@ -42,12 +42,9 @@ namespace luabus { lluabus.set_function("dns", gethostbydomain); lluabus.set_function("create_socket_mgr", create_socket_mgr); lluabus.new_enum("eproto_type", + "pb", eproto_type::proto_pb, "rpc", eproto_type::proto_rpc, - "wss", eproto_type::proto_wss, - "head", eproto_type::proto_head, - "text", eproto_type::proto_text, - "mongo", eproto_type::proto_mongo, - "mysql", eproto_type::proto_mysql + "text", eproto_type::proto_text ); kit_state.new_class( "send", &socket_udp::send, @@ -69,6 +66,8 @@ namespace luabus { "listen", &lua_socket_mgr::listen, "connect", &lua_socket_mgr::connect, "map_token", &lua_socket_mgr::map_token, + "broadcast", &lua_socket_mgr::broadcast, + "broadgroup", &lua_socket_mgr::broadgroup, "get_sendbuf_size", &lua_socket_mgr::get_sendbuf_size, "get_recvbuf_size", &lua_socket_mgr::get_recvbuf_size ); @@ -77,16 +76,15 @@ namespace luabus { "token", &lua_socket_node::m_token, "call", &lua_socket_node::call, "close", &lua_socket_node::close, - "set_codec", &lua_socket_node::set_codec, - "call_head", &lua_socket_node::call_head, + "call_pb", &lua_socket_node::call_pb, "call_data", &lua_socket_node::call_data, + "set_codec", &lua_socket_node::set_codec, "set_nodelay", &lua_socket_node::set_nodelay, "set_timeout", &lua_socket_node::set_timeout, "forward_hash", &lua_socket_node::forward_hash, "transfer_call", &lua_socket_node::transfer_call, "transfer_hash", &lua_socket_node::transfer_hash, "forward_target", &lua_socket_node::forward_target, - "set_proto_type", &lua_socket_node::set_proto_type, "get_route_count", &lua_socket_node::get_route_count, "build_session_id", &lua_socket_node::build_session_id, "forward_transfer", &lua_socket_node::forward_transfer, diff --git a/core/luabus/src/socket_helper.h b/core/luabus/src/socket_helper.h index 4bcf42a4..e6a64fab 100644 --- a/core/luabus/src/socket_helper.h +++ b/core/luabus/src/socket_helper.h @@ -44,6 +44,9 @@ bool wsa_send_empty(socket_t fd, WSAOVERLAPPED& ovl); bool wsa_recv_empty(socket_t fd, WSAOVERLAPPED& ovl); #endif +template +using stdsptr = std::shared_ptr; + bool make_ip_addr(sockaddr_storage* addr, size_t* len, const char ip[], int port); // ip字符串建议大小: char ip[INET6_ADDRSTRLEN]; bool get_ip_string(char ip[], size_t ip_size, const void* addr, size_t addr_len); diff --git a/core/luabus/src/socket_listener.cpp b/core/luabus/src/socket_listener.cpp index d126957f..041ba928 100644 --- a/core/luabus/src/socket_listener.cpp +++ b/core/luabus/src/socket_listener.cpp @@ -102,7 +102,7 @@ void socket_listener::on_complete(WSAOVERLAPPED* ovl) { set_no_block(node->fd); - auto token = m_mgr->accept_stream(node->fd, ip, m_proto_type); + auto token = m_mgr->accept_stream(m_token, node->fd, ip); if (token == 0) { closesocket(node->fd); } @@ -163,7 +163,7 @@ void socket_listener::queue_accept(WSAOVERLAPPED* ovl) { (*m_addrs_func)(node->buffer, 0, sizeof(node->buffer[0]), sizeof(node->buffer[2]), &local_addr, &local_addr_len, &remote_addr, &remote_addr_len); get_ip_string(ip, sizeof(ip), remote_addr, (size_t)remote_addr_len); - auto token = m_mgr->accept_stream(node->fd, ip, m_proto_type); + auto token = m_mgr->accept_stream(m_token, node->fd, ip); if (token == 0) { closesocket(node->fd); node->fd = INVALID_SOCKET; @@ -201,9 +201,9 @@ void socket_listener::on_can_recv(size_t max_len, bool is_eof) { set_no_delay(fd, 1); set_close_on_exec(fd); - auto token = m_mgr->accept_stream(fd, ip); + auto token = m_mgr->accept_stream(m_token, fd, ip); if (token != 0) { - m_accept_cb(token, m_proto_type); + m_accept_cb(token); } else { closesocket(fd); diff --git a/core/luabus/src/socket_listener.h b/core/luabus/src/socket_listener.h index 9c16b353..57a6ab83 100644 --- a/core/luabus/src/socket_listener.h +++ b/core/luabus/src/socket_listener.h @@ -24,8 +24,8 @@ struct socket_listener : public socket_object bool setup(socket_t fd); bool get_remote_ip(std::string& ip) override { return false; } bool update(int64_t now) override; - void set_accept_callback(const std::function& cb) override { m_accept_cb = cb; } - void set_error_callback(const std::function& cb) override { m_error_cb = cb; } + void set_accept_callback(const std::function cb) override { m_accept_cb = cb; } + void set_error_callback(const std::function cb) override { m_error_cb = cb; } #ifdef _MSC_VER void on_complete(WSAOVERLAPPED* ovl); diff --git a/core/luabus/src/socket_mgr.cpp b/core/luabus/src/socket_mgr.cpp index f60a9697..b00f92e7 100644 --- a/core/luabus/src/socket_mgr.cpp +++ b/core/luabus/src/socket_mgr.cpp @@ -192,6 +192,8 @@ int socket_mgr::listen(std::string& err, const char ip[], int port) { if (watch_listen(fd, listener) && listener->setup(fd)) { int token = new_token(); + listener->set_kind(token); + listener->set_token(token); m_objects[token] = listener; return token; } @@ -223,6 +225,7 @@ int socket_mgr::connect(std::string& err, const char node_name[], const char ser stm->connect(node_name, service_name, timeout); int token = new_token(); + stm->set_token(token); m_objects[token] = stm; return token; } @@ -241,10 +244,10 @@ void socket_mgr::set_nodelay(uint32_t token, int flag) { } } -void socket_mgr::set_proto_type(uint32_t token, eproto_type type) { +void socket_mgr::set_codec(uint32_t token, codec_base* codec) { auto node = get_object(token); if (node) { - node->set_proto_type(type); + node->set_codec(codec); } } @@ -262,6 +265,20 @@ void socket_mgr::sendv(uint32_t token, const sendv_item items[], int count) { } } +void socket_mgr::broadcast(size_t kind, const void* data, size_t data_len) { + for(auto node : m_objects) { + if (node.second->is_same_kind(kind)) { + node.second->send(data, data_len); + } + } +} + +void socket_mgr::broadgroup(std::vector& groups, const void* data, size_t data_len) { + for(auto token : groups) { + send(token, data, data_len); + } +} + void socket_mgr::close(uint32_t token) { auto node = get_object(token); if (node) { @@ -293,28 +310,28 @@ int socket_mgr::get_recvbuf_size(uint32_t token){ return 0; } -void socket_mgr::set_accept_callback(uint32_t token, const std::function& cb) { +void socket_mgr::set_accept_callback(uint32_t token, const std::function cb) { auto node = get_object(token); if (node) { node->set_accept_callback(cb); } } -void socket_mgr::set_connect_callback(uint32_t token, const std::function& cb) { +void socket_mgr::set_connect_callback(uint32_t token, const std::function cb) { auto node = get_object(token); if (node) { node->set_connect_callback(cb); } } -void socket_mgr::set_package_callback(uint32_t token, const std::function& cb) { +void socket_mgr::set_package_callback(uint32_t token, const std::function cb) { auto node = get_object(token); if (node) { node->set_package_callback(cb); } } -void socket_mgr::set_error_callback(uint32_t token, const std::function& cb) { +void socket_mgr::set_error_callback(uint32_t token, const std::function cb) { auto node = get_object(token); if (node) { node->set_error_callback(cb); @@ -418,11 +435,12 @@ bool socket_mgr::watch_send(socket_t fd, socket_object* object, bool enable) { #endif } -int socket_mgr::accept_stream(socket_t fd, const char ip[], eproto_type proto_type) { +int socket_mgr::accept_stream(uint32_t ltoken, socket_t fd, const char ip[]) { auto* stm = new socket_stream(this); if (watch_accepted(fd, stm) && stm->accept_socket(fd, ip)) { auto token = new_token(); - stm->set_proto_type(proto_type); + stm->set_kind(ltoken); + stm->set_token(token); m_objects[token] = stm; return token; } diff --git a/core/luabus/src/socket_mgr.h b/core/luabus/src/socket_mgr.h index 8e48edf9..4d909316 100644 --- a/core/luabus/src/socket_mgr.h +++ b/core/luabus/src/socket_mgr.h @@ -22,13 +22,10 @@ enum class elink_status : int // 协议类型 enum class eproto_type : int { - proto_rpc = 0, // rpc协议,根据协议头解析 - proto_wss = 1, // wss协议,协议前n个字节带长度 - proto_head = 2, // head协议,根据协议头解析 - proto_text = 3, // text协议,文本协议 - proto_mongo = 4, // mongo协议,协议前4个字节为长度 - proto_mysql = 5, // mysql协议,协议前3个字节为长度 - proto_max = 6, // max + proto_pb = 0, // pb协议,pb + proto_rpc = 1, // rpc协议,rpc + proto_text = 2, // text协议,mysql/mongo/http/wss/redis + proto_max = 3, // max }; struct sendv_item @@ -50,11 +47,14 @@ struct socket_object virtual void set_nodelay(int flag) { } virtual void send(const void* data, size_t data_len) { } virtual void sendv(const sendv_item items[], int count) { }; - virtual void set_proto_type(eproto_type type) { m_proto_type = type; } - virtual void set_accept_callback(const std::function& cb) { } - virtual void set_connect_callback(const std::function& cb) { } - virtual void set_error_callback(const std::function& cb) { } - virtual void set_package_callback(const std::function& cb) { } + virtual void set_kind(uint32_t kind) { m_kind = kind; } + virtual void set_token(uint32_t token) { m_token = token; } + virtual void set_codec(codec_base* codec) { m_codec = codec; } + virtual void set_accept_callback(const std::function cb) { } + virtual void set_connect_callback(const std::function cb) { } + virtual void set_error_callback(const std::function cb) { } + virtual void set_package_callback(const std::function cb) { } + virtual bool is_same_kind(uint32_t kind) { return m_kind == kind; } #ifdef _MSC_VER virtual void on_complete(WSAOVERLAPPED* ovl) = 0; @@ -66,7 +66,9 @@ struct socket_object #endif protected: - eproto_type m_proto_type = eproto_type::proto_rpc; + uint32_t m_kind = 0; + uint32_t m_token = 0; + codec_base* m_codec = nullptr; elink_status m_link_status = elink_status::link_init; }; @@ -93,20 +95,22 @@ class socket_mgr void set_nodelay(uint32_t token, int flag); void send(uint32_t token, const void* data, size_t data_len); void sendv(uint32_t token, const sendv_item items[], int count); + void broadcast(size_t kind, const void* data, size_t data_len); + void broadgroup(std::vector& groups, const void* data, size_t data_len); void close(uint32_t token); + void set_codec(uint32_t token, codec_base* codec); bool get_remote_ip(uint32_t token, std::string& ip); - void set_proto_type(uint32_t token, eproto_type type); - void set_accept_callback(uint32_t token, const std::function& cb); - void set_connect_callback(uint32_t token, const std::function& cb); - void set_package_callback(uint32_t token, const std::function& cb); - void set_error_callback(uint32_t token, const std::function& cb); + void set_accept_callback(uint32_t token, const std::function cb); + void set_error_callback(uint32_t token, const std::function cb); + void set_connect_callback(uint32_t token, const std::function cb); + void set_package_callback(uint32_t token, const std::function cb); bool watch_listen(socket_t fd, socket_object* object); bool watch_accepted(socket_t fd, socket_object* object); bool watch_connecting(socket_t fd, socket_object* object); bool watch_connected(socket_t fd, socket_object* object); bool watch_send(socket_t fd, socket_object* object, bool enable); - int accept_stream(socket_t fd, const char ip[], eproto_type proto_type); + int accept_stream(uint32_t ltoken, socket_t fd, const char ip[]); void increase_count() { m_count++; } void decrease_count() { m_count--; } diff --git a/core/luabus/src/socket_stream.cpp b/core/luabus/src/socket_stream.cpp index ffefbe3e..4dcbef46 100644 --- a/core/luabus/src/socket_stream.cpp +++ b/core/luabus/src/socket_stream.cpp @@ -62,6 +62,9 @@ void socket_stream::close() { m_link_status = elink_status::link_closed; return; } + if (m_codec) { + m_codec = nullptr; + } shutdown(m_socket, SD_RECEIVE); m_link_status = elink_status::link_colsing; } @@ -459,85 +462,36 @@ void socket_stream::do_recv(size_t max_len, bool is_eof) void socket_stream::dispatch_package() { int64_t now = ltimer::steady_ms(); while (m_link_status == elink_status::link_connected) { - size_t data_len = 0, package_size = 0; - auto* data = m_recv_buffer->data(&data_len); - if (data_len == 0) break; - switch (m_proto_type) { - case eproto_type::proto_rpc: { - size_t header_len = sizeof(router_header); - if (!m_recv_buffer->peek_data(header_len)) return; - // 当前包长小于headlen, 关闭连接 - router_header* header = (router_header*)data; - if (header->len < header_len) { - on_error("package-length-err"); - break; - } - package_size = header->len; - } - break; - case eproto_type::proto_head: { - size_t header_len = sizeof(socket_header); - if (!m_recv_buffer->peek_data(header_len)) return; - // 当前包长小于headlen, 关闭连接 - socket_header* header = (socket_header*)data; - if (header->len < header_len) { - on_error("package-length-err"); - return; - } - package_size = header->len; - } - break; - case eproto_type::proto_mongo: { - uint32_t* length = (uint32_t*)m_recv_buffer->peek_data(sizeof(uint32_t)); - if (!length) return; - //package_size = length + contents - package_size = *length; - } - break; - case eproto_type::proto_mysql: { - uint32_t* length = (uint32_t*)m_recv_buffer->peek_data(sizeof(uint32_t)); - if (!length) return; - //package_size = length + serialize_id + contents - package_size = ((*length) >> 8) + sizeof(uint32_t); - } - break; - case eproto_type::proto_wss: { - uint16_t* length = (uint16_t*)m_recv_buffer->peek_data(sizeof(uint16_t)); - if (!length) return; - uint16_t payload = (*length) & 0x7f; - if (payload < 0x7e) { - package_size = payload + sizeof(uint16_t); - } else { - size_t* length = (size_t*)m_recv_buffer->peek_data((payload == 0x7f) ? 8 : 2, sizeof(uint16_t)); - if (!length) return; - package_size = (*length) + sizeof(uint16_t); - } - } - break; - case eproto_type::proto_text: - package_size = data_len; + if (!m_codec){ + on_error("codec-is-bnull"); break; - default: - on_error("proto-type-not-suppert!"); - return; } - //当前包头标识的数据超过最大长度, 关闭连接 - if (package_size > SOCKET_PACKET_MAX) { + size_t data_len; + auto* data = m_recv_buffer->data(&data_len); + if (data_len == 0) break; + slice* slice = m_recv_buffer->get_slice(); + m_codec->set_slice(slice); + //解析数据包头长度 + int32_t package_size = m_codec->load_packet(data_len); + //当前包头长度解析失败, 关闭连接 + if (package_size < 0){ on_error("package-length-err"); break; } // 数据包还没有收完整 - if (data_len < package_size) break; - int read_size = m_package_cb(m_recv_buffer->get_slice(package_size), m_proto_type); - // 数据包还没有收完整 - if (read_size == 0) { - break; - } + if (package_size == 0) break; + // 数据回调 + slice->attach(data, package_size); + m_package_cb(slice); + if (!m_codec) break; // 数据包解析失败 - if (read_size < 0) { - on_error("package-read-err"); + if (m_codec->failed()) { + on_error(m_codec->err()); break; } + size_t read_size = m_codec->get_packet_len(); + // 数据包还没有收完整 + if (read_size == 0) break; // 接收缓冲读游标调整 m_recv_buffer->pop_size(read_size); m_last_recv_time = ltimer::steady_ms(); diff --git a/core/luabus/src/socket_stream.h b/core/luabus/src/socket_stream.h index 791adc44..bcec044f 100644 --- a/core/luabus/src/socket_stream.h +++ b/core/luabus/src/socket_stream.h @@ -1,6 +1,5 @@ #pragma once -#include "socket_helper.h" #include "socket_mgr.h" struct socket_stream : public socket_object @@ -18,10 +17,9 @@ struct socket_stream : public socket_object bool do_connect(); void try_connect(); void close() override; - void set_accept_callback(const std::function& cb) override { m_accept_cb = cb; } - void set_error_callback(const std::function& cb) override { m_error_cb = cb; } - void set_connect_callback(const std::function& cb) override { m_connect_cb = cb; } - void set_package_callback(const std::function& cb) override { m_package_cb = cb; } + void set_error_callback(const std::function cb) override { m_error_cb = cb; } + void set_connect_callback(const std::function cb) override { m_connect_cb = cb; } + void set_package_callback(const std::function cb) override { m_package_cb = cb; } void set_timeout(int duration) override { m_timeout = duration; } void set_nodelay(int flag) override { set_no_delay(m_socket, flag); } @@ -48,7 +46,6 @@ struct socket_stream : public socket_object void on_error(const char err[]); void on_connect(bool ok, const char reason[]); - int token = 0; socket_mgr* m_mgr = nullptr; socket_t m_socket = INVALID_SOCKET; std::shared_ptr m_recv_buffer = std::make_shared(); @@ -71,8 +68,7 @@ struct socket_stream : public socket_object int m_ovl_ref = 0; #endif - std::function m_accept_cb = nullptr; std::function m_error_cb = nullptr; std::function m_connect_cb = nullptr; - std::function m_package_cb = nullptr; + std::function m_package_cb = nullptr; }; diff --git a/core/quanta/quanta.mak b/core/quanta/quanta.mak index d93bfd59..8e9708d1 100644 --- a/core/quanta/quanta.mak +++ b/core/quanta/quanta.mak @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas diff --git a/extend/laoi/laoi.mak b/extend/laoi/laoi.mak index eb9e5db8..6344d32b 100644 --- a/extend/laoi/laoi.mak +++ b/extend/laoi/laoi.mak @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas diff --git a/extend/lbson/lbson.mak b/extend/lbson/lbson.mak index 44aec61e..e4958c4a 100644 --- a/extend/lbson/lbson.mak +++ b/extend/lbson/lbson.mak @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas diff --git a/extend/lbson/src/bson.h b/extend/lbson/src/bson.h index b325af44..4249159a 100644 --- a/extend/lbson/src/bson.h +++ b/extend/lbson/src/bson.h @@ -324,7 +324,7 @@ namespace lbson { const char* read_bytes(lua_State* L, slice* slice, size_t sz) { const char* dst = (const char*)slice->peek(sz); if (!dst) { - throw length_error("invalid bson string , length = " + sz); + throw invalid_argument("invalid bson string , length = " + sz); } slice->erase(sz); return dst; @@ -333,7 +333,7 @@ namespace lbson { const char* read_string(lua_State* L, slice* slice, size_t& sz) { sz = (size_t)read_val(L, slice); if (sz <= 0) { - throw length_error("invalid bson string , length = " + sz); + throw invalid_argument("invalid bson string , length = " + sz); } sz = sz - 1; const char* dst = ""; @@ -344,17 +344,17 @@ namespace lbson { return dst; } - const char* read_cstring(lua_State * L, slice* slice, size_t& l) { + const char* read_cstring(slice* slice, size_t& l) { size_t sz; const char* dst = (const char*)slice->data(&sz); for (l = 0; l < sz; ++l) { - if (l == sz - 1) { - throw invalid_argument("invalid bson block : cstring"); - } if (dst[l] == '\0') { slice->erase(l + 1); return dst; } + if (l == sz - 1) { + throw invalid_argument("invalid bson block : cstring"); + } } throw invalid_argument("invalid bson block : cstring"); return ""; @@ -362,7 +362,7 @@ namespace lbson { void unpack_key(lua_State* L, slice* slice, bool isarray) { size_t klen = 0; - const char* key = read_cstring(L, slice, klen); + const char* key = read_cstring(slice, klen); if (isarray) { lua_pushinteger(L, std::stoll(key, nullptr, 10) + 1); return; @@ -375,7 +375,7 @@ namespace lbson { void unpack_dict(lua_State* L, slice* slice, bool isarray) { uint32_t sz = read_val(L, slice); if (slice->size() < sz - 4) { - throw length_error("decode can't unpack one value"); + throw invalid_argument("decode can't unpack one value"); } lua_createtable(L, 0, 8); while (!slice->empty()) { @@ -417,7 +417,7 @@ namespace lbson { } break; case bson_type::BSON_REGEX: - lua_push_object(L, new bson_value(bt, read_cstring(L, slice, klen), read_cstring(L, slice, klen))); + lua_push_object(L, new bson_value(bt, read_cstring(slice, klen), read_cstring(slice, klen))); break; case bson_type::BSON_DOCUMENT: unpack_dict(L, slice, false); @@ -442,6 +442,17 @@ namespace lbson { class mgocodec : public codec_base { public: + virtual int load_packet(size_t data_len) { + if (!m_slice) return 0; + uint32_t* packet_len = (uint32_t*)m_slice->peek(sizeof(uint32_t)); + if (!packet_len) return 0; + m_packet_len = *packet_len; + if (m_packet_len > 0xffffff) return -1; + if (m_packet_len > data_len) return 0; + if (!m_slice->peek(m_packet_len)) return 0; + return m_packet_len; + } + virtual uint8_t* encode(lua_State* L, int index, size_t* len) { luabuf* buf = m_bson->get_buffer(); buf->clean(); diff --git a/extend/lbson/src/lbson.cpp b/extend/lbson/src/lbson.cpp index fe3faaca..a133cedc 100644 --- a/extend/lbson/src/lbson.cpp +++ b/extend/lbson/src/lbson.cpp @@ -5,7 +5,6 @@ namespace lbson { thread_local bson thread_bson; - thread_local mgocodec thread_codec; static int encode(lua_State* L) { return thread_bson.encode(L); @@ -38,11 +37,12 @@ namespace lbson { bson_numstr_len[i] = sprintf(tmp, "%d", i); memcpy(bson_numstrs[i], tmp, bson_numstr_len[i]); } - thread_codec.set_bson(&thread_bson); } static codec_base* mongo_codec() { - return &thread_codec; + mgocodec* codec = new mgocodec(); + codec->set_bson(&thread_bson); + return codec; } luakit::lua_table open_lbson(lua_State* L) { @@ -52,7 +52,7 @@ namespace lbson { llbson.set_function("decode", decode); llbson.set_function("encode_slice", encode_slice); llbson.set_function("decode_slice", decode_slice); - llbson.set_function("mongo_codec", mongo_codec); + llbson.set_function("mongocodec", mongo_codec); llbson.set_function("timestamp", timestamp); llbson.set_function("int32", int32); llbson.set_function("int64", int64); diff --git a/extend/lcodec/lcodec.mak b/extend/lcodec/lcodec.mak index f59f0e60..9c0ecae4 100644 --- a/extend/lcodec/lcodec.mak +++ b/extend/lcodec/lcodec.mak @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas diff --git a/extend/lcodec/src/http.h b/extend/lcodec/src/http.h index a0122a3c..8d4c79e8 100644 --- a/extend/lcodec/src/http.h +++ b/extend/lcodec/src/http.h @@ -34,6 +34,11 @@ namespace lcodec { class httpcodec : public codec_base { public: + virtual int load_packet(size_t data_len) { + if (!m_slice) return 0; + return data_len; + } + virtual uint8_t* encode(lua_State* L, int index, size_t* len) { m_buf->clean(); //status (http begining) @@ -63,7 +68,8 @@ namespace lcodec { size_t osize = m_slice->size(); string_view buf = m_slice->contents(); parse_http_packet(L, buf); - m_slice->erase(osize - buf.size()); + m_packet_len = osize - buf.size(); + m_slice->erase(m_packet_len); return lua_gettop(L) - top; } @@ -71,10 +77,6 @@ namespace lcodec { m_jcodec = codec; } - void set_buff(luabuf* buf) { - m_buf = buf; - } - protected: void format_http(size_t status) { switch (status) { @@ -159,12 +161,18 @@ namespace lcodec { contentlenable = true; mslice = m_buf->get_slice(); size_t content_size = atol(header.data()); + if (buf.size() < content_size) { + throw length_error("http text not full"); + } mslice->attach((uint8_t*)buf.data(), content_size); buf.remove_prefix(content_size); } else if (!strncasecmp(key.data(), "Transfer-Encoding", key.size()) && !strncasecmp(header.data(), "chunked", header.size())) { contentlenable = true; size_t pos = buf.find(CRLF2); + if (pos == string_view::npos) { + throw length_error("http text not full"); + } string_view chunk_data = buf.substr(0, pos); buf.remove_prefix(pos + LCRLF2); vector chunks; @@ -186,9 +194,15 @@ namespace lcodec { } } if (!contentlenable) { - mslice = m_buf->get_slice(); - mslice->attach((uint8_t*)buf.data(), buf.size()); - buf.remove_prefix(buf.size()); + if (!buf.empty()) { + mslice = m_buf->get_slice(); + mslice->attach((uint8_t*)buf.data(), buf.size()); + buf.remove_prefix(buf.size()); + } + } + if (!mslice || mslice->empty()) { + lua_pushnil(L); + return; } if (jsonable) { m_jcodec->set_slice(mslice); @@ -222,7 +236,6 @@ namespace lcodec { } protected: - luabuf* m_buf = nullptr; codec_base* m_jcodec = nullptr; }; } diff --git a/extend/lcodec/src/lcodec.cpp b/extend/lcodec/src/lcodec.cpp index bd880a5a..23558ff1 100644 --- a/extend/lcodec/src/lcodec.cpp +++ b/extend/lcodec/src/lcodec.cpp @@ -5,41 +5,35 @@ namespace lcodec { thread_local ketama thread_ketama; - thread_local rdscodec thread_rds; - thread_local wsscodec thread_wss; - thread_local httpcodec thread_http; thread_local luakit::luabuf thread_buff; - static rdscodec* rds_codec(codec_base* codec) { - thread_rds.set_codec(codec); - thread_rds.set_buff(&thread_buff); - return &thread_rds; + static codec_base* rds_codec(codec_base* codec) { + rdscodec* rcodec = new rdscodec(); + rcodec->set_codec(codec); + rcodec->set_buff(&thread_buff); + return rcodec; } - static wsscodec* wss_codec(codec_base* codec) { - thread_wss.set_codec(codec); - thread_wss.set_buff(&thread_buff); - return &thread_wss; + static codec_base* wss_codec(codec_base* codec) { + wsscodec* wcodec = new wsscodec(); + wcodec->set_codec(codec); + wcodec->set_buff(&thread_buff); + return wcodec; } - static httpcodec* http_codec(codec_base* codec) { - thread_http.set_codec(codec); - thread_http.set_buff(&thread_buff); - return &thread_http; + static codec_base* http_codec(codec_base* codec) { + httpcodec* hcodec = new httpcodec(); + hcodec->set_codec(codec); + hcodec->set_buff(&thread_buff); + return hcodec; } - - static int serialize(lua_State* L) { - return luakit::serialize(L, &thread_buff); - } - static int unserialize(lua_State* L) { - return luakit::unserialize(L); - } - static int encode(lua_State* L) { - return luakit::encode(L, &thread_buff); - } - static int decode(lua_State* L) { - return luakit::decode(L, &thread_buff); + + static codec_base* mysql_codec(size_t session_id) { + mysqlscodec* codec = new mysqlscodec(session_id); + codec->set_buff(&thread_buff); + return codec; } + static bool ketama_insert(std::string name, uint32_t node_id) { return thread_ketama.insert(name, node_id, 255); } @@ -65,11 +59,7 @@ namespace lcodec { luakit::lua_table open_lcodec(lua_State* L) { luakit::kit_state kit_state(L); auto llcodec = kit_state.new_table(); - llcodec.set_function("encode", encode); - llcodec.set_function("decode", decode); llcodec.set_function("bitarray", lbarray); - llcodec.set_function("serialize", serialize); - llcodec.set_function("unserialize", unserialize); llcodec.set_function("guid_new", guid_new); llcodec.set_function("guid_string", guid_string); llcodec.set_function("guid_tostring", guid_tostring); @@ -89,10 +79,11 @@ namespace lcodec { llcodec.set_function("ketama_remove", ketama_remove); llcodec.set_function("ketama_next", ketama_next); llcodec.set_function("ketama_map", ketama_map); + llcodec.set_function("mysqlcodec", mysql_codec); llcodec.set_function("rediscodec", rds_codec); llcodec.set_function("httpcodec", http_codec); llcodec.set_function("wsscodec", wss_codec); - + kit_state.new_class( "flip", &bitarray::flip, "fill", &bitarray::fill, diff --git a/extend/lcodec/src/lcodec.h b/extend/lcodec/src/lcodec.h index 495ea2a0..56a6ba0d 100644 --- a/extend/lcodec/src/lcodec.h +++ b/extend/lcodec/src/lcodec.h @@ -7,6 +7,7 @@ #include "http.h" #include "hash.h" #include "redis.h" +#include "mysql.h" #include "ketama.h" #include "websocket.h" #include "bitarray.h" diff --git a/extend/lcodec/src/mysql.h b/extend/lcodec/src/mysql.h index 5ea58f21..c0dcd5ea 100644 --- a/extend/lcodec/src/mysql.h +++ b/extend/lcodec/src/mysql.h @@ -1,43 +1,525 @@ #pragma once +#include +#include #include "lua_kit.h" using namespace std; using namespace luakit; namespace lcodec { + // cmd constants + const uint8_t COM_SLEEP = 0x00; + const uint8_t COM_CONNECT = 0x0b; + const uint8_t COM_STMT_PREPARE = 0x16; + const uint8_t COM_STMT_CLOSE = 0x19; + + // constants + inline uint32_t CLIENT_FLAG = 260047; //0011 1111 0111 1100 1111 + inline uint32_t MAX_PACKET_SIZE = 0xffffff; + inline uint32_t CLIENT_PLUGIN_AUTH = 1 << 19; + inline uint32_t CLIENT_DEPRECATE_EOF = 1 << 24; + + // field types + const uint16_t MYSQL_TYPE_TINY = 0x01; + const uint16_t MYSQL_TYPE_SHORT = 0x02; + const uint16_t MYSQL_TYPE_LONG = 0x03; + const uint16_t MYSQL_TYPE_FLOAT = 0x04; + const uint16_t MYSQL_TYPE_DOUBLE = 0x05; + const uint16_t MYSQL_TYPE_NULL = 0x06; + const uint16_t MYSQL_TYPE_LONGLONG = 0x08; + const uint16_t MYSQL_TYPE_INT24 = 0x09; + const uint16_t MYSQL_TYPE_YEAR = 0x0d; + const uint16_t MYSQL_TYPE_VARCHAR = 0x0f; + const uint16_t MYSQL_TYPE_NEWDECIMAL = 0xf6; + + // server status + inline size_t SERVER_MORE_RESULTS_EXISTS = 8; + + enum class packet_type: int + { + MP_OK = 0, + MP_ERR = 1, + MP_EOF = 2, + MP_DATA = 3, + MP_INF = 4, + }; + + struct mysql_cmd { + uint8_t cmd_id; + size_t session_id; + }; + + struct mysql_column { + string_view name; + uint8_t type; + uint16_t flags; + }; + typedef vector mysql_columns; class mysqlscodec : public codec_base { public: + mysqlscodec(size_t session_id) { + sessions.push_back(mysql_cmd{ COM_SLEEP, session_id }); + } + + virtual int load_packet(size_t data_len) { + if (!m_slice) return 0; + return data_len; + } + virtual uint8_t* encode(lua_State* L, int index, size_t* len) { m_buf->clean(); - return m_buf->data(len); + // cmd_id + uint8_t cmd_id = (uint8_t)lua_tointeger(L, index++); + // session_id + size_t session_id = lua_tointeger(L, index++); + //4 byte header placeholder + m_buf->write(0); + if (cmd_id != COM_CONNECT) { + return comand_encode(L, cmd_id, session_id, index, len); + } + return auth_encode(L, cmd_id, session_id, index, len); } virtual size_t decode(lua_State* L) { - if (!m_slice) return 0; int top = lua_gettop(L); + if (sessions.empty()) { + throw invalid_argument("invalid mysql data"); + } + size_t osize = m_slice->size(); + mysql_cmd cmd = sessions.front(); + lua_pushinteger(L, cmd.session_id); + switch (cmd.cmd_id) { + case COM_SLEEP: + auth_decode(L); + break; + case COM_STMT_PREPARE: + prepare_decode(L); + break; + default: + command_decode(L); + break; + } + sessions.pop_front(); + m_packet_len = osize - m_slice->size(); return lua_gettop(L) - top; } - void set_codec(codec_base* codec) { - m_jcodec = codec; + protected: + packet_type recv_packet() { + uint32_t payload = *(uint32_t*)m_slice->read(); + uint32_t length = (payload & 0xffffff); + if (length >= 0xffffff) { + throw invalid_argument("sharded packet not suppert!"); + } + uint8_t* data = m_slice->erase(length); + if (!data) { + throw length_error("mysql text not full"); + } + m_packet.attach(data, length); + switch (*data) { + case 0xfb: return packet_type::MP_INF; + case 0xfe: return packet_type::MP_EOF; + case 0x00: return packet_type::MP_OK; + case 0xff: return packet_type::MP_ERR; + } + return packet_type::MP_DATA; + } + + uint8_t* comand_encode(lua_State* L, uint8_t cmd_id, size_t session_id, int index, size_t* len) { + m_buf->write(cmd_id); + int top = lua_gettop(L); + if (index <= top) { + if (lua_type(L, index) == LUA_TNUMBER) { + m_buf->write(lua_tointeger(L, index++)); + } + else { + size_t data_len; + uint8_t* query = (uint8_t*)lua_tolstring(L, index++, &data_len); + m_buf->push_data(query, data_len); + } + } + if (index <= top) { + encode_stmt_args(L, index, top - index + 1); + } + // header + uint32_t size = (m_buf->size() - 4) & 0xffffff; + m_buf->copy(0, (uint8_t*)&size, 4); + // cmd + if (cmd_id != COM_STMT_CLOSE) { + sessions.push_back(mysql_cmd{ cmd_id, session_id }); + } + return m_buf->data(len); + } + + uint8_t* auth_encode(lua_State* L, uint8_t cmd_id, size_t session_id, int index, size_t* len) { + //4 byte client_flag + m_buf->write(CLIENT_FLAG); + //4 byte max_packet_size + m_buf->write(MAX_PACKET_SIZE); + //1 byte character_set + m_buf->write((uint8_t)lua_tointeger(L, index++)); + //23 byte filler(all 0) + m_buf->pop_space(23); + // username + uint8_t* user = (uint8_t*)lua_tolstring(L, index++, len); + m_buf->push_data(user, *len); + m_buf->push_data((uint8_t*)"\0", 1); + // auth_data + uint8_t* auth_data = (uint8_t*)lua_tolstring(L, index++, len); + m_buf->write(*len); + m_buf->push_data(auth_data, *len); + // dbname + const uint8_t* dbname = (const uint8_t*)lua_tolstring(L, index++, len); + m_buf->push_data(dbname, *len); + m_buf->push_data((uint8_t*)"\0", 1); + // header + uint32_t size = ((m_buf->size() - 4) & 0xffffff) | 0x01000000; + m_buf->copy(0, (uint8_t*)&size, 4); + // cmd + sessions.push_back(mysql_cmd{ cmd_id, session_id }); + return m_buf->data(len); } - void set_buff(luabuf* buf) { - m_buf = buf; + void command_decode(lua_State* L) { + packet_type type = recv_packet(); + switch (type) { + case packet_type::MP_OK: + return ok_packet_decode(L); + case packet_type::MP_DATA: + return data_packet_decode(L); + case packet_type::MP_ERR: + return err_packet_decode(L); + default: throw invalid_argument("unsuppert mysql packet type"); + } } - protected: - char* xor_byte(char* buffer, char* mask, size_t blen, size_t mlen) { - for (int i = 0; i < blen; i++) { - buffer[i] = buffer[i] ^ mask[i % mlen]; + void field_decode(mysql_columns& columns) { + string_view catalog = decode_length_encoded_string(); + string_view schema = decode_length_encoded_string(); + string_view table = decode_length_encoded_string(); + string_view org_table = decode_length_encoded_string(); + string_view name = decode_length_encoded_string(); + string_view org_name = decode_length_encoded_string(); + // 1 byte fix length (skip) + // 2 byte character_set (skip) + // 4 byte column_length (skip) + m_packet.erase(7); + uint8_t type = *(uint8_t*)m_packet.read(); + uint16_t flags = *(uint16_t*)m_packet.read(); + uint8_t decimals = *(uint8_t*)m_packet.read(); + columns.push_back(mysql_column { name, type, decimals }); + } + + packet_type rows_decode(lua_State* L, mysql_columns& columns) { + // rows + size_t row_indx = 1; + packet_type type = recv_packet(); + while (type == packet_type::MP_DATA){ + // row + lua_createtable(L, 0, 8); + for (const mysql_column& column : columns) { + auto value = decode_length_encoded_string(); + lua_pushlstring(L, column.name.data(), column.name.size()); + switch (column.type) { + case MYSQL_TYPE_FLOAT: + case MYSQL_TYPE_DOUBLE: + lua_pushnumber(L, strtod(value.data(), nullptr)); + break; + case MYSQL_TYPE_TINY: + case MYSQL_TYPE_SHORT: + case MYSQL_TYPE_LONG: + case MYSQL_TYPE_INT24: + case MYSQL_TYPE_YEAR: + case MYSQL_TYPE_LONGLONG: + case MYSQL_TYPE_NEWDECIMAL: + lua_pushinteger(L, strtoll(value.data(), nullptr, 10)); + break; + default: + lua_pushlstring(L, value.data(), value.size()); + break; + } + lua_rawset(L, -3); + } + lua_seti(L, -2, row_indx++); + type = recv_packet(); + } + return type; + } + + bool result_set_decode(lua_State* L, size_t top, size_t rset_idx) { + // result set header + lua_createtable(L, 0, 8); + size_t column_count = decode_length_encoded_number(); + // field metadata + mysql_columns columns; + for (int i = 0; i < column_count; ++i) { + recv_packet(); + field_decode(columns); + } + // field eof + if ((m_capability & CLIENT_DEPRECATE_EOF) != CLIENT_DEPRECATE_EOF) { + recv_packet(); + eof_packet_decode(); + } + // rows data + packet_type type = rows_decode(L, columns); + lua_seti(L, -2, rset_idx); + // terminator + if (type == packet_type::MP_ERR) { + lua_settop(L, top); + err_packet_decode(L); + return false; + } + // rows eof + return eof_packet_decode(); + } + + void data_packet_decode(lua_State* L) { + size_t rset_idx = 1; + int top = lua_gettop(L); + lua_pushboolean(L, true); + //result sets + lua_createtable(L, 0, 4); + bool more = result_set_decode(L, top, rset_idx++); + while (more) { + recv_packet(); + more = result_set_decode(L, top, rset_idx++); + } + } + + void ok_packet_decode(lua_State* L) { + //type + m_packet.read(); + lua_pushboolean(L, true); + lua_createtable(L, 0, 4); + //affected_rows + lua_pushinteger(L, decode_length_encoded_number()); + lua_setfield(L, -2, "affected_rows"); + //last_insert_id + lua_pushinteger(L, decode_length_encoded_number()); + lua_setfield(L, -2, "last_insert_id"); + //status_flags + m_packet.read(); + //warnings + lua_pushinteger(L, *(uint16_t*)m_packet.read()); + lua_setfield(L, -2, "warnings"); + //info + auto info = m_packet.eof(); + lua_pushlstring(L, info.data(), info.size()); + lua_setfield(L, -2, "info"); + } + + void err_packet_decode(lua_State* L) { + //type + m_packet.read(); + lua_pushboolean(L, false); + lua_createtable(L, 0, 4); + //errnoo + lua_pushinteger(L, *(uint16_t*)m_packet.read()); + lua_setfield(L, -2, "errnoo"); + //1 byte sql_state_marker (skip) + m_packet.erase(1); + //5 byte sql_state + char* sql_state = (char*)m_packet.erase(5); + lua_pushlstring(L, sql_state, 5); + lua_setfield(L, -2, "sql_state"); + //error_message + auto error_message = m_packet.eof(); + lua_pushlstring(L, error_message.data(), error_message.size()); + lua_setfield(L, -2, "error_message"); + } + + bool eof_packet_decode() { + //type + m_packet.read(); + if ((m_capability & CLIENT_DEPRECATE_EOF) == CLIENT_DEPRECATE_EOF) { + size_t affected_rows = decode_length_encoded_number(); + size_t last_insert_id = decode_length_encoded_number(); + uint16_t status_flags = *(uint16_t*)m_packet.read(); + uint16_t warnings = *(uint16_t*)m_packet.read(); + auto info = m_packet.eof(); + return ((status_flags & SERVER_MORE_RESULTS_EXISTS) == SERVER_MORE_RESULTS_EXISTS); + } + else { + uint16_t warnings = *(uint16_t*)m_packet.read(); + uint16_t status_flags = *(uint16_t*)m_packet.read(); + return ((status_flags & SERVER_MORE_RESULTS_EXISTS) == SERVER_MORE_RESULTS_EXISTS); + } + } + + void prepare_decode(lua_State* L) { + recv_packet(); + uint8_t status = *(uint8_t*)m_packet.read(); + uint32_t statement_id = *(uint32_t*)m_packet.read(); + uint16_t num_columns = *(uint16_t*)m_packet.read(); + uint16_t num_params = *(uint16_t*)m_packet.read(); + int top = lua_gettop(L); + lua_pushinteger(L, statement_id); + lua_pushinteger(L, num_columns); + lua_pushinteger(L, num_params); + } + + void auth_decode(lua_State* L) { + recv_packet(); + //1 byte protocol version + uint8_t proto = *(uint8_t*)m_packet.read(); + //n byte server version + size_t data_len; + const char* version = read_cstring(m_packet, data_len); + //4 byte thread_id + uint32_t thread_id = *(uint32_t*)m_packet.read(); + //8 byte auth-plugin-data-part-1 + uint8_t* scramble1 = m_packet.peek(8); + //8 byte auth-plugin-data-part-1 + 1 byte filler + m_packet.erase(9); + //2 byte capability_flags_1 + uint16_t capability_flag_1 = *(uint16_t*)m_packet.read(); + //1 byte character_set + uint8_t character_set = *(uint8_t*)m_packet.read(); + lua_pushinteger(L, character_set); + //2 byte status_flags + uint16_t status_flags = *(uint16_t*)m_packet.read(); + //2 byte capability_flags_2 + uint16_t capability_flag_2 = *(uint16_t*)m_packet.read(); + m_capability = capability_flag_2 << 16 | capability_flag_1; + //1 byte character_set + uint8_t auth_plugin_data_len = *(uint8_t*)m_packet.read(); + //10 byte reserved (all 0) + m_packet.erase(10); + //auth-plugin-data-part-2 + char* scramble2 = nullptr; + auth_plugin_data_len = std::max(13, auth_plugin_data_len - 8); + scramble2 = (char*)m_packet.erase(auth_plugin_data_len); + lua_pushlstring(L, (char*)scramble1, 8); + lua_pushlstring(L, scramble2, 12); + //auth_plugin_name + const char* auth_plugin_name = nullptr; + if ((m_capability & CLIENT_PLUGIN_AUTH) == CLIENT_PLUGIN_AUTH) { + auth_plugin_name = read_cstring(m_packet, data_len); + lua_pushlstring(L, auth_plugin_name, data_len); + } + } + + void encode_stmt_args(lua_State* L, int index, int argnum) { + //enum_cursor_type + m_buf->write(0); + //iteration_count + m_buf->write(1); + //null_bitmap, length= (argnum + 7) / 8 + int argpos = 0; + int argbyte = (argnum + 7) / 8; + for (int i = 0; i < argbyte; ++i) { + uint8_t byte = 0; + for (int j = 0; j < 7; ++j) { + int aindex = index + argpos++; + if (aindex < argnum) { + uint8_t bit = lua_isnil(L, aindex) ? 0 : 1; + byte |= (bit < j); + } + } + m_buf->write(byte); + } + //new_params_bind_flag + m_buf->write(1); + //parameter_type + for (int i = 0; i < argnum; ++i) { + encode_args_type(L, index + i); + } + //parameter_values + for (int i = 0; i < argnum; ++i) { + encode_args_value(L, index + i); + } + } + + void encode_args_type(lua_State* L, int index) { + int type = lua_type(L, index); + switch (type) { + case LUA_TNIL: + m_buf->write(MYSQL_TYPE_NULL); + break; + case LUA_TBOOLEAN: + m_buf->write(MYSQL_TYPE_TINY); + break; + case LUA_TSTRING: + m_buf->write(MYSQL_TYPE_VARCHAR); + break; + case LUA_TNUMBER: + m_buf->write(lua_isinteger(L, index) ? MYSQL_TYPE_LONGLONG : MYSQL_TYPE_DOUBLE); + break; + default: + throw invalid_argument("invalid mysql stmt args type"); + } + } + + void encode_args_value(lua_State* L, int index) { + switch (lua_type(L, index)) { + case LUA_TBOOLEAN: + m_buf->write(lua_tointeger(L, index)); + break; + case LUA_TNUMBER: + lua_isinteger(L, index) ? m_buf->write(lua_tointeger(L, index)) : m_buf->write(lua_tonumber(L, index)); + break; + case LUA_TSTRING: { + uint32_t data_len; + uint8_t* data = (uint8_t*)lua_tolstring(L, index, (size_t*)&data_len); + if (data_len < 0xfb) { + m_buf->write(data_len); + } + else if (data_len < 0xffff) { + m_buf->write(0xfc); + m_buf->write(data_len); + } + else if (data_len < 0xffffff) { + m_buf->write((0xfd << 24) | data_len); + } + else { + m_buf->write(0xfe); + m_buf->write(data_len); + } + m_buf->push_data(data, data_len); + } + break; + } + } + + size_t decode_length_encoded_number() { + uint8_t nbyte = *(uint8_t*)m_packet.read(); + if (nbyte < 0xfb) return nbyte; + if (nbyte == 0xfc) return *(uint16_t*)m_packet.read(); + if (nbyte == 0xfd) return *(uint32_t*)m_packet.read(); + if (nbyte == 0xfe) return *(uint64_t*)m_packet.read(); + return 0; + } + + string_view decode_length_encoded_string() { + size_t length = decode_length_encoded_number(); + if (length > 0) { + char* data = (char*)m_packet.erase(length); + if (!data) throw invalid_argument("invalid length coded string"); + return string_view(data, length); + } + return ""; + } + + const char* read_cstring(slice& slice, size_t& l) { + size_t sz; + const char* dst = (const char*)slice.data(&sz); + for (l = 0; l < sz; ++l) { + if (dst[l] == '\0') { + slice.erase(l + 1); + return dst; + } + if (l == sz - 1) throw invalid_argument("invalid mysql block : cstring"); } - return buffer; + throw invalid_argument("invalid mysql block : cstring"); + return ""; } protected: - luabuf* m_buf = nullptr; - codec_base* m_jcodec = nullptr; + deque sessions; + uint32_t m_capability = 0; + slice m_packet; }; } diff --git a/extend/lcodec/src/redis.h b/extend/lcodec/src/redis.h index ca556980..e2cb4e1c 100644 --- a/extend/lcodec/src/redis.h +++ b/extend/lcodec/src/redis.h @@ -1,4 +1,5 @@ #pragma once +#include #include #ifdef _MSC_VER @@ -17,24 +18,33 @@ namespace lcodec { class rdscodec : public codec_base { public: + virtual int load_packet(size_t data_len) { + if (!m_slice) return 0; + return data_len; + } + virtual uint8_t* encode(lua_State* L, int index, size_t* len) { m_buf->clean(); int n = lua_gettop(L); + uint32_t session_id = lua_tointeger(L, index++); char* head = (char*)m_buf->peek_space(HEAD_SIZE); - m_buf->pop_space(sprintf(head, "*%d\r\n", n)); - for (int i = 1; i <= n; ++i) { + m_buf->pop_space(sprintf(head, "*%d\r\n", n - index + 1)); + for (int i = index; i <= n; ++i) { encode_bulk_string(L, i); } + sessions.push_back(session_id); return m_buf->data(len); } virtual size_t decode(lua_State* L) { - if (!m_slice) return 0; int top = lua_gettop(L); size_t osize = m_slice->size(); string_view buf = m_slice->contents(); + lua_pushinteger(L, sessions.empty() ? 0 : sessions.front()); parse_redis_packet(L, buf); - m_slice->erase(osize - buf.size()); + if (!sessions.empty()) sessions.pop_front(); + m_packet_len = osize - buf.size(); + m_slice->erase(m_packet_len); return lua_gettop(L) - top; } @@ -42,10 +52,6 @@ namespace lcodec { m_jcodec = codec; } - void set_buff(luabuf* buf) { - m_buf = buf; - } - protected: void parse_redis_success(lua_State* L, string_view line) { lua_pushboolean(L, true); @@ -200,7 +206,7 @@ namespace lcodec { } protected: - luabuf* m_buf = nullptr; + deque sessions; codec_base* m_jcodec = nullptr; }; } diff --git a/extend/lcodec/src/websocket.h b/extend/lcodec/src/websocket.h index 2f5cc4f4..cbec2ab7 100644 --- a/extend/lcodec/src/websocket.h +++ b/extend/lcodec/src/websocket.h @@ -9,6 +9,21 @@ namespace lcodec { class wsscodec : public codec_base { public: + virtual int load_packet(size_t data_len) { + if (!m_slice) return 0; + uint8_t* payload = (uint8_t*)m_slice->peek(sizeof(uint8_t), 1); + if (!payload) return 0; + uint8_t masklen = (((*payload) & 0x80) == 0x80) ? 4 : 0; + uint8_t payloadlen = (*payload) & 0x7f; + if (payloadlen < 0x7e) { + return masklen + payloadlen + sizeof(uint16_t); + } + size_t* length = (size_t*)m_slice->peek((payloadlen == 0x7f) ? 8 : 2, sizeof(uint16_t)); + if (!length) return 0; + m_packet_len = masklen + (*length) + sizeof(uint16_t); + return m_packet_len; + } + virtual uint8_t* encode(lua_State* L, int index, size_t* len) { m_buf->clean(); uint8_t* body = nullptr; @@ -33,9 +48,8 @@ namespace lcodec { } virtual size_t decode(lua_State* L) { - if (!m_slice) return 0; uint8_t head = *(uint8_t*)m_slice->read(); - if ((head & 0x80) != 0x80) throw length_error("shared packet not suppert!"); + if ((head & 0x80) != 0x80) throw invalid_argument("sharded packet not suppert!"); uint8_t payload = *(uint8_t*)m_slice->read(); uint8_t opcode = head & 0xf; bool mask = ((payload & 0x80) == 0x80); @@ -48,8 +62,7 @@ namespace lcodec { lua_pushinteger(L, opcode); if (mask) { size_t data_len; - char* maskkey = (char*)m_slice->peek(4); - m_slice->erase(4); + char* maskkey = (char*)m_slice->erase(4); char* data = (char*)m_slice->data(&data_len); xor_byte(data, maskkey, data_len, 4); } @@ -68,10 +81,6 @@ namespace lcodec { m_jcodec = codec; } - void set_buff(luabuf* buf) { - m_buf = buf; - } - protected: char* xor_byte(char* buffer, char* mask, size_t blen, size_t mlen) { for (int i = 0; i < blen; i++) { @@ -81,7 +90,6 @@ namespace lcodec { } protected: - luabuf* m_buf = nullptr; codec_base* m_jcodec = nullptr; }; } diff --git a/extend/lcrypt/lcrypt.mak b/extend/lcrypt/lcrypt.mak index 794e5078..88a143ab 100644 --- a/extend/lcrypt/lcrypt.mak +++ b/extend/lcrypt/lcrypt.mak @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas diff --git a/extend/lcurl/lcurl.mak b/extend/lcurl/lcurl.mak index 716b15a7..7601dbaf 100644 --- a/extend/lcurl/lcurl.mak +++ b/extend/lcurl/lcurl.mak @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas diff --git a/extend/ldetour/ldetour.mak b/extend/ldetour/ldetour.mak index 18a4d614..dcfa2734 100644 --- a/extend/ldetour/ldetour.mak +++ b/extend/ldetour/ldetour.mak @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas diff --git a/extend/ljson/ljson.mak b/extend/ljson/ljson.mak index 95b6d2c5..bfd36e06 100644 --- a/extend/ljson/ljson.mak +++ b/extend/ljson/ljson.mak @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas MYCFLAGS += -Wno-implicit-fallthrough diff --git a/extend/ljson/src/ljson.cpp b/extend/ljson/src/ljson.cpp index 691b0a26..04d1a19e 100644 --- a/extend/ljson/src/ljson.cpp +++ b/extend/ljson/src/ljson.cpp @@ -4,11 +4,11 @@ namespace ljson { thread_local yyjson thread_json; - thread_local jsoncodec thread_codec; - static jsoncodec* json_codec() { - thread_codec.set_json(&thread_json); - return &thread_codec; + static codec_base* json_codec() { + jsoncodec* codec = new jsoncodec(); + codec->set_json(&thread_json); + return codec; } luakit::lua_table open_ljson(lua_State* L) { diff --git a/extend/ljson/src/ljson.h b/extend/ljson/src/ljson.h index fc2d8d5a..4ad41863 100644 --- a/extend/ljson/src/ljson.h +++ b/extend/ljson/src/ljson.h @@ -213,6 +213,12 @@ namespace ljson { class jsoncodec : public codec_base { public: + virtual int load_packet(size_t data_len) { + if (!m_slice) return 0; + m_packet_len = data_len; + return data_len; + } + virtual uint8_t* encode(lua_State* L, int index, size_t* len) { yyjson_write_err err; yyjson_mut_doc* doc = yyjson_mut_doc_new(nullptr); diff --git a/extend/lmake/share.lua b/extend/lmake/share.lua index c5203da9..e5d44110 100644 --- a/extend/lmake/share.lua +++ b/extend/lmake/share.lua @@ -16,6 +16,7 @@ FLAGS = { "Wno-sign-compare", "Wno-unused-variable", "Wno-unused-parameter", + "Wno-unused-but-set-variable", "Wno-unused-but-set-parameter", "Wno-unknown-pragmas" } diff --git a/extend/lstdfs/lstdfs.mak b/extend/lstdfs/lstdfs.mak index 1240afc9..af08f3c0 100644 --- a/extend/lstdfs/lstdfs.mak +++ b/extend/lstdfs/lstdfs.mak @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas diff --git a/extend/ltimer/ltimer.mak b/extend/ltimer/ltimer.mak index 60de93b7..489eabb3 100644 --- a/extend/ltimer/ltimer.mak +++ b/extend/ltimer/ltimer.mak @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas diff --git a/extend/lua/lua.mak b/extend/lua/lua.mak index a74ccf66..f0b61fd1 100644 --- a/extend/lua/lua.mak +++ b/extend/lua/lua.mak @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas diff --git a/extend/lua/lua/lapi.c b/extend/lua/lua/lapi.c index 34e64af1..332e97d1 100644 --- a/extend/lua/lua/lapi.c +++ b/extend/lua/lua/lapi.c @@ -417,9 +417,9 @@ LUA_API const char *lua_tolstring (lua_State *L, int idx, size_t *len) { o = index2value(L, idx); /* previous call may reallocate the stack */ } if (len != NULL) - *len = vslen(o); + *len = tsslen(tsvalue(o)); lua_unlock(L); - return svalue(o); + return getstr(tsvalue(o)); } diff --git a/extend/lua/lua/ldebug.c b/extend/lua/lua/ldebug.c index 28b1caab..690ac38f 100644 --- a/extend/lua/lua/ldebug.c +++ b/extend/lua/lua/ldebug.c @@ -426,7 +426,7 @@ static const char *getobjname (const Proto *p, int lastpc, int reg, */ static void kname (const Proto *p, int c, const char **name) { TValue *kvalue = &p->k[c]; - *name = (ttisstring(kvalue)) ? svalue(kvalue) : "?"; + *name = (ttisstring(kvalue)) ? getstr(tsvalue(kvalue)) : "?"; } @@ -569,7 +569,7 @@ static const char *getobjname (const Proto *p, int lastpc, int reg, int b = (op == OP_LOADK) ? GETARG_Bx(i) : GETARG_Ax(p->code[pc + 1]); if (ttisstring(&p->k[b])) { - *name = svalue(&p->k[b]); + *name = getstr(tsvalue(&p->k[b])); return "constant"; } break; @@ -627,7 +627,7 @@ static const char *funcnamefromcode (lua_State *L, const Proto *p, default: return NULL; /* cannot find a reasonable name */ } - *name = getstr(G(L)->tmname[tm]) + 2; + *name = getshrstr(G(L)->tmname[tm]) + 2; return "metamethod"; } @@ -865,6 +865,28 @@ static int changedline (const Proto *p, int oldpc, int newpc) { } +/* +** Traces Lua calls. If code is running the first instruction of a function, +** and function is not vararg, and it is not coming from an yield, +** calls 'luaD_hookcall'. (Vararg functions will call 'luaD_hookcall' +** after adjusting its variable arguments; otherwise, they could call +** a line/count hook before the call hook. Functions coming from +** an yield already called 'luaD_hookcall' before yielding.) +*/ +int luaG_tracecall (lua_State *L) { + CallInfo *ci = L->ci; + Proto *p = ci_func(ci)->p; + ci->u.l.trap = 1; /* ensure hooks will be checked */ + if (ci->u.l.savedpc == p->code) { /* first instruction (not resuming)? */ + if (p->is_vararg) + return 0; /* hooks will start at VARARGPREP instruction */ + else if (!(ci->callstatus & CIST_HOOKYIELD)) /* not yieded? */ + luaD_hookcall(L, ci); /* check 'call' hook */ + } + return 1; /* keep 'trap' on */ +} + + /* ** Traces the execution of a Lua function. Called before the execution ** of each opcode, when debug is on. 'L->oldpc' stores the last diff --git a/extend/lua/lua/ldebug.h b/extend/lua/lua/ldebug.h index 2c3074c6..2bfce3cb 100644 --- a/extend/lua/lua/ldebug.h +++ b/extend/lua/lua/ldebug.h @@ -58,6 +58,7 @@ LUAI_FUNC const char *luaG_addinfo (lua_State *L, const char *msg, TString *src, int line); LUAI_FUNC l_noret luaG_errormsg (lua_State *L); LUAI_FUNC int luaG_traceexec (lua_State *L, const Instruction *pc); +LUAI_FUNC int luaG_tracecall (lua_State *L); #endif diff --git a/extend/lua/lua/lgc.c b/extend/lua/lua/lgc.c index dd824e77..253a2892 100644 --- a/extend/lua/lua/lgc.c +++ b/extend/lua/lua/lgc.c @@ -542,10 +542,12 @@ static void traversestrongtable (global_State *g, Table *h) { static lu_mem traversetable (global_State *g, Table *h) { const char *weakkey, *weakvalue; const TValue *mode = gfasttm(g, h->metatable, TM_MODE); + TString *smode; markobjectN(g, h->metatable); - if (mode && ttisstring(mode) && /* is there a weak mode? */ - (cast_void(weakkey = strchr(svalue(mode), 'k')), - cast_void(weakvalue = strchr(svalue(mode), 'v')), + if (mode && ttisshrstring(mode) && /* is there a weak mode? */ + (cast_void(smode = tsvalue(mode)), + cast_void(weakkey = strchr(getshrstr(smode), 'k')), + cast_void(weakvalue = strchr(getshrstr(smode), 'v')), (weakkey || weakvalue))) { /* is really weak? */ if (!weakkey) /* strong keys? */ traverseweakvalue(g, h); @@ -638,7 +640,9 @@ static int traversethread (global_State *g, lua_State *th) { for (uv = th->openupval; uv != NULL; uv = uv->u.open.next) markobject(g, uv); /* open upvalues cannot be collected */ if (g->gcstate == GCSatomic) { /* final traversal? */ - for (; o < th->stack_last.p + EXTRA_STACK; o++) + if (!g->gcemergency) + luaD_shrinkstack(th); /* do not change stack in emergency cycle */ + for (o = th->top.p; o < th->stack_last.p + EXTRA_STACK; o++) setnilvalue(s2v(o)); /* clear dead stack slice */ /* 'remarkupvals' may have removed thread from 'twups' list */ if (!isintwups(th) && th->openupval != NULL) { @@ -646,8 +650,6 @@ static int traversethread (global_State *g, lua_State *th) { g->twups = th; } } - else if (!g->gcemergency) - luaD_shrinkstack(th); /* do not change stack in emergency cycle */ return 1 + stacksize(th); } @@ -1710,6 +1712,8 @@ static void fullinc (lua_State *L, global_State *g) { entersweep(L); /* sweep everything to turn them back to white */ /* finish any pending sweep phase to start a new cycle */ luaC_runtilstate(L, bitmask(GCSpause)); + luaC_runtilstate(L, bitmask(GCSpropagate)); /* start new cycle */ + g->gcstate = GCSenteratomic; /* go straight to atomic phase ??? */ luaC_runtilstate(L, bitmask(GCScallfin)); /* run up to finalizers */ /* estimate must be correct after a full GC cycle */ lua_assert(g->GCestimate == gettotalbytes(g)); diff --git a/extend/lua/lua/lmathlib.c b/extend/lua/lua/lmathlib.c index d0b1e1e5..f140d623 100644 --- a/extend/lua/lua/lmathlib.c +++ b/extend/lua/lua/lmathlib.c @@ -249,6 +249,15 @@ static int math_type (lua_State *L) { ** =================================================================== */ +/* +** This code uses lots of shifts. ANSI C does not allow shifts greater +** than or equal to the width of the type being shifted, so some shifts +** are written in convoluted ways to match that restriction. For +** preprocessor tests, it assumes a width of 32 bits, so the maximum +** shift there is 31 bits. +*/ + + /* number of binary digits in the mantissa of a float */ #define FIGS l_floatatt(MANT_DIG) @@ -271,16 +280,19 @@ static int math_type (lua_State *L) { /* 'long' has at least 64 bits */ #define Rand64 unsigned long +#define SRand64 long #elif !defined(LUA_USE_C89) && defined(LLONG_MAX) /* there is a 'long long' type (which must have at least 64 bits) */ #define Rand64 unsigned long long +#define SRand64 long long #elif ((LUA_MAXUNSIGNED >> 31) >> 31) >= 3 /* 'lua_Unsigned' has at least 64 bits */ #define Rand64 lua_Unsigned +#define SRand64 lua_Integer #endif @@ -319,23 +331,30 @@ static Rand64 nextrand (Rand64 *state) { } -/* must take care to not shift stuff by more than 63 slots */ - - /* ** Convert bits from a random integer into a float in the ** interval [0,1), getting the higher FIG bits from the ** random unsigned integer and converting that to a float. +** Some old Microsoft compilers cannot cast an unsigned long +** to a floating-point number, so we use a signed long as an +** intermediary. When lua_Number is float or double, the shift ensures +** that 'sx' is non negative; in that case, a good compiler will remove +** the correction. */ /* must throw out the extra (64 - FIGS) bits */ #define shift64_FIG (64 - FIGS) -/* to scale to [0, 1), multiply by scaleFIG = 2^(-FIGS) */ +/* 2^(-FIGS) == 2^-1 / 2^(FIGS-1) */ #define scaleFIG (l_mathop(0.5) / ((Rand64)1 << (FIGS - 1))) static lua_Number I2d (Rand64 x) { - return (lua_Number)(trim64(x) >> shift64_FIG) * scaleFIG; + SRand64 sx = (SRand64)(trim64(x) >> shift64_FIG); + lua_Number res = (lua_Number)(sx) * scaleFIG; + if (sx < 0) + res += 1.0; /* correct the two's complement if negative */ + lua_assert(0 <= res && res < 1); + return res; } /* convert a 'Rand64' to a 'lua_Unsigned' */ @@ -471,8 +490,6 @@ static lua_Number I2d (Rand64 x) { #else /* 32 < FIGS <= 64 */ -/* must take care to not shift stuff by more than 31 slots */ - /* 2^(-FIGS) = 1.0 / 2^30 / 2^3 / 2^(FIGS-33) */ #define scaleFIG \ (l_mathop(1.0) / (UONE << 30) / l_mathop(8.0) / (UONE << (FIGS - 33))) diff --git a/extend/lua/lua/lobject.c b/extend/lua/lua/lobject.c index f73ffc6d..9cfa5227 100644 --- a/extend/lua/lua/lobject.c +++ b/extend/lua/lua/lobject.c @@ -542,7 +542,7 @@ const char *luaO_pushvfstring (lua_State *L, const char *fmt, va_list argp) { addstr2buff(&buff, fmt, strlen(fmt)); /* rest of 'fmt' */ clearbuff(&buff); /* empty buffer into the stack */ lua_assert(buff.pushed == 1); - return svalue(s2v(L->top.p - 1)); + return getstr(tsvalue(s2v(L->top.p - 1))); } diff --git a/extend/lua/lua/lobject.h b/extend/lua/lua/lobject.h index 556608e4..980e42f8 100644 --- a/extend/lua/lua/lobject.h +++ b/extend/lua/lua/lobject.h @@ -386,7 +386,7 @@ typedef struct GCObject { typedef struct TString { CommonHeader; lu_byte extra; /* reserved words for short strings; "has hash" for longs */ - lu_byte shrlen; /* length for short strings */ + lu_byte shrlen; /* length for short strings, 0xFF for long strings */ unsigned int hash; union { size_t lnglen; /* length for long strings */ @@ -398,19 +398,17 @@ typedef struct TString { /* -** Get the actual string (array of bytes) from a 'TString'. +** Get the actual string (array of bytes) from a 'TString'. (Generic +** version and specialized versions for long and short strings.) */ -#define getstr(ts) ((ts)->contents) +#define getstr(ts) ((ts)->contents) +#define getlngstr(ts) check_exp((ts)->shrlen == 0xFF, (ts)->contents) +#define getshrstr(ts) check_exp((ts)->shrlen != 0xFF, (ts)->contents) -/* get the actual string (array of bytes) from a Lua value */ -#define svalue(o) getstr(tsvalue(o)) - /* get string length from 'TString *s' */ -#define tsslen(s) ((s)->tt == LUA_VSHRSTR ? (s)->shrlen : (s)->u.lnglen) - -/* get string length from 'TValue *o' */ -#define vslen(o) tsslen(tsvalue(o)) +#define tsslen(s) \ + ((s)->shrlen != 0xFF ? (s)->shrlen : (s)->u.lnglen) /* }================================================================== */ diff --git a/extend/lua/lua/lparser.c b/extend/lua/lua/lparser.c index b745f236..2b888c7c 100644 --- a/extend/lua/lua/lparser.c +++ b/extend/lua/lua/lparser.c @@ -1022,10 +1022,11 @@ static int explist (LexState *ls, expdesc *v) { } -static void funcargs (LexState *ls, expdesc *f, int line) { +static void funcargs (LexState *ls, expdesc *f) { FuncState *fs = ls->fs; expdesc args; int base, nparams; + int line = ls->linenumber; switch (ls->t.token) { case '(': { /* funcargs -> '(' [ explist ] ')' */ luaX_next(ls); @@ -1063,8 +1064,8 @@ static void funcargs (LexState *ls, expdesc *f, int line) { } init_exp(f, VCALL, luaK_codeABC(fs, OP_CALL, base, nparams+1, 2)); luaK_fixline(fs, line); - fs->freereg = base+1; /* call remove function and arguments and leaves - (unless changed) one result */ + fs->freereg = base+1; /* call removes function and arguments and leaves + one result (unless changed later) */ } @@ -1103,7 +1104,6 @@ static void suffixedexp (LexState *ls, expdesc *v) { /* suffixedexp -> primaryexp { '.' NAME | '[' exp ']' | ':' NAME funcargs | funcargs } */ FuncState *fs = ls->fs; - int line = ls->linenumber; primaryexp(ls, v); for (;;) { switch (ls->t.token) { @@ -1123,12 +1123,12 @@ static void suffixedexp (LexState *ls, expdesc *v) { luaX_next(ls); codename(ls, &key); luaK_self(fs, v, &key); - funcargs(ls, v, line); + funcargs(ls, v); break; } case '(': case TK_STRING: case '{': { /* funcargs */ luaK_exp2nextreg(fs, v); - funcargs(ls, v, line); + funcargs(ls, v); break; } default: return; diff --git a/extend/lua/lua/lstate.c b/extend/lua/lua/lstate.c index 06667dac..7fefacba 100644 --- a/extend/lua/lua/lstate.c +++ b/extend/lua/lua/lstate.c @@ -433,7 +433,7 @@ void luaE_warning (lua_State *L, const char *msg, int tocont) { void luaE_warnerror (lua_State *L, const char *where) { TValue *errobj = s2v(L->top.p - 1); /* error object */ const char *msg = (ttisstring(errobj)) - ? svalue(errobj) + ? getstr(tsvalue(errobj)) : "error object is not a string"; /* produce warning "error in %s (%s)" (where, msg) */ luaE_warning(L, "error in ", 1); diff --git a/extend/lua/lua/lstate.h b/extend/lua/lua/lstate.h index 40ff89aa..007704c8 100644 --- a/extend/lua/lua/lstate.h +++ b/extend/lua/lua/lstate.h @@ -181,7 +181,7 @@ struct CallInfo { union { struct { /* only for Lua functions */ const Instruction *savedpc; - volatile l_signalT trap; + volatile l_signalT trap; /* function is tracing lines/counts */ int nextraargs; /* # of extra arguments in vararg functions */ } l; struct { /* only for C functions */ diff --git a/extend/lua/lua/lstring.c b/extend/lua/lua/lstring.c index 13dcaf42..e921dd0f 100644 --- a/extend/lua/lua/lstring.c +++ b/extend/lua/lua/lstring.c @@ -36,7 +36,7 @@ int luaS_eqlngstr (TString *a, TString *b) { lua_assert(a->tt == LUA_VLNGSTR && b->tt == LUA_VLNGSTR); return (a == b) || /* same instance or... */ ((len == b->u.lnglen) && /* equal length and ... */ - (memcmp(getstr(a), getstr(b), len) == 0)); /* equal contents */ + (memcmp(getlngstr(a), getlngstr(b), len) == 0)); /* equal contents */ } @@ -52,7 +52,7 @@ unsigned int luaS_hashlongstr (TString *ts) { lua_assert(ts->tt == LUA_VLNGSTR); if (ts->extra == 0) { /* no hash? */ size_t len = ts->u.lnglen; - ts->hash = luaS_hash(getstr(ts), len, ts->hash); + ts->hash = luaS_hash(getlngstr(ts), len, ts->hash); ts->extra = 1; /* now it has its hash */ } return ts->hash; @@ -157,6 +157,7 @@ static TString *createstrobj (lua_State *L, size_t l, int tag, unsigned int h) { TString *luaS_createlngstrobj (lua_State *L, size_t l) { TString *ts = createstrobj(L, l, LUA_VLNGSTR, G(L)->seed); ts->u.lnglen = l; + ts->shrlen = 0xFF; /* signals that it is a long string */ return ts; } @@ -193,7 +194,7 @@ static TString *internshrstr (lua_State *L, const char *str, size_t l) { TString **list = &tb->hash[lmod(h, tb->size)]; lua_assert(str != NULL); /* otherwise 'memcmp'/'memcpy' are undefined */ for (ts = *list; ts != NULL; ts = ts->u.hnext) { - if (l == ts->shrlen && (memcmp(str, getstr(ts), l * sizeof(char)) == 0)) { + if (l == ts->shrlen && (memcmp(str, getshrstr(ts), l * sizeof(char)) == 0)) { /* found! */ if (isdead(g, ts)) /* dead (but not collected yet)? */ changewhite(ts); /* resurrect it */ @@ -206,8 +207,8 @@ static TString *internshrstr (lua_State *L, const char *str, size_t l) { list = &tb->hash[lmod(h, tb->size)]; /* rehash with new size */ } ts = createstrobj(L, l, LUA_VSHRSTR, h); - memcpy(getstr(ts), str, l * sizeof(char)); ts->shrlen = cast_byte(l); + memcpy(getshrstr(ts), str, l * sizeof(char)); ts->u.hnext = *list; *list = ts; tb->nuse++; @@ -226,7 +227,7 @@ TString *luaS_newlstr (lua_State *L, const char *str, size_t l) { if (l_unlikely(l >= (MAX_SIZE - sizeof(TString))/sizeof(char))) luaM_toobig(L); ts = luaS_createlngstrobj(L, l); - memcpy(getstr(ts), str, l * sizeof(char)); + memcpy(getlngstr(ts), str, l * sizeof(char)); return ts; } } diff --git a/extend/lua/lua/lundump.c b/extend/lua/lua/lundump.c index 02aed64f..e8d92a85 100644 --- a/extend/lua/lua/lundump.c +++ b/extend/lua/lua/lundump.c @@ -81,7 +81,7 @@ static size_t loadUnsigned (LoadState *S, size_t limit) { static size_t loadSize (LoadState *S) { - return loadUnsigned(S, ~(size_t)0); + return loadUnsigned(S, MAX_SIZET); } @@ -122,7 +122,7 @@ static TString *loadStringN (LoadState *S, Proto *p) { ts = luaS_createlngstrobj(L, size); /* create string */ setsvalue2s(L, L->top.p, ts); /* anchor it ('loadVector' can GC) */ luaD_inctop(L); - loadVector(S, getstr(ts), size); /* load directly in final place */ + loadVector(S, getlngstr(ts), size); /* load directly in final place */ L->top.p--; /* pop string */ } luaC_objbarrier(L, p, ts); diff --git a/extend/lua/lua/lundump.h b/extend/lua/lua/lundump.h index f3748a99..bc71ced8 100644 --- a/extend/lua/lua/lundump.h +++ b/extend/lua/lua/lundump.h @@ -21,8 +21,7 @@ /* ** Encode major-minor version in one byte, one nibble for each */ -#define MYINT(s) (s[0]-'0') /* assume one-digit numerals */ -#define LUAC_VERSION (MYINT(LUA_VERSION_MAJOR)*16+MYINT(LUA_VERSION_MINOR)) +#define LUAC_VERSION (LUA_VERSION_MAJOR_N*16+LUA_VERSION_MINOR_N) #define LUAC_FORMAT 0 /* this is the official format */ diff --git a/extend/lua/lua/lvm.c b/extend/lua/lua/lvm.c index 2b437bdf..4d71cfff 100644 --- a/extend/lua/lua/lvm.c +++ b/extend/lua/lua/lvm.c @@ -91,8 +91,10 @@ static int l_strton (const TValue *obj, TValue *result) { lua_assert(obj != result); if (!cvt2num(obj)) /* is object not a string? */ return 0; - else - return (luaO_str2num(svalue(obj), result) == vslen(obj) + 1); + else { + TString *st = tsvalue(obj); + return (luaO_str2num(getstr(st), result) == tsslen(st) + 1); + } } @@ -626,8 +628,9 @@ int luaV_equalobj (lua_State *L, const TValue *t1, const TValue *t2) { static void copy2buff (StkId top, int n, char *buff) { size_t tl = 0; /* size already copied */ do { - size_t l = vslen(s2v(top - n)); /* length of string being copied */ - memcpy(buff + tl, svalue(s2v(top - n)), l * sizeof(char)); + TString *st = tsvalue(s2v(top - n)); + size_t l = tsslen(st); /* length of string being copied */ + memcpy(buff + tl, getstr(st), l * sizeof(char)); tl += l; } while (--n > 0); } @@ -653,11 +656,11 @@ void luaV_concat (lua_State *L, int total) { } else { /* at least two non-empty string values; get as many as possible */ - size_t tl = vslen(s2v(top - 1)); + size_t tl = tsslen(tsvalue(s2v(top - 1))); TString *ts; /* collect total length and number of strings */ for (n = 1; n < total && tostring(L, s2v(top - n - 1)); n++) { - size_t l = vslen(s2v(top - n - 1)); + size_t l = tsslen(tsvalue(s2v(top - n - 1))); if (l_unlikely(l >= (MAX_SIZE/sizeof(char)) - tl)) { L->top.p = top - total; /* pop strings to avoid wasting stack */ luaG_runerror(L, "string length overflow"); @@ -671,7 +674,7 @@ void luaV_concat (lua_State *L, int total) { } else { /* long string; copy strings directly to final result */ ts = luaS_createlngstrobj(L, tl); - copy2buff(top, n, getstr(ts)); + copy2buff(top, n, getlngstr(ts)); } setsvalue2s(L, top - n, ts); /* create result */ } @@ -1157,18 +1160,11 @@ void luaV_execute (lua_State *L, CallInfo *ci) { startfunc: trap = L->hookmask; returning: /* trap already set */ - cl = clLvalue(s2v(ci->func.p)); + cl = ci_func(ci); k = cl->p->k; pc = ci->u.l.savedpc; - if (l_unlikely(trap)) { - if (pc == cl->p->code) { /* first instruction (not resuming)? */ - if (cl->p->is_vararg) - trap = 0; /* hooks will start after VARARGPREP instruction */ - else /* check 'call' hook */ - luaD_hookcall(L, ci); - } - ci->u.l.trap = 1; /* assume trap is on, for now */ - } + if (l_unlikely(trap)) + trap = luaG_tracecall(L); base = ci->func.p + 1; /* main loop of interpreter */ for (;;) { diff --git a/extend/lua/luac.mak b/extend/lua/luac.mak index 98505e64..c98f5328 100644 --- a/extend/lua/luac.mak +++ b/extend/lua/luac.mak @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas diff --git a/extend/lua/lualib.mak b/extend/lua/lualib.mak index 216cb7d5..0ed88bc9 100644 --- a/extend/lua/lualib.mak +++ b/extend/lua/lualib.mak @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas diff --git a/extend/luakit/include/lua_codec.h b/extend/luakit/include/lua_codec.h index 4787705a..78017b8f 100644 --- a/extend/luakit/include/lua_codec.h +++ b/extend/luakit/include/lua_codec.h @@ -42,7 +42,7 @@ namespace luakit { T value_decode(lua_State* L, slice* slice) { T* value = slice->read(); if (value == nullptr) { - throw std::length_error("decode can't unpack one value"); + throw std::invalid_argument("decode can't unpack one value"); } return *value; } @@ -164,7 +164,7 @@ namespace luakit { } auto str = (const char*)slice->peek(sz); if (str == nullptr || sz > USHRT_MAX) { - throw std::length_error("decode string is out of range"); + throw std::invalid_argument("decode string is out of range"); } slice->erase(sz); lua_pushlstring(L, str, sz); @@ -386,28 +386,58 @@ namespace luakit { class codec_base { public: - void __gc() {} + virtual ~codec_base(){} virtual size_t decode(lua_State* L) = 0; + virtual int load_packet(size_t data_len) = 0; virtual uint8_t* encode(lua_State* L, int index, size_t* len) = 0; size_t decode(lua_State* L, uint8_t* data, size_t len) { slice mslice(data, len); m_slice = &mslice; return decode(L); } - void set_slice(slice* slice) { m_slice = slice; } + virtual void error(const std::string& err) { + m_err = err; + m_failed = true; + } + virtual void set_slice(slice* slice) { + m_err = ""; + m_slice = slice; + m_packet_len = 0; + m_failed = false; + } + virtual bool failed() { return m_failed; } + virtual const char* err() { return m_err.c_str(); } + virtual size_t get_packet_len() { return m_packet_len; } + virtual void set_buff(luabuf* buf) { m_buf = buf; } + protected: + bool m_failed = false; + luabuf* m_buf = nullptr; slice* m_slice = nullptr; + size_t m_packet_len = 0; + std::string m_err = ""; }; class luacodec : public codec_base { public: + virtual int load_packet(size_t data_len) { + if (!m_slice) return 0; + uint32_t* packet_len = (uint32_t*)m_slice->peek(sizeof(uint32_t)); + if (!packet_len) return 0; + m_packet_len = *packet_len; + if (m_packet_len > 0xffffff) return -1; + if (m_packet_len > data_len) return 0; + if (!m_slice->peek(m_packet_len)) return 0; + return m_packet_len; + } + virtual uint8_t* encode(lua_State* L, int index, size_t* len) { - m_buf.clean(); + m_buf->clean(); int n = lua_gettop(L); for (int i = index; i <= n; i++) { - encode_one(L, &m_buf, i, 0); + encode_one(L, m_buf, i, 0); } - return m_buf.data(len); + return m_buf->data(len); } virtual size_t decode(lua_State* L) { @@ -422,8 +452,5 @@ namespace luakit { m_slice = nullptr; return argnum; } - - protected: - luabuf m_buf; }; } diff --git a/extend/luakit/include/lua_function.h b/extend/luakit/include/lua_function.h index 23d45f22..55da046a 100644 --- a/extend/luakit/include/lua_function.h +++ b/extend/luakit/include/lua_function.h @@ -283,7 +283,7 @@ namespace luakit { } catch(const std::length_error&) { return false; } catch(const std::exception& e) { - if (efn) efn(e.what()); + codec->error(e.what()); return false; } if (!lua_call_function(L, efn, arg_num, sizeof...(ret_types))) diff --git a/extend/luakit/include/lua_kit.h b/extend/luakit/include/lua_kit.h index 4ec1dce1..531f435a 100644 --- a/extend/luakit/include/lua_kit.h +++ b/extend/luakit/include/lua_kit.h @@ -5,7 +5,6 @@ #include "lua_class.h" namespace luakit { - class kit_state { public: kit_state() { @@ -20,11 +19,28 @@ namespace luakit { "peek", &slice::check, "string", &slice::string ); + m_buf = new luabuf(); + lua_table luakit = new_table("luakit"); + luakit.set_function("encode", [&](lua_State* L) { return encode(L, m_buf); }); + luakit.set_function("decode", [&](lua_State* L) { return decode(L, m_buf); }); + luakit.set_function("unserialize", [&](lua_State* L) { return unserialize(L); }); + luakit.set_function("serialize", [&](lua_State* L) { return serialize(L, m_buf); }); } kit_state(lua_State* L) : m_L(L) {} void close() { lua_close(m_L); + if (m_buf) { delete m_buf; } + if (m_codec) { delete m_codec; } + } + + codec_base* create_codec() { + if (!m_codec) { + if (!m_buf) m_buf = new luabuf(); + m_codec = new luacodec(); + m_codec->set_buff(m_buf); + } + return m_codec; } template @@ -169,6 +185,8 @@ namespace luakit { } protected: + luabuf* m_buf = nullptr; + luacodec* m_codec = nullptr; lua_State* m_L = nullptr; }; diff --git a/extend/luakit/include/lua_slice.h b/extend/luakit/include/lua_slice.h index 0c43d1a1..685e6147 100644 --- a/extend/luakit/include/lua_slice.h +++ b/extend/luakit/include/lua_slice.h @@ -20,35 +20,30 @@ namespace luakit { return m_tail == m_head; } + slice clone() { + return slice(m_head, m_tail - m_head); + } + void attach(uint8_t* data, size_t size) { m_head = data; m_tail = data + size; } - uint8_t* peek(size_t peek_len) { - size_t data_len = m_tail - m_head; + uint8_t* peek(size_t peek_len, size_t offset = 0) { + size_t data_len = m_tail - m_head - offset; if (peek_len > 0 && data_len >= peek_len) { - return m_head; + return m_head + offset; } return nullptr; } - size_t erase(size_t erase_len) { + uint8_t* erase(size_t erase_len) { + uint8_t* data = m_head; if (m_head + erase_len <= m_tail) { m_head += erase_len; - return erase_len; - } - return 0; - } - - int check(lua_State* L) { - size_t peek_len = lua_tointeger(L, 1); - size_t data_len = m_tail - m_head; - if (peek_len > 0 && data_len >= peek_len) { - lua_pushlstring(L, (const char*)m_head, peek_len); - return 1; + return data; } - return 0; + return nullptr; } size_t pop(uint8_t* dest, size_t read_len) { @@ -73,17 +68,6 @@ namespace luakit { return nullptr; } - int recv(lua_State* L) { - size_t data_len = m_tail - m_head; - size_t read_len = lua_tointeger(L, 1); - if (read_len > 0 && data_len >= read_len) { - lua_pushlstring(L, (const char*)m_head, read_len); - m_head += read_len; - return 1; - } - return 0; - } - uint8_t* data(size_t* len) { *len = (size_t)(m_tail - m_head); return m_head; @@ -98,6 +82,34 @@ namespace luakit { return std::string_view((const char*)m_head, len); } + std::string_view eof() { + uint8_t* head = m_head; + m_head = m_tail; + size_t len = (size_t)(m_tail - head); + return std::string_view((const char*)head, len); + } + + int check(lua_State* L) { + size_t peek_len = lua_tointeger(L, 1); + size_t data_len = m_tail - m_head; + if (peek_len > 0 && data_len >= peek_len) { + lua_pushlstring(L, (const char*)m_head, peek_len); + return 1; + } + return 0; + } + + int recv(lua_State* L) { + size_t data_len = m_tail - m_head; + size_t read_len = lua_tointeger(L, 1); + if (read_len > 0 && data_len >= read_len) { + lua_pushlstring(L, (const char*)m_head, read_len); + m_head += read_len; + return 1; + } + return 0; + } + int string(lua_State* L) { size_t len = (size_t)(m_tail - m_head); lua_pushlstring(L, (const char*)m_head, len); diff --git a/extend/lualog/lualog.mak b/extend/lualog/lualog.mak index bbe964b8..42642b07 100644 --- a/extend/lualog/lualog.mak +++ b/extend/lualog/lualog.mak @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas diff --git a/extend/protobuf/lua-protobuf.lmak b/extend/luapb/luapb.lmak similarity index 69% rename from extend/protobuf/lua-protobuf.lmak rename to extend/luapb/luapb.lmak index 26c013ad..a4583dbe 100644 --- a/extend/protobuf/lua-protobuf.lmak +++ b/extend/luapb/luapb.lmak @@ -1,15 +1,16 @@ --工程名字 -PROJECT_NAME = "pb" +PROJECT_NAME = "luapb" --目标名字 -TARGET_NAME = "pb" +TARGET_NAME = "luapb" ----工程类型: static/dynamic/exe PROJECT_TYPE = "dynamic" --需要的include目录 INCLUDES = { - "../lua/lua" + "../lua/lua", + "../luakit/include" } @@ -18,14 +19,18 @@ WINDOWS_DEFINES = { "LUA_BUILD_AS_DLL" } ---源文件路径 -SRC_DIR = "lua-protobuf" +--源文件路径v +SRC_DIR = "src" --需要连接的库文件 LIBS = { "lua" } +OBJS = { + "luapb.cpp" +} + --依赖项目 DEPS = { "lualib" diff --git a/extend/protobuf/lua-protobuf.mak b/extend/luapb/luapb.mak similarity index 82% rename from extend/protobuf/lua-protobuf.mak rename to extend/luapb/luapb.mak index 778533d2..7176f2d0 100644 --- a/extend/protobuf/lua-protobuf.mak +++ b/extend/luapb/luapb.mak @@ -1,8 +1,8 @@ #工程名字 -PROJECT_NAME = pb +PROJECT_NAME = luapb #目标名字 -TARGET_NAME = pb +TARGET_NAME = luapb #系统环境 UNAME_S = $(shell uname -s) @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas @@ -32,6 +33,7 @@ STDCPP = -std=c++17 #需要的include目录 MYCFLAGS += -I../lua/lua +MYCFLAGS += -I../luakit/include #需要定义的选项 @@ -40,7 +42,7 @@ LDFLAGS = #源文件路径 -SRC_DIR = lua-protobuf +SRC_DIR = src #需要排除的源文件,目录基于$(SRC_DIR) EXCLUDE = @@ -95,11 +97,10 @@ LDFLAGS += -L$(SOLUTION_DIR)library #自动生成目标 OBJS = -#根目录 -OBJS += $(patsubst $(SRC_DIR)/%.c, $(INT_DIR)/%.o, $(filter-out $(EXCLUDE), $(wildcard $(SRC_DIR)/*.c))) -OBJS += $(patsubst $(SRC_DIR)/%.m, $(INT_DIR)/%.o, $(filter-out $(EXCLUDE), $(wildcard $(SRC_DIR)/*.m))) -OBJS += $(patsubst $(SRC_DIR)/%.cc, $(INT_DIR)/%.o, $(filter-out $(EXCLUDE), $(wildcard $(SRC_DIR)/*.cc))) -OBJS += $(patsubst $(SRC_DIR)/%.cpp, $(INT_DIR)/%.o, $(filter-out $(EXCLUDE), $(wildcard $(SRC_DIR)/*.cpp))) +COBJS = $(patsubst %.c, $(INT_DIR)/%.o, luapb.cpp) +MOBJS = $(patsubst %.m, $(INT_DIR)/%.o, $(COBJS)) +CCOBJS = $(patsubst %.cc, $(INT_DIR)/%.o, $(MOBJS)) +OBJS = $(patsubst %.cpp, $(INT_DIR)/%.o, $(CCOBJS)) # 编译所有源文件 $(INT_DIR)/%.o : $(SRC_DIR)/%.c diff --git a/extend/protobuf/lua-protobuf.vcxproj b/extend/luapb/luapb.vcxproj similarity index 85% rename from extend/protobuf/lua-protobuf.vcxproj rename to extend/luapb/luapb.vcxproj index 2b9f87e2..f6d50c6f 100644 --- a/extend/protobuf/lua-protobuf.vcxproj +++ b/extend/luapb/luapb.vcxproj @@ -7,17 +7,20 @@ - + - + + + true + - {F09883B5-7B33-D3D3-3C39-BBC8DD3B2BE2} - pb + {2967F038-B90B-EBF5-B268-BE3BAF66D417} + luapb Win32Proj 10.0 - pb + luapb @@ -36,14 +39,14 @@ <_ProjectFileVersion>11.0.50727.1 - pb + luapb $(SolutionDir)temp\bin\$(Platform)\ $(SolutionDir)temp\$(ProjectName)\$(Platform)\ Disabled - ..\lua\lua;$(SolutionDir)extend\mimalloc\mimalloc\include;%(AdditionalIncludeDirectories) + ..\lua\lua;..\luakit\include;$(SolutionDir)extend\mimalloc\mimalloc\include;%(AdditionalIncludeDirectories) WIN32;NDEBUG;_WINDOWS;_CRT_SECURE_NO_WARNINGS;LUA_BUILD_AS_DLL;%(PreprocessorDefinitions) Default MultiThreadedDLL diff --git a/extend/protobuf/lua-protobuf.vcxproj.filters b/extend/luapb/luapb.vcxproj.filters similarity index 78% rename from extend/protobuf/lua-protobuf.vcxproj.filters rename to extend/luapb/luapb.vcxproj.filters index 0de5dd86..a8c53da0 100644 --- a/extend/protobuf/lua-protobuf.vcxproj.filters +++ b/extend/luapb/luapb.vcxproj.filters @@ -1,12 +1,15 @@  - + inc - + + src + + src diff --git a/extend/luapb/src/luapb.cpp b/extend/luapb/src/luapb.cpp new file mode 100644 index 00000000..1cb7b861 --- /dev/null +++ b/extend/luapb/src/luapb.cpp @@ -0,0 +1,144 @@ +#define LUA_LIB +#include + +#include "pb.c" +#include "lua_kit.h" + +using namespace std; +using namespace luakit; + +namespace luapb { + + thread_local luabuf thread_buff; + + #pragma pack(1) + struct pb_header { + uint16_t len; // 整个包的长度 + uint8_t flag; // 标志位 + uint8_t type; // 消息类型 + uint16_t cmd_id; // 协议ID + uint16_t session_id; // sessionId + uint8_t crc8; // crc8 + }; + #pragma pack() + + class pbcodec : public codec_base { + public: + pbcodec(const char* pbpkg, const char* pbenum) { + m_pbpkg = pbpkg; + m_pbenum = pbenum; + m_pbpkg.append("."); + } + + virtual int load_packet(size_t data_len) { + if (!m_slice) return 0; + pb_header* header =(pb_header*)m_slice->peek(sizeof(pb_header)); + if (!header) return 0; + m_packet_len = header->len; + if (!m_slice->peek(m_packet_len)) return 0; + if (m_packet_len > 0xffff) return -1; + if (m_packet_len > data_len) return 0; + return m_packet_len; + } + + virtual uint8_t* encode(lua_State* L, int index, size_t* len) { + //header + pb_header header; + lpb_State *LS = lpb_lstate(L); + //cmdid + const pb_Type* t = pb_type_from_stack(L, LS, &header, index++); + pb_Slice sh = pb_lslice((const char*)&header, sizeof(header)); + header.flag = (uint8_t)lua_tointeger(L, index++); + header.type = (uint8_t)lua_tointeger(L, index++); + header.crc8 = (uint8_t)lua_tointeger(L, index++); + //encode + lpb_Env e; + e.L = L, e.LS = LS; + pb_resetbuffer(e.b = &LS->buffer); + lua_pushvalue(L, index); + pb_addslice(e.b, sh); + lpbE_encode(&e, t, -1); + *len = pb_bufflen(e.b); + return (uint8_t*)pb_buffer(e.b); + } + + virtual size_t decode(lua_State* L) { + pb_header* header =(pb_header*)m_slice->erase(sizeof(pb_header)); + //cmd_id + lpb_State* LS = lpb_lstate(L); + const pb_Type* t = pb_type_from_enum(L, LS, header->cmd_id); + //data + size_t data_len; + const char* data = (const char*)m_slice->data(&data_len); + pb_Slice s = pb_lslice(data, data_len); + //decode + lpb_Env e; + int top = lua_gettop(L); + lua_pushinteger(L, header->cmd_id); + lua_pushinteger(L, header->flag); + lua_pushinteger(L, header->type); + lua_pushinteger(L, header->crc8); + lpb_pushtypetable(L, LS, t); + e.L = L, e.LS = LS, e.s = &s; + lpbD_message(&e, t); + return lua_gettop(L) - top; + } + + protected: + const pb_Type* pb_type_from_name(lua_State* L, lpb_State* LS, string cmd_name) { + //去掉前缀 NID_ + cmd_name = cmd_name.substr(4); + std::transform(cmd_name.begin(), cmd_name.end(), cmd_name.begin(), [](auto c) { return std::tolower(c); }); + cmd_name = m_pbpkg + cmd_name; + return lpb_type(L, LS, pb_lslice(cmd_name.c_str(), cmd_name.size())); + } + + const pb_Type* pb_type_from_enum(lua_State* L, lpb_State* LS, size_t cmd_id) { + const pb_Type* t = lpb_type(L, LS, pb_lslice(m_pbenum.c_str(), m_pbenum.size())); + const pb_Field* f = pb_field(t, cmd_id); + if (f == nullptr) throw invalid_argument("invalid pb cmdid: " + cmd_id); + return pb_type_from_name(L, LS, (const char*)f->name); + } + + const pb_Type* pb_type_from_stack(lua_State* L, lpb_State* LS, pb_header* header, int index) { + const pb_Type* t = lpb_type(L, LS, pb_lslice(m_pbenum.c_str(), m_pbenum.size())); + const pb_Field* f = lpb_field(L, index, t); + if (f) { + header->cmd_id = f->number; + return pb_type_from_name(L, LS, (const char*)f->name); + } + if (lua_type(L, index) == LUA_TNUMBER) { + luaL_error(L, "invalid pb cmdid: %d", lua_tointeger(L, index)); + } + if (lua_type(L, index) == LUA_TSTRING) { + luaL_error(L, "invalid pb cmd: %s", lua_tostring(L, index)); + } + luaL_error(L, "invalid pb cmd type"); + return nullptr; + } + + protected: + string m_pbpkg; + string m_pbenum; + }; + + static codec_base* pb_codec(const char* pkgname, const char* pbenum) { + pbcodec* codec = new pbcodec(pkgname, pbenum); + codec->set_buff(&thread_buff); + return codec; + } + + luakit::lua_table open_luapb(lua_State* L) { + luaopen_pb(L); + lua_table luapb(L); + luapb.set_function("pbcodec", pb_codec); + return luapb; + } +} + +extern "C" { + LUALIB_API int luaopen_luapb(lua_State* L) { + auto luapb = luapb::open_luapb(L); + return luapb.push_stack(); + } +} diff --git a/extend/protobuf/lua-protobuf/pb.c b/extend/luapb/src/pb.c similarity index 100% rename from extend/protobuf/lua-protobuf/pb.c rename to extend/luapb/src/pb.c diff --git a/extend/protobuf/lua-protobuf/pb.h b/extend/luapb/src/pb.h similarity index 100% rename from extend/protobuf/lua-protobuf/pb.h rename to extend/luapb/src/pb.h diff --git a/extend/luaxlsx/luaxlsx.mak b/extend/luaxlsx/luaxlsx.mak index 8ffe32a9..7b188c40 100644 --- a/extend/luaxlsx/luaxlsx.mak +++ b/extend/luaxlsx/luaxlsx.mak @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas MYCFLAGS += -Wno-implicit-fallthrough diff --git a/extend/lworker/lworker.mak b/extend/lworker/lworker.mak index 1d6c3fab..bb000c36 100644 --- a/extend/lworker/lworker.mak +++ b/extend/lworker/lworker.mak @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas diff --git a/extend/lworker/src/scheduler.h b/extend/lworker/src/scheduler.h index a74a3d98..b35519eb 100644 --- a/extend/lworker/src/scheduler.h +++ b/extend/lworker/src/scheduler.h @@ -15,6 +15,7 @@ namespace lworker { m_service = service; m_sandbox = sandbox; m_lua = std::make_shared(L); + m_codec = m_lua->create_codec(); } std::shared_ptr find_worker(std::string_view name) { @@ -86,7 +87,11 @@ namespace lworker { slice* slice = read_slice(m_read_buf, &plen); while (slice) { m_codec->set_slice(slice); - m_lua->table_call(service, "on_scheduler", nullptr, m_codec.get(), std::tie()); + m_lua->table_call(service, "on_scheduler", nullptr, m_codec, std::tie()); + if (m_codec->failed()){ + m_read_buf->clean(); + break; + } m_read_buf->pop_size(plen); if (ltimer::steady_ms() - clock_ms > 100) break; slice = read_slice(m_read_buf, &plen); @@ -111,11 +116,11 @@ namespace lworker { private: spin_mutex m_mutex; + codec_base* m_codec = nullptr; std::string m_service, m_sandbox; std::shared_ptr m_lua = nullptr; std::shared_ptr m_read_buf = std::make_shared(); std::shared_ptr m_write_buf = std::make_shared(); - std::shared_ptr m_codec = std::make_shared(); std::map, std::less<>> m_worker_map; }; } diff --git a/extend/lworker/src/worker.h b/extend/lworker/src/worker.h index 1756a466..26abe7f1 100644 --- a/extend/lworker/src/worker.h +++ b/extend/lworker/src/worker.h @@ -72,6 +72,9 @@ namespace lworker { size_t data_len; std::unique_lock lock(m_mutex); uint8_t* data = m_codec->encode(L, 2, &data_len); + if (data == nullptr) { + return false; + } uint8_t* target = m_write_buf->peek_space(data_len + sizeof(uint32_t)); if (target) { m_write_buf->write(data_len); @@ -94,7 +97,11 @@ namespace lworker { slice* slice = read_slice(m_read_buf, &plen); while (slice) { m_codec->set_slice(slice); - m_lua->table_call(service, "on_worker", nullptr, m_codec.get(), std::tie()); + m_lua->table_call(service, "on_worker", nullptr, m_codec, std::tie()); + if (m_codec->failed()) { + m_read_buf->clean(); + break; + } m_read_buf->pop_size(plen); slice = read_slice(m_read_buf, &plen); if (ltimer::steady_ms() - clock_ms > 100) break; @@ -106,6 +113,7 @@ namespace lworker { } void run(){ + m_codec = m_lua->create_codec(); auto quanta = m_lua->new_table(m_service.c_str()); quanta.set("title", m_name); quanta.set_function("stop", [&]() { m_running = false; }); @@ -142,12 +150,12 @@ namespace lworker { std::thread m_thread; bool m_stop = false; bool m_running = false; + codec_base* m_codec = nullptr; ischeduler* m_schedulor = nullptr; std::string m_name, m_entry, m_service, m_sandbox; std::shared_ptr m_lua = std::make_shared(); std::shared_ptr m_read_buf = std::make_shared(); std::shared_ptr m_write_buf = std::make_shared(); - std::shared_ptr m_codec = std::make_shared(); }; } diff --git a/extend/mimalloc/mimalloc.mak b/extend/mimalloc/mimalloc.mak index 76d8265c..c1db598c 100644 --- a/extend/mimalloc/mimalloc.mak +++ b/extend/mimalloc/mimalloc.mak @@ -19,6 +19,7 @@ MYCFLAGS += -Wsign-compare MYCFLAGS += -Wno-sign-compare MYCFLAGS += -Wno-unused-variable MYCFLAGS += -Wno-unused-parameter +MYCFLAGS += -Wno-unused-but-set-variable MYCFLAGS += -Wno-unused-but-set-parameter MYCFLAGS += -Wno-unknown-pragmas diff --git a/quanta.sln b/quanta.sln index d9ebe9f5..f4656f3c 100644 --- a/quanta.sln +++ b/quanta.sln @@ -64,17 +64,17 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "lualog", "extend\lualog\lua {B0AEF51E-FBD8-3472-DB40-2C819DE2E5DF} = {B0AEF51E-FBD8-3472-DB40-2C819DE2E5DF} EndProjectSection EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "luaxlsx", "extend\luaxlsx\luaxlsx.vcxproj", "{B2999D78-279A-1A53-CC82-74D494144722}" +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "luapb", "extend\luapb\luapb.vcxproj", "{2967F038-B90B-EBF5-B268-BE3BAF66D417}" ProjectSection(ProjectDependencies) = postProject {B0AEF51E-FBD8-3472-DB40-2C819DE2E5DF} = {B0AEF51E-FBD8-3472-DB40-2C819DE2E5DF} EndProjectSection EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "lworker", "extend\lworker\lworker.vcxproj", "{7186BCD3-4393-85B5-963F-880AC8A6F795}" +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "luaxlsx", "extend\luaxlsx\luaxlsx.vcxproj", "{B2999D78-279A-1A53-CC82-74D494144722}" ProjectSection(ProjectDependencies) = postProject {B0AEF51E-FBD8-3472-DB40-2C819DE2E5DF} = {B0AEF51E-FBD8-3472-DB40-2C819DE2E5DF} EndProjectSection EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "pb", "extend\protobuf\lua-protobuf.vcxproj", "{F09883B5-7B33-D3D3-3C39-BBC8DD3B2BE2}" +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "lworker", "extend\lworker\lworker.vcxproj", "{7186BCD3-4393-85B5-963F-880AC8A6F795}" ProjectSection(ProjectDependencies) = postProject {B0AEF51E-FBD8-3472-DB40-2C819DE2E5DF} = {B0AEF51E-FBD8-3472-DB40-2C819DE2E5DF} EndProjectSection @@ -127,12 +127,12 @@ Global {2DE9C09F-02F5-DF2B-524C-FFBEBC8DE5EE}.Develop|x64.Build.0 = Develop|x64 {E44F97F7-1DE3-A547-6B58-30AEE63F22F5}.Develop|x64.ActiveCfg = Develop|x64 {E44F97F7-1DE3-A547-6B58-30AEE63F22F5}.Develop|x64.Build.0 = Develop|x64 + {2967F038-B90B-EBF5-B268-BE3BAF66D417}.Develop|x64.ActiveCfg = Develop|x64 + {2967F038-B90B-EBF5-B268-BE3BAF66D417}.Develop|x64.Build.0 = Develop|x64 {B2999D78-279A-1A53-CC82-74D494144722}.Develop|x64.ActiveCfg = Develop|x64 {B2999D78-279A-1A53-CC82-74D494144722}.Develop|x64.Build.0 = Develop|x64 {7186BCD3-4393-85B5-963F-880AC8A6F795}.Develop|x64.ActiveCfg = Develop|x64 {7186BCD3-4393-85B5-963F-880AC8A6F795}.Develop|x64.Build.0 = Develop|x64 - {F09883B5-7B33-D3D3-3C39-BBC8DD3B2BE2}.Develop|x64.ActiveCfg = Develop|x64 - {F09883B5-7B33-D3D3-3C39-BBC8DD3B2BE2}.Develop|x64.Build.0 = Develop|x64 {B0AEF51E-FBD8-3472-DB40-2C819DE2E5DF}.Develop|x64.ActiveCfg = Develop|x64 {B0AEF51E-FBD8-3472-DB40-2C819DE2E5DF}.Develop|x64.Build.0 = Develop|x64 {F8DBBBDB-3B80-B4F1-70A8-272978F579EB}.Develop|x64.ActiveCfg = Develop|x64 @@ -157,9 +157,9 @@ Global {8667E717-B97C-935C-4FDF-2878AF7DAB1A} = {DEDA8889-2128-69C7-B2CC-3996BC8EB907} {2DE9C09F-02F5-DF2B-524C-FFBEBC8DE5EE} = {DEDA8889-2128-69C7-B2CC-3996BC8EB907} {E44F97F7-1DE3-A547-6B58-30AEE63F22F5} = {DEDA8889-2128-69C7-B2CC-3996BC8EB907} + {2967F038-B90B-EBF5-B268-BE3BAF66D417} = {DEDA8889-2128-69C7-B2CC-3996BC8EB907} {B2999D78-279A-1A53-CC82-74D494144722} = {DEDA8889-2128-69C7-B2CC-3996BC8EB907} {7186BCD3-4393-85B5-963F-880AC8A6F795} = {DEDA8889-2128-69C7-B2CC-3996BC8EB907} - {F09883B5-7B33-D3D3-3C39-BBC8DD3B2BE2} = {DEDA8889-2128-69C7-B2CC-3996BC8EB907} {B0AEF51E-FBD8-3472-DB40-2C819DE2E5DF} = {EEE587B4-691E-274F-5E0D-3D4754F2B4B6} {F8DBBBDB-3B80-B4F1-70A8-272978F579EB} = {EEE587B4-691E-274F-5E0D-3D4754F2B4B6} {1BD1853E-C5E4-992D-7266-AA4AA72600AB} = {EEE587B4-691E-274F-5E0D-3D4754F2B4B6} diff --git a/script/basic/library.lua b/script/basic/library.lua index 2b284ec1..89d4a401 100644 --- a/script/basic/library.lua +++ b/script/basic/library.lua @@ -11,7 +11,7 @@ local librarys = { --定时器库 timer = require("ltimer"), --PB解析库 - protobuf = require("pb"), + protobuf = require("luapb"), --json库 json = require("ljson"), --bson库 diff --git a/script/basic/logger.lua b/script/basic/logger.lua index fe1ddf6b..f93d93ae 100644 --- a/script/basic/logger.lua +++ b/script/basic/logger.lua @@ -10,7 +10,7 @@ local sformat = string.format local lwarn = log.warn local lfilter = log.filter local lis_filter = log.is_filter -local serialize = codec.serialize +local serialize = luakit.serialize local LOG_LEVEL = log.LOG_LEVEL diff --git a/script/driver/influx.lua b/script/driver/influx.lua index d0d13474..5974dd29 100644 --- a/script/driver/influx.lua +++ b/script/driver/influx.lua @@ -8,7 +8,7 @@ local tconcat = table.concat local sgsub = string.gsub local sformat = string.format local ssplit = qstring.split -local serialize = codec.serialize +local serialize = luakit.serialize local http_client = quanta.get("http_client") local WEEK_S = quanta.enum("PeriodTime", "WEEK_S") diff --git a/script/driver/mongo.lua b/script/driver/mongo.lua index f132df0b..d6e6cbc4 100644 --- a/script/driver/mongo.lua +++ b/script/driver/mongo.lua @@ -27,8 +27,6 @@ local lhmac_sha1 = crypt.hmac_sha1 local lxor_byte = crypt.xor_byte local lclock_ms = timer.clock_ms -local eproto_type = luabus.eproto_type - local timer_mgr = quanta.get("timer_mgr") local event_mgr = quanta.get("event_mgr") local thread_mgr = quanta.get("thread_mgr") @@ -46,7 +44,6 @@ local prop = property(MongoDB) prop:reader("id", nil) --id prop:reader("name", "") --dbname prop:reader("user", nil) --user -prop:reader("codec", nil) --codec prop:reader("passwd", nil) --passwd prop:reader("salted_pass", nil) --salted_pass prop:reader("executer", nil) --执行者 @@ -64,7 +61,7 @@ function MongoDB:__init(conf, id) self.user = conf.user self.passwd = conf.passwd self.cursor_id = bson.int64(0) - self.codec = bson.mongo_codec() + self.codec = bson.mongocodec() self:set_options(conf.opts) self:setup_pool(conf.hosts) --attach_hour @@ -157,7 +154,7 @@ end function MongoDB:login(socket) local id, ip, port = socket.id, socket.ip, socket.port - local ok, err = socket:connect(ip, port, eproto_type.mongo) + local ok, err = socket:connect(ip, port) if not ok then log_err("[MongoDB][login] connect db(%s:%s:%s:%s) failed: %s!", ip, port, self.name, id, err) return false diff --git a/script/driver/mysql.lua b/script/driver/mysql.lua index 59052b0f..fdd9fa67 100644 --- a/script/driver/mysql.lua +++ b/script/driver/mysql.lua @@ -1,856 +1,222 @@ --mysql.lua local Socket = import("driver/socket.lua") -local QueueFIFO = import("container/queue_fifo.lua") -local tonumber = tonumber local lsha1 = crypt.sha1 -local ssub = string.sub -local srep = string.rep local sgsub = string.gsub -local sbyte = string.byte -local schar = string.char -local spack = string.pack local log_err = logger.err local log_info = logger.info local sformat = string.format -local sunpack = string.unpack -local tpack = table.pack -local tunpack = table.unpack -local tointeger = math.tointeger - +local tinsert = table.insert +local tdelete = qtable.delete +local mrandom = qmath.random +local lxor_byte = crypt.xor_byte +local qhash = codec.hash_code +local mysqlcodec = codec.mysqlcodec +local makechan = quanta.make_channel + +local event_mgr = quanta.get("event_mgr") +local timer_mgr = quanta.get("timer_mgr") local thread_mgr = quanta.get("thread_mgr") local update_mgr = quanta.get("update_mgr") +local SUCCESS = quanta.enum("KernCode", "SUCCESS") +local SECOND_MS = quanta.enum("PeriodTime", "SECOND_MS") +local SECOND_10_MS = quanta.enum("PeriodTime", "SECOND_10_MS") local DB_TIMEOUT = quanta.enum("NetwkTime", "DB_CALL_TIMEOUT") - ---charset编码 -local CHARSET_MAP = { - _default = 0, - big5 = 1, - dec8 = 3, - cp850 = 4, - hp8 = 6, - koi8r = 7, - latin1 = 8, - latin2 = 9, - swe7 = 10, - ascii = 11, - ujis = 12, - sjis = 13, - hebrew = 16, - tis620 = 18, - euckr = 19, - koi8u = 22, - gb2312 = 24, - greek = 25, - cp1250 = 26, - gbk = 28, - latin5 = 30, - armscii8 = 32, - utf8 = 33, - ucs2 = 35, - cp866 = 36, - keybcs2 = 37, - macce = 38, - macroman = 39, - cp852 = 40, - latin7 = 41, - utf8mb4 = 45, - cp1251 = 51, - utf16 = 54, - utf16le = 56, - cp1256 = 57, - cp1257 = 59, - utf32 = 60, - binary = 63, - geostd8 = 92, - cp932 = 95, - eucjpms = 97, - gb18030 = 248 -} +local POOL_COUNT = environ.number("QUANTA_DB_POOL_COUNT", 3) -- constants -local COM_QUERY = "\x03" -local COM_PING = "\x0e" -local COM_STMT_PREPARE = "\x16" -local COM_STMT_EXECUTE = "\x17" -local COM_STMT_CLOSE = "\x19" -local COM_STMT_RESET = "\x1a" - -local CURSOR_TYPE_NO_CURSOR = 0x00 -local SERVER_MORE_RESULTS_EXISTS = 8 - --- mysql field value type converters -local converters = { - [0x01] = tonumber, -- tiny - [0x02] = tonumber, -- short - [0x03] = tonumber, -- long - [0x04] = tonumber, -- float - [0x05] = tonumber, -- double - [0x08] = tonumber, -- long long - [0x09] = tonumber, -- int24 - [0x0d] = tonumber, -- year - [0xf6] = tonumber, -- newdecimal -} - -local function _get_int1(data, pos, signed) - return sunpack(signed and "= 0 and first <= 250 then - return first, pos + 1 - end - if first == 251 then - return nil, pos + 1 - end - if first == 252 then - return sunpack(" 0 then - local types_buf, values_buf = "", "" - local field_index, null_count = 1, (arg_num + 7) // 8 - for i = 1, null_count do - local byte = 0 - for j = 0, 7 do - if field_index < arg_num then - local bit = args[field_index] and 0 or 1 - byte = byte | (bit << j) - end - field_index = field_index + 1 - end - cmd_packet = cmd_packet .. schar(byte) - end - for i, v in ipairs(args) do - local f = store_types[type(v)] - if not f then - return false, sformat("invalid parameter %s, type:%s", v, type(v)) - end - local ts, vs = f(v) - types_buf = types_buf .. ts - values_buf = values_buf .. vs - end - cmd_packet = cmd_packet .. schar(0x01) .. types_buf .. values_buf - end - return self:_compose_packet(cmd_packet) -end - ---ok报文 -local function _parse_ok_packet(packet) - --1 byte 0x00报文标志(不处理) - --1-9 byte 受影响行数 - local affrows, pos_aff = _from_length_coded_bin(packet, 2) - --1-9 byte 索引ID值 - local index, pos_idx = _from_length_coded_bin(packet, pos_aff) - --2 byte 服务器状态 - local status, pos_state = sunpack(" 2 then - null_fields[field_idx - 2] = (sbyte(packet, i) & (1 << j) ~= 0) - end - field_idx = field_idx + 1 - end - end - for i, field in ipairs(fields) do - if not null_fields[i] then - local value - local parser = _binary_parser[field.typ] - value, pos = parser(packet, pos, field.signed) - if not field.ignore then - row[field.name] = value - end - end - end - return row -end - -local function _parse_not_data_packet(packet, typ) - if typ == "ERR" then - return nil, _parse_err_packet(packet) - end - if typ == "OK" then - return _parse_ok_packet(packet) - end - return nil, "packet type " .. typ .. " not supported" -end - -local function _parae_packet_type(buff) - if not buff or #buff == 0 then - return nil, "empty packet" - end - local typ = "DATA" - local field_count = sbyte(buff, 1) - if field_count == 0x00 then - typ = "OK" - elseif field_count == 0xff then - typ = "ERR" - elseif field_count == 0xfe then - typ = "EOF" - end - return typ -end - -local function _recv_field_resp(context, packet, typ) - if typ == "EOF" then - return true - end - if typ == "DATA" then - return true, _parse_field_packet(packet) - end - return _parse_not_data_packet(packet, typ) -end - -local function _recv_rows_resp(context, packet, typ, fields, binary) - if typ == "EOF" then - local _, status_flags = _parse_eof_packet(packet) - if status_flags & SERVER_MORE_RESULTS_EXISTS ~= 0 then - return true, "again" - end - return true - end - if typ == "DATA" then - if binary then - return true, _parse_row_data_binary(packet, fields) - end - return true, _parse_row_data_packet(packet, fields) - end - return _parse_not_data_packet(packet, typ) -end - ---result_set报文 -local function _parse_result_set_packet(context, packet, ignores, binary) - --Result Set Header - --1-9 byte Field结构计数field_count - local _, pos_field = _from_length_coded_bin(packet, 1) - --1-9 byte 额外信息 - local _ = _from_length_coded_bin(packet, pos_field) - -- Field结构 - local fields = {} - while true do - local ok, field = _async_call(context, "recv field packet", _recv_field_resp) - if not ok then - return nil, field - end - if not field then - break - end - field.ignore = ignores and ignores[field.name] - fields[#fields + 1] = field - end - -- Row Data - local rows = {} - while true do - local rok, row = _async_call(context, "recv row packet", _recv_rows_resp, fields, binary) - if not rok then - return nil, row - end - if not row then - break - end - if row == "again" then - return rows, row - end - rows[#rows + 1] = row - end - return rows -end - -local function _recv_result_set_resp(context, packet, typ, ignores, binary) - if typ == "DATA" then - local rows, rerr = _parse_result_set_packet(context, packet, ignores, binary) - if not rows then - return nil, rerr - end - return rows, rerr - end - return _parse_not_data_packet(packet, typ) -end - -local function _recv_query_resp(context, packet, typ, ignores, binary) - local res, err = _recv_result_set_resp(context, packet, typ, ignores, binary) - if not res then - return false, err - end - if err ~= "again" then - return true, res - end - local multiresultset = { res } - while err == "again" do - res, err = _async_call(context, "recv resultset packet", _recv_result_set_resp, ignores, binary) - if not res then - return false, err - end - multiresultset[#multiresultset + 1] = res - end - multiresultset.multiresultset = true - return true, multiresultset -end - -local function _compute_auth_token(plugin, passwd, scramble) - if plugin == "mysql_native_password" then - return true, _compute_token(passwd, scramble), true - end - return false, "only mysql_native_password is supported" -end - -local function _recv_auth_resp(context, packet, typ, passwd) - if typ == "ERR" then - return false, _parse_err_packet(packet) - end - if typ == "EOF" then - if #packet == 1 then - return false, "old pre-4.1 authentication protocol not supported" - end - local plugin, pos = _from_cstring(packet, 2) - if not plugin then - return false, "malformed packet" - end - local scramble = ssub(packet, pos); - return _compute_auth_token(plugin, passwd, scramble) - end - return true, packet -end - -local function _recv_prepare_resp(context, packet, typ) - if typ == "ERR" then - return false, _parse_err_packet(packet) - end - --第一节只能是OK - if typ ~= "OK" then - return false, sformat("first typ must be OK, now %s[no:300201]", typ) - end - local prepare_id, field_count, param_count, warning_count = sunpack(" 0 then - while true do - local ok, field = _async_call(context, "recv field packet", _recv_field_resp) - if not ok then - return false, field - end - if not field then - break - end - params[#params + 1] = field - end - end - if field_count > 0 then - while true do - local ok, field = _async_call(context, "recv field packet", _recv_field_resp) - if not ok then - return false, field - end - if not field then - break - end - fields[#fields + 1] = field - end - end - return true, { params = params, fields = fields, prepare_id = prepare_id, - field_count = field_count, param_count = param_count, warning_count = warning_count - } -end +local COM_QUERY = 0x03 +local COM_CONNECT = 0x0b +local COM_PING = 0x0e +local COM_STMT_PREPARE = 0x16 +local COM_STMT_EXECUTE = 0x17 +local COM_STMT_CLOSE = 0x19 +local COM_STMT_RESET = 0x1a local MysqlDB = class() local prop = property(MysqlDB) prop:reader("id", nil) --id -prop:reader("ip", nil) --mysql地址 -prop:reader("sock", nil) --网络连接对象 prop:reader("name", "") --dbname -prop:reader("port", 3306) --mysql端口 prop:reader("user", nil) --user prop:reader("passwd", nil) --passwd -prop:reader("packet_no", 0) --passwd -prop:reader("sessions", nil) --sessions -prop:accessor("charset", "_default") --charset -prop:accessor("max_packet_size", 1024*1024) --max_packet_size,1mb +prop:reader("executer", nil) --执行者 +prop:reader("timer_id", nil) --timer_id +prop:reader("connections", {}) --connections +prop:reader("alives", {}) --alives function MysqlDB:__init(conf, id) self.id = id self.name = conf.db self.user = conf.user self.passwd = conf.passwd - self.sessions = QueueFIFO() - self.sock = Socket(self) - self:choose_host(conf.hosts) + --setup + self:set_options(conf.opts) + self:setup_pool(conf.hosts) --update update_mgr:attach_hour(self) - update_mgr:attach_second(self) end function MysqlDB:__release() self:close() end -function MysqlDB:choose_host(hosts) - for host, port in pairs(hosts) do - self.ip, self.port = host, port - break +function MysqlDB:close() + for _, sock in pairs(self.alives) do + sock:close() + end + for _, sock in pairs(self.connections) do + sock:close() end + self.connections = {} + self.alives = {} end function MysqlDB:set_options(opts) end -function MysqlDB:close() - if self.sock then - self.sessions:clear() - self.sock:close() +function MysqlDB:set_executer(id) + local count = #self.alives + if count > 0 then + local index = qhash(id or mrandom(), count) + self.executer = self.alives[index] + return true end + return false end function MysqlDB:on_hour() - if self.sock:is_alive() then - self:ping() + for _, sock in pairs(self.alives) do + self.executer = sock + self:request(COM_PING, "mysql ping") end end -function MysqlDB:on_second() - if not self.sock:is_alive() then - if not self.sock:connect(self.ip, self.port) then - log_err("[MysqlDB][on_second] connect db(%s:%s:%s) failed!", self.ip, self.port, self.name) - return - end - local ok, err, ver = self:auth() - if not ok then - log_err("[MysqlDB][on_second] auth db(%s:%d:%s) failed! because: %s", self.ip, self.port, self.name, err) - return +function MysqlDB:setup_pool(hosts) + if not next(hosts) then + log_err("[MysqlDB][setup_pool] mysql config err: hosts is empty") + return + end + local count = 1 + for _, host in pairs(hosts) do + for c = 1, POOL_COUNT do + local socket = Socket(self, host[1], host[2]) + self.connections[count] = socket + socket:set_id(count) + count = count + 1 end - log_info("[MysqlDB][on_second] connect db(%s:%d-%s[%s]) success!", self.ip, self.port, self.name, ver) end + self.timer_id = timer_mgr:register(0, SECOND_MS, -1, function() + self:check_alive() + end) end -function MysqlDB:auth() - if not self.passwd or not self.user or not self.name then - return false, "user or password or dbname not config!" - end - local context = { cmd = "auth" } - self.sessions:push(context) - local ok, packet = _async_call(context, "recv auth packet", _recv_auth_resp) - if not ok then - return false, packet - end - --1 byte 协议版本号 (服务器认证报文开始)(skip) - --n byte 服务器版本号 - local version, pos = _from_cstring(packet, 2) - if not version then - return false, "bad handshake initialization packet: bad server version" - end - --4 byte thread_id (skip) - pos = pos + 4 - --8 byte 挑战随机数1 - local scramble1 = ssub(packet, pos, pos + 8 - 1) - if not scramble1 then - return false, "1st part of scramble not found" +function MysqlDB:check_alive() + if next(self.connections) then + thread_mgr:entry(self:address(), function() + local channel = makechan("check mysql") + for _, sock in pairs(self.connections) do + channel:push(function() + return self:login(sock) + end) + end + if channel:execute(true) then + timer_mgr:set_period(self.timer_id, SECOND_10_MS) + end + self:set_executer() + end) end - --1 byte 填充值 (skip) - --2 byte server_capabilities (skip) - --1 byte server_lang (skip) - --2 byte server_status (skip) - --2 byte server_capabilities high (skip) - --1 byte 挑战长度 (未使用) (skip) - --10 byte 填充值 (skip) - pos = pos + 8 + 1 + 2 + 1 + 2 + 2 + 1 + 10 - --12 byte 挑战随机数2 - local scramble2 = ssub(packet, pos, pos + 12 - 1) - if not scramble2 then - return false, "2nd part of scramble not found" +end + +function MysqlDB:login(socket) + local id, ip, port = socket.id, socket.ip, socket.port + if not socket:connect(ip, port) then + log_err("[MysqlDB][login] connect db(%s:%s:%s) failed!", ip, port, id) + return false end - --1 byte 挑战数结束(服务器认证报文结束)(skip) - --n byte plugin - local plugin = "mysql_native_password" - if #packet > pos + 12 then - plugin = _from_cstring(packet, pos + 13) + local ok, res = self:auth(socket) + if not ok then + socket:close() + self:delive(socket) + log_err("[MysqlDB][login] auth db(%s:%s:%s) auth failed! because: %s", ip, port, id, res) + return false end - --客户端认证报文 - --2 byte 客户端权能标志 - --2 byte 客户端权能标志扩展 - local client_flags = 260047 - --4 byte 最大消息长度 - local packet_size = self.max_packet_size - --1 byte 字符编码 - local charset = schar(CHARSET_MAP[self.charset]) - --23 byte 填充值 - local fuller = srep("\0", 23) - --n byte 用户名 - --n byte 挑战认证数据(scramble1+scramble2+passwd) + self.connections[id] = nil + tinsert(self.alives, socket) + log_info("[MysqlDB][login] login db(%s:%s:%s) success!", ip, port, id) + return true, SUCCESS +end + +function MysqlDB:auth(socket) + local session_id = thread_mgr:build_session_id() + socket:set_codec(mysqlcodec(session_id)) + local charset, scramble1, scramble2 = thread_mgr:yield(session_id, "mysql server auth", DB_TIMEOUT) local scramble = scramble1 .. scramble2 - local tok, token = _compute_auth_token(plugin, self.passwd, scramble) - if not tok then - return false, token - end - --n byte 数据库名(可选) - local req = spack(" 0 then - bdata = sock:peek(length, 4) - if not bdata then - break - end - end - sock:pop(4 + length) - --收到一个完整包 - local context = self.sessions:head() - if context then - thread_mgr:fork(function() - local callback = context.callback - local session_id = context.session_id - local typ, err = _parae_packet_type(bdata) - if not typ then - if session_id == context.commit_id then - self.sessions:pop() - end - thread_mgr:response(session_id, false, err) - return - end - local result = tpack(callback(context, bdata, typ, tunpack(context.args))) - if session_id == context.commit_id then - self.sessions:pop() - end - thread_mgr:response(session_id, tunpack(result)) - end) - end +function MysqlDB:on_socket_recv(socket, session_id, ...) + if session_id > 0 then + thread_mgr:response(session_id, ...) end end -function MysqlDB:request(packet, callback, quote, param) - if not self.sock:send(packet) then - return false, "send request failed" +function MysqlDB:request(cmd, quote, ...) + if self.executer then + local session_id = thread_mgr:build_session_id() + if self.executer:send_data(cmd, session_id, ...) then + return thread_mgr:yield(session_id, quote, DB_TIMEOUT) + end end - local context = { cmd = quote } - self.sessions:push(context) - return _async_call(context, quote, callback, param) + return false, "send request failed" end -function MysqlDB:query(query, ignores) - self.packet_no = -1 - log_info("[MysqlDB][query] sql: %s", query) - local querypacket = self:_compose_packet(COM_QUERY .. query) - return self:request(querypacket, _recv_query_resp, "mysql_query", ignores) +function MysqlDB:query(query) + return self:request(COM_QUERY, "mysql query", query) end -- 注册预处理语句 function MysqlDB:prepare(sql) - self.packet_no = -1 - local querypacket = self:_compose_packet(COM_STMT_PREPARE .. sql) - return self:request(querypacket, _recv_prepare_resp, "mysql_prepare") + return self:request(COM_STMT_PREPARE, "mysql prepare", sql) end --执行预处理语句 -function MysqlDB:execute(stmt, ...) - self.packet_no = -1 - local querypacket, err = _compose_stmt_execute(self, stmt, CURSOR_TYPE_NO_CURSOR, {...}) - if not querypacket then - return false, sformat("%s[no:30902]", err) - end - return self:request(querypacket, _recv_query_resp, "mysql_execute") +function MysqlDB:execute(prepare_id, ...) + return self:request(COM_STMT_EXECUTE, "mysql_execute", prepare_id, ...) end --重置预处理句柄 function MysqlDB:stmt_reset(prepare_id) - self.packet_no = -1 - local cmd_packet = spack("c1 0 then local ok, res = self:auth(socket) if not ok or res ~= "OK" then @@ -287,36 +286,27 @@ function RedisDB:on_socket_error(sock, token, err) event_mgr:fire_second(function() self:check_alive() end) - local cmd_queue = sock.cmd_queue - local session_id = cmd_queue:pop() - while session_id do - if session_id > 0 then - thread_mgr:response(session_id, false, err) - end - session_id = cmd_queue:pop() - end end -function RedisDB:on_socket_recv(sock, succ, res) +function RedisDB:on_socket_recv(sock, session_id, succ, res) if self.subscrible then self:do_socket_recv(res) end - local session_id = sock.cmd_queue:pop() - if session_id and session_id > 0 then + if session_id > 0 then self.res_counter:count_increase() thread_mgr:response(session_id, succ, res) end end -function RedisDB:wait_response(socket, session_id, cmd, ...) - if not socket:send_data(cmd, ...) then +function RedisDB:commit(socket, cmd, ...) + local session_id = thread_mgr:build_session_id() + if not socket:send_data(session_id, cmd, ...) then return false, "send request failed" end self.req_counter:count_increase() - socket.cmd_queue:push(session_id) local ok, res = thread_mgr:yield(session_id, sformat("redis_comit:%s", cmd), DB_TIMEOUT) if not ok then - log_err("[RedisDB][wait_response] exec cmd %s failed: %s", cmd, res) + log_err("[RedisDB][commit] exec cmd %s failed: %s", cmd, res) return ok, res end local convertor = rconvertors[slower(cmd)] @@ -326,22 +316,15 @@ function RedisDB:wait_response(socket, session_id, cmd, ...) return ok, res end -function RedisDB:commit(sock, cmd, ...) - local session_id = thread_mgr:build_session_id() - return self:wait_response(sock, session_id, cmd, ...) -end - function RedisDB:send(cmd, key, ...) local sock = self:choose_node(key) - if sock and sock:send_data(cmd, key, ...) then - sock.cmd_queue:push(0) + if sock then + sock:send_data(0, cmd, key, ...) end end function RedisDB:direct_send(sock, cmd, ...) - if sock:send_data(cmd, ...) then - sock.cmd_queue:push(0) - end + sock:send_data(0, cmd, ...) end function RedisDB:execute(cmd, key, ...) diff --git a/script/driver/redisps.lua b/script/driver/redisps.lua index 0dedf2aa..7053c160 100644 --- a/script/driver/redisps.lua +++ b/script/driver/redisps.lua @@ -6,7 +6,6 @@ local event_mgr = quanta.get("event_mgr") local Redis = import("driver/redis.lua") local Socket = import("driver/socket.lua") -local QueueFIFO = import("container/queue_fifo.lua") local subscribe_commands = { subscribe = { cmd = "SUBSCRIBE" }, -- >= 2.0 @@ -61,7 +60,6 @@ function PSRedis:setup_pool(hosts) for _, host in pairs(hosts) do local socket = Socket(self, host[1], host[2]) self.connections[1] = socket - socket.cmd_queue = QueueFIFO() socket:set_id(1) break end diff --git a/script/driver/socket.lua b/script/driver/socket.lua index 20da12d8..21519d4c 100644 --- a/script/driver/socket.lua +++ b/script/driver/socket.lua @@ -1,10 +1,10 @@ --socket.lua -local log_err = logger.err -local log_info = logger.info -local qxpcall = quanta.xpcall +local log_err = logger.err +local log_info = logger.info +local qxpcall = quanta.xpcall -local eproto_type = luabus.eproto_type +local proto_text = luabus.eproto_type.text local socket_mgr = quanta.get("socket_mgr") local thread_mgr = quanta.get("thread_mgr") @@ -17,6 +17,7 @@ local prop = property(Socket) prop:reader("ip", nil) prop:reader("port", 0) prop:reader("host", nil) +prop:reader("codec", nil) prop:reader("token", nil) prop:reader("alive", false) prop:reader("session", nil) --连接成功对象 @@ -38,21 +39,21 @@ function Socket:close() self.session.close() self.alive = false self.session = nil + self.codec = nil self.token = nil end end -function Socket:listen(ip, port, ptype) +function Socket:listen(ip, port) if self.listener then return true end - self.listener = socket_mgr.listen(ip, port) + self.listener = socket_mgr.listen(ip, port, proto_text) if not self.listener then log_err("[Socket][listen] failed to listen: %s:%d", ip, port) return false end self.ip, self.port = ip, port - self.listener.set_proto_type(ptype or eproto_type.text) log_info("[Socket][listen] start listen at: %s:%d", ip, port) self.listener.on_accept = function(session) qxpcall(self.on_socket_accept, "on_socket_accept: %s", self, session, ip, port) @@ -62,8 +63,13 @@ end function Socket:set_codec(codec) if self.session then + self.codec = codec self.session.set_codec(codec) end + if self.listener then + self.codec = codec + self.listener.set_codec(codec) + end end function Socket:connect(ip, port, ptype) @@ -73,12 +79,11 @@ function Socket:connect(ip, port, ptype) end return false, "socket in connecting" end - local session, cerr = socket_mgr.connect(ip, port, CONNECT_TIMEOUT) + local session, cerr = socket_mgr.connect(ip, port, CONNECT_TIMEOUT, proto_text) if not session then - log_err("[Socket][connect] failed to connect: %s:%d err=%s", ip, port, cerr) + log_err("[Socket][connect] failed to connect: %s:%s err=%s", ip, port, cerr) return false, cerr end - session.set_proto_type(ptype or eproto_type.text) --设置阻塞id local token = session.token local block_id = thread_mgr:build_session_id() diff --git a/script/driver/websocket.lua b/script/driver/websocket.lua index 2a51bfc1..f0fca908 100644 --- a/script/driver/websocket.lua +++ b/script/driver/websocket.lua @@ -5,9 +5,12 @@ local log_info = logger.info local log_debug = logger.debug local lsha1 = crypt.sha1 local lb64encode = crypt.b64_encode +local jsoncodec = json.jsoncodec +local wsscodec = codec.wsscodec +local httpcodec = codec.httpcodec local qxpcall = quanta.xpcall -local eproto_type = luabus.eproto_type +local proto_text = luabus.eproto_type.text local socket_mgr = quanta.get("socket_mgr") local thread_mgr = quanta.get("thread_mgr") @@ -19,6 +22,7 @@ local prop = property(WebSocket) prop:reader("ip", nil) prop:reader("host", nil) prop:reader("token", nil) +prop:reader("jcodec", nil) --codec prop:reader("wcodec", nil) --codec prop:reader("hcodec", nil) --codec prop:reader("alive", false) @@ -28,9 +32,9 @@ prop:reader("port", 0) function WebSocket:__init(host) self.host = host - local jcodec = json.jsoncodec() - self.wcodec = codec.wsscodec(jcodec) - self.hcodec = codec.httpcodec(jcodec) + self.jcodec = jsoncodec() + self.wcodec = wsscodec(self.jcodec) + self.hcodec = httpcodec(self.jcodec) end function WebSocket:close() @@ -49,17 +53,18 @@ function WebSocket:listen(ip, port, ptype) if self.listener then return true end - self.listener = socket_mgr.listen(ip, port) - if not self.listener then + local listener = socket_mgr.listen(ip, port, proto_text) + if not listener then log_err("[WebSocket][listen] failed to listen: %s:%d", ip, port) return false end - self.ip, self.port = ip, port - self.listener.set_proto_type(ptype or eproto_type.text) + listener.set_codec(self.hcodec) log_info("[WebSocket][listen] start listen at: %s:%d", ip, port) - self.listener.on_accept = function(session) + listener.on_accept = function(session) qxpcall(self.on_socket_accept, "on_socket_accept: %s", self, session, ip, port) end + self.ip, self.port = ip, port + self.listener = listener return true end @@ -91,7 +96,7 @@ function WebSocket:on_socket_recv(session, token, opcode, message) self:send_frame(0xA, "PONG") return end - if opcode <= 0X02 then + if opcode <= 0x02 then self.host:on_socket_recv(self, token, message) end end) @@ -112,7 +117,6 @@ function WebSocket:accept(session, ip, port) session.on_error = function(stoken, err) self:on_socket_error(stoken, err) end - session.set_codec(self.hcodec) self.ip, self.port = ip, port end @@ -149,7 +153,7 @@ function WebSocket:on_handshake(session, token, url, params, headers, body) self.session = session self:send_data(101, cbheaders, "") self.host:on_socket_accept(self, token) - self.session.set_codec(self.wcodec) + session.set_codec(self.wcodec) return true end diff --git a/script/network/http_server.lua b/script/network/http_server.lua index 9c3d88db..29bc9971 100644 --- a/script/network/http_server.lua +++ b/script/network/http_server.lua @@ -10,12 +10,15 @@ local log_debug = logger.debug local tunpack = table.unpack local signalquit = signal.quit local saddr = qstring.addr +local jsoncodec = json.jsoncodec +local httpcodec = codec.httpcodec local HttpServer = class() local prop = property(HttpServer) prop:reader("ip", nil) --http server地址 prop:reader("port", 8080) --http server端口 -prop:reader("codec", nil) --codec +prop:reader("hcodec", nil) --codec +prop:reader("jcodec", nil) --codec prop:reader("listener", nil) --网络连接对象 prop:reader("clients", {}) --clients prop:reader("get_handlers", {}) --get_handlers @@ -24,8 +27,8 @@ prop:reader("del_handlers", {}) --del_handlers prop:reader("post_handlers", {}) --post_handlers function HttpServer:__init(http_addr) - local jcodec = json.jsoncodec() - self.codec = codec.httpcodec(jcodec) + self.jcodec = jsoncodec() + self.hcodec = httpcodec(self.jcodec) self:setup(http_addr) end @@ -37,6 +40,7 @@ function HttpServer:setup(http_addr) signalquit(1) return end + socket:set_codec(self.hcodec) log_info("[HttpServer][setup] listen(%s:%s) success!", self.ip, self.port) self.listener = socket end @@ -52,14 +56,13 @@ function HttpServer:on_socket_error(socket, token, err) self.listener = nil return end - log_debug("[HttpServer][on_socket_error] client(token:%s) close!", token) + log_debug("[HttpServer][on_socket_error] client(token:%s) close(%s)!", token, err) self.clients[token] = nil end function HttpServer:on_socket_accept(socket, token) --log_debug("[HttpServer][on_socket_accept] client(token:%s) connected!", token) self.clients[token] = socket - socket:set_codec(self.codec) end function HttpServer:on_socket_recv(socket, method, url, params, headers, body) diff --git a/script/network/net_client.lua b/script/network/net_client.lua index 1472ab15..fa83b310 100644 --- a/script/network/net_client.lua +++ b/script/network/net_client.lua @@ -4,25 +4,15 @@ local log_err = logger.err local log_fatal = logger.fatal local qeval = quanta.eval local qxpcall = quanta.xpcall -local env_status = environ.status -local b64_encode = crypt.b64_encode -local b64_decode = crypt.b64_decode -local lz4_encode = crypt.lz4_encode -local lz4_decode = crypt.lz4_decode -local lcrc8 = codec.crc8 local event_mgr = quanta.get("event_mgr") local socket_mgr = quanta.get("socket_mgr") local thread_mgr = quanta.get("thread_mgr") -local protobuf_mgr = quanta.get("protobuf_mgr") local proxy_agent = quanta.get("proxy_agent") -local out_press = env_status("QUANTA_OUT_PRESS") -local out_encrypt = env_status("QUANTA_OUT_ENCRYPT") +local proto_pb = luabus.eproto_type.pb local FLAG_REQ = quanta.enum("FlagMask", "REQ") -local FLAG_ZIP = quanta.enum("FlagMask", "ZIP") -local FLAG_ENCRYPT = quanta.enum("FlagMask", "ENCRYPT") local CONNECT_TIMEOUT = quanta.enum("NetwkTime", "CONNECT_TIMEOUT") local RPC_CALL_TIMEOUT = quanta.enum("NetwkTime", "RPC_CALL_TIMEOUT") @@ -30,17 +20,17 @@ local NetClient = class() local prop = property(NetClient) prop:reader("ip", nil) prop:reader("port", nil) +prop:reader("codec", nil) prop:reader("alive", false) prop:reader("socket", nil) --连接成功对象 prop:reader("holder", nil) --持有者 prop:reader("wait_list", {}) --等待协议列表 -prop:accessor("codec", nil) --编解码器 function NetClient:__init(holder, ip, port) self.ip = ip self.port = port self.holder = holder - self.codec = protobuf_mgr + self.codec = protobuf.pbcodec("ncmd_cs", "ncmd_cs.NCmdId") end -- 发起连接 @@ -48,15 +38,15 @@ function NetClient:connect(block) if self.socket then return true end - local socket, cerr = socket_mgr.connect(self.ip, self.port, CONNECT_TIMEOUT) + local socket, cerr = socket_mgr.connect(self.ip, self.port, CONNECT_TIMEOUT, proto_pb) if not socket then log_err("[NetClient][connect] failed to connect: %s:%s err=%s", self.ip, self.port, cerr) return false, cerr end --设置阻塞id - socket.set_proto_type(luabus.eproto_type.head); local block_id = block and thread_mgr:build_session_id() -- 调用成功,开始安装回调函数 + socket.set_codec(self.codec) socket.on_connect = function(res) local success = (res == "ok") thread_mgr:fork(function() @@ -71,9 +61,9 @@ function NetClient:connect(block) thread_mgr:response(block_id, success, res) end end - socket.on_call_head = function(recv_len, cmd_id, flag, type, crc8, session_id, slice) + socket.on_call_pb = function(recv_len, session_id, cmd_id, flag, type, crc8, body) proxy_agent:statistics("on_proto_recv", cmd_id, recv_len) - qxpcall(self.on_socket_rpc, "on_socket_rpc: %s", self, socket, cmd_id, flag, type, session_id, slice) + qxpcall(self.on_socket_rpc, "on_socket_rpc: %s", self, socket, cmd_id, flag, type, session_id, body) end socket.on_error = function(token, err) thread_mgr:fork(function() @@ -92,47 +82,12 @@ function NetClient:get_token() return self.socket and self.socket.token end -function NetClient:encode(cmd, data, flag) - local en_data, cmd_id = self.codec:encode(cmd, data) - if not en_data then - return - end - -- 加密处理 - if out_encrypt then - en_data = b64_encode(en_data) - flag = flag | FLAG_ENCRYPT - end - -- 压缩处理 - if out_press then - en_data = lz4_encode(en_data) - flag = flag | FLAG_ZIP - end - return en_data, cmd_id, flag -end - -function NetClient:decode(cmd_id, buff, flag) - if flag & FLAG_ZIP == FLAG_ZIP then - --解压处理 - buff = lz4_decode(buff) - end - if flag & FLAG_ENCRYPT == FLAG_ENCRYPT then - --解密处理 - buff = b64_decode(buff) - end - return self.codec:decode(cmd_id, buff) -end - -function NetClient:on_socket_rpc(socket, cmd_id, flag, type, session_id, buff) - local body, cmd_name = self:decode(cmd_id, buff, flag) - if not body then - log_err("[NetClient][on_socket_rpc] decode failed! cmd_id:%s", cmd_id) - return - end +function NetClient:on_socket_rpc(socket, cmd_id, flag, type, session_id, body) event_mgr:notify_trigger("on_message_recv", cmd_id, body) if session_id == 0 or (flag & FLAG_REQ == FLAG_REQ) then -- 执行消息分发 local function dispatch_rpc_message() - local _ = qeval(cmd_name) + local _ = qeval(cmd_id) self.holder:on_socket_rpc(self, cmd_id, body, session_id) end thread_mgr:fork(dispatch_rpc_message) @@ -167,9 +122,9 @@ function NetClient:write(cmd, data, type, session_id, flag) return false end -- call lbus - local send_len = self.socket.call_head(cmd_id, pflag, type or 0, lcrc8(body), session_id or 0, body, #body) + local send_len = self.socket.call_pb(cmd_id, pflag, type, session_id, data) if send_len < 0 then - log_err("[NetClient][write] call_head failed! code:%s", send_len) + log_err("[NetClient][write] call_pb failed! code:%s", send_len) return false end proxy_agent:statistics("on_proto_send", cmd_id, send_len) @@ -178,7 +133,7 @@ end -- 发送数据 function NetClient:send(cmd_id, data, type) - return self:write(cmd_id, data, type, 0, FLAG_REQ) + return self:write(cmd_id, data, type or 0, 0, FLAG_REQ) end -- 发起远程命令 @@ -187,7 +142,7 @@ function NetClient:call(cmd_id, data, type) return false end local session_id = self.socket.build_session_id() - if not self:write(cmd_id, data, type, session_id, FLAG_REQ) then + if not self:write(cmd_id, data, type or 0, session_id, FLAG_REQ) then return false end return thread_mgr:yield(session_id, cmd_id, RPC_CALL_TIMEOUT) diff --git a/script/network/net_server.lua b/script/network/net_server.lua index 1907b2b3..bf252eeb 100644 --- a/script/network/net_server.lua +++ b/script/network/net_server.lua @@ -7,10 +7,8 @@ local log_fatal = logger.fatal local signalquit = signal.quit local qeval = quanta.eval local qxpcall = quanta.xpcall -local b64_encode = crypt.b64_encode -local b64_decode = crypt.b64_decode -local lz4_encode = crypt.lz4_encode -local lz4_decode = crypt.lz4_decode + +local proto_pb = luabus.eproto_type.pb local event_mgr = quanta.get("event_mgr") local thread_mgr = quanta.get("thread_mgr") @@ -20,15 +18,11 @@ local proxy_agent = quanta.get("proxy_agent") local FLAG_REQ = quanta.enum("FlagMask", "REQ") local FLAG_RES = quanta.enum("FlagMask", "RES") -local FLAG_ZIP = quanta.enum("FlagMask", "ZIP") -local FLAG_ENCRYPT = quanta.enum("FlagMask", "ENCRYPT") local NETWORK_TIMEOUT = quanta.enum("NetwkTime", "NETWORK_TIMEOUT") local FAST_MS = quanta.enum("PeriodTime", "FAST_MS") local SECOND_MS = quanta.enum("PeriodTime", "SECOND_MS") local TOO_FAST = quanta.enum("KernCode", "TOO_FAST") -local OUT_PRESS = environ.status("QUANTA_OUT_PRESS") -local OUT_ENCRYPT = environ.status("QUANTA_OUT_ENCRYPT") local FLOW_CTRL = environ.status("QUANTA_FLOW_CTRL") local FC_PACKETS = environ.number("QUANTA_FLOW_CTRL_PACKAGE") local FC_BYTES = environ.number("QUANTA_FLOW_CTRL_BYTES") @@ -42,11 +36,12 @@ prop:reader("sessions", {}) --会话列表 prop:reader("session_type", "default") --会话类型 prop:reader("session_count", 0) --会话数量 prop:reader("listener", nil) --监听器 -prop:accessor("codec", nil) --编解码器 +prop:reader("broadcast_token", nil) --监听器 +prop:reader("codec", nil) --编解码器 function NetServer:__init(session_type) self.session_type = session_type - self.codec = protobuf_mgr + self.codec = protobuf.pbcodec("ncmd_cs", "ncmd_cs.NCmdId") end --induce:根据 order 推导port @@ -58,19 +53,21 @@ function NetServer:setup(ip, port, induce) return end local real_port = induce and (port + quanta.order - 1) or port - self.listener = socket_mgr.listen(ip, real_port) - if not self.listener then + local listener = socket_mgr.listen(ip, real_port, proto_pb) + if not listener then log_err("[NetServer][setup] failed to listen: %s:%d", ip, real_port) signalquit() return end - self.ip, self.port = ip, real_port log_info("[NetServer][setup] start listen at: %s:%d", ip, real_port) -- 安装回调 - self.listener.set_proto_type(luabus.eproto_type.head) - self.listener.on_accept = function(session) + listener.set_codec(self.codec) + listener.on_accept = function(session) qxpcall(self.on_socket_accept, "on_socket_accept: %s", self, session) end + self.listener = listener + self.ip, self.port = ip, real_port + self.broadcast_token = listener.token end -- 连接回调 @@ -87,14 +84,14 @@ function NetServer:on_socket_accept(session) self:add_session(session) -- 绑定call回调 session.call_client = function(cmd_id, flag, session_id, body) - local send_len = session.call_head(cmd_id, flag, 0, 0, session_id, body, #body) + local send_len = session.call_pb(session_id, cmd_id, flag, 0, 0, body) if send_len <= 0 then - log_err("[NetServer][call_client] call_head failed! code:%s", send_len) + log_err("[NetServer][call_client] call_pb failed! code:%s", send_len) return false end return true end - session.on_call_head = function(recv_len, cmd_id, flag, type, crc8, session_id, slice) + session.on_call_pb = function(recv_len, session_id, cmd_id, flag, type, crc8, body) local now_ms = quanta.now_ms if session.lc_crc == crc8 and now_ms - session.lc_time < FAST_MS then self:callback_errcode(session, cmd_id, TOO_FAST, session_id) @@ -107,7 +104,7 @@ function NetServer:on_socket_accept(session) session.fc_bytes = session.fc_bytes + recv_len end proxy_agent:statistics("on_proto_recv", cmd_id, recv_len) - qxpcall(self.on_socket_recv, "on_socket_recv: %s", self, session, cmd_id, flag, type, session_id, slice) + qxpcall(self.on_socket_recv, "on_socket_recv: %s", self, session, cmd_id, flag, type, session_id, body) end -- 绑定网络错误回调(断开) session.on_error = function(stoken, err) @@ -122,41 +119,17 @@ function NetServer:write(session, cmd, data, session_id, flag) log_fatal("[NetServer][write] session lost! cmd_id:%s-(%s)", cmd, data) return false end - local body, cmd_id, pflag = self:encode(cmd, data, flag) - if not body then - log_fatal("[NetServer][write] encode failed! cmd_id:%s-(%s)", cmd, data) - return false - end - if session_id > 0 then - session_id = session_id & 0xffff - end - return session.call_client(cmd_id, pflag, session_id, body) + return session.call_client(cmd, flag, session_id, data) end -- 广播数据 -function NetServer:broadcast(cmd, data) - local body, cmd_id, pflag = self:encode(cmd, data, FLAG_REQ) - if not body then - log_fatal("[NetServer][broadcast] encode failed! cmd_id:%s-(%s)", cmd_id, data) - return false - end - for _, session in pairs(self.sessions) do - session.call_client(cmd_id, pflag, 0, body) - end - return true +function NetServer:broadcast(cmd_id, data) + socket_mgr.broadcast(self.codec, self.broadcast_token, cmd_id, FLAG_REQ, 0, 0, data) end -- 广播数据 -function NetServer:broadcast_groups(sessions, cmd, data) - local body, cmd_id, pflag = self:encode(cmd, data, FLAG_REQ) - if not body then - log_fatal("[NetServer][broadcast_groups] encode failed! cmd_id:%s-(%s)", cmd_id, data) - return false - end - for _, session in pairs(sessions or {}) do - session.call_client(cmd_id, pflag, 0, body) - end - return true +function NetServer:broadcast_groups(tokens, cmd_id, data) + socket_mgr.broadgroup(self.codec, tokens, cmd_id, FLAG_REQ, 0, 0, data) end -- 发送数据 @@ -182,48 +155,12 @@ function NetServer:callback_errcode(session, cmd_id, code, session_id) return self:write(session, callback_id, data, session_id or 0, FLAG_RES) end -function NetServer:encode(cmd, data, flag) - local en_data, cmd_id = self.codec:encode(cmd, data) - if not en_data then - return - end - -- 加密处理 - if OUT_ENCRYPT then - en_data = b64_encode(en_data) - flag = flag | FLAG_ENCRYPT - end - -- 压缩处理 - if OUT_PRESS then - en_data = lz4_encode(en_data) - flag = flag | FLAG_ZIP - end - return en_data, cmd_id, flag -end - -function NetServer:decode(cmd_id, buff, flag) - if flag & FLAG_ZIP == FLAG_ZIP then - --解压处理 - buff = lz4_decode(buff) - end - if flag & FLAG_ENCRYPT == FLAG_ENCRYPT then - --解密处理 - buff = b64_decode(buff) - end - return self.codec:decode(cmd_id, buff) -end - -- 收到远程调用回调 -function NetServer:on_socket_recv(session, cmd_id, flag, type, session_id, buff) - -- 解码 - local body, cmd_name = self:decode(cmd_id, buff, flag) - if not body then - log_warn("[NetServer][on_socket_rpc] decode failed! cmd_id:%s", cmd_id) - return - end +function NetServer:on_socket_recv(session, cmd_id, flag, type, session_id, body) if session_id == 0 or (flag & FLAG_REQ == FLAG_REQ) then - local function dispatch_rpc_message(_session, typ, cmd, bd) - local _ = qeval(cmd_name) - local result = event_mgr:notify_listener("on_socket_cmd", _session, typ, cmd, bd, session_id) + local function dispatch_rpc_message(_session, typ, cmd, cbody) + local _ = qeval(cmd_id) + local result = event_mgr:notify_listener("on_socket_cmd", _session, typ, cmd, cbody, session_id) if not result[1] then log_err("[NetServer][on_socket_recv] on_socket_cmd failed! cmd_id:%s", cmd_id) end diff --git a/script/network/rpc_server.lua b/script/network/rpc_server.lua index 03900d36..0f30d748 100644 --- a/script/network/rpc_server.lua +++ b/script/network/rpc_server.lua @@ -26,11 +26,11 @@ local SERVICE_MAX = 255 local RpcServer = singleton() local prop = property(RpcServer) -prop:reader("ip", "") --监听ip -prop:reader("port", 0) --监听端口 +prop:reader("ip", "") --监听ip +prop:reader("port", 0) --监听端口 prop:reader("clients", {}) prop:reader("listener", nil) -prop:reader("holder", nil) --持有者 +prop:reader("holder", nil) --持有者 --induce:根据 order 推导port function RpcServer:__init(holder, ip, port, induce) @@ -46,13 +46,13 @@ function RpcServer:__init(holder, ip, port, induce) signalquit() return end + listener.on_accept = function(client) + qxpcall(self.on_socket_accept, "on_socket_accept: %s", self, client) + end self.holder = holder self.listener = listener self.ip, self.port = ip, real_port log_info("[RpcServer][setup] now listen %s:%s success!", ip, real_port) - self.listener.on_accept = function(client) - qxpcall(self.on_socket_accept, "on_socket_accept: %s", self, client) - end event_mgr:add_listener(self, "rpc_heartbeat") event_mgr:add_listener(self, "rpc_register") end @@ -111,7 +111,7 @@ function RpcServer:on_socket_accept(client) event_mgr:notify_listener("on_transfer_rpc", client, session_id, service_id, target_id, slice) return end - event_mgr:notify_listener("on_boardcast_rpc", client, target_id, slice) + event_mgr:notify_listener("on_broadcast_rpc", client, target_id, slice) end thread_mgr:fork(dispatch_rpc_message) end @@ -166,6 +166,7 @@ function RpcServer:broadcast(rpc, ...) for _, client in pairs(self.clients) do client.call_rpc(rpc, 0, FLAG_REQ, ...) end + socket_mgr:broadgroup() end --broadcast接口,注册后才转发 diff --git a/script/store/clickhouse_mgr.lua b/script/store/clickhouse_mgr.lua index 09ac1d77..b88284f2 100644 --- a/script/store/clickhouse_mgr.lua +++ b/script/store/clickhouse_mgr.lua @@ -11,12 +11,12 @@ local MAIN_DBID = environ.number("QUANTA_DB_MAIN_ID") local ClickHouseMgr = singleton() local prop = property(ClickHouseMgr) prop:reader("clickhouse_dbs", {}) -- clickhouse_dbs -prop:reader("default_db", nil) -- default_db -prop:reader("default_id", nil) -- default_id function ClickHouseMgr:__init() self:setup() -- 注册事件 + event_mgr:add_listener(self, "rpc_clickhouse_query", "query") + event_mgr:add_listener(self, "rpc_clickhouse_prepare", "prepare") event_mgr:add_listener(self, "rpc_clickhouse_execute", "execute") end @@ -26,7 +26,7 @@ function ClickHouseMgr:setup() local drivers = environ.driver("QUANTA_MYSQL_URLS") for i, conf in ipairs(drivers) do local clickhouse_db = MysqlDB(conf, i) - self.clickhouse_dbs[conf.id] = clickhouse_db + self.clickhouse_dbs[i] = clickhouse_db end end @@ -35,12 +35,36 @@ function ClickHouseMgr:get_db(db_id) return self.clickhouse_dbs[db_id or MAIN_DBID] end -function ClickHouseMgr:execute(db_id, sql) - local clickhousedb = self:get_db(db_id) - if clickhousedb then - local ok, res_oe = clickhousedb:query(sql) +function ClickHouseMgr:query(db_id, primary_id, sql) + local clickhouse_db = self:get_db(db_id) + if clickhouse_db and clickhouse_db:set_executer(primary_id) then + local ok, res_oe = clickhouse_db:query(sql) if not ok then - log_err("[ClickHouseMgr][execute] execute %s failed, because: %s", sql, res_oe) + log_err("[ClickHouseMgr][query] query %s failed, because: %s", sql, res_oe) + end + return ok and SUCCESS or MYSQL_FAILED, res_oe + end + return MYSQL_FAILED, "clickhouse db not exist" +end + +function ClickHouseMgr:execute(db_id, primary_id, stmt, ...) + local clickhouse_db = self:get_db(db_id) + if clickhouse_db and clickhouse_db:set_executer(primary_id) then + local ok, res_oe = clickhouse_db:execute(stmt, ...) + if not ok then + log_err("[ClickHouseMgr][execute] execute %s failed, because: %s", stmt, res_oe) + end + return ok and SUCCESS or MYSQL_FAILED, res_oe + end + return MYSQL_FAILED, "clickhouse db not exist" +end + +function ClickHouseMgr:prepare(db_id, sql) + local clickhouse_db = self:get_db(db_id) + if clickhouse_db and clickhouse_db:set_executer() then + local ok, res_oe = clickhouse_db:prepare(sql) + if not ok then + log_err("[ClickHouseMgr][prepare] prepare %s failed, because: %s", sql, res_oe) end return ok and SUCCESS or MYSQL_FAILED, res_oe end diff --git a/script/store/mysql_mgr.lua b/script/store/mysql_mgr.lua index e1e838c4..c05100ad 100644 --- a/script/store/mysql_mgr.lua +++ b/script/store/mysql_mgr.lua @@ -15,6 +15,8 @@ prop:reader("mysql_dbs", {}) -- mysql_dbs function MysqlMgr:__init() self:setup() -- 注册事件 + event_mgr:add_listener(self, "rpc_mysql_query", "query") + event_mgr:add_listener(self, "rpc_mysql_prepare", "prepare") event_mgr:add_listener(self, "rpc_mysql_execute", "execute") end @@ -24,7 +26,7 @@ function MysqlMgr:setup() local drivers = environ.driver("QUANTA_MYSQL_URLS") for i, conf in ipairs(drivers) do local mysql_db = MysqlDB(conf, i) - self.mysql_dbs[conf.id] = mysql_db + self.mysql_dbs[i] = mysql_db end end @@ -33,12 +35,36 @@ function MysqlMgr:get_db(db_id) return self.mysql_dbs[db_id or MAIN_DBID] end -function MysqlMgr:execute(db_id, sql) +function MysqlMgr:query(db_id, primary_id, sql) local mysqldb = self:get_db(db_id) - if mysqldb then + if mysqldb and mysqldb:set_executer(primary_id) then local ok, res_oe = mysqldb:query(sql) if not ok then - log_err("[MysqlMgr][execute] execute %s failed, because: %s", sql, res_oe) + log_err("[MysqlMgr][query] query %s failed, because: %s", sql, res_oe) + end + return ok and SUCCESS or MYSQL_FAILED, res_oe + end + return MYSQL_FAILED, "mysql db not exist" +end + +function MysqlMgr:execute(db_id, primary_id, stmt, ...) + local mysqldb = self:get_db(db_id) + if mysqldb and mysqldb:set_executer(primary_id) then + local ok, res_oe = mysqldb:execute(stmt, ...) + if not ok then + log_err("[MysqlMgr][execute] execute %s failed, because: %s", stmt, res_oe) + end + return ok and SUCCESS or MYSQL_FAILED, res_oe + end + return MYSQL_FAILED, "mysql db not exist" +end + +function MysqlMgr:prepare(db_id, sql) + local mysqldb = self:get_db(db_id) + if mysqldb and mysqldb:set_executer() then + local ok, res_oe = mysqldb:prepare(sql) + if not ok then + log_err("[MysqlMgr][prepare] prepare %s failed, because: %s", sql, res_oe) end return ok and SUCCESS or MYSQL_FAILED, res_oe end diff --git a/server/business/component/attr_component.lua b/server/business/component/attr_component.lua index f0f66cf9..20b5d96d 100644 --- a/server/business/component/attr_component.lua +++ b/server/business/component/attr_component.lua @@ -217,7 +217,7 @@ function AttrComponent:on_attr_sync() self:send("NID_ENTITY_ATTR_UPDATE_NTF", { id = self.id, attrs = attrs }) end if self.range > 1 and next(battrs) then - self:boardcast_message("NID_ENTITY_ATTR_UPDATE_NTF", { id = self.id, attrs = battrs }) + self:broadcast_message("NID_ENTITY_ATTR_UPDATE_NTF", { id = self.id, attrs = battrs }) end self.sync_attrs = {} end diff --git a/server/cache/cache_gm.lua b/server/cache/cache_gm.lua index 1063fce5..cccf053a 100644 --- a/server/cache/cache_gm.lua +++ b/server/cache/cache_gm.lua @@ -5,7 +5,7 @@ local log_info = logger.info local log_err = logger.err local sformat = string.format local qfailed = quanta.failed -local unserialize = codec.unserialize +local unserialize = luakit.unserialize local gm_agent = quanta.get("gm_agent") local cache_mgr = quanta.get("cache_mgr") diff --git a/server/gateway/gateway.lua b/server/gateway/gateway.lua index 1af99507..4e70bb32 100644 --- a/server/gateway/gateway.lua +++ b/server/gateway/gateway.lua @@ -175,15 +175,15 @@ end --群发消息 function Gateway:rpc_groupcast_client(player_ids, cmd_id, data) - local sessions = {} + local tokens = {} for _, player_id in pairs(player_ids) do local player = self:get_player(player_id) if player then - sessions[#sessions + 1] = player:get_session() + tokens[#tokens + 1] = player:get_session_token() end end - if next(sessions) then - client_mgr:broadcast_groups(sessions, cmd_id, data) + if next(tokens) then + client_mgr:broadcast_groups(tokens, cmd_id, data) end end diff --git a/server/gateway/group_mgr.lua b/server/gateway/group_mgr.lua index 3c9f5e8b..68496853 100644 --- a/server/gateway/group_mgr.lua +++ b/server/gateway/group_mgr.lua @@ -16,12 +16,12 @@ end function GroupMgr:add_member(group_id, player_id, player) log_info("[GroupMgr][add_member] group_id(%s) player_id(%s)!", group_id, player_id) local group = self.groups[group_id] - local session = player:get_session() + local token = player:get_session_token() if not group then - self.groups[group_id] = qtweak({ [player_id] = session }) + self.groups[group_id] = qtweak({ [player_id] = token }) return end - group[player_id] = session + group[player_id] = token end --更新分组信息 @@ -35,9 +35,9 @@ end --广播消息 function GroupMgr:broadcast(group_id, cmd_id, data) - local sessions = self.groups[group_id] - if sessions then - client_mgr:broadcast_groups(sessions, cmd_id, data) + local tokens = self.groups[group_id] + if tokens then + client_mgr:broadcast_groups(tokens, cmd_id, data) end end diff --git a/server/gateway/player.lua b/server/gateway/player.lua index aeb60295..e1517af6 100644 --- a/server/gateway/player.lua +++ b/server/gateway/player.lua @@ -24,6 +24,10 @@ function GatePlayer:__init(session, open_id, player_id) self.player_id = player_id end +function GatePlayer:get_session_token() + return self.session.token +end + --查询组ID function GatePlayer:get_group_id(group_name) return self.groups[group_name] diff --git a/server/router/transfer_mgr.lua b/server/router/transfer_mgr.lua index f1ed3cad..39fc7cb0 100644 --- a/server/router/transfer_mgr.lua +++ b/server/router/transfer_mgr.lua @@ -30,7 +30,7 @@ function TransferMgr:__init() event_mgr:add_listener(self, "rpc_login_service") --消息转发 event_mgr:add_listener(self, "on_transfer_rpc") - event_mgr:add_listener(self, "on_boardcast_rpc") + event_mgr:add_listener(self, "on_broadcast_rpc") --初始化变量 self.rpc_server = router_server:get_rpc_server() end @@ -66,14 +66,14 @@ function TransferMgr:rpc_query_service(client, player_id, serv_name) end --转发广播 -function TransferMgr:on_boardcast_rpc(client, player_id, slice) +function TransferMgr:on_broadcast_rpc(client, player_id, slice) local routers = self:find_routers(player_id) if not routers then slice = slice.string() routers = self:query_routers(player_id, NODE_ID) end if not routers then - log_warn("[TransferMgr][on_boardcast_rpc]: %s find routers failed!", player_id) + log_warn("[TransferMgr][on_broadcast_rpc]: %s find routers failed!", player_id) return end for _, server_id in pairs(routers) do diff --git a/server/test/clickhouse_test.lua b/server/test/clickhouse_test.lua index 0fec9eb0..d1fe1bf9 100644 --- a/server/test/clickhouse_test.lua +++ b/server/test/clickhouse_test.lua @@ -9,15 +9,15 @@ local ck_mgr = ClickMgr() local MAIN_DBID = environ.number("QUANTA_DB_MAIN_ID") timer_mgr:once(2000, function() - local code, res_oe = ck_mgr:execute(MAIN_DBID, "drop table if exists test_ck") + local code, res_oe = ck_mgr:query(MAIN_DBID, MAIN_DBID, "drop table if exists test_ck") log_debug("db drop table code: %s, err = %s", code, res_oe) - code, res_oe = ck_mgr:execute(MAIN_DBID, "create table if not exists test_ck (id int, pid int, value int, primary key (id)) ENGINE = MergeTree") + code, res_oe = ck_mgr:query(MAIN_DBID, MAIN_DBID, "create table if not exists test_ck (id int, pid int, value int, primary key (id)) ENGINE = MergeTree") log_debug("db create table code: %s, err = %s", code, res_oe) - code, res_oe = ck_mgr:execute(MAIN_DBID, "select count(*) as count from test_ck where pid=123456") + code, res_oe = ck_mgr:query(MAIN_DBID, MAIN_DBID, "select count(*) as count from test_ck where pid=123456") log_debug("db select code: %s, count = %s", code, res_oe) - code, res_oe = ck_mgr:execute(MAIN_DBID, "insert into test_ck (id, pid, value) values (1, 123456, 40)") + code, res_oe = ck_mgr:query(MAIN_DBID, MAIN_DBID, "insert into test_ck (id, pid, value) values (1, 123456, 40)") log_debug("db insert code: %s, count = %s", code, res_oe) - code, res_oe = ck_mgr:execute(MAIN_DBID, "select * from test_ck where pid = 123456") + code, res_oe = ck_mgr:query(MAIN_DBID, MAIN_DBID, "select * from test_ck where pid = 123456") log_debug("db select code: %s, res_oe = %s", code, res_oe) end) diff --git a/server/test/codec_test.lua b/server/test/codec_test.lua index 69fe496a..66708322 100644 --- a/server/test/codec_test.lua +++ b/server/test/codec_test.lua @@ -4,15 +4,15 @@ local log_debug = logger.debug local log_dump = logger.dump local lhex_encode = crypt.hex_encode -local encode = codec.encode -local decode = codec.decode -local serialize = codec.serialize -local unserialize = codec.unserialize +local crc8 = codec.crc8 local hash_code = codec.hash_code local fnv_32a = codec.fnv_1a_32 local fnv_32 = codec.fnv_1_32 local jumphash = codec.jumphash -local crc8 = codec.crc8 +local encode = luakit.encode +local decode = luakit.decode +local serialize = luakit.serialize +local unserialize = luakit.unserialize --ketama --[[ diff --git a/server/test/json_test.lua b/server/test/json_test.lua index a20ae243..1d444dd2 100644 --- a/server/test/json_test.lua +++ b/server/test/json_test.lua @@ -9,10 +9,10 @@ local bencode = bson.encode local bdecode = bson.decode local cencode = cjson.encode local cdecode = cjson.decode -local lencode = codec.encode -local ldecode = codec.decode local log_debug = logger.debug local new_guid = codec.guid_new +local lencode = luakit.encode +local ldecode = luakit.decode local protobuf_mgr = quanta.get("protobuf_mgr") diff --git a/server/test/mysql_test.lua b/server/test/mysql_test.lua index 358d064f..51cdcee4 100644 --- a/server/test/mysql_test.lua +++ b/server/test/mysql_test.lua @@ -9,26 +9,29 @@ local mysql_mgr = MysqlMgr() local MAIN_DBID = environ.number("QUANTA_DB_MAIN_ID") timer_mgr:once(3000, function() - local code, res_oe = mysql_mgr:execute(MAIN_DBID, "drop table test_mysql") - log_debug("db drop table code: %s, err = %s", code, res_oe) - code, res_oe = mysql_mgr:execute(MAIN_DBID, "create table if not exists test_mysql (id int auto_increment, pid int, value int, primary key (id))") - log_debug("db create table code: %s, err = %s", code, res_oe) - code, res_oe = mysql_mgr:execute(MAIN_DBID, "select count(*) as count from test_mysql where pid=123456") + log_debug("mysql db start test!") + local code, res_oe = mysql_mgr:query(MAIN_DBID, MAIN_DBID, "drop table test_mysql") + log_debug("db drop table code: %s, res = %s", code, res_oe) + code, res_oe = mysql_mgr:query(MAIN_DBID, MAIN_DBID, "create table if not exists test_mysql (id int auto_increment, pid int, value int, primary key (id))") + log_debug("db create table code: %s, res = %s", code, res_oe) + code, res_oe = mysql_mgr:query(MAIN_DBID, MAIN_DBID, "insert into test_mysql (pid, value) values (123457, 40)") + log_debug("db insert code: %s, res_oe = %s", code, res_oe) + code, res_oe = mysql_mgr:query(MAIN_DBID, MAIN_DBID, "insert into test_mysql (pid, value) values (123456, 40)") + log_debug("db insert code: %s, res_oe = %s", code, res_oe) + code, res_oe = mysql_mgr:query(MAIN_DBID, MAIN_DBID, "select count(*) as count from test_mysql where pid=123457") log_debug("db select code: %s, count = %s", code, res_oe) - code, res_oe = mysql_mgr:execute(MAIN_DBID, "insert into test_mysql (pid, value) values (123457, 40)") - log_debug("db insert code: %s, count = %s", code, res_oe) - code, res_oe = mysql_mgr:execute(MAIN_DBID, "select * from test_mysql where pid = 123456") + code, res_oe = mysql_mgr:query(MAIN_DBID, MAIN_DBID, "select * from test_mysql where pid = 123456") log_debug("db select code: %s, res_oe = %s", code, res_oe) - code, res_oe = mysql_mgr:execute(MAIN_DBID, "update test_mysql set pid = 123454, value = 20 where pid = 123456 limit 1") + code, res_oe = mysql_mgr:query(MAIN_DBID, MAIN_DBID, "update test_mysql set pid = 123454, value = 20 where pid = 123456 limit 1") log_debug("db update code: %s, err = %s", code, res_oe) - code, res_oe = mysql_mgr:execute(MAIN_DBID, "select * from test_mysql") + code, res_oe = mysql_mgr:query(MAIN_DBID, MAIN_DBID, "select * from test_mysql") log_debug("db select code: %s, res_oe = %s", code, res_oe) - code, res_oe = mysql_mgr:execute(MAIN_DBID, "delete from test_mysql where pid = 123457") + code, res_oe = mysql_mgr:query(MAIN_DBID, MAIN_DBID, "delete from test_mysql where pid = 123457") log_debug("db delete code: %s, err = %s", code, res_oe) - code, res_oe = mysql_mgr:execute(MAIN_DBID, "replace into test_mysql (id, pid, value) values (1, 123457, 40)") + code, res_oe = mysql_mgr:query(MAIN_DBID, MAIN_DBID, "replace into test_mysql (id, pid, value) values (1, 123457, 40)") log_debug("db replace code: %s, count = %s", code, res_oe) - code, res_oe = mysql_mgr:execute(MAIN_DBID, "select * from test_mysql") + code, res_oe = mysql_mgr:query(MAIN_DBID, MAIN_DBID, "select * from test_mysql") log_debug("db select code: %s, res_oe = %s", code, res_oe) - code, res_oe = mysql_mgr:execute(MAIN_DBID, "select count(*) as count from test_mysql where pid=123456") + code, res_oe = mysql_mgr:query(MAIN_DBID, MAIN_DBID, "select count(*) as count from test_mysql where pid=123454") log_debug("db count code: %s, count = %s", code, res_oe) end) \ No newline at end of file diff --git a/tools/accord/accord.lua b/tools/accord/accord.lua index 5314f2c9..f864eca4 100644 --- a/tools/accord/accord.lua +++ b/tools/accord/accord.lua @@ -7,7 +7,7 @@ local lappend = stdfs.append local lfilename = stdfs.filename local lextension = stdfs.extension local lcurdir = stdfs.current_path -local serialize = codec.serialize +local serialize = luakit.serialize local pb_enum_id = protobuf.enum local json_encode = json.encode local tunpack = table.unpack