diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 00000000..1ff0c423
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,63 @@
+###############################################################################
+# Set default behavior to automatically normalize line endings.
+###############################################################################
+* text=auto
+
+###############################################################################
+# Set default behavior for command prompt diff.
+#
+# This is need for earlier builds of msysgit that does not have it on by
+# default for csharp files.
+# Note: This is only used by command line
+###############################################################################
+#*.cs diff=csharp
+
+###############################################################################
+# Set the merge driver for project and solution files
+#
+# Merging from the command prompt will add diff markers to the files if there
+# are conflicts (Merging from VS is not affected by the settings below, in VS
+# the diff markers are never inserted). Diff markers may cause the following
+# file extensions to fail to load in VS. An alternative would be to treat
+# these files as binary and thus will always conflict and require user
+# intervention with every merge. To do so, just uncomment the entries below
+###############################################################################
+#*.sln merge=binary
+#*.csproj merge=binary
+#*.vbproj merge=binary
+#*.vcxproj merge=binary
+#*.vcproj merge=binary
+#*.dbproj merge=binary
+#*.fsproj merge=binary
+#*.lsproj merge=binary
+#*.wixproj merge=binary
+#*.modelproj merge=binary
+#*.sqlproj merge=binary
+#*.wwaproj merge=binary
+
+###############################################################################
+# behavior for image files
+#
+# image files are treated as binary by default.
+###############################################################################
+#*.jpg binary
+#*.png binary
+#*.gif binary
+
+###############################################################################
+# diff behavior for common document formats
+#
+# Convert binary document formats to text before diffing them. This feature
+# is only available from the command line. Turn it on by uncommenting the
+# entries below.
+###############################################################################
+#*.doc diff=astextplain
+#*.DOC diff=astextplain
+#*.docx diff=astextplain
+#*.DOCX diff=astextplain
+#*.dot diff=astextplain
+#*.DOT diff=astextplain
+#*.pdf diff=astextplain
+#*.PDF diff=astextplain
+#*.rtf diff=astextplain
+#*.RTF diff=astextplain
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..225dfec3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,28 @@
+.vs
+artifacts
+bin
+obj
+*.user
+TestFiles/
+*.suo
+Tools/PEVerify.*
+packages/
+project.lock.json
+Tests.Dnx/*.dll
+Tests.Dnx/*.bin
+Tests.Dnx/*.dat
+Tests.Dnx/*.dump
+Tests.Dnx/Data.protobuf
+Tests.Dnx/big.file
+Tests.Dnx/sandbox.txt
+Tests.Dnx/protoTest.txt
+Tests.Dnx/TraceCompile.txt
+*.NoCommit.*
+*.pubxml
+del.me
+_ReSharper.Caches/
+BenchmarkDotNet.Artifacts/
+filedata.bin
+*.binlog
+src/VBTest/*
+src/Benchmark/DalSerializer.dll
\ No newline at end of file
diff --git a/Directory.build.props b/Directory.build.props
new file mode 100644
index 00000000..4c38dd63
--- /dev/null
+++ b/Directory.build.props
@@ -0,0 +1,43 @@
+
+
+ ProtoBuf
+ Marc Gravell
+ Library
+ true
+ False
+ NU5105;CS1701;BC42016
+ $(MSBuildThisFileDirectory)ProtoBuf.snk
+ See https://github.com/mgravell/protobuf-net
+ https://github.com/mgravell/protobuf-net/blob/master/Licence.txt
+ https://github.com/mgravell/protobuf-net
+ https://github.com/mgravell/protobuf-net
+ git
+ protobuf-net ($(TargetFramework))
+ https://mgravell.github.io/protobuf-net/releasenotes#$(VersionPrefix)
+
+ binary;serialization;protobuf;grpc
+
+ true
+ embedded
+ en-US
+ false
+ $(MSBuildProjectName.Contains('Test'))
+ preview
+ enable
+ $(MSBuildThisFileDirectory)Shared.ruleset
+
+
+ true
+ true
+
+
+
+
+
+
+
+ all
+ runtime; build; native; contentfiles; analyzers
+
+
+
\ No newline at end of file
diff --git a/ProtoBuf.snk b/ProtoBuf.snk
new file mode 100644
index 00000000..5a323bf5
Binary files /dev/null and b/ProtoBuf.snk differ
diff --git a/global.json b/global.json
new file mode 100644
index 00000000..2900f204
--- /dev/null
+++ b/global.json
@@ -0,0 +1,5 @@
+{
+ "sdk": {
+ "version": "3.0.100-preview7-012341"
+ }
+}
diff --git a/protobuf-net.Grpc.sln b/protobuf-net.Grpc.sln
new file mode 100644
index 00000000..4baf6ec7
--- /dev/null
+++ b/protobuf-net.Grpc.sln
@@ -0,0 +1,41 @@
+
+Microsoft Visual Studio Solution File, Format Version 12.00
+# Visual Studio Version 16
+VisualStudioVersion = 16.0.29006.145
+MinimumVisualStudioVersion = 10.0.40219.1
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "protobuf-net.Grpc", "src\protobuf-net.Grpc\protobuf-net.Grpc.csproj", "{30D0874E-DA1A-497E-A181-37739A46CF32}"
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{3E0CF81A-BA7A-4AAB-B46D-5AC8E22B0644}"
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "examples", "examples", "{F7FAC6AD-62B0-4B79-98AA-DBD99F84E4E9}"
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "grpc", "grpc", "{ABFDBC40-BB23-4E19-80C8-8ACC96671B63}"
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "dotnet-grpc", "dotnet-grpc", "{BCE4682E-1594-4DEF-BC23-35E8571FD002}"
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "pb-net-grpc", "pb-net-grpc", "{7F457A71-E0C6-4DB1-B692-56541CAEDB5A}"
+EndProject
+Global
+ GlobalSection(SolutionConfigurationPlatforms) = preSolution
+ Debug|Any CPU = Debug|Any CPU
+ Release|Any CPU = Release|Any CPU
+ EndGlobalSection
+ GlobalSection(ProjectConfigurationPlatforms) = postSolution
+ {30D0874E-DA1A-497E-A181-37739A46CF32}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {30D0874E-DA1A-497E-A181-37739A46CF32}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {30D0874E-DA1A-497E-A181-37739A46CF32}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {30D0874E-DA1A-497E-A181-37739A46CF32}.Release|Any CPU.Build.0 = Release|Any CPU
+ EndGlobalSection
+ GlobalSection(SolutionProperties) = preSolution
+ HideSolutionNode = FALSE
+ EndGlobalSection
+ GlobalSection(NestedProjects) = preSolution
+ {30D0874E-DA1A-497E-A181-37739A46CF32} = {3E0CF81A-BA7A-4AAB-B46D-5AC8E22B0644}
+ {ABFDBC40-BB23-4E19-80C8-8ACC96671B63} = {F7FAC6AD-62B0-4B79-98AA-DBD99F84E4E9}
+ {BCE4682E-1594-4DEF-BC23-35E8571FD002} = {F7FAC6AD-62B0-4B79-98AA-DBD99F84E4E9}
+ {7F457A71-E0C6-4DB1-B692-56541CAEDB5A} = {F7FAC6AD-62B0-4B79-98AA-DBD99F84E4E9}
+ EndGlobalSection
+ GlobalSection(ExtensibilityGlobals) = postSolution
+ SolutionGuid = {BA14B07C-CA29-430D-A600-F37A050636D3}
+ EndGlobalSection
+EndGlobal
diff --git a/src/protobuf-net.Grpc/CallContext.cs b/src/protobuf-net.Grpc/CallContext.cs
new file mode 100644
index 00000000..f1051d05
--- /dev/null
+++ b/src/protobuf-net.Grpc/CallContext.cs
@@ -0,0 +1,66 @@
+using Grpc.Core;
+using ProtoBuf.Grpc.Internal;
+using System;
+using System.Threading;
+using System.Runtime.CompilerServices;
+using System.ComponentModel;
+
+namespace ProtoBuf.Grpc
+{
+ ///
+ /// Unifies the API for client and server gRPC call contexts; the API intersection is available
+ /// directly - for client-specific or server-specific options: use .Client or .Server; note that
+ /// whether this is a client or server context depends on the usage. Silent conversions are available.
+ ///
+ public readonly struct CallContext
+ {
+ public static readonly CallContext Default; // it is **not** accidental that this is a field - allows effective ldsflda usage
+
+ public CallOptions Client { get; }
+ public ServerCallContext? Server { get; }
+
+ public Metadata RequestHeaders => Client.Headers;
+ public CancellationToken CancellationToken => Client.CancellationToken;
+ public DateTime? Deadline => Client.Deadline;
+ public WriteOptions WriteOptions => Client.WriteOptions;
+
+ [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)]
+ internal MetadataContext? Prepare() => _metadataContext?.Reset();
+
+ public CallContext(ServerCallContext server)
+ {
+ Server = server;
+ Client = server == null ? default : new CallOptions(server.RequestHeaders, server.Deadline, server.CancellationToken, server.WriteOptions);
+ _metadataContext = null;
+ }
+
+ public CallContext(in CallOptions client, CallContextFlags flags = CallContextFlags.None)
+ {
+ Client = client;
+ Server = default;
+ _metadataContext = (flags & CallContextFlags.CaptureMetadata) == 0 ? null : new MetadataContext();
+ }
+
+ private readonly MetadataContext? _metadataContext;
+
+ public Metadata ResponseHeaders() => _metadataContext?.Headers ?? ThrowNoContext();
+
+ public Metadata ResponseTrailers() => _metadataContext?.Trailers ?? ThrowNoContext();
+
+ public Status ResponseStatus() => _metadataContext?.Status ?? ThrowNoContext();
+
+ [MethodImpl]
+ private T ThrowNoContext()
+ {
+ if (Server != null) throw new InvalidOperationException("Response metadata is not available for server contexts");
+ throw new InvalidOperationException("The CaptureMetadata flag must be specified when creating the CallContext to enable response metadata");
+ }
+ }
+
+ [Flags]
+ public enum CallContextFlags
+ {
+ None = 0,
+ CaptureMetadata = 1,
+ }
+}
\ No newline at end of file
diff --git a/src/protobuf-net.Grpc/Client/ClientFactory.cs b/src/protobuf-net.Grpc/Client/ClientFactory.cs
new file mode 100644
index 00000000..9ab78175
--- /dev/null
+++ b/src/protobuf-net.Grpc/Client/ClientFactory.cs
@@ -0,0 +1,330 @@
+using Grpc.Core;
+using Grpc.Net.Client;
+using Microsoft.Extensions.Logging;
+using ProtoBuf.Grpc.Internal;
+using System;
+using System.Linq;
+using System.Net.Http;
+using System.Reflection;
+using System.Reflection.Emit;
+
+namespace ProtoBuf.Grpc.Client
+{
+ //public readonly struct ClientProxy : IDisposable
+ // where T : class
+ //{
+ // private readonly ClientBase _client;
+
+ // internal ClientProxy(ClientBase client) => _client = client;
+
+
+ // public T Channel
+ // {
+ // // assume default behaviour is for the client to implement it directly, but allow alternatives
+ // [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ // get => (T)(object)_client;
+ // }
+
+ // [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ // public void Dispose() => (_client as IDisposable)?.Dispose();
+
+ // [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ // public static implicit operator T (ClientProxy proxy) => proxy.Channel;
+ //}
+
+ public static class ClientFactory
+ {
+ public static TService Create(HttpClient httpClient, ILoggerFactory? loggerFactory = null)
+ where TService : class => ProxyCache.Create(httpClient, loggerFactory);
+ public static TService Create(Channel channel)
+ where TService : class => ProxyCache.Create(channel);
+ public static TService Create(CallInvoker callInvoker)
+ where TService : class => ProxyCache.Create(callInvoker);
+
+ internal readonly struct ProxyCache where TService : class
+ {
+ private static readonly ProxyCache s_factory = ProxyEmitter.CreateFactory();
+
+ public static TService Create(HttpClient httpClient, ILoggerFactory? loggerFactory) => s_factory._httpClient(httpClient, loggerFactory);
+ public static TService Create(CallInvoker callInvoker) => s_factory._callInvoker(callInvoker);
+ public static TService Create(Channel channel) => s_factory._channel(channel);
+
+ private readonly Func _httpClient;
+ private readonly Func _callInvoker;
+ private readonly Func _channel;
+ // public readonly Func ClientBaseConfiguration;
+
+ public ProxyCache(Type type)
+ {
+ if (!FindFactory(type, out _httpClient!)) _httpClient = (a, b) => throw new NotSupportedException();
+ if (!FindFactory(type, out _callInvoker!)) _callInvoker = a => throw new NotSupportedException();
+ if (!FindFactory(type, out _channel!)) _channel = a => throw new NotSupportedException();
+ // if (!FindFactory(type, out ClientBaseConfiguration!)) ClientBaseConfiguration = a => throw new NotSupportedException();
+ }
+ static bool FindFactory(Type type, out T? field) where T : Delegate
+ {
+ field = default;
+ if (type == null) return false;
+ var invoke = typeof(T).GetMethod("Invoke");
+ if (invoke == null) return false;
+ var signature = Array.ConvertAll(invoke.GetParameters(), x => x.ParameterType);
+ var factory = type.GetMethod(ProxyEmitter.FactoryName, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static, null, signature, null);
+ if (factory == null) return false;
+ field = (T)Delegate.CreateDelegate(typeof(T), factory);
+ return true;
+ }
+ }
+
+ // this **abstract** inheritance is just to get access to ClientBaseConfiguration
+ // (without that, this could be a static class)
+ abstract class ProxyEmitter : ClientBase
+ {
+ private ProxyEmitter() { }
+
+ private static readonly string ProxyIdentity = typeof(ClientFactory).Namespace + ".Proxies";
+
+ private static readonly ModuleBuilder s_module = AssemblyBuilder.DefineDynamicAssembly(
+ new AssemblyName(ProxyIdentity), AssemblyBuilderAccess.Run).DefineDynamicModule(ProxyIdentity);
+
+ private static readonly MethodInfo s_ClientBase_CallInvoker = typeof(ClientBase).GetProperty(nameof(CallInvoker),
+ BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)!.GetGetMethod(true)!,
+ s_Object_ToString = typeof(object).GetMethod(nameof(object.ToString))!;
+ private static readonly FieldInfo s_CallContext_Default = typeof(CallContext).GetField(nameof(CallContext.Default))!;
+
+ private static void Ldc_I4(ILGenerator il, int value)
+ {
+ switch (value)
+ {
+ case -1: il.Emit(OpCodes.Ldc_I4_M1); break;
+ case 0: il.Emit(OpCodes.Ldc_I4_0); break;
+ case 1: il.Emit(OpCodes.Ldc_I4_1); break;
+ case 2: il.Emit(OpCodes.Ldc_I4_2); break;
+ case 3: il.Emit(OpCodes.Ldc_I4_3); break;
+ case 4: il.Emit(OpCodes.Ldc_I4_4); break;
+ case 5: il.Emit(OpCodes.Ldc_I4_5); break;
+ case 6: il.Emit(OpCodes.Ldc_I4_6); break;
+ case 7: il.Emit(OpCodes.Ldc_I4_7); break;
+ case 8: il.Emit(OpCodes.Ldc_I4_8); break;
+ case int i when (i >= -128 & i < 127): il.Emit(OpCodes.Ldc_I4_S, (sbyte)i); break;
+ default: il.Emit(OpCodes.Ldc_I4, value); break;
+ }
+ }
+
+ private static void LoadDefault(ILGenerator il) where T : struct
+ {
+ var local = il.DeclareLocal(typeof(T));
+ Ldloca(il, local);
+ il.Emit(OpCodes.Initobj, typeof(T));
+ Ldloc(il, local);
+ }
+
+ private static void Ldloc(ILGenerator il, LocalBuilder local)
+ {
+ switch (local.LocalIndex)
+ {
+ case 0: il.Emit(OpCodes.Ldloc_0); break;
+ case 1: il.Emit(OpCodes.Ldloc_1); break;
+ case 2: il.Emit(OpCodes.Ldloc_2); break;
+ case 3: il.Emit(OpCodes.Ldloc_3); break;
+ case int i when (i >= 0 & i <= 255): il.Emit(OpCodes.Ldloc_S, (byte)i); break;
+ default: il.Emit(OpCodes.Ldloc, local); break;
+ }
+ }
+
+ private static void Ldloca(ILGenerator il, LocalBuilder local)
+ {
+ switch (local.LocalIndex)
+ {
+ case int i when (i >= 0 & i <= 255): il.Emit(OpCodes.Ldloca_S, (byte)i); break;
+ default: il.Emit(OpCodes.Ldloca, local); break;
+ }
+ }
+ private static void Ldarga(ILGenerator il, ushort index)
+ {
+ if (index <= 255)
+ {
+ il.Emit(OpCodes.Ldarga_S, (byte)index);
+ }
+ else
+ {
+ il.Emit(OpCodes.Ldarga, index);
+ }
+ }
+ private static void Ldarg(ILGenerator il, ushort index)
+ {
+ switch(index)
+ {
+ case 0: il.Emit(OpCodes.Ldarg_0); break;
+ case 1: il.Emit(OpCodes.Ldarg_1); break;
+ case 2: il.Emit(OpCodes.Ldarg_2); break;
+ case 3: il.Emit(OpCodes.Ldarg_3); break;
+ case ushort x when x <= 255: il.Emit(OpCodes.Ldarg_S, (byte)x); break;
+ default: il.Emit(OpCodes.Ldarg, index); break;
+ }
+ }
+
+ internal static ProxyCache CreateFactory()
+ where TService : class
+ {
+ // front-load reflection discovery
+ if (!typeof(TService).IsInterface)
+ throw new InvalidOperationException("Type is not an interface: " + typeof(TService).FullName);
+ ContractOperation.TryGetServiceName(typeof(TService), out var serviceName);
+ var ops = ContractOperation.FindOperations(typeof(TService));
+
+ lock (s_module)
+ {
+ // private sealed class IFooProxy...
+ var type = s_module.DefineType(ProxyIdentity + "." + typeof(TService).Name + "_Proxy",
+ TypeAttributes.Class | TypeAttributes.Sealed | TypeAttributes.NotPublic | TypeAttributes.BeforeFieldInit);
+
+ // : ClientBase
+ Type baseType = typeof(ClientBase);
+ type.SetParent(baseType);
+
+ // : TService
+ type.AddInterfaceImplementation(typeof(TService));
+
+ // private IFooProxy() : base() { }
+ type.DefineDefaultConstructor(MethodAttributes.Private);
+
+ // public IFooProxy(CallInvoker callInvoker) : base(callInvoker) { }
+ var ctorCallInvoker = WritePassThruCtor(MethodAttributes.Public);
+
+ // public IFooProxy(Channel channel) : base(callIchannelnvoker) { }
+ var ctorChannel = WritePassThruCtor(MethodAttributes.Public);
+
+ // private IFooProxy(ClientBaseConfiguration configuration) : base(configuration) { }
+ var ctorClientBaseConfig = WritePassThruCtor(MethodAttributes.Family);
+
+ // override ToString
+ {
+ var toString = type.DefineMethod(nameof(ToString), s_Object_ToString.Attributes, s_Object_ToString.CallingConvention,
+ typeof(string), Type.EmptyTypes);
+ var il = toString.GetILGenerator();
+ il.Emit(OpCodes.Ldstr, serviceName);
+ il.Emit(OpCodes.Ret);
+ type.DefineMethodOverride(toString, s_Object_ToString);
+ }
+
+ var cctor = type.DefineTypeInitializer().GetILGenerator();
+
+ // add each method of the interface
+ int fieldIndex = 0;
+ foreach (var op in ops)
+ {
+ Type[] fromTo = new Type[] { op.From, op.To };
+ // public static readonly Method s_{i}
+ var field = type.DefineField("s_op_" + fieldIndex++, typeof(Method<,>).MakeGenericType(fromTo),
+ FieldAttributes.Static | FieldAttributes.Public | FieldAttributes.InitOnly);
+ // = new FullyNamedMethod(opName, methodType, serviceName, method.Name);
+ cctor.Emit(OpCodes.Ldstr, op.Name); // opName
+ Ldc_I4(cctor, (int)op.MethodType); // methodType
+ cctor.Emit(OpCodes.Ldstr, serviceName); // serviceName
+ cctor.Emit(OpCodes.Ldnull); // methodName: leave null (uses opName)
+ cctor.Emit(OpCodes.Ldnull); // requestMarshaller: always null
+ cctor.Emit(OpCodes.Ldnull); // responseMarshaller: always null
+ cctor.Emit(OpCodes.Newobj, typeof(FullyNamedMethod<,>).MakeGenericType(fromTo)
+ .GetConstructors(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance).Single()); // new FullyNamedMethod
+ cctor.Emit(OpCodes.Stsfld, field);
+
+ var impl = type.DefineMethod(typeof(TService).Name + "." + op.Method.Name,
+ MethodAttributes.HideBySig | MethodAttributes.Final | MethodAttributes.NewSlot | MethodAttributes.Private | MethodAttributes.Virtual,
+ op.Method.CallingConvention, op.Method.ReturnType, op.ParameterTypes);
+
+ // implement the method
+ var il = impl.GetILGenerator();
+
+ switch(op.Context)
+ {
+ case ContextKind.CallOptions:
+ // we only support this for signatures that match the exat google pattern, but:
+ // defer for now
+ il.ThrowException(typeof(NotImplementedException));
+ break;
+ case ContextKind.NoContext:
+ case ContextKind.CallContext:
+ // typically looks something like (where this is an extension method on Reshape):
+ // => context.{ReshapeMethod}(CallInvoker, {method}, request, [host: null]);
+ var method = op.TryGetClientHelper();
+ if (method == null)
+ {
+ // unexpected, but...
+ il.ThrowException(typeof(NotSupportedException));
+ }
+ else
+ {
+ if (op.Context == ContextKind.CallContext)
+ {
+ Ldarga(il, 2);
+ }
+ else
+ {
+ il.Emit(OpCodes.Ldsflda, s_CallContext_Default);
+ }
+ il.Emit(OpCodes.Ldarg_0); // this.
+ il.EmitCall(OpCodes.Callvirt, s_ClientBase_CallInvoker, null); // get_CallInvoker
+
+ il.Emit(OpCodes.Ldsfld, field); // {method}
+ il.Emit(OpCodes.Ldarg_1); // request
+ il.Emit(OpCodes.Ldnull); // host (always null)
+ il.EmitCall(OpCodes.Call, method, null);
+ il.Emit(OpCodes.Ret); // return
+ }
+ break;
+ case ContextKind.ServerCallContext: // server call? we're writing a client!
+ default: // who knows!
+ il.ThrowException(typeof(NotSupportedException));
+ break;
+ }
+
+ // mark it as the interface implementation
+ type.DefineMethodOverride(impl, op.Method);
+ }
+
+ cctor.Emit(OpCodes.Ret); // end the type initializer (after creating all the field types)
+
+ // write a factory method
+ WriteFactory(new[] { typeof(HttpClient), typeof(ILoggerFactory) }, typeof(HttpClientCallInvoker), ctorCallInvoker);
+ WriteFactory(new[] { typeof(CallInvoker) }, null, ctorCallInvoker);
+ WriteFactory(new[] { typeof(Channel) }, null, ctorChannel);
+ // WriteFactory(new[] { typeof(ClientBaseConfiguration) }, null, ctorClientBaseConfig);
+
+ // return the factory
+ return new ProxyCache(type.CreateType());
+
+ void WriteFactory(Type[] signature, Type? via, ConstructorBuilder? ctor)
+ {
+ if (ctor == null) return;
+ ConstructorInfo? viaCtor = via?.GetConstructor(signature);
+ if (via != null && viaCtor == null) return; // nope!
+
+ var factory = type.DefineMethod(FactoryName, MethodAttributes.Public | MethodAttributes.Static, typeof(TService), signature);
+ var il = factory.GetILGenerator();
+ for (ushort i = 0; i < signature.Length; i++)
+ Ldarg(il, i);
+ if (viaCtor != null) il.Emit(OpCodes.Newobj, viaCtor);
+ il.Emit(OpCodes.Newobj, ctor);
+ il.Emit(OpCodes.Ret);
+ }
+
+ ConstructorBuilder? WritePassThruCtor(MethodAttributes accessibility)
+ {
+ var signature = new[] { typeof(T) };
+ var baseCtor = baseType.GetConstructor(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance, null, signature, null);
+ if (baseCtor == null) return null;
+
+ var ctor = type.DefineConstructor(accessibility, CallingConventions.HasThis, signature);
+ var il = ctor.GetILGenerator();
+ il.Emit(OpCodes.Ldarg_0);
+ il.Emit(OpCodes.Ldarg_1);
+ il.Emit(OpCodes.Call, baseCtor);
+ il.Emit(OpCodes.Ret);
+ return ctor;
+ }
+ }
+ }
+ internal const string FactoryName = "Create";
+ }
+ }
+}
diff --git a/src/protobuf-net.Grpc/Internal/ContractOperation.cs b/src/protobuf-net.Grpc/Internal/ContractOperation.cs
new file mode 100644
index 00000000..09b1f710
--- /dev/null
+++ b/src/protobuf-net.Grpc/Internal/ContractOperation.cs
@@ -0,0 +1,325 @@
+using System;
+using System.Collections.Generic;
+using System.Reflection;
+using Grpc.Core;
+using System.ServiceModel;
+using System.Buffers;
+using System.Threading.Tasks;
+using System.Linq;
+
+namespace ProtoBuf.Grpc.Internal
+{
+ internal readonly struct ContractOperation
+ {
+ public string Name { get; }
+ public Type From { get; }
+ public Type To { get; }
+ public MethodInfo Method { get; }
+ public Type[] ParameterTypes { get; }
+ public MethodType MethodType { get; }
+ public ContextKind Context { get; }
+ public ResultKind Result { get; }
+
+ public override string ToString() => $"{Name}: {From.Name}=>{To.Name}, {MethodType}, {Result}, {Context}";
+
+ public ContractOperation(string name, Type from, Type to, MethodInfo method,
+ MethodType methodType, ContextKind contextKind, ResultKind resultKind,
+ Type[] parameterTypes)
+ {
+ Name = name;
+ From = from;
+ To = to;
+ Method = method;
+ MethodType = methodType;
+ Context = contextKind;
+ Result = resultKind;
+ ParameterTypes = parameterTypes;
+ }
+
+ public static bool TryGetServiceName(Type contractType, out string? serviceName, bool demandAttribute = false)
+ {
+ var sca = (ServiceContractAttribute?)Attribute.GetCustomAttribute(contractType, typeof(ServiceContractAttribute), inherit: true);
+ if (demandAttribute && sca == null)
+ {
+ serviceName = null;
+ return false;
+ }
+ serviceName = sca?.Name;
+ if (string.IsNullOrWhiteSpace(serviceName))
+ {
+ serviceName = contractType.Name;
+ if (contractType.IsInterface && serviceName.StartsWith('I')) serviceName = serviceName.Substring(1); // IFoo => Foo
+ serviceName = contractType.Namespace + serviceName; // Whatever.Foo
+ serviceName = serviceName.Replace('+', '.'); // nested types
+ }
+ return !string.IsNullOrWhiteSpace(serviceName);
+ }
+
+ // do **not** replace these with a `params` etc version; the point here is to be as cheap
+ // as possible for misses
+ internal static bool IsMatch(Type returnType, ParameterInfo[] parameters, Type?[] types, Type? tRet)
+ => parameters.Length == 0
+ && IsMatch(tRet, returnType, out types[0]);
+ internal static bool IsMatch(Type returnType, ParameterInfo[] parameters, Type?[] types, Type? t0, Type? tRet)
+ => parameters.Length == 1
+ && IsMatch(t0, parameters[0].ParameterType, out types[0])
+ && IsMatch(tRet, returnType, out types[1]);
+ internal static bool IsMatch(Type returnType, ParameterInfo[] parameters, Type?[] types, Type? t0, Type? t1, Type? tRet)
+ => parameters.Length == 2
+ && IsMatch(t0, parameters[0].ParameterType, out types[0])
+ && IsMatch(t1, parameters[1].ParameterType, out types[1])
+ && IsMatch(tRet, returnType, out types[2]);
+ internal static bool IsMatch(Type returnType, ParameterInfo[] parameters, Type?[] types, Type? t0, Type? t1, Type? t2, Type? tRet)
+ => parameters.Length == 3
+ && IsMatch(t0, parameters[0].ParameterType, out types[0])
+ && IsMatch(t1, parameters[1].ParameterType, out types[1])
+ && IsMatch(t2, parameters[2].ParameterType, out types[2])
+ && IsMatch(tRet, returnType, out types[3]);
+
+ private static bool IsMatch(in Type? template, in Type actual, out Type result)
+ {
+ if (template == null || template == actual)
+ {
+ result = actual;
+ return true;
+ } // fine
+ if (actual.IsGenericType && template.IsGenericTypeDefinition
+ && actual.GetGenericTypeDefinition() == template)
+ {
+ // expected Foo<>, got Foo: report T
+ result = actual.GetGenericArguments()[0];
+ return true;
+ }
+ result = typeof(void);
+ return false;
+ }
+
+ public static List FindOperations(Type contractType, bool demandAttribute = false)
+ {
+ var ops = new List();
+ var types = ArrayPool.Shared.Rent(10);
+ try
+ {
+ foreach (var method in contractType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance))
+ {
+ if (method.IsGenericMethodDefinition) continue; // can't work with methods
+
+ var oca = (OperationContractAttribute?)Attribute.GetCustomAttribute(method, typeof(OperationContractAttribute), inherit: true);
+ if (demandAttribute && oca == null) continue;
+ string? opName = oca?.Name;
+ if (string.IsNullOrWhiteSpace(opName))
+ {
+ opName = method.Name;
+ if (opName.EndsWith("Async"))
+ opName = opName.Substring(0, opName.Length - 5);
+ }
+ if (string.IsNullOrWhiteSpace(opName)) continue;
+
+ var args = method.GetParameters();
+ if (args.Length == 0) continue; // no way of inferring anything!
+
+ var ret = method.ReturnType;
+
+ ContextKind contextKind = default;
+ MethodType methodType = default;
+ ResultKind resultKind = ResultKind.Unknown;
+ Type? from = null, to = null;
+
+ void Configure(ContextKind ck, MethodType mt, ResultKind rt, Type f, Type t)
+ {
+ contextKind = ck;
+ methodType = mt;
+ resultKind = rt;
+ from = f;
+ to = t;
+ }
+
+ // google server APIs
+ if (IsMatch(ret, args, types, typeof(IAsyncStreamReader<>), typeof(ServerCallContext), typeof(Task<>)))
+ {
+ Configure(ContextKind.ServerCallContext, MethodType.ClientStreaming, ResultKind.Task, types[0], types[2]);
+ }
+ else if (IsMatch(ret, args, types, typeof(IAsyncStreamReader<>), typeof(IAsyncStreamWriter<>), typeof(ServerCallContext), typeof(Task)))
+ {
+ Configure(ContextKind.ServerCallContext, MethodType.DuplexStreaming, ResultKind.Task, types[0], types[1]);
+ }
+ else if (IsMatch(ret, args, types, null, typeof(IServerStreamWriter<>), typeof(ServerCallContext), typeof(Task)))
+ {
+ Configure(ContextKind.ServerCallContext, MethodType.ServerStreaming, ResultKind.Task, types[0], types[1]);
+ }
+ else if (IsMatch(ret, args, types, null, typeof(ServerCallContext), typeof(Task<>)))
+ {
+ Configure(ContextKind.ServerCallContext, MethodType.Unary, ResultKind.Task, types[0], types[2]);
+ }
+
+ // google client APIs
+ else if (IsMatch(ret, args, types, null, typeof(CallOptions), typeof(AsyncUnaryCall<>)))
+ {
+ Configure(ContextKind.CallOptions, MethodType.Unary, ResultKind.Grpc, types[0], types[2]);
+ }
+ else if (IsMatch(ret, args, types, typeof(CallOptions), typeof(AsyncClientStreamingCall<,>)))
+ {
+ Configure(ContextKind.CallOptions, MethodType.ClientStreaming, ResultKind.Grpc, types[1], ret.GetGenericArguments()[1]);
+ }
+ else if (IsMatch(ret, args, types, typeof(CallOptions), typeof(AsyncDuplexStreamingCall<,>)))
+ {
+ Configure(ContextKind.CallOptions, MethodType.DuplexStreaming, ResultKind.Grpc, types[1], ret.GetGenericArguments()[1]);
+ }
+ else if (IsMatch(ret, args, types, null, typeof(CallOptions), typeof(AsyncServerStreamingCall<>)))
+ {
+ Configure(ContextKind.CallOptions, MethodType.ServerStreaming, ResultKind.Grpc, types[0], types[2]);
+ }
+ else if (IsMatch(ret, args, types, null, typeof(CallOptions), null))
+ {
+ Configure(ContextKind.CallOptions, MethodType.Unary, ResultKind.Sync, types[0], types[2]);
+ }
+
+
+ else if (IsMatch(ret, args, types, typeof(IAsyncEnumerable<>), typeof(CallContext), typeof(IAsyncEnumerable<>)))
+ {
+ Configure(ContextKind.CallContext, MethodType.DuplexStreaming, ResultKind.AsyncEnumerable, types[0], types[2]);
+ }
+ else if (IsMatch(ret, args, types, typeof(IAsyncEnumerable<>), typeof(IAsyncEnumerable<>)))
+ {
+ Configure(ContextKind.NoContext, MethodType.DuplexStreaming, ResultKind.AsyncEnumerable, types[0], types[1]);
+ }
+ else if (IsMatch(ret, args, types, null, typeof(CallContext), typeof(IAsyncEnumerable<>)))
+ {
+ Configure(ContextKind.CallContext, MethodType.ServerStreaming, ResultKind.AsyncEnumerable, types[0], types[2]);
+ }
+ else if (IsMatch(ret, args, types, null, typeof(IAsyncEnumerable<>)))
+ {
+ Configure(ContextKind.NoContext, MethodType.ServerStreaming, ResultKind.AsyncEnumerable, types[0], types[1]);
+ }
+ else if (IsMatch(ret, args, types, typeof(IAsyncEnumerable<>), typeof(CallContext), typeof(Task<>)))
+ {
+ Configure(ContextKind.CallContext, MethodType.ClientStreaming, ResultKind.Task, types[0], types[2]);
+ }
+ else if (IsMatch(ret, args, types, typeof(IAsyncEnumerable<>), typeof(Task<>)))
+ {
+ Configure(ContextKind.NoContext, MethodType.ClientStreaming, ResultKind.Task, types[0], types[1]);
+ }
+ else if (IsMatch(ret, args, types, typeof(IAsyncEnumerable<>), typeof(CallContext), typeof(ValueTask<>)))
+ {
+ Configure(ContextKind.CallContext, MethodType.ClientStreaming, ResultKind.ValueTask, types[0], types[2]);
+ }
+ else if (IsMatch(ret, args, types, typeof(IAsyncEnumerable<>), typeof(ValueTask<>)))
+ {
+ Configure(ContextKind.NoContext, MethodType.ClientStreaming, ResultKind.ValueTask, types[0], types[1]);
+ }
+ else if (IsMatch(ret, args, types, null, typeof(CallContext), typeof(Task<>)))
+ {
+ Configure(ContextKind.CallContext, MethodType.Unary, ResultKind.Task, types[0], types[2]);
+ }
+ else if (IsMatch(ret, args, types, null, typeof(Task<>)))
+ {
+ Configure(ContextKind.NoContext, MethodType.Unary, ResultKind.Task, types[0], types[1]);
+ }
+ else if (IsMatch(ret, args, types, null, typeof(CallContext), typeof(ValueTask<>)))
+ {
+ Configure(ContextKind.CallContext, MethodType.Unary, ResultKind.ValueTask, types[0], types[2]);
+ }
+ else if (IsMatch(ret, args, types, null, typeof(ValueTask<>)))
+ {
+ Configure(ContextKind.NoContext, MethodType.Unary, ResultKind.ValueTask, types[0], types[1]);
+ }
+ else if (IsMatch(ret, args, types, null, typeof(CallContext), null))
+ {
+ Configure(ContextKind.CallContext, MethodType.Unary, ResultKind.Sync, types[0], types[2]);
+ }
+ else if (IsMatch(ret, args, types, null, null))
+ {
+ Configure(ContextKind.NoContext, MethodType.Unary, ResultKind.Sync, types[0], types[1]);
+ }
+
+ Type[] argTypes = Array.ConvertAll(args, x => x.ParameterType);
+ if (resultKind != ResultKind.Unknown && from != null && to != null)
+ {
+ ops.Add(new ContractOperation(opName, from, to, method, methodType, contextKind, resultKind, argTypes));
+ }
+ }
+ }
+ finally
+ {
+ ArrayPool.Shared.Return(types);
+ }
+ return ops;
+ }
+
+
+ internal MethodInfo? TryGetClientHelper()
+ {
+ var name = GetClientHelperName();
+ if (name == null || !s_reshaper.TryGetValue(name, out var method)) return null;
+ return method.MakeGenericMethod(From, To);
+ }
+#pragma warning disable CS0618
+ static readonly Dictionary s_reshaper =
+
+ (from method in typeof(Reshape).GetMethods(BindingFlags.Public | BindingFlags.Static)
+ where method.IsGenericMethodDefinition
+ let parameters = method.GetParameters()
+ where parameters[1].ParameterType == typeof(CallInvoker)
+ && parameters[0].ParameterType == typeof(CallContext).MakeByRefType()
+ select method).ToDictionary(x => x.Name);
+
+ static readonly Dictionary<(MethodType, ResultKind), string> _clientReshapeMap = new Dictionary<(MethodType, ResultKind), string>
+ {
+ {(MethodType.DuplexStreaming, ResultKind.AsyncEnumerable), nameof(Reshape.DuplexAsync) },
+ {(MethodType.ServerStreaming, ResultKind.AsyncEnumerable), nameof(Reshape.ServerStreamingAsync) },
+ {(MethodType.ClientStreaming, ResultKind.Task), nameof(Reshape.ClientStreamingTaskAsync) },
+ {(MethodType.ClientStreaming, ResultKind.ValueTask), nameof(Reshape.ClientStreamingValueTaskAsync) },
+ {(MethodType.Unary, ResultKind.Task), nameof(Reshape.UnaryTaskAsync) },
+ {(MethodType.Unary, ResultKind.ValueTask), nameof(Reshape.UnaryValueTaskAsync) },
+ {(MethodType.Unary, ResultKind.Sync), nameof(Reshape.UnarySync) },
+ };
+#pragma warning restore CS0618
+ private string? GetClientHelperName()
+ {
+ switch (Context)
+ {
+ case ContextKind.CallContext:
+ case ContextKind.NoContext:
+ return _clientReshapeMap.TryGetValue((MethodType, Result), out var helper) ? helper : null;
+ default:
+ return null;
+ }
+ }
+
+
+ internal bool IsSyncT()
+ {
+ return Method.ReturnType == To;
+ }
+ internal bool IsTaskT()
+ {
+ var ret = Method.ReturnType;
+ return ret.IsGenericType && ret.GetGenericTypeDefinition() == typeof(Task<>)
+ && ret.GetGenericArguments()[0] == To;
+ }
+ internal bool IsValueTaskT()
+ {
+ var ret = Method.ReturnType;
+ return ret.IsGenericType && ret.GetGenericTypeDefinition() == typeof(ValueTask<>)
+ && ret.GetGenericArguments()[0] == To;
+ }
+ }
+
+ internal enum ContextKind
+ {
+ NoContext, // no context
+ CallContext, // pb-net shared context kind
+ CallOptions, // GRPC core client context kind
+ ServerCallContext, // GRPC core server context kind
+ }
+
+ internal enum ResultKind
+ {
+ Unknown,
+ Sync,
+ Task,
+ ValueTask,
+ AsyncEnumerable,
+ Grpc,
+ }
+}
diff --git a/src/protobuf-net.Grpc/Internal/Empty.cs b/src/protobuf-net.Grpc/Internal/Empty.cs
new file mode 100644
index 00000000..807f3915
--- /dev/null
+++ b/src/protobuf-net.Grpc/Internal/Empty.cs
@@ -0,0 +1,21 @@
+using Grpc.Core;
+using System;
+using System.ComponentModel;
+
+namespace ProtoBuf.Grpc.Internal
+{
+ [Obsolete(Reshape.WarningMessage, false)]
+ [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)]
+ public sealed class Empty : IEquatable
+ {
+ public static Empty Instance { get; } = new Empty();
+ private Empty() { }
+ public override string ToString() => nameof(Empty);
+ public override bool Equals(object? obj) => obj is Empty;
+ public override int GetHashCode() => 42;
+ bool IEquatable.Equals(Empty other) => other != null;
+
+ internal static readonly Marshaller Marshaller
+ = new Marshaller((Empty _)=> Array.Empty(), (byte[] _) => Instance);
+ }
+}
diff --git a/src/protobuf-net.Grpc/Internal/FullyNamedMethod.cs b/src/protobuf-net.Grpc/Internal/FullyNamedMethod.cs
new file mode 100644
index 00000000..e5af88ff
--- /dev/null
+++ b/src/protobuf-net.Grpc/Internal/FullyNamedMethod.cs
@@ -0,0 +1,25 @@
+using Grpc.Core;
+
+namespace ProtoBuf.Grpc.Internal
+{
+ public class FullyNamedMethod : Method, IMethod
+ {
+ private readonly string _fullName;
+
+ public FullyNamedMethod(
+ string operationName,
+ MethodType type,
+ string serviceName,
+ string? methodName = null,
+ Marshaller? requestMarshaller = null,
+ Marshaller? responseMarshaller = null)
+ : base(type, serviceName, methodName ?? operationName,
+ requestMarshaller ?? MarshallerCache.Instance,
+ responseMarshaller ?? MarshallerCache.Instance)
+ {
+ _fullName = serviceName + "/" + operationName;
+ }
+
+ string IMethod.FullName => _fullName;
+ }
+}
diff --git a/src/protobuf-net.Grpc/Internal/MarshallerCache.cs b/src/protobuf-net.Grpc/Internal/MarshallerCache.cs
new file mode 100644
index 00000000..20d32680
--- /dev/null
+++ b/src/protobuf-net.Grpc/Internal/MarshallerCache.cs
@@ -0,0 +1,33 @@
+using Grpc.Core;
+using ProtoBuf.Meta;
+using System.IO;
+
+namespace ProtoBuf.Grpc.Internal
+{
+ internal static class MarshallerCache
+ {
+#pragma warning disable CS0618
+ public static Marshaller Instance { get; } = typeof(T) == typeof(Empty)
+ ? (Marshaller)(object)Empty.Marshaller : new Marshaller(Serialize, Deserialize);
+#pragma warning restore CS0618
+
+ private static readonly RuntimeTypeModel _model = RuntimeTypeModel.Default;
+
+ private static T Deserialize(byte[] payload)
+ {
+ using (var reader = ProtoReader.Create(out var state, payload, _model))
+ {
+ return (T)_model.Deserialize(reader, ref state, null, typeof(T));
+ }
+ }
+
+ private static byte[] Serialize(T value)
+ {
+ using (var ms = new MemoryStream())
+ {
+ Serializer.Serialize(ms, value);
+ return ms.ToArray();
+ }
+ }
+ }
+}
diff --git a/src/protobuf-net.Grpc/Internal/MetadataContext.cs b/src/protobuf-net.Grpc/Internal/MetadataContext.cs
new file mode 100644
index 00000000..8cb60451
--- /dev/null
+++ b/src/protobuf-net.Grpc/Internal/MetadataContext.cs
@@ -0,0 +1,34 @@
+using System;
+using System.Runtime.CompilerServices;
+using Grpc.Core;
+
+namespace ProtoBuf.Grpc.Internal
+{
+ internal sealed class MetadataContext
+ {
+ internal MetadataContext() { }
+
+ private Metadata? _headers, _trailers;
+ internal Metadata Headers
+ {
+ get => _headers ?? Throw("Headers are not yet available");
+ set => _headers = value;
+ }
+ internal Metadata Trailers
+ {
+ get => _trailers ?? Throw("Trailers are not yet available");
+ set => _trailers = value;
+ }
+ internal Status Status { get; set; }
+
+ [MethodImpl(MethodImplOptions.NoInlining)]
+ private static Metadata Throw(string message) => throw new InvalidOperationException(message);
+
+ internal MetadataContext Reset()
+ {
+ Status = default;
+ _headers = _trailers = null;
+ return this;
+ }
+ }
+}
diff --git a/src/protobuf-net.Grpc/Internal/Reshape.cs b/src/protobuf-net.Grpc/Internal/Reshape.cs
new file mode 100644
index 00000000..222bddb5
--- /dev/null
+++ b/src/protobuf-net.Grpc/Internal/Reshape.cs
@@ -0,0 +1,221 @@
+using Grpc.Core;
+using System;
+using System.Collections.Generic;
+using System.ComponentModel;
+using System.Runtime.CompilerServices;
+using System.Threading;
+using System.Threading.Tasks;
+namespace ProtoBuf.Grpc.Internal
+{
+ [Obsolete(WarningMessage, false)]
+ [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)]
+ public static class Reshape
+ {
+ internal const string WarningMessage = "This API is intended for use by runtime-generated code; all methods can be changed without notice - it is only guaranteed to work with the internally generated code";
+
+ [Obsolete(WarningMessage, false)]
+ [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)]
+ public static async IAsyncEnumerable AsAsyncEnumerable(this IAsyncStreamReader reader, [EnumeratorCancellation] CancellationToken cancellationToken)
+ {
+ using (reader)
+ {
+ while (await reader.MoveNext(cancellationToken))
+ {
+ yield return reader.Current;
+ }
+ }
+ }
+
+ [Obsolete(WarningMessage, false)]
+ [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)]
+ public static async Task WriteTo(this IAsyncEnumerable reader, IServerStreamWriter writer, CancellationToken cancellationToken)
+ {
+ await using (var iter = reader.GetAsyncEnumerator(cancellationToken))
+ {
+ while (await iter.MoveNextAsync())
+ {
+ await writer.WriteAsync(iter.Current);
+ }
+ }
+ }
+
+ [Obsolete(WarningMessage, false)]
+ [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)]
+ public static TResponse UnarySync(
+ this in CallContext context,
+ CallInvoker invoker, Method method, TRequest request, string? host = null)
+ where TRequest : class
+ where TResponse : class
+ {
+ context.Prepare();
+ return invoker.BlockingUnaryCall(method, host, context.Client, request);
+ }
+
+ [Obsolete(WarningMessage, false)]
+ [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)]
+ public static Task UnaryTaskAsync(
+ this in CallContext context,
+ CallInvoker invoker, Method method, TRequest request, string? host = null)
+ where TRequest : class
+ where TResponse : class
+ => UnaryTaskAsyncImpl(invoker.AsyncUnaryCall(method, host, context.Client, request), context.Prepare());
+
+ [Obsolete(WarningMessage, false)]
+ [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)]
+ public static ValueTask UnaryValueTaskAsync(
+ this in CallContext context, CallInvoker invoker,
+ Method method, TRequest request, string? host = null)
+ where TRequest : class
+ where TResponse : class
+ => new ValueTask(UnaryTaskAsyncImpl(invoker.AsyncUnaryCall(method, host, context.Client, request), context.Prepare()));
+
+ private static async Task UnaryTaskAsyncImpl(
+ AsyncUnaryCall call, MetadataContext? metadata)
+ {
+ using (call)
+ {
+ if (metadata != null) metadata.Headers = await call.ResponseHeadersAsync;
+ var value = await call;
+ if (metadata != null)
+ {
+ metadata.Trailers = call.GetTrailers();
+ metadata.Status = call.GetStatus();
+ }
+ return value;
+ }
+ }
+
+ [Obsolete(WarningMessage, false)]
+ [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)]
+ public static IAsyncEnumerable ServerStreamingAsync(
+ this in CallContext context,
+ CallInvoker invoker, Method method, TRequest request, string? host = null)
+ where TRequest : class
+ where TResponse : class
+ => ServerStreamingAsyncImpl(invoker.AsyncServerStreamingCall(method, host, context.Client, request), context.Prepare(), context.CancellationToken);
+
+ private static async IAsyncEnumerable ServerStreamingAsyncImpl(
+ AsyncServerStreamingCall call, MetadataContext? metadata,
+ [EnumeratorCancellation] CancellationToken cancellationToken)
+ {
+ using (call)
+ {
+ if (metadata != null) metadata.Headers = await call.ResponseHeadersAsync;
+
+ using (var seq = call.ResponseStream)
+ {
+ while (await seq.MoveNext(default))
+ {
+ yield return seq.Current;
+ }
+ if (metadata != null)
+ {
+ metadata.Trailers = call.GetTrailers();
+ metadata.Status = call.GetStatus();
+ }
+ }
+ }
+ }
+
+ [Obsolete(WarningMessage, false)]
+ [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)]
+ public static Task ClientStreamingTaskAsync(
+ this in CallContext options,
+ CallInvoker invoker, Method method, IAsyncEnumerable request, string? host = null)
+ where TRequest : class
+ where TResponse : class
+ => ClientStreamingTaskAsyncImpl(invoker.AsyncClientStreamingCall(method, host, options.Client), options.Prepare(), options.CancellationToken, request);
+
+ [Obsolete(WarningMessage, false)]
+ [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)]
+ public static ValueTask ClientStreamingValueTaskAsync(
+ this in CallContext options,
+ CallInvoker invoker, Method method, IAsyncEnumerable request, string? host = null)
+ where TRequest : class
+ where TResponse : class
+ => new ValueTask(ClientStreamingTaskAsyncImpl(invoker.AsyncClientStreamingCall(method, host, options.Client), options.Prepare(), options.CancellationToken, request));
+
+ private static async Task ClientStreamingTaskAsyncImpl(
+ AsyncClientStreamingCall call, MetadataContext? metadata,
+ CancellationToken cancellationToken, IAsyncEnumerable request)
+ {
+ using (call)
+ {
+ var output = call.RequestStream;
+ await using (var iter = request.GetAsyncEnumerator(cancellationToken))
+ {
+ while (await iter.MoveNextAsync())
+ {
+ await output.WriteAsync(iter.Current);
+ }
+ }
+ await output.CompleteAsync();
+
+ if (metadata != null) metadata.Headers = await call.ResponseHeadersAsync;
+
+ var result = await call.ResponseAsync;
+
+ if (metadata != null)
+ {
+ metadata.Trailers = call.GetTrailers();
+ metadata.Status = call.GetStatus();
+ }
+ return result;
+ }
+ }
+
+ [Obsolete(WarningMessage, false)]
+ [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)]
+ public static IAsyncEnumerable DuplexAsync(
+ this in CallContext options,
+ CallInvoker invoker, Method method, IAsyncEnumerable request, string? host = null)
+ where TRequest : class
+ where TResponse : class
+ => DuplexAsyncImpl(invoker.AsyncDuplexStreamingCall(method, host, options.Client), options.Prepare(), options.CancellationToken, request);
+
+ private static async IAsyncEnumerable DuplexAsyncImpl(
+ AsyncDuplexStreamingCall call, MetadataContext? metadata,
+ [EnumeratorCancellation] CancellationToken cancellationToken, IAsyncEnumerable request)
+ {
+ using (call)
+ {
+ // we'll run the "send" as a concurrent operation
+ var sendAll = Task.Run(() => SendAll(call.RequestStream, request, cancellationToken));
+
+ if (metadata != null) metadata.Headers = await call.ResponseHeadersAsync;
+
+ using (var seq = call.ResponseStream)
+ {
+ while (await seq.MoveNext(default))
+ {
+ yield return seq.Current;
+ }
+ await sendAll; // observe any problems from sending
+
+ if (metadata != null)
+ {
+ metadata.Trailers = call.GetTrailers();
+ metadata.Status = call.GetStatus();
+ }
+ }
+ }
+ }
+
+ private static async Task SendAll(IClientStreamWriter output, IAsyncEnumerable request, CancellationToken cancellationToken)
+ {
+ try
+ {
+ await using (var iter = request.GetAsyncEnumerator(cancellationToken))
+ {
+ while (await iter.MoveNextAsync())
+ {
+ var item = iter.Current;
+ await output.WriteAsync(item);
+ }
+ }
+ await output.CompleteAsync();
+ }
+ catch (TaskCanceledException) { }
+ }
+ }
+}
diff --git a/src/protobuf-net.Grpc/Server/ServicesExtensions.cs b/src/protobuf-net.Grpc/Server/ServicesExtensions.cs
new file mode 100644
index 00000000..3af1ae1f
--- /dev/null
+++ b/src/protobuf-net.Grpc/Server/ServicesExtensions.cs
@@ -0,0 +1,239 @@
+using Grpc.AspNetCore.Server.Model;
+using Grpc.Core;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.DependencyInjection.Extensions;
+using Microsoft.Extensions.Logging;
+using ProtoBuf.Grpc.Internal;
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Linq;
+using System.Linq.Expressions;
+using System.Reflection;
+using System.ServiceModel;
+using System.Threading.Tasks;
+
+namespace ProtoBuf.Grpc.Server
+{
+ public static class ServicesExtensions
+ {
+ public static void AddCodeFirstGrpc(this IServiceCollection services)
+ {
+ services.TryAddEnumerable(ServiceDescriptor.Singleton(typeof(IServiceMethodProvider<>), typeof(CodeFirstServiceMethodProvider<>)));
+ }
+
+ private sealed class CodeFirstServiceMethodProvider : IServiceMethodProvider where TService : class
+ {
+ private readonly ILogger> _logger;
+
+ public CodeFirstServiceMethodProvider(ILoggerFactory loggerFactory)
+ {
+ _logger = _logger = loggerFactory.CreateLogger>();
+ }
+ public void OnServiceMethodDiscovery(ServiceMethodProviderContext context)
+ {
+ // ignore any services that are known to be the default handler
+ if (Attribute.IsDefined(typeof(TService), typeof(BindServiceMethodAttribute))) return;
+
+ // we support methods that match suitable signatures, where:
+ // 1. (removed - no longer supported) the method is directly on TService and is marked [OperationContract]
+ // 2. the method is on an interface that TService implements, and the interface is marked [ServiceContract]
+ // AddMethodsForService(context,typeof(TService));
+
+ foreach (var iType in typeof(TService).GetInterfaces())
+ {
+ AddMethodsForService(context, iType);
+ }
+ }
+
+ static Expression ToTaskT(Expression expression)
+ {
+ var type = expression.Type;
+ if (type.IsGenericType)
+ {
+ if (type.GetGenericTypeDefinition() == typeof(Task<>))
+ return expression;
+ if (type.GetGenericTypeDefinition() == typeof(ValueTask<>))
+ return Expression.Call(expression, nameof(ValueTask.AsTask), null);
+ }
+ return Expression.Call(typeof(Task), nameof(Task.FromResult), new Type[] { expression.Type }, expression);
+ }
+
+ internal static readonly ConstructorInfo s_CallContext_FromServerContext = typeof(CallContext).GetConstructor(new[] { typeof(ServerCallContext) })!;
+ static Expression ToCallContext(Expression context) => Expression.New(s_CallContext_FromServerContext, context);
+#pragma warning disable CS0618
+ static Expression AsAsyncEnumerable(ParameterExpression value, ParameterExpression context)
+ => Expression.Call(typeof(Reshape), nameof(Reshape.AsAsyncEnumerable),
+ typeArguments: value.Type.GetGenericArguments(),
+ arguments: new Expression[] { value, Expression.Property(context, nameof(ServerCallContext.CancellationToken)) });
+
+ static Expression WriteTo(Expression value, ParameterExpression writer, ParameterExpression context)
+ => Expression.Call(typeof(Reshape), nameof(Reshape.WriteTo),
+ typeArguments: value.Type.GetGenericArguments(),
+ arguments: new Expression[] {value, writer, Expression.Property(context, nameof(ServerCallContext.CancellationToken)) });
+#pragma warning restore CS0618
+
+ static readonly Dictionary<(MethodType, ContextKind, ResultKind), Func?> _invokers
+ = new Dictionary<(MethodType, ContextKind, ResultKind), Func?>
+ {
+ // GRPC-style server methods are direct match; no mapping required
+ // => service.{method}(args)
+ { (MethodType.Unary, ContextKind.ServerCallContext, ResultKind.Task), null },
+ { (MethodType.ServerStreaming, ContextKind.ServerCallContext, ResultKind.Task), null },
+ { (MethodType.ClientStreaming, ContextKind.ServerCallContext, ResultKind.Task), null },
+ { (MethodType.DuplexStreaming, ContextKind.ServerCallContext, ResultKind.Task), null },
+
+ // Unary: Task Foo(TService service, TRequest request, ServerCallContext serverCallContext);
+ // => service.{method}(request, [new CallContext(serverCallContext)])
+ {(MethodType.Unary, ContextKind.NoContext, ResultKind.Task), (method, args) => ToTaskT(Expression.Call(args[0], method, args[1])) },
+ {(MethodType.Unary, ContextKind.NoContext, ResultKind.ValueTask), (method, args) => ToTaskT(Expression.Call(args[0], method, args[1])) },
+ {(MethodType.Unary, ContextKind.NoContext, ResultKind.Sync), (method, args) => ToTaskT(Expression.Call(args[0], method, args[1])) },
+
+ {(MethodType.Unary, ContextKind.CallContext, ResultKind.Task), (method, args) => ToTaskT(Expression.Call(args[0], method, args[1], ToCallContext(args[2]))) },
+ {(MethodType.Unary, ContextKind.CallContext, ResultKind.ValueTask), (method, args) => ToTaskT(Expression.Call(args[0], method, args[1], ToCallContext(args[2]))) },
+ {(MethodType.Unary, ContextKind.CallContext, ResultKind.Sync), (method, args) => ToTaskT(Expression.Call(args[0], method, args[1], ToCallContext(args[2]))) },
+
+ // Client Streaming: Task Foo(TService service, IAsyncStreamReader stream, ServerCallContext serverCallContext);
+ // => service.{method}(reader.AsAsyncEnumerable(serverCallContext.CancellationToken), [new CallContext(serverCallContext)])
+ {(MethodType.ClientStreaming, ContextKind.NoContext, ResultKind.Task), (method, args) => ToTaskT(Expression.Call(args[0], method, AsAsyncEnumerable(args[1], args[2]))) },
+ {(MethodType.ClientStreaming, ContextKind.NoContext, ResultKind.ValueTask), (method, args) => ToTaskT(Expression.Call(args[0], method, AsAsyncEnumerable(args[1], args[2]))) },
+ {(MethodType.ClientStreaming, ContextKind.NoContext, ResultKind.Sync), (method, args) => ToTaskT(Expression.Call(args[0], method, AsAsyncEnumerable(args[1], args[2]))) },
+
+ {(MethodType.ClientStreaming, ContextKind.CallContext, ResultKind.Task), (method, args) => ToTaskT(Expression.Call(args[0], method, AsAsyncEnumerable(args[1], args[2]), ToCallContext(args[2]))) },
+ {(MethodType.ClientStreaming, ContextKind.CallContext, ResultKind.ValueTask), (method, args) => ToTaskT(Expression.Call(args[0], method, AsAsyncEnumerable(args[1], args[2]), ToCallContext(args[2]))) },
+ {(MethodType.ClientStreaming, ContextKind.CallContext, ResultKind.Sync), (method, args) => ToTaskT(Expression.Call(args[0], method, AsAsyncEnumerable(args[1], args[2]), ToCallContext(args[2]))) },
+
+ // Server Streaming: Task Foo(TService service, TRequest request, IServerStreamWriter stream, ServerCallContext serverCallContext);
+ // => service.{method}(request, [new CallContext(serverCallContext)]).WriteTo(stream, serverCallContext.CancellationToken)
+ {(MethodType.ServerStreaming, ContextKind.NoContext, ResultKind.AsyncEnumerable), (method, args) => WriteTo(Expression.Call(args[0], method, args[1]), args[2], args[3])},
+ {(MethodType.ServerStreaming, ContextKind.CallContext, ResultKind.AsyncEnumerable), (method, args) => WriteTo(Expression.Call(args[0], method, args[1], ToCallContext(args[3])), args[2], args[3])},
+
+ // Duplex: Task Foo(TService service, IAsyncStreamReader input, IServerStreamWriter output, ServerCallContext serverCallContext);
+ // => service.{method}(input.AsAsyncEnumerable(serverCallContext.CancellationToken), [new CallContext(serverCallContext)]).WriteTo(output, serverCallContext.CancellationToken)
+ {(MethodType.DuplexStreaming, ContextKind.NoContext, ResultKind.AsyncEnumerable), (method, args) => WriteTo(Expression.Call(args[0], method, AsAsyncEnumerable(args[1], args[3])), args[2], args[3]) },
+ {(MethodType.DuplexStreaming, ContextKind.CallContext, ResultKind.AsyncEnumerable), (method, args) => WriteTo(Expression.Call(args[0], method, AsAsyncEnumerable(args[1], args[3]), ToCallContext(args[3])), args[2], args[3]) },
+ };
+ private void AddMethodsForService(ServiceMethodProviderContext context, Type serviceContract)
+ {
+ bool isPublicContract = typeof(TService) == serviceContract;
+ if (!ContractOperation.TryGetServiceName(serviceContract, out var serviceName, !isPublicContract)) return;
+ _logger.Log(LogLevel.Trace, "pb-net processing {0}/{1} as {2}", typeof(TService).Name, serviceContract.Name, serviceName);
+ object?[]? argsBuffer = null;
+ Type[] typesBuffer = Array.Empty();
+
+ int count = 0;
+ foreach (var op in ContractOperation.FindOperations(serviceContract, isPublicContract))
+ {
+ if (_invokers.TryGetValue((op.MethodType, op.Context, op.Result), out var invoker)
+ && AddMethod(op.From, op.To, op.Method, op.MethodType, invoker))
+ {
+ // yay!
+ count++;
+ }
+ else
+ {
+ _logger.Log(LogLevel.Warning, "operation cannot be hosted as a server: {0}", op);
+ }
+ }
+ if (count != 0) _logger.Log(LogLevel.Information, "{0} implementing service {1} (via '{2}') with {3} operation(s)", typeof(TService), serviceName, serviceContract.Name, count);
+
+ bool AddMethod(Type @in, Type @out, MethodInfo m, MethodType t, Func? invoker = null)
+ {
+ try
+ {
+ if (typesBuffer.Length == 0)
+ {
+ typesBuffer = new Type[] { typeof(TService), typeof(void), typeof(void) };
+ }
+ typesBuffer[1] = @in;
+ typesBuffer[2] = @out;
+
+ if (argsBuffer == null)
+ {
+ argsBuffer = new object?[] { serviceName, null, null, context, _logger, null };
+ }
+ argsBuffer[1] = m;
+ argsBuffer[2] = t;
+ argsBuffer[5] = invoker;
+
+ s_addMethod.MakeGenericMethod(typesBuffer).Invoke(null, argsBuffer);
+ return true;
+ }
+ catch (Exception fail)
+ {
+ if (fail is TargetInvocationException tie) fail = tie.InnerException!;
+ _logger.Log(LogLevel.Error, "Failure processing {0}: {1}", m.Name, fail.Message);
+ return false;
+ }
+ }
+ }
+ }
+
+ private static readonly MethodInfo s_addMethod = typeof(ServicesExtensions).GetMethod(
+ nameof(AddMethod), BindingFlags.Static | BindingFlags.NonPublic)!;
+
+ private static void AddMethod(
+ string serviceName, MethodInfo method, MethodType methodType,
+ ServiceMethodProviderContext context, ILogger logger,
+ Func? invoker = null)
+ where TService : class
+ where TRequest : class
+ where TResponse : class
+ {
+ var oca = (OperationContractAttribute?)Attribute.GetCustomAttribute(method, typeof(OperationContractAttribute), inherit: true);
+ var operationName = oca?.Name;
+ if (string.IsNullOrWhiteSpace(operationName))
+ {
+ operationName = method.Name;
+ if (operationName.EndsWith("Async")) operationName = operationName.Substring(0, operationName.Length - 5);
+ }
+
+ var metadata = new List