Skip to content

Commit

Permalink
THRIFT-3580 THeader for Haskell
Browse files Browse the repository at this point in the history
Client: hs

This closes apache#820
This closes apache#1423
  • Loading branch information
nsuke authored and jeking3 committed Nov 30, 2017
1 parent 2147466 commit 3c42007
Show file tree
Hide file tree
Showing 15 changed files with 690 additions and 141 deletions.
64 changes: 31 additions & 33 deletions compiler/cpp/src/thrift/generate/t_hs_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -711,13 +711,13 @@ void t_hs_generator::generate_hs_struct_reader(ofstream& out, t_struct* tstruct)
string tmap = type_name(tstruct, "typemap_");
indent(out) << "to_" << sname << " _ = P.error \"not a struct\"" << endl;

indent(out) << "read_" << sname << " :: (T.Transport t, T.Protocol p) => p t -> P.IO " << sname
indent(out) << "read_" << sname << " :: T.Protocol p => p -> P.IO " << sname
<< endl;
indent(out) << "read_" << sname << " iprot = to_" << sname;
out << " <$> T.readVal iprot (T.T_STRUCT " << tmap << ")" << endl;

indent(out) << "decode_" << sname
<< " :: (T.Protocol p, T.Transport t) => p t -> LBS.ByteString -> " << sname << endl;
<< " :: T.StatelessProtocol p => p -> LBS.ByteString -> " << sname << endl;
indent(out) << "decode_" << sname << " iprot bs = to_" << sname << " $ ";
out << "T.deserializeVal iprot (T.T_STRUCT " << tmap << ") bs" << endl;
}
Expand Down Expand Up @@ -818,13 +818,13 @@ void t_hs_generator::generate_hs_struct_writer(ofstream& out, t_struct* tstruct)
indent_down();

// write
indent(out) << "write_" << name << " :: (T.Protocol p, T.Transport t) => p t -> " << name
indent(out) << "write_" << name << " :: T.Protocol p => p -> " << name
<< " -> P.IO ()" << endl;
indent(out) << "write_" << name << " oprot record = T.writeVal oprot $ from_";
out << name << " record" << endl;

