diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..1ff0c423 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,63 @@ +############################################################################### +# Set default behavior to automatically normalize line endings. +############################################################################### +* text=auto + +############################################################################### +# Set default behavior for command prompt diff. +# +# This is need for earlier builds of msysgit that does not have it on by +# default for csharp files. +# Note: This is only used by command line +############################################################################### +#*.cs diff=csharp + +############################################################################### +# Set the merge driver for project and solution files +# +# Merging from the command prompt will add diff markers to the files if there +# are conflicts (Merging from VS is not affected by the settings below, in VS +# the diff markers are never inserted). Diff markers may cause the following +# file extensions to fail to load in VS. An alternative would be to treat +# these files as binary and thus will always conflict and require user +# intervention with every merge. To do so, just uncomment the entries below +############################################################################### +#*.sln merge=binary +#*.csproj merge=binary +#*.vbproj merge=binary +#*.vcxproj merge=binary +#*.vcproj merge=binary +#*.dbproj merge=binary +#*.fsproj merge=binary +#*.lsproj merge=binary +#*.wixproj merge=binary +#*.modelproj merge=binary +#*.sqlproj merge=binary +#*.wwaproj merge=binary + +############################################################################### +# behavior for image files +# +# image files are treated as binary by default. +############################################################################### +#*.jpg binary +#*.png binary +#*.gif binary + +############################################################################### +# diff behavior for common document formats +# +# Convert binary document formats to text before diffing them. This feature +# is only available from the command line. Turn it on by uncommenting the +# entries below. +############################################################################### +#*.doc diff=astextplain +#*.DOC diff=astextplain +#*.docx diff=astextplain +#*.DOCX diff=astextplain +#*.dot diff=astextplain +#*.DOT diff=astextplain +#*.pdf diff=astextplain +#*.PDF diff=astextplain +#*.rtf diff=astextplain +#*.RTF diff=astextplain diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..225dfec3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,28 @@ +.vs +artifacts +bin +obj +*.user +TestFiles/ +*.suo +Tools/PEVerify.* +packages/ +project.lock.json +Tests.Dnx/*.dll +Tests.Dnx/*.bin +Tests.Dnx/*.dat +Tests.Dnx/*.dump +Tests.Dnx/Data.protobuf +Tests.Dnx/big.file +Tests.Dnx/sandbox.txt +Tests.Dnx/protoTest.txt +Tests.Dnx/TraceCompile.txt +*.NoCommit.* +*.pubxml +del.me +_ReSharper.Caches/ +BenchmarkDotNet.Artifacts/ +filedata.bin +*.binlog +src/VBTest/* +src/Benchmark/DalSerializer.dll \ No newline at end of file diff --git a/Directory.build.props b/Directory.build.props new file mode 100644 index 00000000..4c38dd63 --- /dev/null +++ b/Directory.build.props @@ -0,0 +1,43 @@ + + + ProtoBuf + Marc Gravell + Library + true + False + NU5105;CS1701;BC42016 + $(MSBuildThisFileDirectory)ProtoBuf.snk + See https://github.com/mgravell/protobuf-net + https://github.com/mgravell/protobuf-net/blob/master/Licence.txt + https://github.com/mgravell/protobuf-net + https://github.com/mgravell/protobuf-net + git + protobuf-net ($(TargetFramework)) + https://mgravell.github.io/protobuf-net/releasenotes#$(VersionPrefix) + + binary;serialization;protobuf;grpc + + true + embedded + en-US + false + $(MSBuildProjectName.Contains('Test')) + preview + enable + $(MSBuildThisFileDirectory)Shared.ruleset + + + true + true + + + + + + + + all + runtime; build; native; contentfiles; analyzers + + + \ No newline at end of file diff --git a/ProtoBuf.snk b/ProtoBuf.snk new file mode 100644 index 00000000..5a323bf5 Binary files /dev/null and b/ProtoBuf.snk differ diff --git a/global.json b/global.json new file mode 100644 index 00000000..2900f204 --- /dev/null +++ b/global.json @@ -0,0 +1,5 @@ +{ + "sdk": { + "version": "3.0.100-preview7-012341" + } +} diff --git a/protobuf-net.Grpc.sln b/protobuf-net.Grpc.sln new file mode 100644 index 00000000..4baf6ec7 --- /dev/null +++ b/protobuf-net.Grpc.sln @@ -0,0 +1,41 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 16 +VisualStudioVersion = 16.0.29006.145 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "protobuf-net.Grpc", "src\protobuf-net.Grpc\protobuf-net.Grpc.csproj", "{30D0874E-DA1A-497E-A181-37739A46CF32}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{3E0CF81A-BA7A-4AAB-B46D-5AC8E22B0644}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "examples", "examples", "{F7FAC6AD-62B0-4B79-98AA-DBD99F84E4E9}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "grpc", "grpc", "{ABFDBC40-BB23-4E19-80C8-8ACC96671B63}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "dotnet-grpc", "dotnet-grpc", "{BCE4682E-1594-4DEF-BC23-35E8571FD002}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "pb-net-grpc", "pb-net-grpc", "{7F457A71-E0C6-4DB1-B692-56541CAEDB5A}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {30D0874E-DA1A-497E-A181-37739A46CF32}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {30D0874E-DA1A-497E-A181-37739A46CF32}.Debug|Any CPU.Build.0 = Debug|Any CPU + {30D0874E-DA1A-497E-A181-37739A46CF32}.Release|Any CPU.ActiveCfg = Release|Any CPU + {30D0874E-DA1A-497E-A181-37739A46CF32}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(NestedProjects) = preSolution + {30D0874E-DA1A-497E-A181-37739A46CF32} = {3E0CF81A-BA7A-4AAB-B46D-5AC8E22B0644} + {ABFDBC40-BB23-4E19-80C8-8ACC96671B63} = {F7FAC6AD-62B0-4B79-98AA-DBD99F84E4E9} + {BCE4682E-1594-4DEF-BC23-35E8571FD002} = {F7FAC6AD-62B0-4B79-98AA-DBD99F84E4E9} + {7F457A71-E0C6-4DB1-B692-56541CAEDB5A} = {F7FAC6AD-62B0-4B79-98AA-DBD99F84E4E9} + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {BA14B07C-CA29-430D-A600-F37A050636D3} + EndGlobalSection +EndGlobal diff --git a/src/protobuf-net.Grpc/CallContext.cs b/src/protobuf-net.Grpc/CallContext.cs new file mode 100644 index 00000000..f1051d05 --- /dev/null +++ b/src/protobuf-net.Grpc/CallContext.cs @@ -0,0 +1,66 @@ +using Grpc.Core; +using ProtoBuf.Grpc.Internal; +using System; +using System.Threading; +using System.Runtime.CompilerServices; +using System.ComponentModel; + +namespace ProtoBuf.Grpc +{ + /// + /// Unifies the API for client and server gRPC call contexts; the API intersection is available + /// directly - for client-specific or server-specific options: use .Client or .Server; note that + /// whether this is a client or server context depends on the usage. Silent conversions are available. + /// + public readonly struct CallContext + { + public static readonly CallContext Default; // it is **not** accidental that this is a field - allows effective ldsflda usage + + public CallOptions Client { get; } + public ServerCallContext? Server { get; } + + public Metadata RequestHeaders => Client.Headers; + public CancellationToken CancellationToken => Client.CancellationToken; + public DateTime? Deadline => Client.Deadline; + public WriteOptions WriteOptions => Client.WriteOptions; + + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + internal MetadataContext? Prepare() => _metadataContext?.Reset(); + + public CallContext(ServerCallContext server) + { + Server = server; + Client = server == null ? default : new CallOptions(server.RequestHeaders, server.Deadline, server.CancellationToken, server.WriteOptions); + _metadataContext = null; + } + + public CallContext(in CallOptions client, CallContextFlags flags = CallContextFlags.None) + { + Client = client; + Server = default; + _metadataContext = (flags & CallContextFlags.CaptureMetadata) == 0 ? null : new MetadataContext(); + } + + private readonly MetadataContext? _metadataContext; + + public Metadata ResponseHeaders() => _metadataContext?.Headers ?? ThrowNoContext(); + + public Metadata ResponseTrailers() => _metadataContext?.Trailers ?? ThrowNoContext(); + + public Status ResponseStatus() => _metadataContext?.Status ?? ThrowNoContext(); + + [MethodImpl] + private T ThrowNoContext() + { + if (Server != null) throw new InvalidOperationException("Response metadata is not available for server contexts"); + throw new InvalidOperationException("The CaptureMetadata flag must be specified when creating the CallContext to enable response metadata"); + } + } + + [Flags] + public enum CallContextFlags + { + None = 0, + CaptureMetadata = 1, + } +} \ No newline at end of file diff --git a/src/protobuf-net.Grpc/Client/ClientFactory.cs b/src/protobuf-net.Grpc/Client/ClientFactory.cs new file mode 100644 index 00000000..9ab78175 --- /dev/null +++ b/src/protobuf-net.Grpc/Client/ClientFactory.cs @@ -0,0 +1,330 @@ +using Grpc.Core; +using Grpc.Net.Client; +using Microsoft.Extensions.Logging; +using ProtoBuf.Grpc.Internal; +using System; +using System.Linq; +using System.Net.Http; +using System.Reflection; +using System.Reflection.Emit; + +namespace ProtoBuf.Grpc.Client +{ + //public readonly struct ClientProxy : IDisposable + // where T : class + //{ + // private readonly ClientBase _client; + + // internal ClientProxy(ClientBase client) => _client = client; + + + // public T Channel + // { + // // assume default behaviour is for the client to implement it directly, but allow alternatives + // [MethodImpl(MethodImplOptions.AggressiveInlining)] + // get => (T)(object)_client; + // } + + // [MethodImpl(MethodImplOptions.AggressiveInlining)] + // public void Dispose() => (_client as IDisposable)?.Dispose(); + + // [MethodImpl(MethodImplOptions.AggressiveInlining)] + // public static implicit operator T (ClientProxy proxy) => proxy.Channel; + //} + + public static class ClientFactory + { + public static TService Create(HttpClient httpClient, ILoggerFactory? loggerFactory = null) + where TService : class => ProxyCache.Create(httpClient, loggerFactory); + public static TService Create(Channel channel) + where TService : class => ProxyCache.Create(channel); + public static TService Create(CallInvoker callInvoker) + where TService : class => ProxyCache.Create(callInvoker); + + internal readonly struct ProxyCache where TService : class + { + private static readonly ProxyCache s_factory = ProxyEmitter.CreateFactory(); + + public static TService Create(HttpClient httpClient, ILoggerFactory? loggerFactory) => s_factory._httpClient(httpClient, loggerFactory); + public static TService Create(CallInvoker callInvoker) => s_factory._callInvoker(callInvoker); + public static TService Create(Channel channel) => s_factory._channel(channel); + + private readonly Func _httpClient; + private readonly Func _callInvoker; + private readonly Func _channel; + // public readonly Func ClientBaseConfiguration; + + public ProxyCache(Type type) + { + if (!FindFactory(type, out _httpClient!)) _httpClient = (a, b) => throw new NotSupportedException(); + if (!FindFactory(type, out _callInvoker!)) _callInvoker = a => throw new NotSupportedException(); + if (!FindFactory(type, out _channel!)) _channel = a => throw new NotSupportedException(); + // if (!FindFactory(type, out ClientBaseConfiguration!)) ClientBaseConfiguration = a => throw new NotSupportedException(); + } + static bool FindFactory(Type type, out T? field) where T : Delegate + { + field = default; + if (type == null) return false; + var invoke = typeof(T).GetMethod("Invoke"); + if (invoke == null) return false; + var signature = Array.ConvertAll(invoke.GetParameters(), x => x.ParameterType); + var factory = type.GetMethod(ProxyEmitter.FactoryName, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static, null, signature, null); + if (factory == null) return false; + field = (T)Delegate.CreateDelegate(typeof(T), factory); + return true; + } + } + + // this **abstract** inheritance is just to get access to ClientBaseConfiguration + // (without that, this could be a static class) + abstract class ProxyEmitter : ClientBase + { + private ProxyEmitter() { } + + private static readonly string ProxyIdentity = typeof(ClientFactory).Namespace + ".Proxies"; + + private static readonly ModuleBuilder s_module = AssemblyBuilder.DefineDynamicAssembly( + new AssemblyName(ProxyIdentity), AssemblyBuilderAccess.Run).DefineDynamicModule(ProxyIdentity); + + private static readonly MethodInfo s_ClientBase_CallInvoker = typeof(ClientBase).GetProperty(nameof(CallInvoker), + BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)!.GetGetMethod(true)!, + s_Object_ToString = typeof(object).GetMethod(nameof(object.ToString))!; + private static readonly FieldInfo s_CallContext_Default = typeof(CallContext).GetField(nameof(CallContext.Default))!; + + private static void Ldc_I4(ILGenerator il, int value) + { + switch (value) + { + case -1: il.Emit(OpCodes.Ldc_I4_M1); break; + case 0: il.Emit(OpCodes.Ldc_I4_0); break; + case 1: il.Emit(OpCodes.Ldc_I4_1); break; + case 2: il.Emit(OpCodes.Ldc_I4_2); break; + case 3: il.Emit(OpCodes.Ldc_I4_3); break; + case 4: il.Emit(OpCodes.Ldc_I4_4); break; + case 5: il.Emit(OpCodes.Ldc_I4_5); break; + case 6: il.Emit(OpCodes.Ldc_I4_6); break; + case 7: il.Emit(OpCodes.Ldc_I4_7); break; + case 8: il.Emit(OpCodes.Ldc_I4_8); break; + case int i when (i >= -128 & i < 127): il.Emit(OpCodes.Ldc_I4_S, (sbyte)i); break; + default: il.Emit(OpCodes.Ldc_I4, value); break; + } + } + + private static void LoadDefault(ILGenerator il) where T : struct + { + var local = il.DeclareLocal(typeof(T)); + Ldloca(il, local); + il.Emit(OpCodes.Initobj, typeof(T)); + Ldloc(il, local); + } + + private static void Ldloc(ILGenerator il, LocalBuilder local) + { + switch (local.LocalIndex) + { + case 0: il.Emit(OpCodes.Ldloc_0); break; + case 1: il.Emit(OpCodes.Ldloc_1); break; + case 2: il.Emit(OpCodes.Ldloc_2); break; + case 3: il.Emit(OpCodes.Ldloc_3); break; + case int i when (i >= 0 & i <= 255): il.Emit(OpCodes.Ldloc_S, (byte)i); break; + default: il.Emit(OpCodes.Ldloc, local); break; + } + } + + private static void Ldloca(ILGenerator il, LocalBuilder local) + { + switch (local.LocalIndex) + { + case int i when (i >= 0 & i <= 255): il.Emit(OpCodes.Ldloca_S, (byte)i); break; + default: il.Emit(OpCodes.Ldloca, local); break; + } + } + private static void Ldarga(ILGenerator il, ushort index) + { + if (index <= 255) + { + il.Emit(OpCodes.Ldarga_S, (byte)index); + } + else + { + il.Emit(OpCodes.Ldarga, index); + } + } + private static void Ldarg(ILGenerator il, ushort index) + { + switch(index) + { + case 0: il.Emit(OpCodes.Ldarg_0); break; + case 1: il.Emit(OpCodes.Ldarg_1); break; + case 2: il.Emit(OpCodes.Ldarg_2); break; + case 3: il.Emit(OpCodes.Ldarg_3); break; + case ushort x when x <= 255: il.Emit(OpCodes.Ldarg_S, (byte)x); break; + default: il.Emit(OpCodes.Ldarg, index); break; + } + } + + internal static ProxyCache CreateFactory() + where TService : class + { + // front-load reflection discovery + if (!typeof(TService).IsInterface) + throw new InvalidOperationException("Type is not an interface: " + typeof(TService).FullName); + ContractOperation.TryGetServiceName(typeof(TService), out var serviceName); + var ops = ContractOperation.FindOperations(typeof(TService)); + + lock (s_module) + { + // private sealed class IFooProxy... + var type = s_module.DefineType(ProxyIdentity + "." + typeof(TService).Name + "_Proxy", + TypeAttributes.Class | TypeAttributes.Sealed | TypeAttributes.NotPublic | TypeAttributes.BeforeFieldInit); + + // : ClientBase + Type baseType = typeof(ClientBase); + type.SetParent(baseType); + + // : TService + type.AddInterfaceImplementation(typeof(TService)); + + // private IFooProxy() : base() { } + type.DefineDefaultConstructor(MethodAttributes.Private); + + // public IFooProxy(CallInvoker callInvoker) : base(callInvoker) { } + var ctorCallInvoker = WritePassThruCtor(MethodAttributes.Public); + + // public IFooProxy(Channel channel) : base(callIchannelnvoker) { } + var ctorChannel = WritePassThruCtor(MethodAttributes.Public); + + // private IFooProxy(ClientBaseConfiguration configuration) : base(configuration) { } + var ctorClientBaseConfig = WritePassThruCtor(MethodAttributes.Family); + + // override ToString + { + var toString = type.DefineMethod(nameof(ToString), s_Object_ToString.Attributes, s_Object_ToString.CallingConvention, + typeof(string), Type.EmptyTypes); + var il = toString.GetILGenerator(); + il.Emit(OpCodes.Ldstr, serviceName); + il.Emit(OpCodes.Ret); + type.DefineMethodOverride(toString, s_Object_ToString); + } + + var cctor = type.DefineTypeInitializer().GetILGenerator(); + + // add each method of the interface + int fieldIndex = 0; + foreach (var op in ops) + { + Type[] fromTo = new Type[] { op.From, op.To }; + // public static readonly Method s_{i} + var field = type.DefineField("s_op_" + fieldIndex++, typeof(Method<,>).MakeGenericType(fromTo), + FieldAttributes.Static | FieldAttributes.Public | FieldAttributes.InitOnly); + // = new FullyNamedMethod(opName, methodType, serviceName, method.Name); + cctor.Emit(OpCodes.Ldstr, op.Name); // opName + Ldc_I4(cctor, (int)op.MethodType); // methodType + cctor.Emit(OpCodes.Ldstr, serviceName); // serviceName + cctor.Emit(OpCodes.Ldnull); // methodName: leave null (uses opName) + cctor.Emit(OpCodes.Ldnull); // requestMarshaller: always null + cctor.Emit(OpCodes.Ldnull); // responseMarshaller: always null + cctor.Emit(OpCodes.Newobj, typeof(FullyNamedMethod<,>).MakeGenericType(fromTo) + .GetConstructors(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance).Single()); // new FullyNamedMethod + cctor.Emit(OpCodes.Stsfld, field); + + var impl = type.DefineMethod(typeof(TService).Name + "." + op.Method.Name, + MethodAttributes.HideBySig | MethodAttributes.Final | MethodAttributes.NewSlot | MethodAttributes.Private | MethodAttributes.Virtual, + op.Method.CallingConvention, op.Method.ReturnType, op.ParameterTypes); + + // implement the method + var il = impl.GetILGenerator(); + + switch(op.Context) + { + case ContextKind.CallOptions: + // we only support this for signatures that match the exat google pattern, but: + // defer for now + il.ThrowException(typeof(NotImplementedException)); + break; + case ContextKind.NoContext: + case ContextKind.CallContext: + // typically looks something like (where this is an extension method on Reshape): + // => context.{ReshapeMethod}(CallInvoker, {method}, request, [host: null]); + var method = op.TryGetClientHelper(); + if (method == null) + { + // unexpected, but... + il.ThrowException(typeof(NotSupportedException)); + } + else + { + if (op.Context == ContextKind.CallContext) + { + Ldarga(il, 2); + } + else + { + il.Emit(OpCodes.Ldsflda, s_CallContext_Default); + } + il.Emit(OpCodes.Ldarg_0); // this. + il.EmitCall(OpCodes.Callvirt, s_ClientBase_CallInvoker, null); // get_CallInvoker + + il.Emit(OpCodes.Ldsfld, field); // {method} + il.Emit(OpCodes.Ldarg_1); // request + il.Emit(OpCodes.Ldnull); // host (always null) + il.EmitCall(OpCodes.Call, method, null); + il.Emit(OpCodes.Ret); // return + } + break; + case ContextKind.ServerCallContext: // server call? we're writing a client! + default: // who knows! + il.ThrowException(typeof(NotSupportedException)); + break; + } + + // mark it as the interface implementation + type.DefineMethodOverride(impl, op.Method); + } + + cctor.Emit(OpCodes.Ret); // end the type initializer (after creating all the field types) + + // write a factory method + WriteFactory(new[] { typeof(HttpClient), typeof(ILoggerFactory) }, typeof(HttpClientCallInvoker), ctorCallInvoker); + WriteFactory(new[] { typeof(CallInvoker) }, null, ctorCallInvoker); + WriteFactory(new[] { typeof(Channel) }, null, ctorChannel); + // WriteFactory(new[] { typeof(ClientBaseConfiguration) }, null, ctorClientBaseConfig); + + // return the factory + return new ProxyCache(type.CreateType()); + + void WriteFactory(Type[] signature, Type? via, ConstructorBuilder? ctor) + { + if (ctor == null) return; + ConstructorInfo? viaCtor = via?.GetConstructor(signature); + if (via != null && viaCtor == null) return; // nope! + + var factory = type.DefineMethod(FactoryName, MethodAttributes.Public | MethodAttributes.Static, typeof(TService), signature); + var il = factory.GetILGenerator(); + for (ushort i = 0; i < signature.Length; i++) + Ldarg(il, i); + if (viaCtor != null) il.Emit(OpCodes.Newobj, viaCtor); + il.Emit(OpCodes.Newobj, ctor); + il.Emit(OpCodes.Ret); + } + + ConstructorBuilder? WritePassThruCtor(MethodAttributes accessibility) + { + var signature = new[] { typeof(T) }; + var baseCtor = baseType.GetConstructor(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance, null, signature, null); + if (baseCtor == null) return null; + + var ctor = type.DefineConstructor(accessibility, CallingConventions.HasThis, signature); + var il = ctor.GetILGenerator(); + il.Emit(OpCodes.Ldarg_0); + il.Emit(OpCodes.Ldarg_1); + il.Emit(OpCodes.Call, baseCtor); + il.Emit(OpCodes.Ret); + return ctor; + } + } + } + internal const string FactoryName = "Create"; + } + } +} diff --git a/src/protobuf-net.Grpc/Internal/ContractOperation.cs b/src/protobuf-net.Grpc/Internal/ContractOperation.cs new file mode 100644 index 00000000..09b1f710 --- /dev/null +++ b/src/protobuf-net.Grpc/Internal/ContractOperation.cs @@ -0,0 +1,325 @@ +using System; +using System.Collections.Generic; +using System.Reflection; +using Grpc.Core; +using System.ServiceModel; +using System.Buffers; +using System.Threading.Tasks; +using System.Linq; + +namespace ProtoBuf.Grpc.Internal +{ + internal readonly struct ContractOperation + { + public string Name { get; } + public Type From { get; } + public Type To { get; } + public MethodInfo Method { get; } + public Type[] ParameterTypes { get; } + public MethodType MethodType { get; } + public ContextKind Context { get; } + public ResultKind Result { get; } + + public override string ToString() => $"{Name}: {From.Name}=>{To.Name}, {MethodType}, {Result}, {Context}"; + + public ContractOperation(string name, Type from, Type to, MethodInfo method, + MethodType methodType, ContextKind contextKind, ResultKind resultKind, + Type[] parameterTypes) + { + Name = name; + From = from; + To = to; + Method = method; + MethodType = methodType; + Context = contextKind; + Result = resultKind; + ParameterTypes = parameterTypes; + } + + public static bool TryGetServiceName(Type contractType, out string? serviceName, bool demandAttribute = false) + { + var sca = (ServiceContractAttribute?)Attribute.GetCustomAttribute(contractType, typeof(ServiceContractAttribute), inherit: true); + if (demandAttribute && sca == null) + { + serviceName = null; + return false; + } + serviceName = sca?.Name; + if (string.IsNullOrWhiteSpace(serviceName)) + { + serviceName = contractType.Name; + if (contractType.IsInterface && serviceName.StartsWith('I')) serviceName = serviceName.Substring(1); // IFoo => Foo + serviceName = contractType.Namespace + serviceName; // Whatever.Foo + serviceName = serviceName.Replace('+', '.'); // nested types + } + return !string.IsNullOrWhiteSpace(serviceName); + } + + // do **not** replace these with a `params` etc version; the point here is to be as cheap + // as possible for misses + internal static bool IsMatch(Type returnType, ParameterInfo[] parameters, Type?[] types, Type? tRet) + => parameters.Length == 0 + && IsMatch(tRet, returnType, out types[0]); + internal static bool IsMatch(Type returnType, ParameterInfo[] parameters, Type?[] types, Type? t0, Type? tRet) + => parameters.Length == 1 + && IsMatch(t0, parameters[0].ParameterType, out types[0]) + && IsMatch(tRet, returnType, out types[1]); + internal static bool IsMatch(Type returnType, ParameterInfo[] parameters, Type?[] types, Type? t0, Type? t1, Type? tRet) + => parameters.Length == 2 + && IsMatch(t0, parameters[0].ParameterType, out types[0]) + && IsMatch(t1, parameters[1].ParameterType, out types[1]) + && IsMatch(tRet, returnType, out types[2]); + internal static bool IsMatch(Type returnType, ParameterInfo[] parameters, Type?[] types, Type? t0, Type? t1, Type? t2, Type? tRet) + => parameters.Length == 3 + && IsMatch(t0, parameters[0].ParameterType, out types[0]) + && IsMatch(t1, parameters[1].ParameterType, out types[1]) + && IsMatch(t2, parameters[2].ParameterType, out types[2]) + && IsMatch(tRet, returnType, out types[3]); + + private static bool IsMatch(in Type? template, in Type actual, out Type result) + { + if (template == null || template == actual) + { + result = actual; + return true; + } // fine + if (actual.IsGenericType && template.IsGenericTypeDefinition + && actual.GetGenericTypeDefinition() == template) + { + // expected Foo<>, got Foo: report T + result = actual.GetGenericArguments()[0]; + return true; + } + result = typeof(void); + return false; + } + + public static List FindOperations(Type contractType, bool demandAttribute = false) + { + var ops = new List(); + var types = ArrayPool.Shared.Rent(10); + try + { + foreach (var method in contractType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)) + { + if (method.IsGenericMethodDefinition) continue; // can't work with methods + + var oca = (OperationContractAttribute?)Attribute.GetCustomAttribute(method, typeof(OperationContractAttribute), inherit: true); + if (demandAttribute && oca == null) continue; + string? opName = oca?.Name; + if (string.IsNullOrWhiteSpace(opName)) + { + opName = method.Name; + if (opName.EndsWith("Async")) + opName = opName.Substring(0, opName.Length - 5); + } + if (string.IsNullOrWhiteSpace(opName)) continue; + + var args = method.GetParameters(); + if (args.Length == 0) continue; // no way of inferring anything! + + var ret = method.ReturnType; + + ContextKind contextKind = default; + MethodType methodType = default; + ResultKind resultKind = ResultKind.Unknown; + Type? from = null, to = null; + + void Configure(ContextKind ck, MethodType mt, ResultKind rt, Type f, Type t) + { + contextKind = ck; + methodType = mt; + resultKind = rt; + from = f; + to = t; + } + + // google server APIs + if (IsMatch(ret, args, types, typeof(IAsyncStreamReader<>), typeof(ServerCallContext), typeof(Task<>))) + { + Configure(ContextKind.ServerCallContext, MethodType.ClientStreaming, ResultKind.Task, types[0], types[2]); + } + else if (IsMatch(ret, args, types, typeof(IAsyncStreamReader<>), typeof(IAsyncStreamWriter<>), typeof(ServerCallContext), typeof(Task))) + { + Configure(ContextKind.ServerCallContext, MethodType.DuplexStreaming, ResultKind.Task, types[0], types[1]); + } + else if (IsMatch(ret, args, types, null, typeof(IServerStreamWriter<>), typeof(ServerCallContext), typeof(Task))) + { + Configure(ContextKind.ServerCallContext, MethodType.ServerStreaming, ResultKind.Task, types[0], types[1]); + } + else if (IsMatch(ret, args, types, null, typeof(ServerCallContext), typeof(Task<>))) + { + Configure(ContextKind.ServerCallContext, MethodType.Unary, ResultKind.Task, types[0], types[2]); + } + + // google client APIs + else if (IsMatch(ret, args, types, null, typeof(CallOptions), typeof(AsyncUnaryCall<>))) + { + Configure(ContextKind.CallOptions, MethodType.Unary, ResultKind.Grpc, types[0], types[2]); + } + else if (IsMatch(ret, args, types, typeof(CallOptions), typeof(AsyncClientStreamingCall<,>))) + { + Configure(ContextKind.CallOptions, MethodType.ClientStreaming, ResultKind.Grpc, types[1], ret.GetGenericArguments()[1]); + } + else if (IsMatch(ret, args, types, typeof(CallOptions), typeof(AsyncDuplexStreamingCall<,>))) + { + Configure(ContextKind.CallOptions, MethodType.DuplexStreaming, ResultKind.Grpc, types[1], ret.GetGenericArguments()[1]); + } + else if (IsMatch(ret, args, types, null, typeof(CallOptions), typeof(AsyncServerStreamingCall<>))) + { + Configure(ContextKind.CallOptions, MethodType.ServerStreaming, ResultKind.Grpc, types[0], types[2]); + } + else if (IsMatch(ret, args, types, null, typeof(CallOptions), null)) + { + Configure(ContextKind.CallOptions, MethodType.Unary, ResultKind.Sync, types[0], types[2]); + } + + + else if (IsMatch(ret, args, types, typeof(IAsyncEnumerable<>), typeof(CallContext), typeof(IAsyncEnumerable<>))) + { + Configure(ContextKind.CallContext, MethodType.DuplexStreaming, ResultKind.AsyncEnumerable, types[0], types[2]); + } + else if (IsMatch(ret, args, types, typeof(IAsyncEnumerable<>), typeof(IAsyncEnumerable<>))) + { + Configure(ContextKind.NoContext, MethodType.DuplexStreaming, ResultKind.AsyncEnumerable, types[0], types[1]); + } + else if (IsMatch(ret, args, types, null, typeof(CallContext), typeof(IAsyncEnumerable<>))) + { + Configure(ContextKind.CallContext, MethodType.ServerStreaming, ResultKind.AsyncEnumerable, types[0], types[2]); + } + else if (IsMatch(ret, args, types, null, typeof(IAsyncEnumerable<>))) + { + Configure(ContextKind.NoContext, MethodType.ServerStreaming, ResultKind.AsyncEnumerable, types[0], types[1]); + } + else if (IsMatch(ret, args, types, typeof(IAsyncEnumerable<>), typeof(CallContext), typeof(Task<>))) + { + Configure(ContextKind.CallContext, MethodType.ClientStreaming, ResultKind.Task, types[0], types[2]); + } + else if (IsMatch(ret, args, types, typeof(IAsyncEnumerable<>), typeof(Task<>))) + { + Configure(ContextKind.NoContext, MethodType.ClientStreaming, ResultKind.Task, types[0], types[1]); + } + else if (IsMatch(ret, args, types, typeof(IAsyncEnumerable<>), typeof(CallContext), typeof(ValueTask<>))) + { + Configure(ContextKind.CallContext, MethodType.ClientStreaming, ResultKind.ValueTask, types[0], types[2]); + } + else if (IsMatch(ret, args, types, typeof(IAsyncEnumerable<>), typeof(ValueTask<>))) + { + Configure(ContextKind.NoContext, MethodType.ClientStreaming, ResultKind.ValueTask, types[0], types[1]); + } + else if (IsMatch(ret, args, types, null, typeof(CallContext), typeof(Task<>))) + { + Configure(ContextKind.CallContext, MethodType.Unary, ResultKind.Task, types[0], types[2]); + } + else if (IsMatch(ret, args, types, null, typeof(Task<>))) + { + Configure(ContextKind.NoContext, MethodType.Unary, ResultKind.Task, types[0], types[1]); + } + else if (IsMatch(ret, args, types, null, typeof(CallContext), typeof(ValueTask<>))) + { + Configure(ContextKind.CallContext, MethodType.Unary, ResultKind.ValueTask, types[0], types[2]); + } + else if (IsMatch(ret, args, types, null, typeof(ValueTask<>))) + { + Configure(ContextKind.NoContext, MethodType.Unary, ResultKind.ValueTask, types[0], types[1]); + } + else if (IsMatch(ret, args, types, null, typeof(CallContext), null)) + { + Configure(ContextKind.CallContext, MethodType.Unary, ResultKind.Sync, types[0], types[2]); + } + else if (IsMatch(ret, args, types, null, null)) + { + Configure(ContextKind.NoContext, MethodType.Unary, ResultKind.Sync, types[0], types[1]); + } + + Type[] argTypes = Array.ConvertAll(args, x => x.ParameterType); + if (resultKind != ResultKind.Unknown && from != null && to != null) + { + ops.Add(new ContractOperation(opName, from, to, method, methodType, contextKind, resultKind, argTypes)); + } + } + } + finally + { + ArrayPool.Shared.Return(types); + } + return ops; + } + + + internal MethodInfo? TryGetClientHelper() + { + var name = GetClientHelperName(); + if (name == null || !s_reshaper.TryGetValue(name, out var method)) return null; + return method.MakeGenericMethod(From, To); + } +#pragma warning disable CS0618 + static readonly Dictionary s_reshaper = + + (from method in typeof(Reshape).GetMethods(BindingFlags.Public | BindingFlags.Static) + where method.IsGenericMethodDefinition + let parameters = method.GetParameters() + where parameters[1].ParameterType == typeof(CallInvoker) + && parameters[0].ParameterType == typeof(CallContext).MakeByRefType() + select method).ToDictionary(x => x.Name); + + static readonly Dictionary<(MethodType, ResultKind), string> _clientReshapeMap = new Dictionary<(MethodType, ResultKind), string> + { + {(MethodType.DuplexStreaming, ResultKind.AsyncEnumerable), nameof(Reshape.DuplexAsync) }, + {(MethodType.ServerStreaming, ResultKind.AsyncEnumerable), nameof(Reshape.ServerStreamingAsync) }, + {(MethodType.ClientStreaming, ResultKind.Task), nameof(Reshape.ClientStreamingTaskAsync) }, + {(MethodType.ClientStreaming, ResultKind.ValueTask), nameof(Reshape.ClientStreamingValueTaskAsync) }, + {(MethodType.Unary, ResultKind.Task), nameof(Reshape.UnaryTaskAsync) }, + {(MethodType.Unary, ResultKind.ValueTask), nameof(Reshape.UnaryValueTaskAsync) }, + {(MethodType.Unary, ResultKind.Sync), nameof(Reshape.UnarySync) }, + }; +#pragma warning restore CS0618 + private string? GetClientHelperName() + { + switch (Context) + { + case ContextKind.CallContext: + case ContextKind.NoContext: + return _clientReshapeMap.TryGetValue((MethodType, Result), out var helper) ? helper : null; + default: + return null; + } + } + + + internal bool IsSyncT() + { + return Method.ReturnType == To; + } + internal bool IsTaskT() + { + var ret = Method.ReturnType; + return ret.IsGenericType && ret.GetGenericTypeDefinition() == typeof(Task<>) + && ret.GetGenericArguments()[0] == To; + } + internal bool IsValueTaskT() + { + var ret = Method.ReturnType; + return ret.IsGenericType && ret.GetGenericTypeDefinition() == typeof(ValueTask<>) + && ret.GetGenericArguments()[0] == To; + } + } + + internal enum ContextKind + { + NoContext, // no context + CallContext, // pb-net shared context kind + CallOptions, // GRPC core client context kind + ServerCallContext, // GRPC core server context kind + } + + internal enum ResultKind + { + Unknown, + Sync, + Task, + ValueTask, + AsyncEnumerable, + Grpc, + } +} diff --git a/src/protobuf-net.Grpc/Internal/Empty.cs b/src/protobuf-net.Grpc/Internal/Empty.cs new file mode 100644 index 00000000..807f3915 --- /dev/null +++ b/src/protobuf-net.Grpc/Internal/Empty.cs @@ -0,0 +1,21 @@ +using Grpc.Core; +using System; +using System.ComponentModel; + +namespace ProtoBuf.Grpc.Internal +{ + [Obsolete(Reshape.WarningMessage, false)] + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public sealed class Empty : IEquatable + { + public static Empty Instance { get; } = new Empty(); + private Empty() { } + public override string ToString() => nameof(Empty); + public override bool Equals(object? obj) => obj is Empty; + public override int GetHashCode() => 42; + bool IEquatable.Equals(Empty other) => other != null; + + internal static readonly Marshaller Marshaller + = new Marshaller((Empty _)=> Array.Empty(), (byte[] _) => Instance); + } +} diff --git a/src/protobuf-net.Grpc/Internal/FullyNamedMethod.cs b/src/protobuf-net.Grpc/Internal/FullyNamedMethod.cs new file mode 100644 index 00000000..e5af88ff --- /dev/null +++ b/src/protobuf-net.Grpc/Internal/FullyNamedMethod.cs @@ -0,0 +1,25 @@ +using Grpc.Core; + +namespace ProtoBuf.Grpc.Internal +{ + public class FullyNamedMethod : Method, IMethod + { + private readonly string _fullName; + + public FullyNamedMethod( + string operationName, + MethodType type, + string serviceName, + string? methodName = null, + Marshaller? requestMarshaller = null, + Marshaller? responseMarshaller = null) + : base(type, serviceName, methodName ?? operationName, + requestMarshaller ?? MarshallerCache.Instance, + responseMarshaller ?? MarshallerCache.Instance) + { + _fullName = serviceName + "/" + operationName; + } + + string IMethod.FullName => _fullName; + } +} diff --git a/src/protobuf-net.Grpc/Internal/MarshallerCache.cs b/src/protobuf-net.Grpc/Internal/MarshallerCache.cs new file mode 100644 index 00000000..20d32680 --- /dev/null +++ b/src/protobuf-net.Grpc/Internal/MarshallerCache.cs @@ -0,0 +1,33 @@ +using Grpc.Core; +using ProtoBuf.Meta; +using System.IO; + +namespace ProtoBuf.Grpc.Internal +{ + internal static class MarshallerCache + { +#pragma warning disable CS0618 + public static Marshaller Instance { get; } = typeof(T) == typeof(Empty) + ? (Marshaller)(object)Empty.Marshaller : new Marshaller(Serialize, Deserialize); +#pragma warning restore CS0618 + + private static readonly RuntimeTypeModel _model = RuntimeTypeModel.Default; + + private static T Deserialize(byte[] payload) + { + using (var reader = ProtoReader.Create(out var state, payload, _model)) + { + return (T)_model.Deserialize(reader, ref state, null, typeof(T)); + } + } + + private static byte[] Serialize(T value) + { + using (var ms = new MemoryStream()) + { + Serializer.Serialize(ms, value); + return ms.ToArray(); + } + } + } +} diff --git a/src/protobuf-net.Grpc/Internal/MetadataContext.cs b/src/protobuf-net.Grpc/Internal/MetadataContext.cs new file mode 100644 index 00000000..8cb60451 --- /dev/null +++ b/src/protobuf-net.Grpc/Internal/MetadataContext.cs @@ -0,0 +1,34 @@ +using System; +using System.Runtime.CompilerServices; +using Grpc.Core; + +namespace ProtoBuf.Grpc.Internal +{ + internal sealed class MetadataContext + { + internal MetadataContext() { } + + private Metadata? _headers, _trailers; + internal Metadata Headers + { + get => _headers ?? Throw("Headers are not yet available"); + set => _headers = value; + } + internal Metadata Trailers + { + get => _trailers ?? Throw("Trailers are not yet available"); + set => _trailers = value; + } + internal Status Status { get; set; } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static Metadata Throw(string message) => throw new InvalidOperationException(message); + + internal MetadataContext Reset() + { + Status = default; + _headers = _trailers = null; + return this; + } + } +} diff --git a/src/protobuf-net.Grpc/Internal/Reshape.cs b/src/protobuf-net.Grpc/Internal/Reshape.cs new file mode 100644 index 00000000..222bddb5 --- /dev/null +++ b/src/protobuf-net.Grpc/Internal/Reshape.cs @@ -0,0 +1,221 @@ +using Grpc.Core; +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +namespace ProtoBuf.Grpc.Internal +{ + [Obsolete(WarningMessage, false)] + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public static class Reshape + { + internal const string WarningMessage = "This API is intended for use by runtime-generated code; all methods can be changed without notice - it is only guaranteed to work with the internally generated code"; + + [Obsolete(WarningMessage, false)] + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public static async IAsyncEnumerable AsAsyncEnumerable(this IAsyncStreamReader reader, [EnumeratorCancellation] CancellationToken cancellationToken) + { + using (reader) + { + while (await reader.MoveNext(cancellationToken)) + { + yield return reader.Current; + } + } + } + + [Obsolete(WarningMessage, false)] + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public static async Task WriteTo(this IAsyncEnumerable reader, IServerStreamWriter writer, CancellationToken cancellationToken) + { + await using (var iter = reader.GetAsyncEnumerator(cancellationToken)) + { + while (await iter.MoveNextAsync()) + { + await writer.WriteAsync(iter.Current); + } + } + } + + [Obsolete(WarningMessage, false)] + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public static TResponse UnarySync( + this in CallContext context, + CallInvoker invoker, Method method, TRequest request, string? host = null) + where TRequest : class + where TResponse : class + { + context.Prepare(); + return invoker.BlockingUnaryCall(method, host, context.Client, request); + } + + [Obsolete(WarningMessage, false)] + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public static Task UnaryTaskAsync( + this in CallContext context, + CallInvoker invoker, Method method, TRequest request, string? host = null) + where TRequest : class + where TResponse : class + => UnaryTaskAsyncImpl(invoker.AsyncUnaryCall(method, host, context.Client, request), context.Prepare()); + + [Obsolete(WarningMessage, false)] + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public static ValueTask UnaryValueTaskAsync( + this in CallContext context, CallInvoker invoker, + Method method, TRequest request, string? host = null) + where TRequest : class + where TResponse : class + => new ValueTask(UnaryTaskAsyncImpl(invoker.AsyncUnaryCall(method, host, context.Client, request), context.Prepare())); + + private static async Task UnaryTaskAsyncImpl( + AsyncUnaryCall call, MetadataContext? metadata) + { + using (call) + { + if (metadata != null) metadata.Headers = await call.ResponseHeadersAsync; + var value = await call; + if (metadata != null) + { + metadata.Trailers = call.GetTrailers(); + metadata.Status = call.GetStatus(); + } + return value; + } + } + + [Obsolete(WarningMessage, false)] + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public static IAsyncEnumerable ServerStreamingAsync( + this in CallContext context, + CallInvoker invoker, Method method, TRequest request, string? host = null) + where TRequest : class + where TResponse : class + => ServerStreamingAsyncImpl(invoker.AsyncServerStreamingCall(method, host, context.Client, request), context.Prepare(), context.CancellationToken); + + private static async IAsyncEnumerable ServerStreamingAsyncImpl( + AsyncServerStreamingCall call, MetadataContext? metadata, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + using (call) + { + if (metadata != null) metadata.Headers = await call.ResponseHeadersAsync; + + using (var seq = call.ResponseStream) + { + while (await seq.MoveNext(default)) + { + yield return seq.Current; + } + if (metadata != null) + { + metadata.Trailers = call.GetTrailers(); + metadata.Status = call.GetStatus(); + } + } + } + } + + [Obsolete(WarningMessage, false)] + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public static Task ClientStreamingTaskAsync( + this in CallContext options, + CallInvoker invoker, Method method, IAsyncEnumerable request, string? host = null) + where TRequest : class + where TResponse : class + => ClientStreamingTaskAsyncImpl(invoker.AsyncClientStreamingCall(method, host, options.Client), options.Prepare(), options.CancellationToken, request); + + [Obsolete(WarningMessage, false)] + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public static ValueTask ClientStreamingValueTaskAsync( + this in CallContext options, + CallInvoker invoker, Method method, IAsyncEnumerable request, string? host = null) + where TRequest : class + where TResponse : class + => new ValueTask(ClientStreamingTaskAsyncImpl(invoker.AsyncClientStreamingCall(method, host, options.Client), options.Prepare(), options.CancellationToken, request)); + + private static async Task ClientStreamingTaskAsyncImpl( + AsyncClientStreamingCall call, MetadataContext? metadata, + CancellationToken cancellationToken, IAsyncEnumerable request) + { + using (call) + { + var output = call.RequestStream; + await using (var iter = request.GetAsyncEnumerator(cancellationToken)) + { + while (await iter.MoveNextAsync()) + { + await output.WriteAsync(iter.Current); + } + } + await output.CompleteAsync(); + + if (metadata != null) metadata.Headers = await call.ResponseHeadersAsync; + + var result = await call.ResponseAsync; + + if (metadata != null) + { + metadata.Trailers = call.GetTrailers(); + metadata.Status = call.GetStatus(); + } + return result; + } + } + + [Obsolete(WarningMessage, false)] + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public static IAsyncEnumerable DuplexAsync( + this in CallContext options, + CallInvoker invoker, Method method, IAsyncEnumerable request, string? host = null) + where TRequest : class + where TResponse : class + => DuplexAsyncImpl(invoker.AsyncDuplexStreamingCall(method, host, options.Client), options.Prepare(), options.CancellationToken, request); + + private static async IAsyncEnumerable DuplexAsyncImpl( + AsyncDuplexStreamingCall call, MetadataContext? metadata, + [EnumeratorCancellation] CancellationToken cancellationToken, IAsyncEnumerable request) + { + using (call) + { + // we'll run the "send" as a concurrent operation + var sendAll = Task.Run(() => SendAll(call.RequestStream, request, cancellationToken)); + + if (metadata != null) metadata.Headers = await call.ResponseHeadersAsync; + + using (var seq = call.ResponseStream) + { + while (await seq.MoveNext(default)) + { + yield return seq.Current; + } + await sendAll; // observe any problems from sending + + if (metadata != null) + { + metadata.Trailers = call.GetTrailers(); + metadata.Status = call.GetStatus(); + } + } + } + } + + private static async Task SendAll(IClientStreamWriter output, IAsyncEnumerable request, CancellationToken cancellationToken) + { + try + { + await using (var iter = request.GetAsyncEnumerator(cancellationToken)) + { + while (await iter.MoveNextAsync()) + { + var item = iter.Current; + await output.WriteAsync(item); + } + } + await output.CompleteAsync(); + } + catch (TaskCanceledException) { } + } + } +} diff --git a/src/protobuf-net.Grpc/Server/ServicesExtensions.cs b/src/protobuf-net.Grpc/Server/ServicesExtensions.cs new file mode 100644 index 00000000..3af1ae1f --- /dev/null +++ b/src/protobuf-net.Grpc/Server/ServicesExtensions.cs @@ -0,0 +1,239 @@ +using Grpc.AspNetCore.Server.Model; +using Grpc.Core; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Logging; +using ProtoBuf.Grpc.Internal; +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.ServiceModel; +using System.Threading.Tasks; + +namespace ProtoBuf.Grpc.Server +{ + public static class ServicesExtensions + { + public static void AddCodeFirstGrpc(this IServiceCollection services) + { + services.TryAddEnumerable(ServiceDescriptor.Singleton(typeof(IServiceMethodProvider<>), typeof(CodeFirstServiceMethodProvider<>))); + } + + private sealed class CodeFirstServiceMethodProvider : IServiceMethodProvider where TService : class + { + private readonly ILogger> _logger; + + public CodeFirstServiceMethodProvider(ILoggerFactory loggerFactory) + { + _logger = _logger = loggerFactory.CreateLogger>(); + } + public void OnServiceMethodDiscovery(ServiceMethodProviderContext context) + { + // ignore any services that are known to be the default handler + if (Attribute.IsDefined(typeof(TService), typeof(BindServiceMethodAttribute))) return; + + // we support methods that match suitable signatures, where: + // 1. (removed - no longer supported) the method is directly on TService and is marked [OperationContract] + // 2. the method is on an interface that TService implements, and the interface is marked [ServiceContract] + // AddMethodsForService(context,typeof(TService)); + + foreach (var iType in typeof(TService).GetInterfaces()) + { + AddMethodsForService(context, iType); + } + } + + static Expression ToTaskT(Expression expression) + { + var type = expression.Type; + if (type.IsGenericType) + { + if (type.GetGenericTypeDefinition() == typeof(Task<>)) + return expression; + if (type.GetGenericTypeDefinition() == typeof(ValueTask<>)) + return Expression.Call(expression, nameof(ValueTask.AsTask), null); + } + return Expression.Call(typeof(Task), nameof(Task.FromResult), new Type[] { expression.Type }, expression); + } + + internal static readonly ConstructorInfo s_CallContext_FromServerContext = typeof(CallContext).GetConstructor(new[] { typeof(ServerCallContext) })!; + static Expression ToCallContext(Expression context) => Expression.New(s_CallContext_FromServerContext, context); +#pragma warning disable CS0618 + static Expression AsAsyncEnumerable(ParameterExpression value, ParameterExpression context) + => Expression.Call(typeof(Reshape), nameof(Reshape.AsAsyncEnumerable), + typeArguments: value.Type.GetGenericArguments(), + arguments: new Expression[] { value, Expression.Property(context, nameof(ServerCallContext.CancellationToken)) }); + + static Expression WriteTo(Expression value, ParameterExpression writer, ParameterExpression context) + => Expression.Call(typeof(Reshape), nameof(Reshape.WriteTo), + typeArguments: value.Type.GetGenericArguments(), + arguments: new Expression[] {value, writer, Expression.Property(context, nameof(ServerCallContext.CancellationToken)) }); +#pragma warning restore CS0618 + + static readonly Dictionary<(MethodType, ContextKind, ResultKind), Func?> _invokers + = new Dictionary<(MethodType, ContextKind, ResultKind), Func?> + { + // GRPC-style server methods are direct match; no mapping required + // => service.{method}(args) + { (MethodType.Unary, ContextKind.ServerCallContext, ResultKind.Task), null }, + { (MethodType.ServerStreaming, ContextKind.ServerCallContext, ResultKind.Task), null }, + { (MethodType.ClientStreaming, ContextKind.ServerCallContext, ResultKind.Task), null }, + { (MethodType.DuplexStreaming, ContextKind.ServerCallContext, ResultKind.Task), null }, + + // Unary: Task Foo(TService service, TRequest request, ServerCallContext serverCallContext); + // => service.{method}(request, [new CallContext(serverCallContext)]) + {(MethodType.Unary, ContextKind.NoContext, ResultKind.Task), (method, args) => ToTaskT(Expression.Call(args[0], method, args[1])) }, + {(MethodType.Unary, ContextKind.NoContext, ResultKind.ValueTask), (method, args) => ToTaskT(Expression.Call(args[0], method, args[1])) }, + {(MethodType.Unary, ContextKind.NoContext, ResultKind.Sync), (method, args) => ToTaskT(Expression.Call(args[0], method, args[1])) }, + + {(MethodType.Unary, ContextKind.CallContext, ResultKind.Task), (method, args) => ToTaskT(Expression.Call(args[0], method, args[1], ToCallContext(args[2]))) }, + {(MethodType.Unary, ContextKind.CallContext, ResultKind.ValueTask), (method, args) => ToTaskT(Expression.Call(args[0], method, args[1], ToCallContext(args[2]))) }, + {(MethodType.Unary, ContextKind.CallContext, ResultKind.Sync), (method, args) => ToTaskT(Expression.Call(args[0], method, args[1], ToCallContext(args[2]))) }, + + // Client Streaming: Task Foo(TService service, IAsyncStreamReader stream, ServerCallContext serverCallContext); + // => service.{method}(reader.AsAsyncEnumerable(serverCallContext.CancellationToken), [new CallContext(serverCallContext)]) + {(MethodType.ClientStreaming, ContextKind.NoContext, ResultKind.Task), (method, args) => ToTaskT(Expression.Call(args[0], method, AsAsyncEnumerable(args[1], args[2]))) }, + {(MethodType.ClientStreaming, ContextKind.NoContext, ResultKind.ValueTask), (method, args) => ToTaskT(Expression.Call(args[0], method, AsAsyncEnumerable(args[1], args[2]))) }, + {(MethodType.ClientStreaming, ContextKind.NoContext, ResultKind.Sync), (method, args) => ToTaskT(Expression.Call(args[0], method, AsAsyncEnumerable(args[1], args[2]))) }, + + {(MethodType.ClientStreaming, ContextKind.CallContext, ResultKind.Task), (method, args) => ToTaskT(Expression.Call(args[0], method, AsAsyncEnumerable(args[1], args[2]), ToCallContext(args[2]))) }, + {(MethodType.ClientStreaming, ContextKind.CallContext, ResultKind.ValueTask), (method, args) => ToTaskT(Expression.Call(args[0], method, AsAsyncEnumerable(args[1], args[2]), ToCallContext(args[2]))) }, + {(MethodType.ClientStreaming, ContextKind.CallContext, ResultKind.Sync), (method, args) => ToTaskT(Expression.Call(args[0], method, AsAsyncEnumerable(args[1], args[2]), ToCallContext(args[2]))) }, + + // Server Streaming: Task Foo(TService service, TRequest request, IServerStreamWriter stream, ServerCallContext serverCallContext); + // => service.{method}(request, [new CallContext(serverCallContext)]).WriteTo(stream, serverCallContext.CancellationToken) + {(MethodType.ServerStreaming, ContextKind.NoContext, ResultKind.AsyncEnumerable), (method, args) => WriteTo(Expression.Call(args[0], method, args[1]), args[2], args[3])}, + {(MethodType.ServerStreaming, ContextKind.CallContext, ResultKind.AsyncEnumerable), (method, args) => WriteTo(Expression.Call(args[0], method, args[1], ToCallContext(args[3])), args[2], args[3])}, + + // Duplex: Task Foo(TService service, IAsyncStreamReader input, IServerStreamWriter output, ServerCallContext serverCallContext); + // => service.{method}(input.AsAsyncEnumerable(serverCallContext.CancellationToken), [new CallContext(serverCallContext)]).WriteTo(output, serverCallContext.CancellationToken) + {(MethodType.DuplexStreaming, ContextKind.NoContext, ResultKind.AsyncEnumerable), (method, args) => WriteTo(Expression.Call(args[0], method, AsAsyncEnumerable(args[1], args[3])), args[2], args[3]) }, + {(MethodType.DuplexStreaming, ContextKind.CallContext, ResultKind.AsyncEnumerable), (method, args) => WriteTo(Expression.Call(args[0], method, AsAsyncEnumerable(args[1], args[3]), ToCallContext(args[3])), args[2], args[3]) }, + }; + private void AddMethodsForService(ServiceMethodProviderContext context, Type serviceContract) + { + bool isPublicContract = typeof(TService) == serviceContract; + if (!ContractOperation.TryGetServiceName(serviceContract, out var serviceName, !isPublicContract)) return; + _logger.Log(LogLevel.Trace, "pb-net processing {0}/{1} as {2}", typeof(TService).Name, serviceContract.Name, serviceName); + object?[]? argsBuffer = null; + Type[] typesBuffer = Array.Empty(); + + int count = 0; + foreach (var op in ContractOperation.FindOperations(serviceContract, isPublicContract)) + { + if (_invokers.TryGetValue((op.MethodType, op.Context, op.Result), out var invoker) + && AddMethod(op.From, op.To, op.Method, op.MethodType, invoker)) + { + // yay! + count++; + } + else + { + _logger.Log(LogLevel.Warning, "operation cannot be hosted as a server: {0}", op); + } + } + if (count != 0) _logger.Log(LogLevel.Information, "{0} implementing service {1} (via '{2}') with {3} operation(s)", typeof(TService), serviceName, serviceContract.Name, count); + + bool AddMethod(Type @in, Type @out, MethodInfo m, MethodType t, Func? invoker = null) + { + try + { + if (typesBuffer.Length == 0) + { + typesBuffer = new Type[] { typeof(TService), typeof(void), typeof(void) }; + } + typesBuffer[1] = @in; + typesBuffer[2] = @out; + + if (argsBuffer == null) + { + argsBuffer = new object?[] { serviceName, null, null, context, _logger, null }; + } + argsBuffer[1] = m; + argsBuffer[2] = t; + argsBuffer[5] = invoker; + + s_addMethod.MakeGenericMethod(typesBuffer).Invoke(null, argsBuffer); + return true; + } + catch (Exception fail) + { + if (fail is TargetInvocationException tie) fail = tie.InnerException!; + _logger.Log(LogLevel.Error, "Failure processing {0}: {1}", m.Name, fail.Message); + return false; + } + } + } + } + + private static readonly MethodInfo s_addMethod = typeof(ServicesExtensions).GetMethod( + nameof(AddMethod), BindingFlags.Static | BindingFlags.NonPublic)!; + + private static void AddMethod( + string serviceName, MethodInfo method, MethodType methodType, + ServiceMethodProviderContext context, ILogger logger, + Func? invoker = null) + where TService : class + where TRequest : class + where TResponse : class + { + var oca = (OperationContractAttribute?)Attribute.GetCustomAttribute(method, typeof(OperationContractAttribute), inherit: true); + var operationName = oca?.Name; + if (string.IsNullOrWhiteSpace(operationName)) + { + operationName = method.Name; + if (operationName.EndsWith("Async")) operationName = operationName.Substring(0, operationName.Length - 5); + } + + var metadata = new List(); + // Add type metadata first so it has a lower priority + metadata.AddRange(typeof(TService).GetCustomAttributes(inherit: true)); + // Add method metadata last so it has a higher priority + metadata.AddRange(method.GetCustomAttributes(inherit: true)); + + TDelegate As() where TDelegate : Delegate + { + if (invoker == null) + { + // basic - direct call + return (TDelegate)Delegate.CreateDelegate(typeof(TDelegate), null, method); + } + var finalSignature = typeof(TDelegate).GetMethod("Invoke")!; + + var methodParameters = finalSignature.GetParameters(); + var lambdaParameters = Array.ConvertAll(methodParameters, p => Expression.Parameter(p.ParameterType, p.Name)); + var body = invoker?.Invoke(method, lambdaParameters); + var lambda = Expression.Lambda(body, lambdaParameters); + logger.Log(LogLevel.Trace, "mapped {0} via {1}", operationName, lambda); + return lambda.Compile(); + } + +#pragma warning disable CS8625 + switch (methodType) + { + case MethodType.Unary: + context.AddUnaryMethod( + new FullyNamedMethod(operationName, methodType, serviceName, method.Name), metadata, As>()); + break; + case MethodType.ClientStreaming: + context.AddClientStreamingMethod( + new FullyNamedMethod(operationName, methodType, serviceName, method.Name), metadata, As>()); + break; + case MethodType.ServerStreaming: + context.AddServerStreamingMethod( + new FullyNamedMethod(operationName, methodType, serviceName, method.Name), metadata, As>()); + break; + case MethodType.DuplexStreaming: + context.AddDuplexStreamingMethod( + new FullyNamedMethod(operationName, methodType, serviceName, method.Name), metadata, As>()); + break; + default: + throw new NotSupportedException(methodType.ToString()); + } +#pragma warning restore CS8625 + } + } +} \ No newline at end of file diff --git a/src/protobuf-net.Grpc/protobuf-net.Grpc.csproj b/src/protobuf-net.Grpc/protobuf-net.Grpc.csproj new file mode 100644 index 00000000..f66c22b3 --- /dev/null +++ b/src/protobuf-net.Grpc/protobuf-net.Grpc.csproj @@ -0,0 +1,29 @@ + + + + netcoreapp3.0 + ProtoBuf.Grpc + 0.1-preview1 + + + + + + + + + + + + + + + + ix + + + + \ No newline at end of file diff --git a/version.json b/version.json new file mode 100644 index 00000000..df70a602 --- /dev/null +++ b/version.json @@ -0,0 +1,13 @@ +{ + "$schema": "https://raw.githubusercontent.com/AArnott/Nerdbank.GitVersioning/master/src/NerdBank.GitVersioning/version.schema.json", + "version": "0.0.1-alpha.{height}", + "assemblyVersion": "1.0", + "nugetPackageVersion": { + "semVer": 2 + }, + "publicReleaseRefSpec": [ + "^refs/heads/master$", + "^refs/tags/v\\d+\\.\\d+", + "^refs/heads/split-reader" + ] +} \ No newline at end of file