From 6cf0ffcec969e4a983171a5f411506b2ed0fd2c1 Mon Sep 17 00:00:00 2001 From: Roger Meier Date: Sat, 5 Apr 2014 00:45:42 +0200 Subject: [PATCH] THRIFT-1681: Add Lua Support Patch: Dave Watson Github Pull Request: This closes #92 --- compiler/cpp/Makefile.am | 3 +- compiler/cpp/src/generate/t_lua_generator.cc | 1226 ++++++++++++++++++ configure.ac | 18 + lib/Makefile.am | 3 + lib/lua/Makefile.am | 58 + lib/lua/TBinaryProtocol.lua | 264 ++++ lib/lua/TBufferedTransport.lua | 91 ++ lib/lua/TFramedTransport.lua | 119 ++ lib/lua/TMemoryBuffer.lua | 91 ++ lib/lua/TProtocol.lua | 162 +++ lib/lua/TServer.lua | 139 ++ lib/lua/TSocket.lua | 132 ++ lib/lua/TTransport.lua | 93 ++ lib/lua/Thrift.lua | 273 ++++ lib/lua/src/longnumberutils.c | 47 + lib/lua/src/luabitwise.c | 83 ++ lib/lua/src/luabpack.c | 162 +++ lib/lua/src/lualongnumber.c | 228 ++++ lib/lua/src/luasocket.c | 386 ++++++ lib/lua/src/socket.h | 78 ++ lib/lua/src/usocket.c | 362 ++++++ test/ThriftTest.thrift | 1 + test/lua/test_basic_client.lua | 136 ++ test/lua/test_basic_server.lua | 104 ++ 24 files changed, 4258 insertions(+), 1 deletion(-) create mode 100644 compiler/cpp/src/generate/t_lua_generator.cc create mode 100644 lib/lua/Makefile.am create mode 100644 lib/lua/TBinaryProtocol.lua create mode 100644 lib/lua/TBufferedTransport.lua create mode 100644 lib/lua/TFramedTransport.lua create mode 100644 lib/lua/TMemoryBuffer.lua create mode 100644 lib/lua/TProtocol.lua create mode 100644 lib/lua/TServer.lua create mode 100644 lib/lua/TSocket.lua create mode 100644 lib/lua/TTransport.lua create mode 100644 lib/lua/Thrift.lua create mode 100644 lib/lua/src/longnumberutils.c create mode 100644 lib/lua/src/luabitwise.c create mode 100644 lib/lua/src/luabpack.c create mode 100644 lib/lua/src/lualongnumber.c create mode 100644 lib/lua/src/luasocket.c create mode 100644 lib/lua/src/socket.h create mode 100644 lib/lua/src/usocket.c create mode 100644 test/lua/test_basic_client.lua create mode 100644 test/lua/test_basic_server.lua diff --git a/compiler/cpp/Makefile.am b/compiler/cpp/Makefile.am index 5fee8569517..47725bd4c36 100644 --- a/compiler/cpp/Makefile.am +++ b/compiler/cpp/Makefile.am @@ -90,7 +90,8 @@ thrift_SOURCES += src/generate/t_c_glib_generator.cc \ src/generate/t_delphi_generator.cc \ src/generate/t_go_generator.cc \ src/generate/t_gv_generator.cc \ - src/generate/t_d_generator.cc + src/generate/t_d_generator.cc \ + src/generate/t_lua_generator.cc thrift_CPPFLAGS = -I$(srcdir)/src thrift_CXXFLAGS = -Wall diff --git a/compiler/cpp/src/generate/t_lua_generator.cc b/compiler/cpp/src/generate/t_lua_generator.cc new file mode 100644 index 00000000000..b7fdad41df5 --- /dev/null +++ b/compiler/cpp/src/generate/t_lua_generator.cc @@ -0,0 +1,1226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include "t_oop_generator.h" +#include "platform.h" + +using std::ofstream; +using std::string; +using std::vector; +using std::map; + +static const string endl = "\n"; // avoid ostream << std::endl flushes + +/** + * LUA code generator. + * + */ +class t_lua_generator : public t_oop_generator { + public: + t_lua_generator( + t_program* program, + const std::map& parsed_options, + const std::string& option_string) + : t_oop_generator(program) + { + std::map::const_iterator iter; + + iter = parsed_options.find("omit_requires"); + gen_requires_ = (iter == parsed_options.end()); + + out_dir_base_ = "gen-lua"; + } + + /** + * Init and close methods + */ + void init_generator(); + void close_generator(); + + /** + * Program-level generation functions + */ + void generate_typedef (t_typedef* ttypedef); + void generate_enum (t_enum* tenum); + void generate_const (t_const* tconst); + void generate_struct (t_struct* tstruct); + void generate_xception (t_struct* txception); + void generate_service (t_service* tservice); + + std::string render_const_value(t_type* type, t_const_value* value); + + private: + + /** + * True iff we should generate lua require statements. + */ + bool gen_requires_; + + /** + * Struct-level generation functions + */ + void generate_lua_struct_definition( + std::ofstream& out, t_struct* tstruct, bool is_xception=false); + void generate_lua_struct_reader(std::ofstream& out, t_struct* tstruct); + void generate_lua_struct_writer(std::ofstream& out, t_struct* tstruct); + + /** + * Service-level generation functions + */ + void generate_service_client (std::ofstream& out, t_service* tservice); + void generate_service_interface (std::ofstream& out, t_service* tservice); + void generate_service_processor (std::ofstream& out, t_service* tservice); + void generate_process_function (std::ofstream& out, t_service* tservice, + t_function* tfunction); + void generate_service_helpers (ofstream &out, t_service* tservice); + void generate_function_helpers (ofstream &out, t_function* tfunction); + + /** + * Deserialization (Read) + */ + void generate_deserialize_field( + std::ofstream &out, t_field *tfield, std::string prefix=""); + + void generate_deserialize_struct( + std::ofstream &out, t_struct *tstruct, std::string prefix=""); + + void generate_deserialize_container( + std::ofstream &out, t_type *ttype, std::string prefix=""); + + void generate_deserialize_set_element( + std::ofstream &out, t_set *tset, std::string prefix=""); + + void generate_deserialize_map_element( + std::ofstream &out, t_map *tmap, std::string prefix=""); + + void generate_deserialize_list_element( + std::ofstream &out, t_list *tlist, std::string prefix=""); + + /** + * Serialization (Write) + */ + void generate_serialize_field( + std::ofstream &out, t_field *tfield, std::string prefix=""); + + void generate_serialize_struct( + std::ofstream &out, t_struct *tstruct, std::string prefix=""); + + void generate_serialize_container( + std::ofstream &out, t_type *ttype, std::string prefix=""); + + void generate_serialize_map_element( + std::ofstream &out, t_map *tmap, std::string kiter, std::string viter); + + void generate_serialize_set_element( + std::ofstream &out, t_set *tmap, std::string iter); + + void generate_serialize_list_element( + std::ofstream &out, t_list *tlist, std::string iter); + + /** + * Helper rendering functions + */ + std::string lua_includes(); + std::string function_signature(t_function* tfunction, std::string prefix=""); + std::string argument_list(t_struct* tstruct, std::string prefix=""); + std::string type_to_enum(t_type* ttype); + static std::string get_namespace(const t_program* program); + + std::string autogen_comment() { + return + std::string("--\n") + + "-- Autogenerated by Thrift\n" + + "--\n" + + "-- DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING\n" + + "-- @""generated\n" + + "--\n"; + } + + /** + * File streams + */ + std::ofstream f_types_; + std::ofstream f_consts_; + std::ofstream f_service_; +}; + + +/** + * Init and close methods + */ +void t_lua_generator::init_generator() { + // Make output directory + string outdir = get_out_dir(); + MKDIR(outdir.c_str()); + + // Make output files + string cur_namespace = get_namespace(program_); + string f_consts_name = outdir + cur_namespace + "constants.lua"; + f_consts_.open(f_consts_name.c_str()); + string f_types_name = outdir + cur_namespace + "ttypes.lua"; + f_types_.open(f_types_name.c_str()); + + // Add headers + f_consts_ << autogen_comment() << lua_includes(); + f_types_ << autogen_comment() << lua_includes(); + if (gen_requires_) { + f_types_ << endl << "require '" << cur_namespace << "constants'"; + } +} + +void t_lua_generator::close_generator() { + // Close types file + f_types_.close(); + f_consts_.close(); +} + +/** + * Generate a typedef (essentially a constant) + */ +void t_lua_generator::generate_typedef(t_typedef* ttypedef) { + f_types_ + << endl << endl << indent() + << ttypedef->get_symbolic() << " = " + << ttypedef->get_type()->get_name(); +} + +/** + * Generates code for an enumerated type (table) + */ +void t_lua_generator::generate_enum(t_enum* tenum) { + f_types_ << endl << endl << tenum->get_name() << " = {" << endl; + + vector constants = tenum->get_constants(); + vector::iterator c_iter; + for (c_iter = constants.begin(); c_iter != constants.end();) { + int32_t value = (*c_iter)->get_value(); + + f_types_ << " " << (*c_iter)->get_name() << " = " << value; + ++c_iter; + if (c_iter != constants.end()) { + f_types_ << ","; + } + f_types_ << endl; + } + f_types_ << "}"; +} + +/** + * Generate a constant (non-local) value + */ +void t_lua_generator::generate_const(t_const* tconst) { + t_type* type = tconst->get_type(); + string name = tconst->get_name(); + t_const_value* value = tconst->get_value(); + + f_consts_ << endl << endl << name << " = "; + f_consts_ << render_const_value(type, value); +} + +/** + * Prints the value of a constant with the given type. + */ +string t_lua_generator::render_const_value( + t_type* type, t_const_value* value) { + std::ostringstream out; + + type = get_true_type(type); + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_STRING: + out << "'" << value->get_string() << "'"; + break; + case t_base_type::TYPE_BOOL: + out << (value->get_integer() > 0 ? "true" : "false"); + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + out << value->get_integer(); + break; + case t_base_type::TYPE_I64: + out << "lualongnumber.new('" << value->get_string() << "')"; + break; + case t_base_type::TYPE_DOUBLE: + if (value->get_type() == t_const_value::CV_INTEGER) { + out << value->get_integer(); + } else { + out << value->get_double(); + } + break; + default: + throw "compiler error: no const of base type " + + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + out << value->get_integer(); + } else if (type->is_struct() || type->is_xception()) { + out << type->get_name() << " = {" << endl; + indent_up(); + + const vector& fields = ((t_struct*)type)->get_members(); + vector::const_iterator f_iter; + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end();) { + t_type* field_type = NULL; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_name() == v_iter->first->get_string()) { + field_type = (*f_iter)->get_type(); + } + } + if (field_type == NULL) { + throw "type error: " + type->get_name() + " has no field " + + v_iter->first->get_string(); + } + + indent(out); + out << render_const_value(g_type_string, v_iter->first); + out << " = "; + out << render_const_value(field_type, v_iter->second); + ++v_iter; + if (v_iter != val.end()) { + out << ","; + } + } + + out << "}"; + indent_down(); + } else if (type->is_map()) { + out << type->get_name() << "{" << endl; + indent_up(); + + t_type* ktype = ((t_map*)type)->get_key_type(); + t_type* vtype = ((t_map*)type)->get_val_type(); + + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end();) { + indent(out) + << "[" << render_const_value(ktype, v_iter->first) << "] = " + << render_const_value(vtype, v_iter->second); + ++v_iter; + if (v_iter != val.end()) { + out << ","; + } + out << endl; + } + indent_down(); + indent(out) << "}"; + } else if (type->is_list() || type->is_set()) { + t_type* etype; + if (type->is_list()) { + etype = ((t_list*)type)->get_elem_type(); + } else { + etype = ((t_set*)type)->get_elem_type(); + } + out << type->get_name() << " = {" << endl; + const vector& val = value->get_list(); + vector::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end();) { + indent(out); + out << "[" << render_const_value(etype, *v_iter) << "]"; + if (type->is_set()) { + out << " = true"; + } else { + out << " = false"; + } + ++v_iter; + if (v_iter != val.end()) { + out << "," << endl; + } + } + out << "}"; + } + return out.str(); +} + +/** + * Generate a thrift struct + */ +void t_lua_generator::generate_struct(t_struct* tstruct) { + generate_lua_struct_definition(f_types_, tstruct, false); +} + +/** + * Generate a thrift exception + */ +void t_lua_generator::generate_xception(t_struct* txception) { + generate_lua_struct_definition(f_types_, txception, true); +} + +/** + * Generate a thrift struct or exception (lua table) + */ +void t_lua_generator::generate_lua_struct_definition(ofstream &out, + t_struct *tstruct, + bool is_exception) { + vector::const_iterator m_iter; + const vector& members = tstruct->get_members(); + + indent(out) << endl << endl << tstruct->get_name(); + if (is_exception) { + out << " = TException:new{" << endl << + indent() << " __type = '" << tstruct->get_name() << "'"; + if (members.size() > 0) { + out << ","; + } + out << endl; + } else { + out << " = __TObject:new{" << endl; + } + indent_up(); + for (m_iter = members.begin(); m_iter != members.end();) { + indent(out); + out << (*m_iter)->get_name(); + ++m_iter; + if (m_iter != members.end()) { + out << "," << endl; + } + } + indent_down(); + indent(out); + out << endl << "}"; + + generate_lua_struct_reader(out, tstruct); + generate_lua_struct_writer(out, tstruct); +} + +/** + * Generate a struct/exception reader + */ +void t_lua_generator::generate_lua_struct_reader(ofstream& out, + t_struct* tstruct) { + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + // function + indent(out) << endl << endl + << "function " << tstruct->get_name() << ":read(iprot)" << endl; + indent_up(); + + indent(out) << "iprot:readStructBegin()" << endl; + + // while: Read in fields + indent(out) << "while true do" << endl; + indent_up(); + + // if: Check what to read + indent(out) << "local fname, ftype, fid = iprot:readFieldBegin()" << endl; + indent(out) << "if ftype == TType.STOP then" << endl; + indent_up(); + indent(out) << "break" << endl; + + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + indent_down(); + indent(out) + << "elseif fid == " << (*f_iter)->get_key() << " then" << endl; + indent_up(); + indent(out) + << "if ftype == " << type_to_enum((*f_iter)->get_type()) + << " then" << endl; + indent_up(); + + // Read field contents + generate_deserialize_field(out, *f_iter, "self."); + + indent_down(); + indent(out) << "else" << endl; + indent(out) << " iprot:skip(ftype)" << endl; + indent(out) << "end" << endl; + } + + // end if + indent_down(); + indent(out) << "else" << endl; + indent(out) << " iprot:skip(ftype)" << endl; + indent(out) << "end" << endl; + indent(out) << "iprot:readFieldEnd()" << endl; + + // end while + indent_down(); + indent(out) << "end" << endl; + indent(out) << "iprot:readStructEnd()" << endl; + + // end function + indent_down(); + indent(out); + out << "end"; +} + +/** + * Generate a struct/exception writer + */ +void t_lua_generator::generate_lua_struct_writer(ofstream& out, + t_struct* tstruct) { + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + // function + indent(out) << endl << endl + << "function " << tstruct->get_name() << ":write(oprot)" << endl; + indent_up(); + + indent(out) + << "oprot:writeStructBegin('" << tstruct->get_name() << "')" << endl; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + indent(out) << "if self." << (*f_iter)->get_name() << " then" << endl; + indent_up(); + indent(out) + << "oprot:writeFieldBegin('" << (*f_iter)->get_name() << "', " + << type_to_enum((*f_iter)->get_type()) << ", " + << (*f_iter)->get_key() << ")" << endl; + + // Write field contents + generate_serialize_field(out, *f_iter, "self."); + + indent(out) + << "oprot:writeFieldEnd()" << endl; + indent_down(); + indent(out) << "end" << endl; + } + indent(out) << "oprot:writeFieldStop()" << endl; + indent(out) << "oprot:writeStructEnd()" << endl; + + // end function + indent_down(); + indent(out); + out << "end"; +} + +/** + * Generate a thrift service + */ +void t_lua_generator::generate_service(t_service* tservice) { + // Get output directory + string outdir = get_out_dir(); + + // Open the file for writing + string cur_ns = get_namespace(program_); + string f_service_name = outdir + cur_ns + tservice->get_name() + ".lua"; + f_service_.open(f_service_name.c_str()); + + // Headers + f_service_ << autogen_comment() << lua_includes(); + if (gen_requires_) { + f_service_ << endl << "require '" << cur_ns << "ttypes'" << endl; + + if (tservice->get_extends() != NULL) { + f_service_ + << "require '" << get_namespace(tservice->get_extends()->get_program()) + << tservice->get_extends()->get_name() << "'" << endl; + } + } + + f_service_ << endl; + + generate_service_client(f_service_, tservice); + generate_service_interface(f_service_, tservice); + generate_service_processor(f_service_, tservice); + generate_service_helpers(f_service_, tservice); + + // Close the file + f_service_.close(); +} + +void t_lua_generator::generate_service_interface(ofstream &out, + t_service* tservice) { + string classname = tservice->get_name() + "Iface"; + t_service* extends_s = tservice->get_extends(); + + // Interface object definition + out << classname << " = "; + if (extends_s) { + out << extends_s->get_name() << "Iface:new{" << endl; + } else { + out << "__TObject:new{" << endl; + } + out + << " __type = '" << classname << "'" << endl + << "}" << endl << endl; +} + +void t_lua_generator::generate_service_client(ofstream &out, + t_service* tservice) { + string classname = tservice->get_name() + "Client"; + t_service* extends_s = tservice->get_extends(); + + // Client object definition + out << classname << " = __TObject.new("; + if (extends_s != NULL) { + out << extends_s->get_name() << "Client"; + } else { + out << "__TClient"; + } + out + <<", {" << endl + << " __type = '" << classname << "'" << endl + << "})" << endl; + + // Send/Recv functions + vector functions = tservice->get_functions(); + vector::const_iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + string sig = function_signature(*f_iter); + string funcname = (*f_iter)->get_name(); + + // Wrapper function + indent(out) << endl << "function " << classname << ":" << sig << endl; + indent_up(); + + indent(out) << "self:send_" << sig << endl << indent(); + if (!(*f_iter)->is_oneway()) { + if (!(*f_iter)->get_returntype()->is_void()) { + out << "return "; + } + out << "self:recv_" << sig << endl; + } + + indent_down(); + indent(out) << "end" << endl; + + // Send function + indent(out) << endl << "function " << classname << ":send_" << sig << endl; + indent_up(); + + indent(out) << "self.oprot:writeMessageBegin('" << funcname << + "', TMessageType.CALL, self._seqid)" << endl; + indent(out) << "local args = " << funcname << "_args:new{}" << endl; + + // Set the args + const vector& args = (*f_iter)->get_arglist()->get_members(); + vector::const_iterator fld_iter; + for (fld_iter = args.begin(); fld_iter != args.end(); ++fld_iter) { + std::string argname = (*fld_iter)->get_name(); + indent(out) << "args." << argname << " = " << argname << endl; + } + + indent(out) << "args:write(self.oprot)" << endl; + indent(out) << "self.oprot:writeMessageEnd()" << endl; + indent(out) << "self.oprot.trans:flush()" << endl; + + indent_down(); + indent(out) << "end" << endl; + + // Recv function + if (!(*f_iter)->is_oneway()) { + indent(out) + << endl << "function " << classname << ":recv_" << sig << endl; + indent_up(); + + out << + indent() << "local fname, mtype, rseqid = self.iprot:" + << "readMessageBegin()"<< endl << + indent() << "if mtype == TMessageType.EXCEPTION then" << endl << + indent() << " local x = TApplicationException:new{}" << endl << + indent() << " x:read(self.iprot)" << endl << + indent() << " self.iprot:readMessageEnd()" << endl << + indent() << " error(x)" << endl << + indent() << "end" << endl << + indent() << "local result = " << funcname << "_result:new{}" + << endl << + indent() << "result:read(self.iprot)" << endl << + indent() << "self.iprot:readMessageEnd()" << endl; + + // Return the result if it's not a void function + if (!(*f_iter)->get_returntype()->is_void()) { + out << + indent() << "if result.success then" << endl << + indent() << " return result.success" << endl; + + // Throw custom exceptions + const std::vector& xf = + (*f_iter)->get_xceptions()->get_members(); + vector::const_iterator x_iter; + for (x_iter = xf.begin(); x_iter != xf.end(); ++x_iter) { + out << + indent() << "elseif result." << (*x_iter)->get_name() << " then" + << endl << + indent() << " error(result." << (*x_iter)->get_name() << ")" + << endl; + } + + out << + indent() << "end" << endl << + indent() << "error(TApplicationException:new{errorCode = " + << "TApplicationException.MISSING_RESULT})" << endl; + } + + indent_down(); + indent(out) << "end" << endl; + } + } +} + +void t_lua_generator::generate_service_processor(ofstream &out, + t_service* tservice) { + string classname = tservice->get_name() + "Processor"; + t_service* extends_s = tservice->get_extends(); + + // Define processor table + out << endl + << classname << " = __TObject.new("; + if (extends_s != NULL) { + out << extends_s << "Processor" << endl; + } else { + out << "__TProcessor" << endl; + } + out + << ", {" << endl + << " __type = '" << classname << "'" << endl + << "})" << endl; + + // Process function + indent(out) << endl << "function " << classname + << ":process(iprot, oprot, server_ctx)" << endl; + indent_up(); + + indent(out) + << "local name, mtype, seqid = iprot:readMessageBegin()" << endl; + indent(out) + << "local func_name = 'process_' .. name" << endl; + indent(out) + << "if not self[func_name] or ttype(self[func_name]) ~= 'function' then"; + indent_up(); + out << endl << + indent() << "iprot:skip(TType.STRUCT)" << endl << + indent() << "iprot:readMessageEnd()" << endl << + indent() << "x = TApplicationException:new{" << endl << + indent() << " errorCode = TApplicationException.UNKNOWN_METHOD" << endl + << indent() << "}" << endl << + indent() << "oprot:writeMessageBegin(name, TMessageType.EXCEPTION, " + << "seqid)" << endl << + indent() << "x:write(oprot)" << endl << + indent() << "oprot:writeMessageEnd()" << endl << + indent() << "oprot.trans:flush()" << endl; + indent_down(); + indent(out) << "else" << endl << + indent() << " self[func_name](self, seqid, iprot, oprot, server_ctx)" + << endl + << indent() << "end" << endl; + + indent_down(); + indent(out) << "end" << endl; + + // Generate the process subfunctions + vector functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + generate_process_function(out, tservice, *f_iter); + } +} + +void t_lua_generator::generate_process_function(ofstream &out, + t_service* tservice, + t_function* tfunction) { + string classname = tservice->get_name() + "Processor"; + string argsname = tfunction->get_name() + "_args"; + string resultname = tfunction->get_name() + "_result"; + string fn_name = tfunction->get_name(); + + indent(out) << endl << "function " << classname << ":process_" << fn_name + << "(seqid, iprot, oprot, server_ctx)" << endl; + indent_up(); + + // Read the request + out << + indent() << "local args = " << argsname << ":new{}" << endl << + indent() << "local reply_type = TMessageType.REPLY" << endl << + indent() << "args:read(iprot)" << endl << + indent() << "iprot:readMessageEnd()" << endl << + indent() << "local result = " << resultname << ":new{}" << endl << + indent() << "local status, res = pcall(self.handler." << fn_name + << ", self.handler"; + + // Print arguments + t_struct *args = tfunction->get_arglist(); + if (args->get_members().size() > 0) { + out << ", " << argument_list(args, "args."); + } + + // Check for errors + out << ")" << endl << + indent() << "if not status then" << endl << + indent() << " reply_type = TMessageType.EXCEPTION" << endl << + indent() << " result = TApplicationException:new{message = res}" + << endl; + + // Handle custom exceptions + const std::vector& xf = tfunction->get_xceptions()->get_members(); + if (xf.size() > 0) { + vector::const_iterator x_iter; + for (x_iter = xf.begin(); x_iter != xf.end(); ++x_iter) { + out << + indent() << "elseif ttype(res) == '" + << (*x_iter)->get_type()->get_name() << "' then" << endl << + indent() << " result." << (*x_iter)->get_name() << " = res" << endl; + } + } + + // Set the result and write the reply + out << + indent() << "else" << endl << + indent() << " result.success = res" << endl << + indent() << "end" << endl << + indent() << "oprot:writeMessageBegin('" << fn_name << "', reply_type, " + << "seqid)" << endl << + indent() << "result:write(oprot)" << endl << + indent() << "oprot:writeMessageEnd()" << endl << + indent() << "oprot.trans:flush()" << endl; + + indent_down(); + indent(out) << "end" << endl; +} + +// Service helpers +void t_lua_generator::generate_service_helpers(ofstream &out, + t_service* tservice) { + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + out << endl << "-- HELPER FUNCTIONS AND STRUCTURES"; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_struct* ts = (*f_iter)->get_arglist(); + generate_lua_struct_definition(out, ts, false); + generate_function_helpers(out, *f_iter); + } +} + +void t_lua_generator::generate_function_helpers(ofstream &out, + t_function *tfunction) { + if (!tfunction->is_oneway()) { + t_struct result(program_, tfunction->get_name() + "_result"); + t_field success(tfunction->get_returntype(), "success", 0); + if (!tfunction->get_returntype()->is_void()) { + result.append(&success); + } + + t_struct* xs = tfunction->get_xceptions(); + const vector& fields = xs->get_members(); + vector::const_iterator f_iter; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + result.append(*f_iter); + } + generate_lua_struct_definition(out, &result, false); + } +} + +/** + * Deserialize (Read) + */ +void t_lua_generator::generate_deserialize_field(ofstream &out, + t_field* tfield, + string prefix) { + t_type* type = get_true_type(tfield->get_type()); + + if (type->is_void()) { + throw "CANNOT GENERATE DESERIALIZE CODE FOR void TYPE: " + + prefix + tfield->get_name(); + } + + string name = prefix + tfield->get_name(); + + if (type->is_struct() || type->is_xception()) { + generate_deserialize_struct(out, (t_struct*)type, name); + } else if (type->is_container()) { + generate_deserialize_container(out, type, name); + } else if (type->is_base_type() || type->is_enum()) { + indent(out) << + name << " = iprot:"; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "compiler error: cannot serialize void field in a struct: " + + name; + break; + case t_base_type::TYPE_STRING: + out << "readString()"; + break; + case t_base_type::TYPE_BOOL: + out << "readBool()"; + break; + case t_base_type::TYPE_BYTE: + out << "readByte()"; + break; + case t_base_type::TYPE_I16: + out << "readI16()"; + break; + case t_base_type::TYPE_I32: + out << "readI32()"; + break; + case t_base_type::TYPE_I64: + out << "readI64()"; + break; + case t_base_type::TYPE_DOUBLE: + out << "readDouble()"; + break; + default: + throw "compiler error: no PHP name for base type " + + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + out << "readI32()"; + } + out << endl; + + } else { + printf("DO NOT KNOW HOW TO DESERIALIZE FIELD '%s' TYPE '%s'\n", + tfield->get_name().c_str(), type->get_name().c_str()); + } +} + +void t_lua_generator::generate_deserialize_struct(ofstream &out, + t_struct* tstruct, + string prefix) { + indent(out) + << prefix << " = " << tstruct->get_name() << ":new{}" << endl + << indent() << prefix << ":read(iprot)" << endl; +} + +void t_lua_generator::generate_deserialize_container(ofstream &out, + t_type* ttype, + string prefix) { + string size = tmp("_size"); + string ktype = tmp("_ktype"); + string vtype = tmp("_vtype"); + string etype = tmp("_etype"); + + t_field fsize(g_type_i32, size); + t_field fktype(g_type_byte, ktype); + t_field fvtype(g_type_byte, vtype); + t_field fetype(g_type_byte, etype); + + // Declare variables, read header + indent(out) << prefix << " = {}" << endl; + if (ttype->is_map()) { + indent(out) << "local " << ktype << ", " << vtype << ", " << size + << " = iprot:readMapBegin() " << endl; + } else if (ttype->is_set()) { + indent(out) << "local " << etype << ", " << size + << " = iprot:readSetBegin()" << endl; + } else if (ttype->is_list()) { + indent(out) << "local " << etype << ", " << size + << " = iprot:readListBegin()" << endl; + } + + // Deserialize + indent(out) << "for _i=1," << size << " do" << endl; + indent_up(); + + if (ttype->is_map()) { + generate_deserialize_map_element(out, (t_map*)ttype, prefix); + } else if (ttype->is_set()) { + generate_deserialize_set_element(out, (t_set*)ttype, prefix); + } else if (ttype->is_list()) { + generate_deserialize_list_element(out, (t_list*)ttype, prefix); + } + + indent_down(); + indent(out) << "end" << endl; + + // Read container end + if (ttype->is_map()) { + indent(out) << "iprot:readMapEnd()" << endl; + } else if (ttype->is_set()) { + indent(out) << "iprot:readSetEnd()" << endl; + } else if (ttype->is_list()) { + indent(out) << "iprot:readListEnd()" << endl; + } +} + +void t_lua_generator::generate_deserialize_map_element(ofstream &out, + t_map* tmap, + string prefix) { + // A map is represented by a table indexable by any lua type + string key = tmp("_key"); + string val = tmp("_val"); + t_field fkey(tmap->get_key_type(), key); + t_field fval(tmap->get_val_type(), val); + + generate_deserialize_field(out, &fkey); + generate_deserialize_field(out, &fval); + + indent(out) << prefix << "[" << key << "] = " << val << endl; +} + +void t_lua_generator::generate_deserialize_set_element(ofstream &out, + t_set* tset, + string prefix) { + // A set is represented by a table indexed by the value + string elem = tmp("_elem"); + t_field felem(tset->get_elem_type(), elem); + + generate_deserialize_field(out, &felem); + + indent(out) << + prefix << "[" << elem << "] = " << elem << endl; +} + +void t_lua_generator::generate_deserialize_list_element(ofstream &out, + t_list* tlist, + string prefix) { + // A list is represented by a table indexed by integer values + // LUA natively provides all of the functions required to maintain a list + string elem = tmp("_elem"); + t_field felem(tlist->get_elem_type(), elem); + + generate_deserialize_field(out, &felem); + + indent(out) << "table.insert(" << prefix << ", " << elem << ")" << endl; +} + +/** + * Serialize (Write) + */ +void t_lua_generator::generate_serialize_field(ofstream &out, + t_field* tfield, + string prefix) { + t_type* type = get_true_type(tfield->get_type()); + string name = prefix + tfield->get_name(); + + // Do nothing for void types + if (type->is_void()) { + throw "CANNOT GENERATE SERIALIZE CODE FOR void TYPE: " + name; + } + + if (type->is_struct() || type->is_xception()) { + generate_serialize_struct(out, (t_struct*)type, name); + } else if (type->is_container()) { + generate_serialize_container(out, type, name); + } else if (type->is_base_type() || type->is_enum()) { + indent(out) << "oprot:"; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw + "compiler error: cannot serialize void field in a struct: " + name; + break; + case t_base_type::TYPE_STRING: + out << "writeString(" << name << ")"; + break; + case t_base_type::TYPE_BOOL: + out << "writeBool(" << name << ")"; + break; + case t_base_type::TYPE_BYTE: + out << "writeByte(" << name << ")"; + break; + case t_base_type::TYPE_I16: + out << "writeI16(" << name << ")"; + break; + case t_base_type::TYPE_I32: + out << "writeI32(" << name << ")"; + break; + case t_base_type::TYPE_I64: + out << "writeI64(" << name << ")"; + break; + case t_base_type::TYPE_DOUBLE: + out << "writeDouble(" << name << ")"; + break; + default: + throw "compiler error: no PHP name for base type " + + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + out << "writeI32(" << name << ")"; + } + out << endl; + } else { + printf("DO NOT KNOW HOW TO SERIALIZE FIELD '%s' TYPE '%s'\n", + name.c_str(), + type->get_name().c_str()); + } +} + +void t_lua_generator::generate_serialize_struct(ofstream &out, + t_struct* tstruct, + string prefix) { + indent(out) << prefix << ":write(oprot)" << endl; +} + +void t_lua_generator::generate_serialize_container(ofstream &out, + t_type* ttype, + string prefix) { + // Begin writing + if (ttype->is_map()) { + indent(out) << + "oprot:writeMapBegin(" << + type_to_enum(((t_map*)ttype)->get_key_type()) << ", " << + type_to_enum(((t_map*)ttype)->get_val_type()) << ", " << + "string.len(" << prefix << "))" << endl; + } else if (ttype->is_set()) { + indent(out) << + "oprot:writeSetBegin(" << + type_to_enum(((t_set*)ttype)->get_elem_type()) << ", " << + "string.len(" << prefix << "))" << endl; + } else if (ttype->is_list()) { + indent(out) << + "oprot:writeListBegin(" << + type_to_enum(((t_list*)ttype)->get_elem_type()) << ", " << + "string.len(" << prefix << "))" << endl; + } + + // Serialize + if (ttype->is_map()) { + string kiter = tmp("kiter"); + string viter = tmp("viter"); + indent(out) + << "for " << kiter << "," << viter << " in pairs(" << prefix << ") do" + << endl; + indent_up(); + generate_serialize_map_element(out, (t_map*)ttype, kiter, viter); + indent_down(); + indent(out) << "end" << endl; + } else if (ttype->is_set()) { + string iter = tmp("iter"); + indent(out) << + "for " << iter << ",_ in pairs(" << prefix << ") do" << endl; + indent_up(); + generate_serialize_set_element(out, (t_set*)ttype, iter); + indent_down(); + indent(out) << "end" << endl; + } else if (ttype->is_list()) { + string iter = tmp("iter"); + indent(out) << + "for _," << iter << " in ipairs(" << prefix << ") do" << endl; + indent_up(); + generate_serialize_list_element(out, (t_list*)ttype, iter); + indent_down(); + indent(out) << "end" << endl; + } + + // Finish writing + if (ttype->is_map()) { + indent(out) << "oprot:writeMapEnd()" << endl; + } else if (ttype->is_set()) { + indent(out) << "oprot:writeSetEnd()" << endl; + } else if (ttype->is_list()) { + indent(out) << "oprot:writeListEnd()" << endl; + } +} + +void t_lua_generator::generate_serialize_map_element(ofstream &out, + t_map* tmap, + string kiter, + string viter) { + t_field kfield(tmap->get_key_type(), kiter); + generate_serialize_field(out, &kfield, ""); + + t_field vfield(tmap->get_val_type(), viter); + generate_serialize_field(out, &vfield, ""); +} + +void t_lua_generator::generate_serialize_set_element(ofstream &out, + t_set* tset, + string iter) { + t_field efield(tset->get_elem_type(), iter); + generate_serialize_field(out, &efield, ""); +} + +void t_lua_generator::generate_serialize_list_element(ofstream &out, + t_list* tlist, + string iter) { + t_field efield(tlist->get_elem_type(), iter); + generate_serialize_field(out, &efield, ""); +} + +/** + * Helper rendering functions + */ +string t_lua_generator::lua_includes() { + if (gen_requires_) { + return "\n\nrequire 'Thrift'"; + } else { + return ""; + } +} + +string t_lua_generator::get_namespace(const t_program* program) { + std::string real_module = program->get_namespace("lua"); + if (real_module.empty()) { + return program->get_name() + "_"; + } + return real_module + "_"; +} + +string t_lua_generator::function_signature(t_function* tfunction, + string prefix) { + std::string ret = tfunction->get_name() + "(" + + argument_list(tfunction->get_arglist()) + ")"; + return ret; +} + +string t_lua_generator::argument_list(t_struct* tstruct, string prefix) { + const vector& fields = tstruct->get_members(); + vector::const_iterator fld_iter; + std::string ret = ""; + for (fld_iter = fields.begin(); fld_iter != fields.end();) { + ret += prefix + (*fld_iter)->get_name(); + ++fld_iter; + if (fld_iter != fields.end()) { + ret += ", "; + } + } + return ret; +} + +string t_lua_generator::type_to_enum(t_type* type) { + type = get_true_type(type); + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "NO T_VOID CONSTRUCT"; + case t_base_type::TYPE_STRING: + return "TType.STRING"; + case t_base_type::TYPE_BOOL: + return "TType.BOOL"; + case t_base_type::TYPE_BYTE: + return "TType.BYTE"; + case t_base_type::TYPE_I16: + return "TType.I16"; + case t_base_type::TYPE_I32: + return "TType.I32"; + case t_base_type::TYPE_I64: + return "TType.I64"; + case t_base_type::TYPE_DOUBLE: + return "TType.DOUBLE"; + } + } else if (type->is_enum()) { + return "TType.I32"; + } else if (type->is_struct() || type->is_xception()) { + return "TType.STRUCT"; + } else if (type->is_map()) { + return "TType.MAP"; + } else if (type->is_set()) { + return "TType.SET"; + } else if (type->is_list()) { + return "TType.LIST"; + } + + throw "INVALID TYPE IN type_to_enum: " + type->get_name(); +} + +THRIFT_REGISTER_GENERATOR(lua, "Lua", ""); diff --git a/configure.ac b/configure.ac index 7eea98f4660..483c2839957 100755 --- a/configure.ac +++ b/configure.ac @@ -120,6 +120,7 @@ if test "$enable_libs" = "no"; then with_go="no" with_d="no" with_nodejs="no" + with_lua="no" fi @@ -214,6 +215,16 @@ fi AM_CONDITIONAL(WITH_NODEJS, [test "$have_nodejs" = "yes"]) AM_CONDITIONAL(HAVE_NPM, [test "x$NPM" != "x"]) +AX_THRIFT_LIB(lua, [Lua], yes) +have_lua=no +if test "$with_lua" = "yes"; then + AC_PATH_PROGS([LUA], [lua]) + if test "x$LUA" != "x"; then + have_lua="yes" + fi +fi +AM_CONDITIONAL(WITH_LUA, [test "$have_lua" = "yes"]) + AX_THRIFT_LIB(python, [Python], yes) if test "$with_python" = "yes"; then AM_PATH_PYTHON(2.4,, :) @@ -634,6 +645,7 @@ AC_CONFIG_FILES([ lib/php/test/Makefile lib/py/Makefile lib/rb/Makefile + lib/lua/Makefile test/Makefile test/cpp/Makefile test/hs/Makefile @@ -674,6 +686,7 @@ echo "Building Erlang Library ...... : $have_erlang" echo "Building Go Library .......... : $have_go" echo "Building D Library ........... : $have_d" echo "Building NodeJS Library ...... : $have_nodejs" +echo "Building Lua Library ......... : $have_lua" if test "$have_cpp" = "yes" ; then echo @@ -744,6 +757,11 @@ if test "$have_nodejs" = "yes" ; then echo " Using NodeJS .............. : $NODEJS" echo " Using NodeJS version....... : $($NODEJS --version)" fi +if test "$have_lua" = "yes" ; then + echo + echo "Lua Library:" + echo " Using Lua .............. : $LUA" +fi echo echo "If something is missing that you think should be present," echo "please skim the output of configure to find the missing" diff --git a/lib/Makefile.am b/lib/Makefile.am index 26d9020dfed..0ff7fa018a9 100644 --- a/lib/Makefile.am +++ b/lib/Makefile.am @@ -74,6 +74,9 @@ if WITH_NODEJS SUBDIRS += nodejs endif +if WITH_LUA +SUBDIRS += lua +endif # All of the libs that don't use Automake need to go in here # so they will end up in our release tarballs. diff --git a/lib/lua/Makefile.am b/lib/lua/Makefile.am new file mode 100644 index 00000000000..1c429679578 --- /dev/null +++ b/lib/lua/Makefile.am @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +SUBDIRS = . + +lib_LTLIBRARIES = libluasocket.la \ + libluabpack.la \ + libluabitwise.la \ + liblualongnumber.la + +libluasocket_la_SOURCES = src/luasocket.c \ + src/usocket.c + +libluasocket_la_CPPFLAGS = $(AM_CPPFLAGS) -I/usr/include/lua5.2 -DLUA_COMPAT_MODULE +libluasocket_la_LDFLAGS = $(AM_LDFLAGS) -llua5.2 -lm + +libluabpack_la_SOURCES = src/luabpack.c + +libluabpack_la_CPPFLAGS = $(AM_CPPFLAGS) -I/usr/include/lua5.2 -DLUA_COMPAT_MODULE +libluabpack_la_LDFLAGS = $(AM_LDFLAGS) -llua5.2 -lm +libluabpack_la_LIBADD = liblualongnumber.la + +libluabitwise_la_SOURCES = src/luabitwise.c + +libluabitwise_la_CPPFLAGS = $(AM_CPPFLAGS) -I/usr/include/lua5.2 -DLUA_COMPAT_MODULE +libluabitwise_la_LDFLAGS = $(AM_LDFLAGS) -llua5.2 -lm + +liblualongnumber_la_SOURCES = src/lualongnumber.c \ + src/longnumberutils.c + +liblualongnumber_la_CPPFLAGS = $(AM_CPPFLAGS) -I/usr/include/lua5.2 -DLUA_COMPAT_MODULE +liblualongnumber_la_LDFLAGS = $(AM_LDFLAGS) -llua5.2 -lm + +EXTRA_DIST = TBinaryProtocol.lua \ + TBufferedTransport.lua \ + TFramedTransport.lua \ + Thrift.lua \ + TMemoryBuffer.lua \ + TProtocol.lua \ + TServer.lua \ + TSocket.lua \ + TTransport.lua \ No newline at end of file diff --git a/lib/lua/TBinaryProtocol.lua b/lib/lua/TBinaryProtocol.lua new file mode 100644 index 00000000000..df13d61e0c1 --- /dev/null +++ b/lib/lua/TBinaryProtocol.lua @@ -0,0 +1,264 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TProtocol' +require 'libluabpack' +require 'libluabitwise' + +TBinaryProtocol = __TObject.new(TProtocolBase, { + __type = 'TBinaryProtocol', + VERSION_MASK = -65536, -- 0xffff0000 + VERSION_1 = -2147418112, -- 0x80010000 + TYPE_MASK = 0x000000ff, + strictRead = false, + strictWrite = true +}) + +function TBinaryProtocol:writeMessageBegin(name, ttype, seqid) + if self.stirctWrite then + self:writeI32(libluabitwise.bor(TBinaryProtocol.VERSION_1, ttype)) + self:writeString(name) + self:writeI32(seqid) + else + self:writeString(name) + self:writeByte(ttype) + self:writeI32(seqid) + end +end + +function TBinaryProtocol:writeMessageEnd() +end + +function TBinaryProtocol:writeStructBegin(name) +end + +function TBinaryProtocol:writeStructEnd() +end + +function TBinaryProtocol:writeFieldBegin(name, ttype, id) + self:writeByte(ttype) + self:writeI16(id) +end + +function TBinaryProtocol:writeFieldEnd() +end + +function TBinaryProtocol:writeFieldStop() + self:writeByte(TType.STOP); +end + +function TBinaryProtocol:writeMapBegin(ktype, vtype, size) + self:writeByte(ktype) + self:writeByte(vtype) + self:writeI32(size) +end + +function TBinaryProtocol:writeMapEnd() +end + +function TBinaryProtocol:writeListBegin(etype, size) + self:writeByte(etype) + self:writeI32(size) +end + +function TBinaryProtocol:writeListEnd() +end + +function TBinaryProtocol:writeSetBegin(etype, size) + self:writeByte(etype) + self:writeI32(size) +end + +function TBinaryProtocol:writeSetEnd() +end + +function TBinaryProtocol:writeBool(bool) + if bool then + self:writeByte(1) + else + self:writeByte(0) + end +end + +function TBinaryProtocol:writeByte(byte) + local buff = libluabpack.bpack('c', byte) + self.trans:write(buff) +end + +function TBinaryProtocol:writeI16(i16) + local buff = libluabpack.bpack('s', i16) + self.trans:write(buff) +end + +function TBinaryProtocol:writeI32(i32) + local buff = libluabpack.bpack('i', i32) + self.trans:write(buff) +end + +function TBinaryProtocol:writeI64(i64) + local buff = libluabpack.bpack('l', i64) + self.trans:write(buff) +end + +function TBinaryProtocol:writeDouble(dub) + local buff = libluabpack.bpack('d', dub) + self.trans:write(buff) +end + +function TBinaryProtocol:writeString(str) + -- Should be utf-8 + self:writeI32(string.len(str)) + self.trans:write(str) +end + +function TBinaryProtocol:readMessageBegin() + local sz, ttype, name, seqid = self:readI32() + if sz < 0 then + local version = libluabitwise.band(sz, TBinaryProtocol.VERSION_MASK) + if version ~= TBinaryProtocol.VERSION_1 then + terror(TProtocolException:new{ + message = 'Bad version in readMessageBegin: ' .. sz + }) + end + ttype = libluabitwise.band(sz, TBinaryProtocol.TYPE_MASK) + name = self:readString() + seqid = self:readI32() + else + if self.strictRead then + terror(TProtocolException:new{message = 'No protocol version header'}) + end + name = self.trans:readAll(sz) + ttype = self:readByte() + seqid = self:readI32() + end + return name, ttype, seqid +end + +function TBinaryProtocol:readMessageEnd() +end + +function TBinaryProtocol:readStructBegin() + return nil +end + +function TBinaryProtocol:readStructEnd() +end + +function TBinaryProtocol:readFieldBegin() + local ttype = self:readByte() + if ttype == TType.STOP then + return nil, ttype, 0 + end + local id = self:readI16() + return nil, ttype, id +end + +function TBinaryProtocol:readFieldEnd() +end + +function TBinaryProtocol:readMapBegin() + local ktype = self:readByte() + local vtype = self:readByte() + local size = self:readI32() + return ktype, vtype, size +end + +function TBinaryProtocol:readMapEnd() +end + +function TBinaryProtocol:readListBegin() + local etype = self:readByte() + local size = self:readI32() + return etype, size +end + +function TBinaryProtocol:readListEnd() +end + +function TBinaryProtocol:readSetBegin() + local etype = self:readByte() + local size = self:readI32() + return etype, size +end + +function TBinaryProtocol:readSetEnd() +end + +function TBinaryProtocol:readBool() + local byte = self:readByte() + if byte == 0 then + return false + end + return true +end + +function TBinaryProtocol:readByte() + local buff = self.trans:readAll(1) + local val = libluabpack.bunpack('c', buff) + return val +end + +function TBinaryProtocol:readI16() + local buff = self.trans:readAll(2) + local val = libluabpack.bunpack('s', buff) + return val +end + +function TBinaryProtocol:readI32() + local buff = self.trans:readAll(4) + local val = libluabpack.bunpack('i', buff) + return val +end + +function TBinaryProtocol:readI64() + local buff = self.trans:readAll(8) + local val = libluabpack.bunpack('l', buff) + return val +end + +function TBinaryProtocol:readDouble() + local buff = self.trans:readAll(8) + local val = libluabpack.bunpack('d', buff) + return val +end + +function TBinaryProtocol:readString() + local len = self:readI32() + local str = self.trans:readAll(len) + return str +end + +TBinaryProtocolFactory = TProtocolFactory:new{ + __type = 'TBinaryProtocolFactory', + strictRead = false +} + +function TBinaryProtocolFactory:getProtocol(trans) + -- TODO Enforce that this must be a transport class (ie not a bool) + if not trans then + terror(TProtocolException:new{ + message = 'Must supply a transport to ' .. ttype(self) + }) + end + return TBinaryProtocol:new{ + trans = trans, + strictRead = self.strictRead, + strictWrite = true + } +end diff --git a/lib/lua/TBufferedTransport.lua b/lib/lua/TBufferedTransport.lua new file mode 100644 index 00000000000..2b0b94647cf --- /dev/null +++ b/lib/lua/TBufferedTransport.lua @@ -0,0 +1,91 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TTransport' + +TBufferedTransport = TTransportBase:new{ + __type = 'TBufferedTransport', + rBufSize = 2048, + wBufSize = 2048, + wBuf = '', + rBuf = '' +} + +function TBufferedTransport:new(obj) + if ttype(obj) ~= 'table' then + error(ttype(self) .. 'must be initialized with a table') + end + + -- Ensure a transport is provided + if not obj.trans then + error('You must provide ' .. ttype(self) .. ' with a trans') + end + + return TTransportBase:new(obj) +end + +function TBufferedTransport:isOpen() + return self.trans:isOpen() +end + +function TBufferedTransport:open() + return self.trans:open() +end + +function TBufferedTransport:close() + return self.trans:close() +end + +function TBufferedTransport:read(len) + return self.trans:read(len) +end + +function TBufferedTransport:readAll(len) + return self.trans:readAll(len) +end + +function TBufferedTransport:write(buf) + self.wBuf = self.wBuf .. buf + if string.len(self.wBuf) >= self.wBufSize then + self.trans:write(self.wBuf) + self.wBuf = '' + end +end + +function TBufferedTransport:flush() + if string.len(self.wBuf) > 0 then + self.trans:write(self.wBuf) + self.wBuf = '' + end +end + +TBufferedTransportFactory = TTransportFactoryBase:new{ + __type = 'TBufferedTransportFactory' +} + +function TBufferedTransportFactory:getTransport(trans) + if not trans then + terror(TTransportException:new{ + message = 'Must supply a transport to ' .. ttype(self) + }) + end + return TBufferedTransport:new{ + trans = trans + } +end diff --git a/lib/lua/TFramedTransport.lua b/lib/lua/TFramedTransport.lua new file mode 100644 index 00000000000..84ae3ecf2c0 --- /dev/null +++ b/lib/lua/TFramedTransport.lua @@ -0,0 +1,119 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TTransport' +require 'libluabpack' + +TFramedTransport = TTransportBase:new{ + __type = 'TFramedTransport', + doRead = true, + doWrite = true, + wBuf = '', + rBuf = '' +} + +function TFramedTransport:new(obj) + if ttype(obj) ~= 'table' then + error(ttype(self) .. 'must be initialized with a table') + end + + -- Ensure a transport is provided + if not obj.trans then + error('You must provide ' .. ttype(self) .. ' with a trans') + end + + return TTransportBase:new(obj) +end + +function TFramedTransport:isOpen() + return self.trans:isOpen() +end + +function TFramedTransport:open() + return self.trans:open() +end + +function TFramedTransport:close() + return self.trans:close() +end + +function TFramedTransport:read(len) + if string.len(self.rBuf) == 0 then + self:__readFrame() + end + + if self.doRead == false then + return self.trans:read(len) + end + + if len > string.len(self.rBuf) then + local val = self.rBuf + self.rBuf = '' + return val + end + + local val = string.sub(self.rBuf, 0, len) + self.rBuf = string.sub(self.rBuf, len) + return val +end + +function TFramedTransport:__readFrame() + local buf = self.trans:readAll(4) + local frame_len = libluabpack.bunpack('i', buf) + self.rBuf = self.trans:readAll(frame_len) +end + +function TFramedTransport:readAll(len) + return self.trans:readAll(len) +end + +function TFramedTransport:write(buf, len) + if self.doWrite == false then + return self.trans:write(buf, len) + end + + if len and len < string.len(buf) then + buf = string.sub(buf, 0, len) + end + self.wBuf = self.wBuf + buf +end + +function TFramedTransport:flush() + if self.doWrite == false then + return self.trans:flush() + end + + -- If the write fails we still want wBuf to be clear + local tmp = self.wBuf + self.wBuf = '' + self.trans:write(tmp) + self.trans:flush() +end + +TFramedTransportFactory = TTransportFactoryBase:new{ + __type = 'TFramedTransportFactory' +} +function TFramedTransportFactory:getTransport(trans) + if not trans then + terror(TProtocolException:new{ + message = 'Must supply a transport to ' .. ttype(self) + }) + end + return TFramedTransport:new{trans = trans} +end diff --git a/lib/lua/TMemoryBuffer.lua b/lib/lua/TMemoryBuffer.lua new file mode 100644 index 00000000000..3d4368674fd --- /dev/null +++ b/lib/lua/TMemoryBuffer.lua @@ -0,0 +1,91 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TTransport' + +TMemoryBuffer = TTransportBase:new{ + __type = 'TMemoryBuffer', + buffer = '', + bufferSize = 1024, + wPos = 0, + rPos = 0 +} +function TMemoryBuffer:isOpen() + return 1 +end +function TMemoryBuffer:open() end +function TMemoryBuffer:close() end + +function TMemoryBuffer:peak() + return self.rPos < self.wPos +end + +function TMemoryBuffer:getBuffer() + return self.buffer +end + +function TMemoryBuffer:resetBuffer(buf) + if buf then + self.buffer = buf + self.bufferSize = string.len(buf) + else + self.buffer = '' + self.bufferSize = 1024 + end + self.wPos = string.len(buf) + self.rPos = 0 +end + +function TMemoryBuffer:available() + return self.wPos - self.rPos +end + +function TMemoryBuffer:read(len) + local avail = self:available() + if avail == 0 then + return '' + end + + if avail < len then + len = avail + end + + local val = string.sub(self.buffer, self.rPos, len) + self.rPos = self.rPos + len + return val +end + +function TMemoryBuffer:readAll(len) + local avail = self:available() + + if avail < len then + local msg = string.format('Attempt to readAll(%d) found only %d available', + len, avail) + terror(TTransportException:new{message = msg}) + end + -- read should block so we don't need a loop here + return self:read(len) +end + +function TMemoryBuffer:write(buf) + self.buffer = self.buffer + buf + self.wPos = self.wPos + buf +end + +function TMemoryBuffer:flush() end diff --git a/lib/lua/TProtocol.lua b/lib/lua/TProtocol.lua new file mode 100644 index 00000000000..9eb94f59571 --- /dev/null +++ b/lib/lua/TProtocol.lua @@ -0,0 +1,162 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'Thrift' + +TProtocolException = TException:new { + UNKNOWN = 0, + INVALID_DATA = 1, + NEGATIVE_SIZE = 2, + SIZE_LIMIT = 3, + BAD_VERSION = 4, + INVALID_PROTOCOL = 5, + MISSING_REQUIRED_FIELD = 6, + errorCode = 0, + __type = 'TProtocolException' +} +function TProtocolException:__errorCodeToString() + if self.errorCode == self.INVALID_DATA then + return 'Invalid data' + elseif self.errorCode == self.NEGATIVE_SIZE then + return 'Negative size' + elseif self.errorCode == self.SIZE_LIMIT then + return 'Size limit' + elseif self.errorCode == self.BAD_VERSION then + return 'Bad version' + elseif self.errorCode == self.INVALID_PROTOCOL then + return 'Invalid protocol' + elseif self.errorCode == self.MISSING_REQUIRED_FIELD then + return 'Missing required field' + else + return 'Default (unknown)' + end +end + +TProtocolBase = __TObject:new{ + __type = 'TProtocolBase', + trans +} + +function TProtocolBase:new(obj) + if ttype(obj) ~= 'table' then + error(ttype(self) .. 'must be initialized with a table') + end + + -- Ensure a transport is provided + if not obj.trans then + error('You must provide ' .. ttype(self) .. ' with a trans') + end + + return __TObject.new(self, obj) +end + +function TProtocolBase:writeMessageBegin(name, ttype, seqid) end +function TProtocolBase:writeMessageEnd() end +function TProtocolBase:writeStructBegin(name) end +function TProtocolBase:writeStructEnd() end +function TProtocolBase:writeFieldBegin(name, ttype, id) end +function TProtocolBase:writeFieldEnd() end +function TProtocolBase:writeFieldStop() end +function TProtocolBase:writeMapBegin(ktype, vtype, size) end +function TProtocolBase:writeMapEnd() end +function TProtocolBase:writeListBegin(ttype, size) end +function TProtocolBase:writeListEnd() end +function TProtocolBase:writeSetBegin(ttype, size) end +function TProtocolBase:writeSetEnd() end +function TProtocolBase:writeBool(bool) end +function TProtocolBase:writeByte(byte) end +function TProtocolBase:writeI16(i16) end +function TProtocolBase:writeI32(i32) end +function TProtocolBase:writeI64(i64) end +function TProtocolBase:writeDouble(dub) end +function TProtocolBase:writeString(str) end +function TProtocolBase:readMessageBegin() end +function TProtocolBase:readMessageEnd() end +function TProtocolBase:readStructBegin() end +function TProtocolBase:readStructEnd() end +function TProtocolBase:readFieldBegin() end +function TProtocolBase:readFieldEnd() end +function TProtocolBase:readMapBegin() end +function TProtocolBase:readMapEnd() end +function TProtocolBase:readListBegin() end +function TProtocolBase:readListEnd() end +function TProtocolBase:readSetBegin() end +function TProtocolBase:readSetEnd() end +function TProtocolBase:readBool() end +function TProtocolBase:readByte() end +function TProtocolBase:readI16() end +function TProtocolBase:readI32() end +function TProtocolBase:readI64() end +function TProtocolBase:readDouble() end +function TProtocolBase:readString() end + +function TProtocolBase:skip(ttype) + if type == TType.STOP then + return + elseif ttype == TType.BOOL then + self:readBool() + elseif ttype == TType.BYTE then + self:readByte() + elseif ttype == TType.I16 then + self:readI16() + elseif ttype == TType.I32 then + self:readI32() + elseif ttype == TType.I64 then + self:readI64() + elseif ttype == TType.DOUBLE then + self:readDouble() + elseif ttype == TType.STRING then + self:readString() + elseif ttype == TType.STRUCT then + local name = self:readStructBegin() + while true do + local name, ttype, id = self:readFieldBegin() + if ttype == TType.STOP then + break + end + self:skip(ttype) + self:readFieldEnd() + end + self:readStructEnd() + elseif ttype == TType.MAP then + local kttype, vttype, size = self:readMapBegin() + for i = 1, size, 1 do + self:skip(kttype) + self:skip(vttype) + end + self:readMapEnd() + elseif ttype == TType.SET then + local ettype, size = self:readSetBegin() + for i = 1, size, 1 do + self:skip(ettype) + end + self:readSetEnd() + elseif ttype == TType.LIST then + local ettype, size = self:readListBegin() + for i = 1, size, 1 do + self:skip(ettype) + end + self:readListEnd() + end +end + +TProtocolFactory = __TObject:new{ + __type = 'TProtocolFactory', +} +function TProtocolFactory:getProtocol(trans) end diff --git a/lib/lua/TServer.lua b/lib/lua/TServer.lua new file mode 100644 index 00000000000..d6b9cd07658 --- /dev/null +++ b/lib/lua/TServer.lua @@ -0,0 +1,139 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'Thrift' +require 'TFramedTransport' +require 'TBinaryProtocol' + +-- TServer +TServer = __TObject:new{ + __type = 'TServer' +} + +-- 2 possible constructors +-- 1. {processor, serverTransport} +-- 2. {processor, serverTransport, transportFactory, protocolFactory} +function TServer:new(args) + if ttype(args) ~= 'table' then + error('TServer must be initialized with a table') + end + if args.processor == nil then + terror('You must provide ' .. ttype(self) .. ' with a processor') + end + if args.serverTransport == nil then + terror('You must provide ' .. ttype(self) .. ' with a serverTransport') + end + + -- Create the object + local obj = __TObject.new(self, args) + + if obj.transportFactory then + obj.inputTransportFactory = obj.transportFactory + obj.outputTransportFactory = obj.transportFactory + obj.transportFactory = nil + else + obj.inputTransportFactory = TFramedTransportFactory:new{} + obj.outputTransportFactory = obj.inputTransportFactory + end + + if obj.protocolFactory then + obj.inputProtocolFactory = obj.protocolFactory + obj.outputProtocolFactory = obj.protocolFactory + obj.protocolFactory = nil + else + obj.inputProtocolFactory = TBinaryProtocolFactory:new{} + obj.outputProtocolFactory = obj.inputProtocolFactory + end + + -- Set the __server variable in the handler so we can stop the server + obj.processor.handler.__server = self + + return obj +end + +function TServer:setServerEventHandler(handler) + self.serverEventHandler = handler +end + +function TServer:_clientBegin(content, iprot, oprot) + if self.serverEventHandler and + type(self.serverEventHandler.clientBegin) == 'function' then + self.serverEventHandler:clientBegin(iprot, oprot) + end +end + +function TServer:_preServe() + if self.serverEventHandler and + type(self.serverEventHandler.preServe) == 'function' then + self.serverEventHandler:preServe(self.serverTransport:getSocketInfo()) + end +end + +function TServer:_handleException(err) + if string.find(err, 'TTransportException') == nil then + print(err) + end +end + +function TServer:serve() end +function TServer:handle(client) + local itrans, otrans, iprot, oprot, ret, err = + self.inputTransportFactory:getTransport(client), + self.outputTransportFactory:getTransport(client), + self.inputProtocolFactory:getProtocol(client), + self.outputProtocolFactory:getProtocol(client) + + self:_clientBegin(iprot, oprot) + while true do + ret, err = pcall(self.processor.process, self.processor, iprot, oprot) + if ret == false and err then + if not string.find(err, "TTransportException") then + self:_handleException(err) + end + break + end + end + itrans:close() + otrans:close() +end + +function TServer:close() + self.serverTransport:close() +end + +-- TSimpleServer +-- Single threaded server that handles one transport (connection) +TSimpleServer = __TObject:new(TServer, { + __type = 'TSimpleServer', + __stop = false +}) + +function TSimpleServer:serve() + self.serverTransport:listen() + self:_preServe() + while not self.__stop do + client = self.serverTransport:accept() + self:handle(client) + end + self:close() +end + +function TSimpleServer:stop() + self.__stop = true +end diff --git a/lib/lua/TSocket.lua b/lib/lua/TSocket.lua new file mode 100644 index 00000000000..d71fc1f984e --- /dev/null +++ b/lib/lua/TSocket.lua @@ -0,0 +1,132 @@ +---- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'TTransport' +require 'libluasocket' + +-- TSocketBase +TSocketBase = TTransportBase:new{ + __type = 'TSocketBase', + timeout = 1000, + host = 'localhost', + port = 9090, + handle +} + +function TSocketBase:close() + if self.handle then + self.handle:destroy() + self.handle = nil + end +end + +-- Returns a table with the fields host and port +function TSocketBase:getSocketInfo() + if self.handle then + return self.handle:getsockinfo() + end + terror(TTransportException:new{errorCode = TTransportException.NOT_OPEN}) +end + +function TSocketBase:setTimeout(timeout) + if timeout and ttype(timeout) == 'number' then + if self.handle then + self.handle:settimeout(timeout) + end + self.timeout = timeout + end +end + +-- TSocket +TSocket = TSocketBase:new{ + __type = 'TSocket', + host = 'localhost', + port = 9090 +} + +function TSocket:isOpen() + if self.handle then + return true + end + return false +end + +function TSocket:open() + if self.handle then + self:close() + end + + -- Create local handle + local sock, err = luasocket.create_and_connect( + self.host, self.port, self.timeout) + if err == nil then + self.handle = sock + end + + if err then + terror(TTransportException:new{ + message = 'Could not connect to ' .. self.host .. ':' .. self.port + .. ' (' .. err .. ')' + }) + end +end + +function TSocket:read(len) + local buf = self.handle:receive(self.handle, len) + if not buf or string.len(buf) ~= len then + terror(TTransportException:new{errorCode = TTransportException.UNKNOWN}) + end + return buf +end + +function TSocket:write(buf) + self.handle:send(self.handle, buf) +end + +function TSocket:flush() +end + +-- TServerSocket +TServerSocket = TSocketBase:new{ + __type = 'TServerSocket', + host = 'localhost', + port = 9090 +} + +function TServerSocket:listen() + if self.handle then + self:close() + end + + local sock, err = luasocket.create(self.host, self.port) + if not err then + self.handle = sock + else + terror(err) + end + self.handle:settimeout(self.timeout) + self.handle:listen() +end + +function TServerSocket:accept() + local client, err = self.handle:accept() + if err then + terror(err) + end + return TSocket:new({handle = client}) +end diff --git a/lib/lua/TTransport.lua b/lib/lua/TTransport.lua new file mode 100644 index 00000000000..01c7e597971 --- /dev/null +++ b/lib/lua/TTransport.lua @@ -0,0 +1,93 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +require 'Thrift' + +TTransportException = TException:new { + UNKNOWN = 0, + NOT_OPEN = 1, + ALREADY_OPEN = 2, + TIMED_OUT = 3, + END_OF_FILE = 4, + INVALID_FRAME_SIZE = 5, + INVALID_TRANSFORM = 6, + INVALID_CLIENT_TYPE = 7, + errorCode = 0, + __type = 'TTransportException' +} + +function TTransportException:__errorCodeToString() + if self.errorCode == self.NOT_OPEN then + return 'Transport not open' + elseif self.errorCode == self.ALREADY_OPEN then + return 'Transport already open' + elseif self.errorCode == self.TIMED_OUT then + return 'Transport timed out' + elseif self.errorCode == self.END_OF_FILE then + return 'End of file' + elseif self.errorCode == self.INVALID_FRAME_SIZE then + return 'Invalid frame size' + elseif self.errorCode == self.INVALID_TRANSFORM then + return 'Invalid transform' + elseif self.errorCode == self.INVALID_CLIENT_TYPE then + return 'Invalid client type' + else + return 'Default (unknown)' + end +end + +TTransportBase = __TObject:new{ + __type = 'TTransportBase' +} + +function TTransportBase:isOpen() end +function TTransportBase:open() end +function TTransportBase:close() end +function TTransportBase:read(len) end +function TTransportBase:readAll(len) + local buf, have, chunk = '', 0 + while have < len do + chunk = self:read(len - have) + have = have + string.len(chunk) + buf = buf .. chunk + + if string.len(chunk) == 0 then + terror(TTransportException:new{ + errorCode = TTransportException.END_OF_FILE + }) + end + end + return buf +end +function TTransportBase:write(buf) end +function TTransportBase:flush() end + +TServerTransportBase = __TObject:new{ + __type = 'TServerTransportBase' +} +function TServerTransportBase:listen() end +function TServerTransportBase:accept() end +function TServerTransportBase:close() end + +TTransportFactoryBase = __TObject:new{ + __type = 'TTransportFactoryBase' +} +function TTransportFactoryBase:getTransport(trans) + return trans +end diff --git a/lib/lua/Thrift.lua b/lib/lua/Thrift.lua new file mode 100644 index 00000000000..6ff8ecbc181 --- /dev/null +++ b/lib/lua/Thrift.lua @@ -0,0 +1,273 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +---- namespace thrift +--thrift = {} +--setmetatable(thrift, {__index = _G}) --> perf hit for accessing global methods +--setfenv(1, thrift) + +package.cpath = package.cpath .. ';bin/?.so' -- TODO FIX +function ttype(obj) + if type(obj) == 'table' and + obj.__type and + type(obj.__type) == 'string' then + return obj.__type + end + return type(obj) +end + +function terror(e) + if e and e.__tostring then + error(e:__tostring()) + return + end + error(e) +end + +version = 1.0 + +TType = { + STOP = 0, + VOID = 1, + BOOL = 2, + BYTE = 3, + I08 = 3, + DOUBLE = 4, + I16 = 6, + I32 = 8, + I64 = 10, + STRING = 11, + UTF7 = 11, + STRUCT = 12, + MAP = 13, + SET = 14, + LIST = 15, + UTF8 = 16, + UTF16 = 17 +} + +TMessageType = { + CALL = 1, + REPLY = 2, + EXCEPTION = 3, + ONEWAY = 4 +} + +-- Recursive __index function to achive inheritance +function __tobj_index(self, key) + local v = rawget(self, key) + if v ~= nil then + return v + end + + local p = rawget(self, '__parent') + if p then + return __tobj_index(p, key) + end + + return nil +end + +-- Basic Thrift-Lua Object +__TObject = { + __type = '__TObject', + __mt = { + __index = __tobj_index + } +} +function __TObject:new(init_obj) + local obj = {} + if ttype(obj) == 'table' then + obj = init_obj + end + + -- Use the __parent key and the __index function to achieve inheritance + obj.__parent = self + setmetatable(obj, __TObject.__mt) + return obj +end + +-- Return a string representation of any lua variable +function thrift_print_r(t) + local ret = '' + local ltype = type(t) + if (ltype == 'table') then + ret = ret .. '{ ' + for key,value in pairs(t) do + ret = ret .. tostring(key) .. '=' .. thrift_print_r(value) .. ' ' + end + ret = ret .. '}' + elseif ltype == 'string' then + ret = ret .. "'" .. tostring(t) .. "'" + else + ret = ret .. tostring(t) + end + return ret +end + +-- Basic Exception +TException = __TObject:new{ + message, + errorCode, + __type = 'TException' +} +function TException:__tostring() + if self.message then + return string.format('%s: %s', self.__type, self.message) + else + local message + if self.errorCode and self.__errorCodeToString then + message = string.format('%d: %s', self.errorCode, self:__errorCodeToString()) + else + message = thrift_print_r(self) + end + return string.format('%s:%s', self.__type, message) + end +end + +TApplicationException = TException:new{ + UNKNOWN = 0, + UNKNOWN_METHOD = 1, + INVALID_MESSAGE_TYPE = 2, + WRONG_METHOD_NAME = 3, + BAD_SEQUENCE_ID = 4, + MISSING_RESULT = 5, + INTERNAL_ERROR = 6, + PROTOCOL_ERROR = 7, + INVALID_TRANSFORM = 8, + INVALID_PROTOCOL = 9, + UNSUPPORTED_CLIENT_TYPE = 10, + errorCode = 0, + __type = 'TApplicationException' +} + +function TApplicationException:__errorCodeToString() + if self.errorCode == self.UNKNOWN_METHOD then + return 'Unknown method' + elseif self.errorCode == self.INVALID_MESSAGE_TYPE then + return 'Invalid message type' + elseif self.errorCode == self.WRONG_METHOD_NAME then + return 'Wrong method name' + elseif self.errorCode == self.BAD_SEQUENCE_ID then + return 'Bad sequence ID' + elseif self.errorCode == self.MISSING_RESULT then + return 'Missing result' + elseif self.errorCode == self.INTERNAL_ERROR then + return 'Internal error' + elseif self.errorCode == self.PROTOCOL_ERROR then + return 'Protocol error' + elseif self.errorCode == self.INVALID_TRANSFORM then + return 'Invalid transform' + elseif self.errorCode == self.INVALID_PROTOCOL then + return 'Invalid protocol' + elseif self.errorCode == self.UNSUPPORTED_CLIENT_TYPE then + return 'Unsupported client type' + else + return 'Default (unknown)' + end +end + +function TException:read(iprot) + iprot:readStructBegin() + while true do + local fname, ftype, fid = iprot:readFieldBegin() + if ftype == TType.STOP then + break + elseif fid == 1 then + if ftype == TType.STRING then + self.message = iprot:readString() + else + iprot:skip(ftype) + end + elseif fid == 2 then + if ftype == TType.I32 then + self.errorCode = iprot:readI32() + else + iprot:skip(ftype) + end + else + iprot:skip(ftype) + end + iprot:readFieldEnd() + end + iprot:readStructEnd() +end + +function TException:write(oprot) + oprot:writeStructBegin('TApplicationException') + if self.message then + oprot:writeFieldBegin('message', TType.STRING, 1) + oprot:writeString(self.message) + oprot:writeFieldEnd() + end + if self.errorCode then + oprot:writeFieldBegin('type', TType.I32, 2) + oprot:writeI32(self.errorCode) + oprot:writeFieldEnd() + end + oprot:writeFieldStop() + oprot:writeStructEnd() +end + +-- Basic Client (used in generated lua code) +__TClient = __TObject:new{ + __type = '__TClient', + _seqid = 0 +} +function __TClient:new(obj) + if ttype(obj) ~= 'table' then + error('TClient must be initialized with a table') + end + + -- Set iprot & oprot + if obj.protocol then + obj.iprot = obj.protocol + obj.oprot = obj.protocol + obj.protocol = nil + elseif not obj.iprot then + error('You must provide ' .. ttype(self) .. ' with an iprot') + end + if not obj.oprot then + obj.oprot = obj.iprot + end + + return __TObject.new(self, obj) +end + +function __TClient:close() + self.iprot.trans:close() + self.oprot.trans:close() +end + +-- Basic Processor (used in generated lua code) +__TProcessor = __TObject:new{ + __type = '__TProcessor' +} +function __TProcessor:new(obj) + if ttype(obj) ~= 'table' then + error('TProcessor must be initialized with a table') + end + + -- Ensure a handler is provided + if not obj.handler then + error('You must provide ' .. ttype(self) .. ' with a handler') + end + + return __TObject.new(self, obj) +end diff --git a/lib/lua/src/longnumberutils.c b/lib/lua/src/longnumberutils.c new file mode 100644 index 00000000000..fbc678900c6 --- /dev/null +++ b/lib/lua/src/longnumberutils.c @@ -0,0 +1,47 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include +#include +#include +#include + +const char * LONG_NUM_TYPE = "__thrift_longnumber"; +int64_t lualongnumber_checklong(lua_State *L, int index) { + switch (lua_type(L, index)) { + case LUA_TNUMBER: + return (int64_t)lua_tonumber(L, index); + case LUA_TSTRING: + return atoll(lua_tostring(L, index)); + default: + return *((int64_t *)luaL_checkudata(L, index, LONG_NUM_TYPE)); + } +} + +// Creates a new longnumber and pushes it onto the statck +int64_t * lualongnumber_pushlong(lua_State *L, int64_t *val) { + int64_t *data = (int64_t *)lua_newuserdata(L, sizeof(int64_t)); // longnum + luaL_getmetatable(L, LONG_NUM_TYPE); // longnum, mt + lua_setmetatable(L, -2); // longnum + if (val) { + *data = *val; + } + return data; +} + diff --git a/lib/lua/src/luabitwise.c b/lib/lua/src/luabitwise.c new file mode 100644 index 00000000000..2e07e1724ce --- /dev/null +++ b/lib/lua/src/luabitwise.c @@ -0,0 +1,83 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include +#include + +static int l_not(lua_State *L) { + int a = luaL_checkinteger(L, 1); + a = ~a; + lua_pushnumber(L, a); + return 1; +} + +static int l_xor(lua_State *L) { + int a = luaL_checkinteger(L, 1); + int b = luaL_checkinteger(L, 2); + a ^= b; + lua_pushnumber(L, a); + return 1; +} + +static int l_and(lua_State *L) { + int a = luaL_checkinteger(L, 1); + int b = luaL_checkinteger(L, 2); + a &= b; + lua_pushnumber(L, a); + return 1; +} + +static int l_or(lua_State *L) { + int a = luaL_checkinteger(L, 1); + int b = luaL_checkinteger(L, 2); + a |= b; + lua_pushnumber(L, a); + return 1; +} + +static int l_shiftr(lua_State *L) { + int a = luaL_checkinteger(L, 1); + int b = luaL_checkinteger(L, 2); + a = a >> b; + lua_pushnumber(L, a); + return 1; +} + +static int l_shiftl(lua_State *L) { + int a = luaL_checkinteger(L, 1); + int b = luaL_checkinteger(L, 2); + a = a << b; + lua_pushnumber(L, a); + return 1; +} + +static const struct luaL_Reg funcs[] = { + {"band", l_and}, + {"bor", l_or}, + {"bxor", l_xor}, + {"bnot", l_not}, + {"shiftl", l_shiftl}, + {"shiftr", l_shiftr}, + {NULL, NULL} +}; + +int luaopen_libluabitwise(lua_State *L) { + luaL_register(L, "libluabitwise", funcs); + return 1; +} diff --git a/lib/lua/src/luabpack.c b/lib/lua/src/luabpack.c new file mode 100644 index 00000000000..c936428cd93 --- /dev/null +++ b/lib/lua/src/luabpack.c @@ -0,0 +1,162 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include +#include +#include +#include +#include + +extern int64_t lualongnumber_checklong(lua_State *L, int index); +extern int64_t lualongnumber_pushlong(lua_State *L, int64_t *val); + +// host order to network order (64-bit) +static int64_t T_htonll(uint64_t data) { + uint32_t d1 = htonl((uint32_t)data); + uint32_t d2 = htonl((uint32_t)(data >> 32)); + return ((uint64_t)d1 << 32) + (uint64_t)d2; +} + +// network order to host order (64-bit) +static int64_t T_ntohll(uint64_t data) { + uint32_t d1 = ntohl((uint32_t)data); + uint32_t d2 = ntohl((uint32_t)(data >> 32)); + return ((uint64_t)d1 << 32) + (uint64_t)d2; +} + +/** + * bpack(type, data) + * c - Signed Byte + * s - Signed Short + * i - Signed Int + * l - Signed Long + * d - Double + */ +static int l_bpack(lua_State *L) { + const char *code = luaL_checkstring(L, 1); + luaL_argcheck(L, code[1] == '\0', 0, "Format code must be one character."); + luaL_Buffer buf; + luaL_buffinit(L, &buf); + + switch (code[0]) { + case 'c': { + int8_t data = luaL_checknumber(L, 2); + luaL_addlstring(&buf, (void*)&data, sizeof(data)); + break; + } + case 's': { + int16_t data = luaL_checknumber(L, 2); + data = (int16_t)htons(data); + luaL_addlstring(&buf, (void*)&data, sizeof(data)); + break; + } + case 'i': { + int32_t data = luaL_checkinteger(L, 2); + data = (int32_t)htonl(data); + luaL_addlstring(&buf, (void*)&data, sizeof(data)); + break; + } + case 'l': { + int64_t data = lualongnumber_checklong(L, 2); + data = (int64_t)T_htonll(data); + luaL_addlstring(&buf, (void*)&data, sizeof(data)); + break; + } + case 'd': { + double data = luaL_checknumber(L, 2); + luaL_addlstring(&buf, (void*)&data, sizeof(data)); + break; + } + default: + luaL_argcheck(L, 0, 0, "Invalid format code."); + } + + luaL_pushresult(&buf); + return 1; +} + +/** + * bunpack(type, data) + * c - Signed Byte + * s - Signed Short + * i - Signed Int + * l - Signed Long + * d - Double + */ +static int l_bunpack(lua_State *L) { + const char *code = luaL_checkstring(L, 1); + luaL_argcheck(L, code[1] == '\0', 0, "Format code must be one character."); + const char *data = luaL_checkstring(L, 2); + size_t len = lua_rawlen(L, 2); + + switch (code[0]) { + case 'c': { + int8_t val; + luaL_argcheck(L, len == sizeof(val), 1, "Invalid input string size."); + memcpy(&val, data, sizeof(val)); + lua_pushnumber(L, val); + break; + } + case 's': { + int16_t val; + luaL_argcheck(L, len == sizeof(val), 1, "Invalid input string size."); + memcpy(&val, data, sizeof(val)); + val = (int16_t)ntohs(val); + lua_pushnumber(L, val); + break; + } + case 'i': { + int32_t val; + luaL_argcheck(L, len == sizeof(val), 1, "Invalid input string size."); + memcpy(&val, data, sizeof(val)); + val = (int32_t)ntohl(val); + lua_pushnumber(L, val); + break; + } + case 'l': { + int64_t val; + luaL_argcheck(L, len == sizeof(val), 1, "Invalid input string size."); + memcpy(&val, data, sizeof(val)); + val = (int64_t)T_ntohll(val); + lualongnumber_pushlong(L, &val); + break; + } + case 'd': { + double val; + luaL_argcheck(L, len == sizeof(val), 1, "Invalid input string size."); + memcpy(&val, data, sizeof(val)); + lua_pushnumber(L, val); + break; + } + default: + luaL_argcheck(L, 0, 0, "Invalid format code."); + } + return 1; +} + +static const struct luaL_Reg lua_bpack[] = { + {"bpack", l_bpack}, + {"bunpack", l_bunpack}, + {NULL, NULL} +}; + +int luaopen_libluabpack(lua_State *L) { + luaL_register(L, "libluabpack", lua_bpack); + return 1; +} diff --git a/lib/lua/src/lualongnumber.c b/lib/lua/src/lualongnumber.c new file mode 100644 index 00000000000..9001e4a90dd --- /dev/null +++ b/lib/lua/src/lualongnumber.c @@ -0,0 +1,228 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include +#include +#include +#include +#include +#include + +extern const char * LONG_NUM_TYPE; +extern int64_t lualongnumber_checklong(lua_State *L, int index); +extern int64_t lualongnumber_pushlong(lua_State *L, int64_t *val); + +//////////////////////////////////////////////////////////////////////////////// + +static void l_serialize(char *buf, int len, int64_t val) { + snprintf(buf, len, "%"PRId64, val); +} + +static int64_t l_deserialize(const char *buf) { + int64_t data; + int rv; + // Support hex prefixed with '0x' + if (strstr(buf, "0x") == buf) { + rv = sscanf(buf, "%"PRIx64, &data); + } else { + rv = sscanf(buf, "%"PRId64, &data); + } + if (rv == 1) { + return data; + } + return 0; // Failed +} + +//////////////////////////////////////////////////////////////////////////////// + +static int l_new(lua_State *L) { + int64_t val; + const char *str = NULL; + if (lua_type(L, 1) == LUA_TSTRING) { + str = lua_tostring(L, 1); + val = l_deserialize(str); + } else if (lua_type(L, 1) == LUA_TNUMBER) { + val = (int64_t)lua_tonumber(L, 1); + str = (const char *)1; + } + lualongnumber_pushlong(L, (str ? &val : NULL)); + return 1; +} + +//////////////////////////////////////////////////////////////////////////////// + +// a + b +static int l_add(lua_State *L) { + int64_t a, b, c; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + c = a + b; + lualongnumber_pushlong(L, &c); + return 1; +} + +// a / b +static int l_div(lua_State *L) { + int64_t a, b, c; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + c = a / b; + lualongnumber_pushlong(L, &c); + return 1; +} + +// a == b (both a and b are lualongnumber's) +static int l_eq(lua_State *L) { + int64_t a, b; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + lua_pushboolean(L, (a == b ? 1 : 0)); + return 1; +} + +// garbage collection +static int l_gc(lua_State *L) { + lua_pushnil(L); + lua_setmetatable(L, 1); + return 0; +} + +// a < b +static int l_lt(lua_State *L) { + int64_t a, b; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + lua_pushboolean(L, (a < b ? 1 : 0)); + return 1; +} + +// a <= b +static int l_le(lua_State *L) { + int64_t a, b; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + lua_pushboolean(L, (a <= b ? 1 : 0)); + return 1; +} + +// a % b +static int l_mod(lua_State *L) { + int64_t a, b, c; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + c = a % b; + lualongnumber_pushlong(L, &c); + return 1; +} + +// a * b +static int l_mul(lua_State *L) { + int64_t a, b, c; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + c = a * b; + lualongnumber_pushlong(L, &c); + return 1; +} + +// a ^ b +static int l_pow(lua_State *L) { + long double a, b; + int64_t c; + a = (long double)lualongnumber_checklong(L, 1); + b = (long double)lualongnumber_checklong(L, 2); + c = (int64_t)pow(a, b); + lualongnumber_pushlong(L, &c); + return 1; +} + +// a - b +static int l_sub(lua_State *L) { + int64_t a, b, c; + a = lualongnumber_checklong(L, 1); + b = lualongnumber_checklong(L, 2); + c = a - b; + lualongnumber_pushlong(L, &c); + return 1; +} + +// tostring() +static int l_tostring(lua_State *L) { + int64_t a; + char str[256]; + l_serialize(str, 256, lualongnumber_checklong(L, 1)); + lua_pushstring(L, str); + return 1; +} + +// -a +static int l_unm(lua_State *L) { + int64_t a, c; + a = lualongnumber_checklong(L, 1); + c = -a; + lualongnumber_pushlong(L, &c); + return 1; +} + +//////////////////////////////////////////////////////////////////////////////// + +static const luaL_Reg methods[] = { + {"__add", l_add}, + {"__div", l_div}, + {"__eq", l_eq}, + {"__gc", l_gc}, + {"__lt", l_lt}, + {"__le", l_le}, + {"__mod", l_mod}, + {"__mul", l_mul}, + {"__pow", l_pow}, + {"__sub", l_sub}, + {"__tostring", l_tostring}, + {"__unm", l_unm}, + {NULL, NULL}, +}; + +static const luaL_Reg funcs[] = { + {"new", l_new}, + {NULL, NULL} +}; + +//////////////////////////////////////////////////////////////////////////////// + +static void set_methods(lua_State *L, + const char *metatablename, + const struct luaL_Reg *methods) { + luaL_getmetatable(L, metatablename); // mt + // No need for a __index table since everything is __* + for (; methods->name; methods++) { + lua_pushstring(L, methods->name); // mt, "name" + lua_pushcfunction(L, methods->func); // mt, "name", func + lua_rawset(L, -3); // mt + } + lua_pop(L, 1); +} + +LUALIB_API int luaopen_liblualongnumber(lua_State *L) { + luaL_newmetatable(L, LONG_NUM_TYPE); + lua_pop(L, 1); + set_methods(L, LONG_NUM_TYPE, methods); + + luaL_register(L, "liblualongnumber", funcs); + return 1; +} diff --git a/lib/lua/src/luasocket.c b/lib/lua/src/luasocket.c new file mode 100644 index 00000000000..c8a678ff511 --- /dev/null +++ b/lib/lua/src/luasocket.c @@ -0,0 +1,386 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include +#include + +#include +#include "string.h" +#include "socket.h" + +//////////////////////////////////////////////////////////////////////////////// + +static const char *SOCKET_ANY = "__thrift_socket_any"; +static const char *SOCKET_CONN = "__thrift_socket_connected"; + +static const char *SOCKET_GENERIC = "__thrift_socket_generic"; +static const char *SOCKET_CLIENT = "__thrift_socket_client"; +static const char *SOCKET_SERVER = "__thrift_socket_server"; + +static const char *DEFAULT_HOST = "localhost"; + +typedef struct __t_tcp { + t_socket sock; + int timeout; // Milliseconds +} t_tcp; +typedef t_tcp *p_tcp; + +//////////////////////////////////////////////////////////////////////////////// +// Util + +static void throw_argerror(lua_State *L, int index, const char *expected) { + char msg[256]; + sprintf(msg, "%s expected, got %s", expected, luaL_typename(L, index)); + luaL_argerror(L, index, msg); +} + +static void *checkgroup(lua_State *L, int index, const char *groupname) { + if (!lua_getmetatable(L, index)) { + throw_argerror(L, index, groupname); + } + + lua_pushstring(L, groupname); + lua_rawget(L, -2); + if (lua_isnil(L, -1)) { + lua_pop(L, 2); + throw_argerror(L, index, groupname); + } else { + lua_pop(L, 2); + return lua_touserdata(L, index); + } + return NULL; // Not reachable +} + +static void *checktype(lua_State *L, int index, const char *typename) { + if (strcmp(typename, SOCKET_ANY) == 0 || + strcmp(typename, SOCKET_CONN) == 0) { + return checkgroup(L, index, typename); + } else { + return luaL_checkudata(L, index, typename); + } +} + +static void settype(lua_State *L, int index, const char *typename) { + luaL_getmetatable(L, typename); + lua_setmetatable(L, index); +} + +#define LUA_SUCCESS_RETURN(L) \ + lua_pushnumber(L, 1); \ + return 1 + +#define LUA_CHECK_RETURN(L, err) \ + if (err) { \ + lua_pushnil(L); \ + lua_pushstring(L, err); \ + return 2; \ + } \ + LUA_SUCCESS_RETURN(L) + +//////////////////////////////////////////////////////////////////////////////// + +static int l_socket_create(lua_State *L); +static int l_socket_destroy(lua_State *L); +static int l_socket_settimeout(lua_State *L); +static int l_socket_getsockinfo(lua_State *L); + +static int l_socket_accept(lua_State *L); +static int l_socket_listen(lua_State *L); + +static int l_socket_create_and_connect(lua_State *L); +static int l_socket_connect(lua_State *L); +static int l_socket_send(lua_State *L); +static int l_socket_receive(lua_State *L); + +//////////////////////////////////////////////////////////////////////////////// + +static const struct luaL_Reg methods_generic[] = { + {"destroy", l_socket_destroy}, + {"settimeout", l_socket_settimeout}, + {"getsockinfo", l_socket_getsockinfo}, + {"listen", l_socket_listen}, + {"connect", l_socket_connect}, + {NULL, NULL} +}; + +static const struct luaL_Reg methods_server[] = { + {"destroy", l_socket_destroy}, + {"getsockinfo", l_socket_getsockinfo}, + {"accept", l_socket_accept}, + {"send", l_socket_send}, + {"receive", l_socket_receive}, + {NULL, NULL} +}; + +static const struct luaL_Reg methods_client[] = { + {"destroy", l_socket_destroy}, + {"settimeout", l_socket_settimeout}, + {"getsockinfo", l_socket_getsockinfo}, + {"send", l_socket_send}, + {"receive", l_socket_receive}, + {NULL, NULL} +}; + +static const struct luaL_Reg funcs_luasocket[] = { + {"create", l_socket_create}, + {"create_and_connect", l_socket_create_and_connect}, + {NULL, NULL} +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Check/enforce inheritance +static void add_to_group(lua_State *L, + const char *metatablename, + const char *groupname) { + luaL_getmetatable(L, metatablename); // mt + lua_pushstring(L, groupname); // mt, "name" + lua_pushboolean(L, 1); // mt, "name", true + lua_rawset(L, -3); // mt + lua_pop(L, 1); +} + +static void set_methods(lua_State *L, + const char *metatablename, + const struct luaL_Reg *methods) { + luaL_getmetatable(L, metatablename); // mt + // Create the __index table + lua_pushstring(L, "__index"); // mt, "__index" + lua_newtable(L); // mt, "__index", t + for (; methods->name; methods++) { + lua_pushstring(L, methods->name); // mt, "__index", t, "name" + lua_pushcfunction(L, methods->func); // mt, "__index", t, "name", func + lua_rawset(L, -3); // mt, "__index", t + } + lua_rawset(L, -3); // mt + lua_pop(L, 1); +} + +int luaopen_libluasocket(lua_State *L) { + luaL_newmetatable(L, SOCKET_GENERIC); + luaL_newmetatable(L, SOCKET_CLIENT); + luaL_newmetatable(L, SOCKET_SERVER); + lua_pop(L, 3); + add_to_group(L, SOCKET_GENERIC, SOCKET_ANY); + add_to_group(L, SOCKET_CLIENT, SOCKET_ANY); + add_to_group(L, SOCKET_SERVER, SOCKET_ANY); + add_to_group(L, SOCKET_CLIENT, SOCKET_CONN); + add_to_group(L, SOCKET_SERVER, SOCKET_CONN); + set_methods(L, SOCKET_GENERIC, methods_generic); + set_methods(L, SOCKET_CLIENT, methods_client); + set_methods(L, SOCKET_SERVER, methods_server); + + luaL_register(L, "luasocket", funcs_luasocket); + return 1; +} + +//////////////////////////////////////////////////////////////////////////////// +// General + +// sock,err create(bind_host, bind_port) +// sock,err create(bind_host) -> any port +// sock,err create() -> any port on localhost +static int l_socket_create(lua_State *L) { + const char *err; + t_socket sock; + const char *addr = lua_tostring(L, 1); + if (!addr) { + addr = DEFAULT_HOST; + } + unsigned short port = lua_tonumber(L, 2); + err = tcp_create(&sock); + if (!err) { + err = tcp_bind(&sock, addr, port); // bind on create + if (err) { + tcp_destroy(&sock); + } else { + p_tcp tcp = (p_tcp) lua_newuserdata(L, sizeof(t_tcp)); + settype(L, -2, SOCKET_GENERIC); + socket_setnonblocking(&sock); + tcp->sock = sock; + tcp->timeout = 0; + return 1; // Return userdata + } + } + LUA_CHECK_RETURN(L, err); +} + +// destroy() +static int l_socket_destroy(lua_State *L) { + p_tcp tcp = (p_tcp) checktype(L, 1, SOCKET_ANY); + const char *err = tcp_destroy(&tcp->sock); + LUA_CHECK_RETURN(L, err); +} + +// send(socket, data) +static int l_socket_send(lua_State *L) { + p_tcp self = (p_tcp) checktype(L, 1, SOCKET_CONN); + p_tcp tcp = (p_tcp) checktype(L, 2, SOCKET_CONN); + size_t len; + const char *data = luaL_checklstring(L, 3, &len); + const char *err = + tcp_send(&tcp->sock, data, len, tcp->timeout); + LUA_CHECK_RETURN(L, err); +} + +#define LUA_READ_STEP 8192 +static int l_socket_receive(lua_State *L) { + p_tcp self = (p_tcp) checktype(L, 1, SOCKET_CONN); + p_tcp handle = (p_tcp) checktype(L, 2, SOCKET_CONN); + size_t len = luaL_checknumber(L, 3); + char buf[LUA_READ_STEP]; + const char *err = NULL; + int received; + size_t got = 0, step = 0; + luaL_Buffer b; + + luaL_buffinit(L, &b); + do { + step = (LUA_READ_STEP < len - got ? LUA_READ_STEP : len - got); + err = tcp_raw_receive(&handle->sock, buf, step, self->timeout, &received); + if (err == NULL) { + luaL_addlstring(&b, buf, received); + got += received; + } + } while (err == NULL && got < len); + + if (err) { + lua_pushnil(L); + lua_pushstring(L, err); + return 2; + } + luaL_pushresult(&b); + return 1; +} + +// settimeout(timeout) +static int l_socket_settimeout(lua_State *L) { + p_tcp self = (p_tcp) checktype(L, 1, SOCKET_ANY); + int timeout = luaL_checknumber(L, 2); + self->timeout = timeout; + LUA_SUCCESS_RETURN(L); +} + +// table getsockinfo() +static int l_socket_getsockinfo(lua_State *L) { + char buf[256]; + short port = 0; + p_tcp tcp = (p_tcp) checktype(L, 1, SOCKET_ANY); + if (socket_get_info(&tcp->sock, &port, buf, 256) == SUCCESS) { + lua_newtable(L); // t + lua_pushstring(L, "host"); // t, "host" + lua_pushstring(L, buf); // t, "host", buf + lua_rawset(L, -3); // t + lua_pushstring(L, "port"); // t, "port" + lua_pushnumber(L, port); // t, "port", port + lua_rawset(L, -3); // t + return 1; + } + return 0; +} + +//////////////////////////////////////////////////////////////////////////////// +// Server + +// accept() +static int l_socket_accept(lua_State *L) { + const char *err; + p_tcp self = (p_tcp) checktype(L, 1, SOCKET_SERVER); + t_socket sock; + err = tcp_accept(&self->sock, &sock, self->timeout); + if (!err) { // Success + // Create a reference to the client + p_tcp client = (p_tcp) lua_newuserdata(L, sizeof(t_tcp)); + settype(L, 2, SOCKET_CLIENT); + socket_setnonblocking(&sock); + client->sock = sock; + client->timeout = self->timeout; + return 1; + } + LUA_CHECK_RETURN(L, err); +} + +static int l_socket_listen(lua_State *L) { + const char* err; + p_tcp tcp = (p_tcp) checktype(L, 1, SOCKET_GENERIC); + int backlog = 10; + err = tcp_listen(&tcp->sock, backlog); + if (!err) { + // Set the current as a server + settype(L, 1, SOCKET_SERVER); // Now a server + } + LUA_CHECK_RETURN(L, err); +} + +//////////////////////////////////////////////////////////////////////////////// +// Client + +// create_and_connect(host, port, timeout) +extern double __gettime(); +static int l_socket_create_and_connect(lua_State *L) { + const char* err = NULL; + double end; + t_socket sock; + const char *host = luaL_checkstring(L, 1); + unsigned short port = luaL_checknumber(L, 2); + int timeout = luaL_checknumber(L, 3); + + // Create and connect loop for timeout milliseconds + end = __gettime() + timeout/1000; + do { + // Create the socket + err = tcp_create(&sock); + if (!err) { + // Bind to any port on localhost + err = tcp_bind(&sock, DEFAULT_HOST, 0); + if (err) { + tcp_destroy(&sock); + } else { + // Connect + err = tcp_connect(&sock, host, port, timeout); + if (err) { + tcp_destroy(&sock); + usleep(100000); // sleep for 100ms + } else { + p_tcp tcp = (p_tcp) lua_newuserdata(L, sizeof(t_tcp)); + settype(L, -2, SOCKET_CLIENT); + socket_setnonblocking(&sock); + tcp->sock = sock; + tcp->timeout = timeout; + return 1; // Return userdata + } + } + } + } while (err && __gettime() < end); + + LUA_CHECK_RETURN(L, err); +} + +// connect(host, port) +static int l_socket_connect(lua_State *L) { + const char *err; + p_tcp tcp = (p_tcp) checktype(L, 1, SOCKET_GENERIC); + const char *host = luaL_checkstring(L, 2); + unsigned short port = luaL_checknumber(L, 3); + err = tcp_connect(&tcp->sock, host, port, tcp->timeout); + if (!err) { + settype(L, 1, SOCKET_CLIENT); // Now a client + } + LUA_CHECK_RETURN(L, err); +} diff --git a/lib/lua/src/socket.h b/lib/lua/src/socket.h new file mode 100644 index 00000000000..8019ffed8e3 --- /dev/null +++ b/lib/lua/src/socket.h @@ -0,0 +1,78 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#ifndef LUA_THRIFT_SOCKET_H +#define LUA_THRIFT_SOCKET_H + +#include + +#ifdef _WIN32 +// SOL +#else +typedef int t_socket; +typedef t_socket* p_socket; +#endif + +// Error Codes +enum { + SUCCESS = 0, + TIMEOUT = -1, + CLOSED = -2, +}; +typedef int T_ERRCODE; + +static const char * TIMEOUT_MSG = "Timeout"; +static const char * CLOSED_MSG = "Connection Closed"; + +typedef struct sockaddr t_sa; +typedef t_sa * p_sa; + +T_ERRCODE socket_create(p_socket sock, int domain, int type, int protocol); +T_ERRCODE socket_destroy(p_socket sock); +T_ERRCODE socket_bind(p_socket sock, p_sa addr, int addr_len); +T_ERRCODE socket_get_info(p_socket sock, short *port, char *buf, size_t len); +T_ERRCODE socket_send(p_socket sock, const char *data, size_t len, int timeout); +T_ERRCODE socket_recv(p_socket sock, char *data, size_t len, int timeout, + int *received); + +void socket_setblocking(p_socket sock); +void socket_setnonblocking(p_socket sock); + +T_ERRCODE socket_accept(p_socket sock, p_socket sibling, + p_sa addr, socklen_t *addr_len, int timeout); +T_ERRCODE socket_listen(p_socket sock, int backlog); + +T_ERRCODE socket_connect(p_socket sock, p_sa addr, int addr_len, int timeout); + +const char * tcp_create(p_socket sock); +const char * tcp_destroy(p_socket sock); +const char * tcp_bind(p_socket sock, const char *host, unsigned short port); +const char * tcp_send(p_socket sock, const char *data, size_t w_len, + int timeout); +const char * tcp_receive(p_socket sock, char *data, size_t r_len, int timeout); +const char * tcp_raw_receive(p_socket sock, char * data, size_t r_len, + int timeout, int *received); + +const char * tcp_listen(p_socket sock, int backlog); +const char * tcp_accept(p_socket sock, p_socket client, int timeout); + +const char * tcp_connect(p_socket sock, const char *host, unsigned short port, + int timeout); + +#endif diff --git a/lib/lua/src/usocket.c b/lib/lua/src/usocket.c new file mode 100644 index 00000000000..be696e06e23 --- /dev/null +++ b/lib/lua/src/usocket.c @@ -0,0 +1,362 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +#include +#include +#include +#include +#include +#include +#include +#include + +#include // TODO REMOVE + +#include "socket.h" + +//////////////////////////////////////////////////////////////////////////////// +// Private + +// Num seconds since Jan 1 1970 (UTC) +#ifdef _WIN32 +// SOL +#else + double __gettime() { + struct timeval v; + gettimeofday(&v, (struct timezone*) NULL); + return v.tv_sec + v.tv_usec/1.0e6; + } +#endif + +#define WAIT_MODE_R 1 +#define WAIT_MODE_W 2 +#define WAIT_MODE_C (WAIT_MODE_R|WAIT_MODE_W) +T_ERRCODE socket_wait(p_socket sock, int mode, int timeout) { + int ret = 0; + fd_set rfds, wfds; + struct timeval tv; + double end, t; + if (!timeout) { + return TIMEOUT; + } + + end = __gettime() + timeout/1000; + do { + // Specify what I/O operations we care about + if (mode & WAIT_MODE_R) { + FD_ZERO(&rfds); + FD_SET(*sock, &rfds); + } + if (mode & WAIT_MODE_W) { + FD_ZERO(&wfds); + FD_SET(*sock, &wfds); + } + + // Check for timeout + t = end - __gettime(); + if (t < 0.0) { + break; + } + + // Wait + tv.tv_sec = (int)t; + tv.tv_usec = (int)((t - tv.tv_sec) * 1.0e6); + ret = select(*sock+1, &rfds, &wfds, NULL, &tv); + } while (ret == -1 && errno == EINTR); + if (ret == -1) { + return errno; + } + + // Check for timeout + if (ret == 0) { + return TIMEOUT; + } + + // Verify that we can actually read from the remote host + if (mode & WAIT_MODE_C && FD_ISSET(*sock, &rfds) && + recv(*sock, (char*) &rfds, 0, 0) != 0) { + return errno; + } + + return SUCCESS; +} + +//////////////////////////////////////////////////////////////////////////////// +// General + +T_ERRCODE socket_create(p_socket sock, int domain, int type, int protocol) { + *sock = socket(domain, type, protocol); + if (*sock > 0) { + return SUCCESS; + } else { + return errno; + } +} + +T_ERRCODE socket_destroy(p_socket sock) { + // TODO Figure out if I should be free-ing this + if (*sock > 0) { + socket_setblocking(sock); + close(*sock); + *sock = -1; + } + return SUCCESS; +} + +T_ERRCODE socket_bind(p_socket sock, p_sa addr, int addr_len) { + int ret = SUCCESS; + socket_setblocking(sock); + if (bind(*sock, addr, addr_len)) { + ret = errno; + } + socket_setnonblocking(sock); + return ret; +} + +T_ERRCODE socket_get_info(p_socket sock, short *port, char *buf, size_t len) { + struct sockaddr_in sa; + socklen_t addrlen; + memset(&sa, 0, sizeof(sa)); + int rc = getsockname(*sock, (struct sockaddr*)&sa, &addrlen); + if (!rc) { + char *addr = inet_ntoa(sa.sin_addr); + *port = ntohs(sa.sin_port); + if (strlen(addr) < len) { + len = strlen(addr); + } + memcpy(buf, addr, len); + return SUCCESS; + } + return rc; +} + +//////////////////////////////////////////////////////////////////////////////// +// Server + +T_ERRCODE socket_accept(p_socket sock, p_socket client, + p_sa addr, socklen_t *addrlen, int timeout) { + int err; + if (*sock < 0) { + return CLOSED; + } + do { + *client = accept(*sock, addr, addrlen); + if (*client > 0) { + return SUCCESS; + } + err = errno; + } while (err != EINTR); + if (err == EAGAIN || err == ECONNABORTED) { + return socket_wait(sock, WAIT_MODE_R, timeout); + } + return err; +} + +T_ERRCODE socket_listen(p_socket sock, int backlog) { + int ret = SUCCESS; + socket_setblocking(sock); + if (listen(*sock, backlog)) { + ret = errno; + } + socket_setnonblocking(sock); + return ret; +} + +//////////////////////////////////////////////////////////////////////////////// +// Client + +T_ERRCODE socket_connect(p_socket sock, p_sa addr, int addr_len, int timeout) { + int err; + if (*sock < 0) { + return CLOSED; + } + + do { + if (connect(*sock, addr, addr_len) == 0) { + return SUCCESS; + } + } while ((err = errno) == EINTR); + if (err != EINPROGRESS && err != EAGAIN) { + return err; + } + return socket_wait(sock, WAIT_MODE_C, timeout); +} + +T_ERRCODE socket_send( + p_socket sock, const char *data, size_t len, int timeout) { + int err, put = 0; + if (*sock < 0) { + return CLOSED; + } + do { + put = send(*sock, data, len, 0); + if (put > 0) { + return SUCCESS; + } + err = errno; + } while (err != EINTR); + + if (err == EAGAIN) { + return socket_wait(sock, WAIT_MODE_W, timeout); + } + return err; +} + +T_ERRCODE socket_recv( + p_socket sock, char *data, size_t len, int timeout, int *received) { + int err, got = 0; + if (*sock < 0) { + return CLOSED; + } + + int flags = fcntl(*sock, F_GETFL, 0); + do { + got = recv(*sock, data, len, 0); + if (got > 0) { + *received = got; + return SUCCESS; + } + err = errno; + + // Connection has been closed by peer + if (got == 0) { + return CLOSED; + } + } while (err != EINTR); + + if (err == EAGAIN) { + return socket_wait(sock, WAIT_MODE_R, timeout); + } + return err; +} + +//////////////////////////////////////////////////////////////////////////////// +// Util + +void socket_setnonblocking(p_socket sock) { + int flags = fcntl(*sock, F_GETFL, 0); + flags |= O_NONBLOCK; + fcntl(*sock, F_SETFL, flags); +} + +void socket_setblocking(p_socket sock) { + int flags = fcntl(*sock, F_GETFL, 0); + flags &= (~(O_NONBLOCK)); + fcntl(*sock, F_SETFL, flags); +} + +//////////////////////////////////////////////////////////////////////////////// +// TCP + +#define ERRORSTR_RETURN(err) \ + if (err == SUCCESS) { \ + return NULL; \ + } else if (err == TIMEOUT) { \ + return TIMEOUT_MSG; \ + } else if (err == CLOSED) { \ + return CLOSED_MSG; \ + } \ + return strerror(err) + +const char * tcp_create(p_socket sock) { + int err = socket_create(sock, AF_INET, SOCK_STREAM, 0); + ERRORSTR_RETURN(err); +} + +const char * tcp_destroy(p_socket sock) { + int err = socket_destroy(sock); + ERRORSTR_RETURN(err); +} + +const char * tcp_bind(p_socket sock, const char *host, unsigned short port) { + int err; + struct hostent *h; + struct sockaddr_in local; + memset(&local, 0, sizeof(local)); + local.sin_family = AF_INET; + local.sin_addr.s_addr = htonl(INADDR_ANY); + local.sin_port = htons(port); + if (strcmp(host, "*") && !inet_aton(host, &local.sin_addr)) { + h = gethostbyname(host); + if (!h) { + return hstrerror(h_errno); + } + memcpy(&local.sin_addr, + (struct in_addr *)h->h_addr_list[0], + sizeof(struct in_addr)); + } + err = socket_bind(sock, (p_sa) &local, sizeof(local)); + ERRORSTR_RETURN(err); +} + +const char * tcp_listen(p_socket sock, int backlog) { + int err = socket_listen(sock, backlog); + ERRORSTR_RETURN(err); +} + +const char * tcp_accept(p_socket sock, p_socket client, int timeout) { + int err = socket_accept(sock, client, NULL, NULL, timeout); + ERRORSTR_RETURN(err); +} + +const char * tcp_connect(p_socket sock, + const char *host, + unsigned short port, + int timeout) { + int err; + struct hostent *h; + struct sockaddr_in remote; + memset(&remote, 0, sizeof(remote)); + remote.sin_family = AF_INET; + remote.sin_port = htons(port); + if (strcmp(host, "*") && !inet_aton(host, &remote.sin_addr)) { + h = gethostbyname(host); + if (!h) { + return hstrerror(h_errno); + } + memcpy(&remote.sin_addr, + (struct in_addr *)h->h_addr_list[0], + sizeof(struct in_addr)); + } + err = socket_connect(sock, (p_sa) &remote, sizeof(remote), timeout); + ERRORSTR_RETURN(err); +} + +#define WRITE_STEP 8192 +const char * tcp_send( + p_socket sock, const char * data, size_t w_len, int timeout) { + int err; + size_t put = 0, step; + if (!w_len) { + return NULL; + } + + do { + step = (WRITE_STEP < w_len - put ? WRITE_STEP : w_len - put); + err = socket_send(sock, data + put, step, timeout); + put += step; + } while (err == SUCCESS && put < w_len); + ERRORSTR_RETURN(err); +} + +const char * tcp_raw_receive( + p_socket sock, char * data, size_t r_len, int timeout, int *received) { + int err = socket_recv(sock, data, r_len, timeout, received); + ERRORSTR_RETURN(err); +} diff --git a/test/ThriftTest.thrift b/test/ThriftTest.thrift index 568ed1b68b6..7ca194efd49 100644 --- a/test/ThriftTest.thrift +++ b/test/ThriftTest.thrift @@ -35,6 +35,7 @@ namespace go ThriftTest namespace php ThriftTest namespace delphi Thrift.Test namespace cocoa ThriftTest +namespace lua ThriftTest // Presence of namespaces and sub-namespaces for which there is // no generator should compile with warnings only diff --git a/test/lua/test_basic_client.lua b/test/lua/test_basic_client.lua new file mode 100644 index 00000000000..e2e0d48dcf1 --- /dev/null +++ b/test/lua/test_basic_client.lua @@ -0,0 +1,136 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at + +-- http://www.apache.org/licenses/LICENSE-2.0 + +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + + +require('TSocket') +require('TBinaryProtocol') +require('ThriftTest_ThriftTest') +require('liblualongnumber') + +local client + +function teardown() + if client then + -- Shuts down the server + client:testVoid() + + -- close the connection + client:close() + end +end + +function assertEqual(val1, val2, msg) + assert(val1 == val2, msg) +end + +function testBasicClient() + local socket = TSocket:new{ + port = 9090 + } + assert(socket, 'Failed to create client socket') + socket:setTimeout(5000) + + local protocol = TBinaryProtocol:new{ + trans = socket + } + assert(protocol, 'Failed to create binary protocol') + + client = ThriftTestClient:new{ + protocol = protocol + } + assert(client, 'Failed to create client') + + -- Open the socket + local status, _ = pcall(socket.open, socket) + assert(status, 'Failed to connect to server') + + -- String + assertEqual(client:testString('lala'), 'lala', 'Failed testString') + assertEqual(client:testString('wahoo'), 'wahoo', 'Failed testString') + + -- Byte + assertEqual(client:testByte(0x01), 1, 'Failed testByte 1') + assertEqual(client:testByte(0x40), 64, 'Failed testByte 2') + assertEqual(client:testByte(0x7f), 127, 'Failed testByte 3') + assertEqual(client:testByte(0x80), -128, 'Failed testByte 4') + assertEqual(client:testByte(0xbf), -65, 'Failed testByte 5') + assertEqual(client:testByte(0xff), -1, 'Failed testByte 6') + assertEqual(client:testByte(128), -128, 'Failed testByte 7') + assertEqual(client:testByte(255), -1, 'Failed testByte 8') + + -- I32 + assertEqual(client:testI32(0x00000001), 1, 'Failed testI32 1') + assertEqual(client:testI32(0x40000000), 1073741824, 'Failed testI32 2') + assertEqual(client:testI32(0x7fffffff), 2147483647, 'Failed testI32 3') + assertEqual(client:testI32(0x80000000), -2147483648, 'Failed testI32 4') + assertEqual(client:testI32(0xbfffffff), -1073741825, 'Failed testI32 5') + assertEqual(client:testI32(0xffffffff), -1, 'Failed testI32 6') + assertEqual(client:testI32(2147483648), -2147483648, 'Failed testI32 7') + assertEqual(client:testI32(4294967295), -1, 'Failed testI32 8') + + -- I64 (lua only supports 16 decimal precision so larger numbers are + -- initialized by their string value) + local long = liblualongnumber.new + assertEqual(client:testI64(long(0x0000000000000001)), + long(1), + 'Failed testI64 1') + assertEqual(client:testI64(long(0x4000000000000000)), + long(4611686018427387904), + 'Failed testI64 2') + assertEqual(client:testI64(long('0x7fffffffffffffff')), + long('9223372036854775807'), + 'Failed testI64 3') + assertEqual(client:testI64(long(0x8000000000000000)), + long(-9223372036854775808), + 'Failed testI64 4') + assertEqual(client:testI64(long('0xbfffffffffffffff')), + long('-4611686018427387905'), + 'Failed testI64 5') + assertEqual(client:testI64(long('0xffffffffffffffff')), + long(-1), + 'Failed testI64 6') + + -- Double + assertEqual( + client:testDouble(1.23456789), 1.23456789, 'Failed testDouble 1') + assertEqual( + client:testDouble(0.123456789), 0.123456789, 'Failed testDouble 2') + assertEqual( + client:testDouble(0.123456789), 0.123456789, 'Failed testDouble 3') + + -- Accuracy of 16 decimal digits (rounds) + local a, b = 1.12345678906666663, 1.12345678906666661 + assertEqual(a, b) + assertEqual(client:testDouble(a), b, 'Failed testDouble 5') + + -- Struct + local a = { + string_thing = 'Zero', + byte_thing = 1, + i32_thing = -3, + i64_thing = long(-5) + } + + -- TODO fix client struct equality + --assertEqual(client:testStruct(a), a, 'Failed testStruct') + + -- Call the void function and end the test (handler stops server) + client:testVoid() +end + +testBasicClient() +teardown() \ No newline at end of file diff --git a/test/lua/test_basic_server.lua b/test/lua/test_basic_server.lua new file mode 100644 index 00000000000..7c175daca9a --- /dev/null +++ b/test/lua/test_basic_server.lua @@ -0,0 +1,104 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at + +-- http://www.apache.org/licenses/LICENSE-2.0 + +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +require('ThriftTest_ThriftTest') +require('TSocket') +require('TFramedTransport') +require('TBinaryProtocol') +require('TServer') +require('liblualongnumber') + +-------------------------------------------------------------------------------- +-- Handler +TestHandler = ThriftTestIface:new{} + +-- Stops the server +function TestHandler:testVoid() + self.__server:stop() +end + +function TestHandler:testString(str) + return str +end + +function TestHandler:testByte(byte) + return byte +end + +function TestHandler:testI32(i32) + return i32 +end + +function TestHandler:testI64(i64) + return i64 +end + +function TestHandler:testDouble(d) + return d +end + +function TestHandler:testStruct(thing) + return thing +end + +-------------------------------------------------------------------------------- +-- Test +local server + +function teardown() + if server then + server:close() + end +end + +function testBasicServer() + -- Handler & Processor + local handler = TestHandler:new{} + assert(handler, 'Failed to create handler') + local processor = ThriftTestProcessor:new{ + handler = handler + } + assert(processor, 'Failed to create processor') + + -- Server Socket + local socket = TServerSocket:new{ + port = 9090 + } + assert(socket, 'Failed to create server socket') + + -- Transport & Factory + local trans_factory = TFramedTransportFactory:new{} + assert(trans_factory, 'Failed to create framed transport factory') + local prot_factory = TBinaryProtocolFactory:new{} + assert(prot_factory, 'Failed to create binary protocol factory') + + -- Simple Server + server = TSimpleServer:new{ + processor = processor, + serverTransport = socket, + transportFactory = trans_factory, + protocolFactory = prot_factory + } + assert(server, 'Failed to create server') + + -- Serve + server:serve() + server = nil +end + +testBasicServer() +teardown() \ No newline at end of file