// encode
indent(out) << "encode_" << name << " :: (T.Protocol p, T.Transport t) => p t -> " << name
indent(out) << "encode_" << name << " :: T.StatelessProtocol p => p -> " << name
<< " -> LBS.ByteString" << endl;
indent(out) << "encode_" << name << " oprot record = T.serializeVal oprot $ ";
out << "from_" << name << " record" << endl;
Expand Down Expand Up @@ -1085,8 +1085,9 @@ void t_hs_generator::generate_service_client(t_service* tservice) {
// Serialize the request header
string fname = (*f_iter)->get_name();
string msgType = (*f_iter)->is_oneway() ? "T.M_ONEWAY" : "T.M_CALL";
indent(f_client_) << "T.writeMessageBegin op (\"" << fname << "\", " << msgType << ", seqn)"
indent(f_client_) << "T.writeMessage op (\"" << fname << "\", " << msgType << ", seqn) $"
<< endl;
indent_up();
indent(f_client_) << "write_" << argsname << " op (" << argsname << "{";

bool first = true;
Expand All @@ -1102,10 +1103,7 @@ void t_hs_generator::generate_service_client(t_service* tservice) {
first = false;
}
f_client_ << "})" << endl;
indent(f_client_) << "T.writeMessageEnd op" << endl;

// Write to the stream
indent(f_client_) << "T.tFlush (T.getTransport op)" << endl;
indent_down();
indent_down();

if (!(*f_iter)->is_oneway()) {
Expand All @@ -1119,12 +1117,12 @@ void t_hs_generator::generate_service_client(t_service* tservice) {
indent(f_client_) << funname << " ip = do" << endl;
indent_up();

indent(f_client_) << "(fname, mtype, rseqid) <- T.readMessageBegin ip" << endl;
indent(f_client_) << "T.readMessage ip $ \\(fname, mtype, rseqid) -> do" << endl;
indent_up();
indent(f_client_) << "M.when (mtype == T.M_EXCEPTION) $ do { exn <- T.readAppExn ip ; "
"T.readMessageEnd ip ; X.throw exn }" << endl;
"X.throw exn }" << endl;

indent(f_client_) << "res <- read_" << resultname << " ip" << endl;
indent(f_client_) << "T.readMessageEnd ip" << endl;

t_struct* xs = (*f_iter)->get_xceptions();
const vector<t_field*>& xceptions = xs->get_members();
Expand All @@ -1142,6 +1140,7 @@ void t_hs_generator::generate_service_client(t_service* tservice) {

// Close function
indent_down();
indent_down();
}
}

Expand Down Expand Up @@ -1180,11 +1179,11 @@ void t_hs_generator::generate_service_server(t_service* tservice) {
f_service_ << "do" << endl;
indent_up();
indent(f_service_) << "_ <- T.readVal iprot (T.T_STRUCT Map.empty)" << endl;
indent(f_service_) << "T.writeMessageBegin oprot (name,T.M_EXCEPTION,seqid)" << endl;
indent(f_service_) << "T.writeMessage oprot (name,T.M_EXCEPTION,seqid) $" << endl;
indent_up();
indent(f_service_) << "T.writeAppExn oprot (T.AppExn T.AE_UNKNOWN_METHOD (\"Unknown function "
"\" ++ LT.unpack name))" << endl;
indent(f_service_) << "T.writeMessageEnd oprot" << endl;
indent(f_service_) << "T.tFlush (T.getTransport oprot)" << endl;
indent_down();
indent_down();
}

Expand All @@ -1194,9 +1193,8 @@ void t_hs_generator::generate_service_server(t_service* tservice) {
indent(f_service_) << "process handler (iprot, oprot) = do" << endl;
indent_up();

indent(f_service_) << "(name, typ, seqid) <- T.readMessageBegin iprot" << endl;
indent(f_service_) << "proc_ handler (iprot,oprot) (name,typ,seqid)" << endl;
indent(f_service_) << "T.readMessageEnd iprot" << endl;
indent(f_service_) << "T.readMessage iprot (" << endl;
indent(f_service_) << " proc_ handler (iprot,oprot))" << endl;
indent(f_service_) << "P.return P.True" << endl;
indent_down();
}
Expand Down Expand Up @@ -1286,11 +1284,11 @@ void t_hs_generator::generate_process_function(t_service* tservice, t_function*
if (tfunction->is_oneway()) {
indent(f_service_) << "P.return ()";
} else {
indent(f_service_) << "T.writeMessageBegin oprot (\"" << tfunction->get_name()
<< "\", T.M_REPLY, seqid)" << endl;
indent(f_service_) << "write_" << resultname << " oprot res" << endl;
indent(f_service_) << "T.writeMessageEnd oprot" << endl;
indent(f_service_) << "T.tFlush (T.getTransport oprot)";
indent(f_service_) << "T.writeMessage oprot (\"" << tfunction->get_name()
<< "\", T.M_REPLY, seqid) $" << endl;
indent_up();
indent(f_service_) << "write_" << resultname << " oprot res";
indent_down();
}
if (n > 0) {
f_service_ << ")";
Expand All @@ -1307,11 +1305,11 @@ void t_hs_generator::generate_process_function(t_service* tservice, t_function*
indent(f_service_) << "let res = default_" << resultname << "{"
<< field_name(resultname, (*x_iter)->get_name()) << " = P.Just e}"
<< endl;
indent(f_service_) << "T.writeMessageBegin oprot (\"" << tfunction->get_name()
<< "\", T.M_REPLY, seqid)" << endl;
indent(f_service_) << "write_" << resultname << " oprot res" << endl;
indent(f_service_) << "T.writeMessageEnd oprot" << endl;
indent(f_service_) << "T.tFlush (T.getTransport oprot)";
indent(f_service_) << "T.writeMessage oprot (\"" << tfunction->get_name()
<< "\", T.M_REPLY, seqid) $" << endl;
indent_up();
indent(f_service_) << "write_" << resultname << " oprot res";
indent_down();
} else {
indent(f_service_) << "P.return ()";
}
Expand All @@ -1324,11 +1322,11 @@ void t_hs_generator::generate_process_function(t_service* tservice, t_function*
indent_up();

if (!tfunction->is_oneway()) {
indent(f_service_) << "T.writeMessageBegin oprot (\"" << tfunction->get_name()
<< "\", T.M_EXCEPTION, seqid)" << endl;
indent(f_service_) << "T.writeAppExn oprot (T.AppExn T.AE_UNKNOWN \"\")" << endl;
indent(f_service_) << "T.writeMessageEnd oprot" << endl;
indent(f_service_) << "T.tFlush (T.getTransport oprot)";
indent(f_service_) << "T.writeMessage oprot (\"" << tfunction->get_name()
<< "\", T.M_EXCEPTION, seqid) $" << endl;
indent_up();
indent(f_service_) << "T.writeAppExn oprot (T.AppExn T.AE_UNKNOWN \"\")";
indent_down();
} else {
indent(f_service_) << "P.return ()";
}
Expand Down
4 changes: 2 additions & 2 deletions lib/hs/src/Thrift.hs
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ data AppExn = AppExn { ae_type :: AppExnType, ae_message :: String }
deriving ( Show, Typeable )
instance Exception AppExn

writeAppExn :: (Protocol p, Transport t) => p t -> AppExn -> IO ()
writeAppExn :: Protocol p => p -> AppExn -> IO ()
writeAppExn pt ae = writeVal pt $ TStruct $ Map.fromList
[ (1, ("message", TString $ encodeUtf8 $ pack $ ae_message ae))
, (2, ("type", TI32 $ fromIntegral $ fromEnum (ae_type ae)))
]

readAppExn :: (Protocol p, Transport t) => p t -> IO AppExn
readAppExn :: Protocol p => p -> IO AppExn
readAppExn pt = do
let typemap = Map.fromList [(1,("message",T_STRING)),(2,("type",T_I32))]
TStruct fields <- readVal pt $ T_STRUCT typemap
Expand Down
41 changes: 14 additions & 27 deletions lib/hs/src/Thrift/Protocol.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,18 @@

module Thrift.Protocol
( Protocol(..)
, StatelessProtocol(..)
, ProtocolExn(..)
, ProtocolExnType(..)
, getTypeOf
, runParser
, versionMask
, version1
, bsToDouble
, bsToDoubleLE
) where

import Control.Exception
import Data.Attoparsec.ByteString
import Data.Bits
import Data.ByteString.Lazy (ByteString, toStrict)
import Data.ByteString.Unsafe
import Data.Functor ((<$>))
import Data.Int
Expand All @@ -44,37 +42,26 @@ import Data.Text.Lazy (Text)
import Data.Typeable (Typeable)
import Data.Word
import Foreign.Ptr (castPtr)
import Foreign.Storable (Storable, peek, poke)
import Foreign.Storable (peek, poke)
import System.IO.Unsafe
import qualified Data.ByteString as BS
import qualified Data.HashMap.Strict as Map
import qualified Data.ByteString.Lazy as LBS

import Thrift.Types
import Thrift.Transport

versionMask :: Int32
versionMask = fromIntegral (0xffff0000 :: Word32)

version1 :: Int32
version1 = fromIntegral (0x80010000 :: Word32)
import Thrift.Types

class Protocol a where
getTransport :: Transport t => a t -> t

writeMessageBegin :: Transport t => a t -> (Text, MessageType, Int32) -> IO ()
writeMessageEnd :: Transport t => a t -> IO ()
writeMessageEnd _ = return ()

readMessageBegin :: Transport t => a t -> IO (Text, MessageType, Int32)
readMessageEnd :: Transport t => a t -> IO ()
readMessageEnd _ = return ()
readByte :: a -> IO LBS.ByteString
readVal :: a -> ThriftType -> IO ThriftVal
readMessage :: a -> ((Text, MessageType, Int32) -> IO b) -> IO b

serializeVal :: Transport t => a t -> ThriftVal -> ByteString
deserializeVal :: Transport t => a t -> ThriftType -> ByteString -> ThriftVal
writeVal :: a -> ThriftVal -> IO ()
writeMessage :: a -> (Text, MessageType, Int32) -> IO () -> IO ()

writeVal :: Transport t => a t -> ThriftVal -> IO ()
writeVal p = tWrite (getTransport p) . serializeVal p
readVal :: Transport t => a t -> ThriftType -> IO ThriftVal
class Protocol a => StatelessProtocol a where
serializeVal :: a -> ThriftVal -> LBS.ByteString
deserializeVal :: a -> ThriftType -> LBS.ByteString -> ThriftVal

data ProtocolExnType
= PE_UNKNOWN
Expand Down Expand Up @@ -105,10 +92,10 @@ getTypeOf v = case v of
TBinary{} -> T_BINARY
TDouble{} -> T_DOUBLE

runParser :: (Protocol p, Transport t, Show a) => p t -> Parser a -> IO a
runParser :: (Protocol p, Show a) => p -> Parser a -> IO a
runParser prot p = refill >>= getResult . parse p
where
refill = handle handleEOF $ toStrict <$> tReadAll (getTransport prot) 1
refill = handle handleEOF $ LBS.toStrict <$> readByte prot
getResult (Done _ a) = return a
getResult (Partial k) = refill >>= getResult . k
getResult f = throw $ ProtocolExn PE_INVALID_DATA (show f)
Expand Down
59 changes: 40 additions & 19 deletions lib/hs/src/Thrift/Protocol/Binary.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
module Thrift.Protocol.Binary
( module Thrift.Protocol
, BinaryProtocol(..)
, versionMask
, version1
) where

import Control.Exception ( throw )
Expand All @@ -35,6 +37,7 @@ import Data.Functor
import Data.Int
import Data.Monoid
import Data.Text.Lazy.Encoding ( decodeUtf8, encodeUtf8 )
import Data.Word

import Thrift.Protocol
import Thrift.Transport
Expand All @@ -47,37 +50,55 @@ import qualified Data.ByteString.Lazy as LBS
import qualified Data.HashMap.Strict as Map
import qualified Data.Text.Lazy as LT

data BinaryProtocol a = BinaryProtocol a
versionMask :: Int32
versionMask = fromIntegral (0xffff0000 :: Word32)

version1 :: Int32
version1 = fromIntegral (0x80010000 :: Word32)

data BinaryProtocol a = Transport a => BinaryProtocol a

getTransport :: Transport t => BinaryProtocol t -> t
getTransport (BinaryProtocol t) = t

-- NOTE: Reading and Writing functions rely on Builders and Data.Binary to
-- encode and decode data. Data.Binary assumes that the binary values it is
-- encoding to and decoding from are in BIG ENDIAN format, and converts the
-- endianness as necessary to match the local machine.
instance Protocol BinaryProtocol where
getTransport (BinaryProtocol t) = t

writeMessageBegin p (n, t, s) = tWrite (getTransport p) $ toLazyByteString $
buildBinaryValue (TI32 (version1 .|. fromIntegral (fromEnum t))) <>
buildBinaryValue (TString $ encodeUtf8 n) <>
buildBinaryValue (TI32 s)

readMessageBegin p = runParser p $ do
TI32 ver <- parseBinaryValue T_I32
if ver .&. versionMask /= version1
then throw $ ProtocolExn PE_BAD_VERSION "Missing version identifier"
else do
TString s <- parseBinaryValue T_STRING
TI32 sz <- parseBinaryValue T_I32
return (decodeUtf8 s, toEnum $ fromIntegral $ ver .&. 0xFF, sz)
instance Transport t => Protocol (BinaryProtocol t) where
readByte p = tReadAll (getTransport p) 1
-- flushTransport p = tFlush (getTransport p)
writeMessage p (n, t, s) f = do
tWrite (getTransport p) messageBegin
f
tFlush $ getTransport p
where
messageBegin = toLazyByteString $
buildBinaryValue (TI32 (version1 .|. fromIntegral (fromEnum t))) <>
buildBinaryValue (TString $ encodeUtf8 n) <>
buildBinaryValue (TI32 s)

readMessage p = (readMessageBegin p >>=)
where
readMessageBegin p = runParser p $ do
TI32 ver <- parseBinaryValue T_I32
if ver .&. versionMask /= version1
then throw $ ProtocolExn PE_BAD_VERSION "Missing version identifier"
else do
TString s <- parseBinaryValue T_STRING
TI32 sz <- parseBinaryValue T_I32
return (decodeUtf8 s, toEnum $ fromIntegral $ ver .&. 0xFF, sz)

writeVal p = tWrite (getTransport p) . toLazyByteString . buildBinaryValue
readVal p = runParser p . parseBinaryValue

instance Transport t => StatelessProtocol (BinaryProtocol t) where
serializeVal _ = toLazyByteString . buildBinaryValue
deserializeVal _ ty bs =
case LP.eitherResult $ LP.parse (parseBinaryValue ty) bs of
Left s -> error s
Right val -> val

readVal p = runParser p . parseBinaryValue

-- | Writing Functions
buildBinaryValue :: ThriftVal -> Builder
buildBinaryValue (TStruct fields) = buildBinaryStruct fields <> buildType T_STOP
Expand Down
Loading

0 comments on commit 3c42007

Please sign in to comment.