diff --git a/extend/lbson/src/bson.h b/extend/lbson/src/bson.h index 2ee202a3..08c00910 100644 --- a/extend/lbson/src/bson.h +++ b/extend/lbson/src/bson.h @@ -91,22 +91,6 @@ namespace lbson { return lua_gettop(L); } - int pairs(lua_State* L) { - m_buffer.clean(); - size_t data_len = 0; - bson_value* value = lua_to_object(L, -1); - if (value == nullptr) { - char* data = (char*)encode_pairs(L, &data_len); - value = new bson_value(bson_type::BSON_DOCUMENT, data, data_len); - } else { - lua_pop(L, 1); - char* data = (char*)encode_pairs(L, &data_len); - value->str = string(data, data_len); - } - lua_push_object(L, value); - return 1; - } - uint8_t* encode_pairs(lua_State* L, size_t* data_len) { int n = lua_gettop(L); if (n < 2 || n % 2 != 0) { @@ -138,7 +122,68 @@ namespace lbson { return &m_buffer;; } + int date(lua_State* L, int64_t value) { + value = value * 1000; + return make_bson_value(L, bson_type::BSON_DATE, (const char *)&value, sizeof(value)); + } + + int int64(lua_State* L, int64_t value) { + return make_bson_value(L, bson_type::BSON_INT64, (const char *)&value, sizeof(value)); + } + + int objectid(lua_State* L) { + size_t data_len = 0; + const char* value = lua_tolstring(L, 1, &data_len); + if (data_len != 12) return luaL_error(L, "Invalid object id"); + return make_bson_value(L, bson_type::BSON_OBJECTID, value, data_len); + } + + int pairs(lua_State* L) { + m_buffer.clean(); + size_t data_len = 0; + char* data = (char*)encode_pairs(L, &data_len); + return make_bson_value(L, bson_type::BSON_DOCUMENT, data, data_len); + } + + int binary(lua_State* L) { + size_t data_len = 0; + const char* value = lua_tolstring(L, 1, &data_len); + luaL_Buffer b; + luaL_buffinit(L, &b); + luaL_addchar(&b, 0); + luaL_addchar(&b, (int)bson_type::BSON_BINARY); + luaL_addchar(&b, 0); + luaL_addlstring(&b, value, data_len); + luaL_pushresult(&b); + return 1; + } + + int regex(lua_State* L) { + luaL_Buffer b; + luaL_buffinit(L, &b); + luaL_addchar(&b, 0); + luaL_addchar(&b, (int)bson_type::BSON_REGEX); + lua_pushvalue(L,1); + luaL_addvalue(&b); + luaL_addchar(&b, 0); + lua_pushvalue(L,2); + luaL_addvalue(&b); + luaL_addchar(&b, 0); + luaL_pushresult(&b); + return 1; + } + protected: + int make_bson_value(lua_State *L, bson_type type, const char* ptr, size_t len) { + luaL_Buffer b; + luaL_buffinit(L, &b); + luaL_addchar(&b, 0); + luaL_addchar(&b, (int)type); + luaL_addlstring(&b, ptr, len); + luaL_pushresult(&b); + return 1; + } + size_t bson_index(char* str, size_t i) { if (i < max_bson_index) { memcpy(str, bson_numstrs[i], 4); @@ -299,8 +344,12 @@ namespace lbson { case LUA_TSTRING: { size_t sz; const char* buf = lua_tolstring(L, -1, &sz); - write_key(bson_type::BSON_STRING, key, len); - write_string(buf, sz); + if (buf[0] == 0 && sz >= 2) { + m_buffer.push_data((uint8_t*)buf + 1, sz - 1); + } else { + write_key(bson_type::BSON_STRING, key, len); + write_string(buf, sz); + } } break; case LUA_TNIL: diff --git a/extend/lbson/src/lbson.cpp b/extend/lbson/src/lbson.cpp index 84ae8c1d..9558cbd2 100644 --- a/extend/lbson/src/lbson.cpp +++ b/extend/lbson/src/lbson.cpp @@ -15,17 +15,20 @@ namespace lbson { static int pairs(lua_State* L) { return thread_bson.pairs(L); } - static bson_value* doc() { - return new bson_value(bson_type::BSON_DOCUMENT, ""); + static int regex(lua_State* L) { + return thread_bson.regex(L); } - static bson_value* int32(int32_t value) { - return new bson_value(bson_type::BSON_INT32, value); + static int binary(lua_State* L) { + return thread_bson.binary(L); } - static bson_value* int64(int64_t value) { - return new bson_value(bson_type::BSON_INT64, value); + static int objectid(lua_State* L) { + return thread_bson.objectid(L); } - static bson_value* date(int64_t value) { - return new bson_value(bson_type::BSON_DATE, value * 1000); + static int int64(lua_State* L, int64_t value) { + return thread_bson.int64(L, value); + } + static int date(lua_State* L, int64_t value) { + return thread_bson.date(L, value); } static void init_static_bson() { @@ -45,14 +48,15 @@ namespace lbson { luakit::lua_table open_lbson(lua_State* L) { luakit::kit_state kit_state(L); auto llbson = kit_state.new_table("bson"); + llbson.set_function("mongocodec", mongo_codec); + llbson.set_function("objectid", objectid); llbson.set_function("encode", encode); llbson.set_function("decode", decode); - llbson.set_function("mongocodec", mongo_codec); - llbson.set_function("int32", int32); + llbson.set_function("binary", binary); llbson.set_function("int64", int64); llbson.set_function("pairs", pairs); + llbson.set_function("regex", regex); llbson.set_function("date", date); - llbson.set_function("doc", doc); kit_state.new_class( "val", &bson_value::val, "str", &bson_value::str, diff --git a/script/driver/mongo.lua b/script/driver/mongo.lua index a8bd19be..671f15dd 100644 --- a/script/driver/mongo.lua +++ b/script/driver/mongo.lua @@ -27,6 +27,7 @@ local lb64decode = ssl.b64_decode local lhmac_sha1 = ssl.hmac_sha1 local lxor_byte = ssl.xor_byte local bsonpairs = bson.pairs +local bint64 = bson.int64 local lclock_ms = timer.clock_ms local timer_mgr = quanta.get("timer_mgr") @@ -49,8 +50,6 @@ prop:reader("passwd", nil) --passwd prop:reader("salted_pass", nil) --salted_pass prop:reader("executer", nil) --执行者 prop:reader("timer_id", nil) --timer_id -prop:reader("cursor_id", nil) --cursor_id -prop:reader("sort_doc", nil) --sort_doc prop:reader("connections", {}) --connections prop:reader("alives", {}) --alives prop:reader("req_counter", nil) @@ -60,8 +59,6 @@ function MongoDB:__init(conf) self.name = conf.db self.user = conf.user self.passwd = conf.passwd - self.sort_doc = bson.doc() - self.cursor_id = bson.int64(0) self.codec = bson.mongocodec() self:set_options(conf.opts) self:setup_pool(conf.hosts) @@ -380,14 +377,11 @@ function MongoDB:find_one(co_name, query, projection) return succ end -function MongoDB:format_pairs(args, doc) +function MongoDB:format_pairs(args) if args then if type(next(args)) == "string" then return args end - if doc then - tinsert(args, doc) - end return bsonpairs(tunpack(args)) end end @@ -395,7 +389,7 @@ end -- 参数说明 --sort: {k1=1} / {k1,1,k2,-1,k3,-1} function MongoDB:find(co_name, query, projection, sortor, limit, skip) - local fsortor = self:format_pairs(sortor, self.sort_doc) + local fsortor = self:format_pairs(sortor) local succ, reply = self:runCommand("find", co_name, "filter", query, "projection", projection, "sort", fsortor, "limit", limit, "skip", skip) if not succ then return succ, reply @@ -411,8 +405,8 @@ function MongoDB:find(co_name, query, projection, sortor, limit, skip) if limit and #results >= limit then break end - self.cursor_id.val = cursor.id - local msucc, moreply = self:runCommand("getMore", self.cursor_id, "collection", co_name, "batchSize", limit) + local new_cur_id = bint64(cursor.id) + local msucc, moreply = self:runCommand("getMore", new_cur_id, "collection", co_name, "batchSize", limit) if not msucc then return msucc, moreply end