Skip to content

Commit

Permalink
ssl扩展
Browse files Browse the repository at this point in the history
  • Loading branch information
xiyoo0812 committed Apr 11, 2024
1 parent b81ebcf commit 8e165e4
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 10 deletions.
2 changes: 2 additions & 0 deletions extend/lcodec/src/http.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ namespace lcodec {
if (parts.size() < 2) {
throw lua_exception("invalid http header");
}
//proto
lua_pushstring(L, "HTTP");
//status
string status = string(parts[1]);
if (lua_stringtonumber(L, status.c_str()) == 0) {
Expand Down
1 change: 1 addition & 0 deletions extend/lssl/lssl.lmak
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ WINDOWS_DEFINES = {

DEFINES = {
"WOLFSSL_LIB",
"WOLFSSL_SRTP",
"WOLFSSL_NO_SOCK",
"WOLFSSL_USER_IO",
"WOLFSSL_USER_SETTINGS"
Expand Down
1 change: 1 addition & 0 deletions extend/lssl/lssl.mak
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ MYCFLAGS += -I../luakit/include

#需要定义的选项
MYCFLAGS += -DWOLFSSL_LIB
MYCFLAGS += -DWOLFSSL_SRTP
MYCFLAGS += -DWOLFSSL_NO_SOCK
MYCFLAGS += -DWOLFSSL_USER_IO
MYCFLAGS += -DWOLFSSL_USER_SETTINGS
Expand Down
2 changes: 1 addition & 1 deletion extend/lssl/lssl.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@
<ClCompile>
<Optimization>Disabled</Optimization>
<AdditionalIncludeDirectories>.\src;..\lua\lua;..\luakit\include;$(SolutionDir)extend\mimalloc\mimalloc\include;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<PreprocessorDefinitions>WIN32;NDEBUG;_WINDOWS;_CRT_SECURE_NO_WARNINGS;WOLFSSL_LIB;WOLFSSL_NO_SOCK;WOLFSSL_USER_IO;WOLFSSL_USER_SETTINGS;LUA_BUILD_AS_DLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions>WIN32;NDEBUG;_WINDOWS;_CRT_SECURE_NO_WARNINGS;WOLFSSL_LIB;WOLFSSL_SRTP;WOLFSSL_NO_SOCK;WOLFSSL_USER_IO;WOLFSSL_USER_SETTINGS;LUA_BUILD_AS_DLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<BasicRuntimeChecks>Default</BasicRuntimeChecks>
<RuntimeLibrary>MultiThreadedDLL</RuntimeLibrary>
<PrecompiledHeader></PrecompiledHeader>
Expand Down
20 changes: 20 additions & 0 deletions extend/lssl/src/ssl/lssl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#define HEX(v,c) { char tmp = (char) c; if (tmp >= '0' && tmp <= '9') { v = tmp-'0'; } else { v = tmp - 'a' + 10; } }

namespace lssl {
thread_local luakit::luabuf thread_buff;

static void hash(const char* str, int sz, char key[8]) {
long djb_hash = 5381L;
Expand Down Expand Up @@ -350,6 +351,13 @@ namespace lssl {
return 1;
}

static tlscodec* tls_codec(codec_base* codec) {
tlscodec* tcodec = new tlscodec();
tcodec->set_codec(codec);
tcodec->set_buff(&thread_buff);
return tcodec;
}

luakit::lua_table open_lssl(lua_State* L) {
luakit::kit_state kit_state(L);
auto luassl = kit_state.new_table("ssl");
Expand Down Expand Up @@ -382,18 +390,30 @@ namespace lssl {
luassl.set_function("xxtea_decode", lxxtea_decode);
luassl.set_function("rsa_init_pubkey", lrsa_init_pubkey);
luassl.set_function("rsa_init_prikey", lrsa_init_prikey);
luassl.set_function("tlscodec", tls_codec);
kit_state.new_class<lua_rsa_key>(
"pub_encode", &lua_rsa_key::pub_encode,
"pub_decode", &lua_rsa_key::pub_decode,
"pri_encode", &lua_rsa_key::pri_encode,
"pri_decode", &lua_rsa_key::pri_decode
);
kit_state.new_class<tlscodec>(
"init_tls", &tlscodec::init_tls,
"set_cert", &tlscodec::set_cert,
"set_ciphers", &tlscodec::set_ciphers
);
return luassl;
}
}

extern "C" {
static bool SSL_IS_INIT = false;
LUALIB_API int luaopen_lssl(lua_State* L) {
if (!SSL_IS_INIT) {
SSL_IS_INIT = true;
SSL_library_init();
OpenSSL_add_all_algorithms();
}
auto luassl = lssl::open_lssl(L);
return luassl.push_stack();
}
Expand Down
172 changes: 172 additions & 0 deletions extend/lssl/src/ssl/lssl.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

#include "lua_kit.h"

using namespace luakit;

#define RSA_PADDING_LEN 11
#define SSL_TLS_READ_SIZE 1024
#define RSA_ENCODE_LEN(m) (m) - RSA_PADDING_LEN
#define BASE64_DECODE_OUT_SIZE(s) ((unsigned int)(((s) / 4) * 3))
#define BASE64_ENCODE_OUT_SIZE(s) ((unsigned int)((((s) + 2) / 3) * 4 + 1))
Expand Down Expand Up @@ -145,4 +148,173 @@ namespace lssl {
RSA* rsa = nullptr;
char buf[RSA_MAX_SIZE / 8];
};

class tlscodec : public codec_base {
public:
~tlscodec() {
if (ssl) SSL_free(ssl);
if (ctx) SSL_CTX_free(ctx);
}



virtual int load_packet(size_t data_len) {
if (!m_slice) return 0;
return data_len;
}

virtual uint8_t* encode(lua_State* L, int index, size_t* len) {
if (!is_handshake) {
return (uint8_t*)lua_tolstring(L, index, len);
}
size_t slen = 0;
uint8_t* body = m_hcodec->encode(L, index, &slen);
while (slen > 0) {
int written = SSL_write(ssl, body, slen);
if (written <= 0 || written > slen) {
int err = SSL_get_error(ssl, written);
ERR_clear_error();
luaL_error(L, "SSL_write error:%d", err);
}
body += written;
slen -= written;
}
bio_read(L);
return m_buf->data(len);
}

virtual size_t decode(lua_State* L) {
if (!is_handshake) {
return handshake(L);
}
m_buf->clean();
bio_write(L, m_slice->size());
do {
uint8_t* outbuff = m_buf->peek_space(SSL_TLS_READ_SIZE);
int read = SSL_read(ssl, outbuff, sizeof(SSL_TLS_READ_SIZE));
if (read == 0) break;
if (read < 0 || read > SSL_TLS_READ_SIZE) {
int err = SSL_get_error(ssl, read);
ERR_clear_error();
throw lua_exception("SSL_read error:%d", err);
}
m_buf->pop_space(SSL_TLS_READ_SIZE);
} while (true);
m_hcodec->set_slice(m_buf->get_slice());
return m_hcodec->decode(L);
}

void set_codec(codec_base* codec) {
m_hcodec = codec;
}

int init_tls(lua_State* L, bool is_server = false) {
ssl = SSL_new(ctx);
if (!ssl) luaL_error(L, "SSL_new faild");
in_bio = BIO_new(BIO_s_mem());
if (!in_bio) luaL_error(L, "new in bio faild");
out_bio = BIO_new(BIO_s_mem());
if (!out_bio) luaL_error(L, "new out bio faild");
BIO_set_mem_eof_return(in_bio, -1);
BIO_set_mem_eof_return(out_bio, -1);
SSL_set_bio(ssl, in_bio, out_bio);
ctx = SSL_CTX_new(SSLv23_method());
if (!ctx) {
char buf[256];
ERR_error_string_n(ERR_get_error(), buf, sizeof(buf));
luaL_error(L, "SSL_CTX_new faild. %s\n", buf);
}
if (is_server) {
SSL_set_accept_state(ssl);
} else {
SSL_set_connect_state(ssl);
}
return 0;
}

int set_ciphers(lua_State* L, std::string_view cipher) {
if (int ret = SSL_CTX_set_tlsext_use_srtp(ctx, cipher.data()) != 0) {
luaL_error(L, "SSL_CTX_set_tlsext_use_srtp error: %d", ret);
}
return 0;
}

int set_cert(lua_State* L, std::string_view certfile, std::string_view key) {
if (int ret = SSL_CTX_use_certificate_chain_file(ctx, certfile.data()) != 1) {
luaL_error(L, "SSL_CTX_use_certificate_chain_file error:%d", ret);
}
if (int ret = SSL_CTX_use_PrivateKey_file(ctx, key.data(), SSL_FILETYPE_PEM) != 1) {
luaL_error(L, "SSL_CTX_use_PrivateKey_file error:%d", ret);
}
if (int ret = SSL_CTX_check_private_key(ctx) != 1) {
luaL_error(L, "SSL_CTX_check_private_key error:%d", ret);
}
return 0;
}

protected:
size_t handshake(lua_State* L) {
int top = lua_gettop(L);
size_t sz = m_slice->size();
is_handshake = SSL_is_init_finished(ssl);
lua_pushstring(L, "TLS");
lua_pushinteger(L, is_handshake);
if (!is_handshake) {
bio_write(L, sz);
if (int ret = SSL_do_handshake(ssl) > 0) {
int err = SSL_get_error(ssl, ret);
ERR_clear_error();
throw lua_exception("SSL_do_handshake error:%d ret:%d", err, ret);
} else if (ret < 0) {
int err = SSL_get_error(ssl, ret);
ERR_clear_error();
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
if (size_t len = bio_read(L) > 0) {
lua_pushlstring(L, (const char*)m_buf->head(), len);
}
}
}
}
m_slice->erase(sz);
return lua_gettop(L) - top;
}

void bio_write(lua_State* L, size_t sz) {
char* p = (char*)m_slice->head();
while (sz > 0) {
int written = BIO_write(in_bio, p, sz);
if (written <= 0 || written > sz) {
throw lua_exception("BIO_write error:%d", written);
}
sz -= written;
p += written;
}
}

size_t bio_read(lua_State* L) {
int pending = BIO_ctrl_pending(out_bio);
if (pending > 0) {
m_buf->clean();
while (pending > 0) {
uint8_t* outbuff = m_buf->peek_space(SSL_TLS_READ_SIZE);
int read = BIO_read(out_bio, outbuff, SSL_TLS_READ_SIZE);
if (read <= 0 || read > SSL_TLS_READ_SIZE) {
throw lua_exception("BIO_read error:%d", read);
}
m_buf->pop_space(SSL_TLS_READ_SIZE);
pending = BIO_ctrl_pending(out_bio);
}
return m_buf->size();
}
return 0;
}

protected:
SSL* ssl = nullptr;
BIO* in_bio = nullptr;
BIO* out_bio = nullptr;
SSL_CTX* ctx = nullptr;
codec_base* m_hcodec = nullptr;
bool is_handshake = false;
};
}
41 changes: 32 additions & 9 deletions script/network/http_client.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ local sgmatch = string.gmatch
local jsoncodec = json.jsoncodec
local httpccodec = codec.httpccodec
local luencode = codec.url_encode
local tlscodec = ssl.tlscodec

local thread_mgr = quanta.get("thread_mgr")
local update_mgr = quanta.get("update_mgr")
Expand Down Expand Up @@ -42,8 +43,21 @@ function HttpClient:on_hour()
self.domains = {}
end

function HttpClient:on_socket_recv(socket, status, headers, body)
local token = socket.token
function HttpClient:on_socket_recv(socket, proto, ...)
if proto == "TLS" then
self:on_handshake(socket.token, ...)
end
return self:on_http_recv(socket.token, ...)
end

function HttpClient:on_handshake(token, message)
local client = self.clients[token]
if client and message then
client:send_data(message)
end
end

function HttpClient:on_http_recv(token, status, headers, body)
local client = self.clients[token]
if client then
client:close()
Expand All @@ -59,7 +73,7 @@ end

--构建请求
function HttpClient:send_request(url, timeout, querys, headers, method, datas)
local host, port, path = self:parse_url_addr(url)
local host, port, path, proto = self:parse_url_addr(url)
if not host then
log_err("[HttpClient][send_request] failed : {}", port)
return false, port
Expand All @@ -69,15 +83,24 @@ function HttpClient:send_request(url, timeout, querys, headers, method, datas)
if not ok then
return false, cerr
end
if proto == "https" then
local codec = tlscodec(self.hcodec)
if not codec then
return false, "tls codec create failed!"
end
codec:init_tls()
socket:set_codec(codec)
else
socket:set_codec(self.hcodec)
end
if not headers then
headers = {["Content-Type"] = "text/plain" }
end
if type(datas) == "table" then
headers["Content-Type"] = "application/json"
end
local fmt_url = self:format_url(path, querys)
local session_id = thread_mgr:build_session_id()
socket:set_codec(self.hcodec)
local fmt_url = self:format_url(path, querys)
socket.session_id = session_id
self.clients[socket.token] = socket
socket:send_data(fmt_url, method, headers, datas or "")
Expand Down Expand Up @@ -129,13 +152,13 @@ function HttpClient:parse_url_addr(url)
if url:sub(-1) ~= "/" then
url = sformat("%s/", url)
end
local http, addr, path = sgmatch(url, "(.+)://([^/]-)(/.*)")()
if not http then
local proto, addr, path = sgmatch(url, "(.+)://([^/]-)(/.*)")()
if not proto then
return nil, "Illegal htpp url"
end
local host, port = qsaddr(addr)
if not port then
port = http == "https" and 443 or 80
port = proto == "https" and 443 or 80
end
local ip = self.domains[host]
if not ip then
Expand All @@ -146,7 +169,7 @@ function HttpClient:parse_url_addr(url)
ip = ips[1]
self.domains[host] = ip
end
return ip, port, path
return ip, port, path, proto
end

quanta.http_client = HttpClient()
Expand Down

0 comments on commit 8e165e4

Please sign in to comment.