diff --git a/extend/lcodec/src/mysql.h b/extend/lcodec/src/mysql.h index 477ed212..7ff9f1d5 100644 --- a/extend/lcodec/src/mysql.h +++ b/extend/lcodec/src/mysql.h @@ -1,35 +1,53 @@ #pragma once #include +#include #include "lua_kit.h" using namespace std; using namespace luakit; namespace lcodec { + // cmd constants + const uint8_t COM_SLEEP = 0x00; + const uint8_t COM_CONNECT = 0x0b; + const uint8_t COM_STMT_PREPARE = 0x16; + const uint8_t COM_STMT_CLOSE = 0x19; + // constants - inline size_t CLIENT_FLAG = 260047; - inline size_t MAX_PACKET_SIZE = 0xffffff; - inline size_t CLIENT_PLUGIN_AUTH = 1 << 3; + inline size_t CLIENT_FLAG = 260047; //0011 1111 0111 1100 1111 + inline size_t MAX_PACKET_SIZE = 0xffffff; + inline size_t CLIENT_PLUGIN_AUTH = 1 << 19; + inline size_t CLIENT_DEPRECATE_EOF = 1 << 24; // field types - inline uint16_t MYSQL_TYPE_TINY = 0x01; - inline uint16_t MYSQL_TYPE_DOUBLE = 0x05; - inline uint16_t MYSQL_TYPE_NULL = 0x08; - inline uint16_t MYSQL_TYPE_LONGLONG = 0x08; - inline uint16_t MYSQL_TYPE_VARCHAR = 0x0f; + const uint16_t MYSQL_TYPE_TINY = 0x01; + const uint16_t MYSQL_TYPE_SHORT = 0x02; + const uint16_t MYSQL_TYPE_LONG = 0x03; + const uint16_t MYSQL_TYPE_FLOAT = 0x04; + const uint16_t MYSQL_TYPE_DOUBLE = 0x05; + const uint16_t MYSQL_TYPE_NULL = 0x06; + const uint16_t MYSQL_TYPE_LONGLONG = 0x08; + const uint16_t MYSQL_TYPE_INT24 = 0x09; + const uint16_t MYSQL_TYPE_YEAR = 0x0d; + const uint16_t MYSQL_TYPE_VARCHAR = 0x0f; + const uint16_t MYSQL_TYPE_NEWDECIMAL = 0xf6; - // cmd constants - const uint8_t COM_SLEEP = 0x00; - const uint8_t COM_CONNECT = 0x0b; - const uint8_t COM_STMT_PREPARE = 0x16; - const uint8_t COM_STMT_CLOSE = 0x19; + // server status + inline size_t SERVER_MORE_RESULTS_EXISTS = 8; struct mysql_cmd { uint8_t cmd_id; size_t session_id; }; + struct mysql_column { + string_view name; + uint8_t type; + uint16_t flags; + }; + typedef vector mysql_columns; + class mysqlscodec : public codec_base { public: mysqlscodec(size_t session_id) { @@ -140,13 +158,84 @@ namespace lcodec { size_t command_decode(lua_State* L) { uint8_t type = *(uint8_t*)m_slice->read(); - if (type == 0x00) ok_packet_decode(L); - if (type == 0xff) err_packet_decode(L); - return 0; + if (type == 0x00) return ok_packet_decode(L); + if (type <= 0xfa) return data_packet_decode(L); + if (type == 0xff) return err_packet_decode(L); + throw invalid_argument("unsuppert packet type:" + type); + } + + void column_decode(mysql_columns& columns, size_t column_count) { + for (int i = 0; i < column_count; ++i) { + string_view catalog = decode_length_encoded_string(); + string_view schema = decode_length_encoded_string(); + string_view table = decode_length_encoded_string(); + string_view org_table = decode_length_encoded_string(); + string_view name = decode_length_encoded_string(); + string_view org_name = decode_length_encoded_string(); + // 1 byte fix length (skip) + // 2 byte character_set (skip) + // 4 byte column_length (skip) + m_slice->erase(7); + uint8_t type = *(uint8_t*)m_slice->read(); + uint16_t flags = *(uint16_t*)m_slice->read(); + uint8_t decimals = *(uint8_t*)m_slice->read(); + columns.push_back(mysql_column { name, type, decimals }); + } } - void ok_packet_decode(lua_State* L) { + void rows_decode(lua_State* L, mysql_columns& columns, size_t index) { + lua_createtable(L, 0, 8); + for (auto column : columns) { + auto value = decode_length_encoded_string(); + switch (column.type) { + case MYSQL_TYPE_TINY: + case MYSQL_TYPE_SHORT: + case MYSQL_TYPE_LONG: + case MYSQL_TYPE_FLOAT: + case MYSQL_TYPE_DOUBLE: + case MYSQL_TYPE_INT24: + case MYSQL_TYPE_YEAR: + case MYSQL_TYPE_LONGLONG: + case MYSQL_TYPE_NEWDECIMAL: + if (lua_stringtonumber(L, value.data()) == 0) { + lua_pushlstring(L, value.data(), value.size()); + } + break; + default: + lua_pushlstring(L, value.data(), value.size()); + break; + } + lua_setfield(L, -2, column.name.data()); + } + lua_seti(L, -2, index); + } + + bool result_set_decode(lua_State* L, size_t index) { + mysql_columns columns; + size_t column_count = decode_length_encoded_number(); + column_decode(columns, column_count); + if ((m_capability & CLIENT_DEPRECATE_EOF) != CLIENT_DEPRECATE_EOF) { + eof_packet_decode(L); + } + rows_decode(L, columns, index); + return eof_packet_decode(L); + } + + int data_packet_decode(lua_State* L) { + size_t index = 1; + int top = lua_gettop(L); + lua_pushboolean(L, true); + lua_createtable(L, 0, 4); + bool again = result_set_decode(L, index++); + while (again) { + again = result_set_decode(L, index++); + } + return lua_gettop(L) - top; + } + + int ok_packet_decode(lua_State* L) { size_t data_len; + int top = lua_gettop(L); size_t affected_rows = decode_length_encoded_number(); size_t last_insert_id = decode_length_encoded_number(); uint16_t status_flags = *(uint16_t*)m_slice->read(); @@ -157,25 +246,44 @@ namespace lcodec { lua_pushinteger(L, last_insert_id); lua_pushinteger(L, warnings); lua_pushstring(L, info); + return lua_gettop(L) - top; } - void err_packet_decode(lua_State* L) { + int err_packet_decode(lua_State* L) { + size_t data_len; + int top = lua_gettop(L); uint16_t errnoo = *(uint16_t*)m_slice->read(); //skip sql_state_marker - size_t data_len; char* sql_state = (char*)m_slice->peek(5, 1); const char* error_message = read_cstring(m_slice, data_len); lua_pushboolean(L, false); lua_pushinteger(L, errnoo); lua_pushstring(L, sql_state); lua_pushstring(L, error_message); + return lua_gettop(L) - top; + } + + bool eof_packet_decode(lua_State* L) { + if ((m_capability & CLIENT_DEPRECATE_EOF) == CLIENT_DEPRECATE_EOF) { + size_t data_len; + size_t affected_rows = decode_length_encoded_number(); + size_t last_insert_id = decode_length_encoded_number(); + uint16_t status_flags = *(uint16_t*)m_slice->read(); + uint16_t warnings = *(uint16_t*)m_slice->read(); + const char* info = read_cstring(m_slice, data_len); + return ((status_flags & SERVER_MORE_RESULTS_EXISTS) == SERVER_MORE_RESULTS_EXISTS); + } else { + uint16_t warnings = *(uint16_t*)m_slice->read(); + uint16_t status_flags = *(uint16_t*)m_slice->read(); + return ((status_flags & SERVER_MORE_RESULTS_EXISTS) == SERVER_MORE_RESULTS_EXISTS); + } } size_t prepare_decode(lua_State* L) { uint8_t status = *(uint8_t*)m_slice->read(); uint32_t statement_id = *(uint32_t*)m_slice->read(); uint16_t num_columns = *(uint16_t*)m_slice->read(); - uint16_t num_params = *(uint16_t*)m_slice->read(); + uint16_t num_params = *(uint16_t*)m_slice->read(); int top = lua_gettop(L); lua_pushinteger(L, statement_id); lua_pushinteger(L, num_columns); @@ -203,6 +311,7 @@ namespace lcodec { uint16_t status_flags = *(uint16_t*)m_slice->read(); //2 byte capability_flags_2 uint16_t capability_flag_2 = *(uint16_t*)m_slice->read(); + m_capability = capability_flag_2 << 16 | capability_flag_1; //1 byte character_set uint8_t auth_plugin_data_len = *(uint8_t*)m_slice->read(); //10 byte reserved (all 0) @@ -215,7 +324,7 @@ namespace lcodec { } //auth_plugin_name const char* auth_plugin_name = nullptr; - if ((capability_flag_2 & CLIENT_PLUGIN_AUTH) == CLIENT_PLUGIN_AUTH) { + if ((m_capability & CLIENT_PLUGIN_AUTH) == CLIENT_PLUGIN_AUTH) { auth_plugin_name = read_cstring(m_slice, data_len); } int top = lua_gettop(L); @@ -330,17 +439,21 @@ namespace lcodec { if (nbyte == 0xfc) return *(uint16_t*)m_slice->read(); if (nbyte == 0xfd) return *(uint32_t*)m_slice->read(); if (nbyte == 0xfe) return *(uint64_t*)m_slice->read(); - throw invalid_argument("invalid length coded number:" + nbyte); + return 0; } string_view decode_length_encoded_string() { size_t length = decode_length_encoded_number(); - char* data = (char*)m_slice->peek(length); - if (!data) throw invalid_argument("invalid length coded string:" + length); - return string_view(data, length); + if (length > 0) { + char* data = (char*)m_slice->peek(length); + if (!data) throw invalid_argument("invalid length coded string:" + length); + return string_view(data, length); + } + return ""; } protected: deque sessions; + uint32_t m_capability = 0; }; }