Skip to content

Commit

Permalink
Fix HttpContext race by copying values to reader and writer (#2294)
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK authored Oct 18, 2023
1 parent 683fbdf commit 91be392
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 5 deletions.
15 changes: 14 additions & 1 deletion src/Grpc.AspNetCore.Server/Internal/GrpcProtocolHelpers.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -234,4 +234,17 @@ internal static bool ShouldSkipHeader(string name)
{
return name.StartsWith(':') || GrpcProtocolConstants.FilteredHeaders.Contains(name);
}

internal static IHttpRequestLifetimeFeature GetRequestLifetimeFeature(HttpContext httpContext)
{
var lifetimeFeature = httpContext.Features.Get<IHttpRequestLifetimeFeature>();
if (lifetimeFeature is null)
{
// This should only run in tests where the HttpContext is manually created.
lifetimeFeature = new HttpRequestLifetimeFeature();
httpContext.Features.Set(lifetimeFeature);
}

return lifetimeFeature;
}
}
14 changes: 12 additions & 2 deletions src/Grpc.AspNetCore.Server/Internal/HttpContextStreamReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
#endregion

using System.Diagnostics;
using System.IO.Pipelines;
using Grpc.Core;
using Grpc.Shared;
using Microsoft.AspNetCore.Http.Features;

namespace Grpc.AspNetCore.Server.Internal;

Expand All @@ -28,6 +30,8 @@ internal class HttpContextStreamReader<TRequest> : IAsyncStreamReader<TRequest>
{
private readonly HttpContextServerCallContext _serverCallContext;
private readonly Func<DeserializationContext, TRequest> _deserializer;
private readonly PipeReader _bodyReader;
private readonly IHttpRequestLifetimeFeature _requestLifetimeFeature;
private bool _completed;
private long _readCount;
private bool _endOfStream;
Expand All @@ -36,6 +40,12 @@ public HttpContextStreamReader(HttpContextServerCallContext serverCallContext, F
{
_serverCallContext = serverCallContext;
_deserializer = deserializer;

// Copy HttpContext values.
// This is done to avoid a race condition when reading them from HttpContext later when running in a separate thread.
_bodyReader = _serverCallContext.HttpContext.Request.BodyReader;
// Copy lifetime feature because HttpContext.RequestAborted on .NET 6 doesn't return the real cancellation token.
_requestLifetimeFeature = GrpcProtocolHelpers.GetRequestLifetimeFeature(_serverCallContext.HttpContext);
}

public TRequest Current { get; private set; } = default!;
Expand All @@ -54,7 +64,7 @@ async Task<bool> MoveNextAsync(ValueTask<TRequest?> readStreamTask)
return Task.FromCanceled<bool>(cancellationToken);
}

if (_completed || _serverCallContext.CancellationToken.IsCancellationRequested)
if (_completed || _requestLifetimeFeature.RequestAborted.IsCancellationRequested)
{
return Task.FromException<bool>(new InvalidOperationException("Can't read messages after the request is complete."));
}
Expand All @@ -63,7 +73,7 @@ async Task<bool> MoveNextAsync(ValueTask<TRequest?> readStreamTask)
// In a long running stream this can allow the previous value to be GCed.
Current = null!;

var request = _serverCallContext.HttpContext.Request.BodyReader.ReadStreamMessageAsync(_serverCallContext, _deserializer, cancellationToken);
var request = _bodyReader.ReadStreamMessageAsync(_serverCallContext, _deserializer, cancellationToken);
if (!request.IsCompletedSuccessfully)
{
return MoveNextAsync(request);
Expand Down
14 changes: 12 additions & 2 deletions src/Grpc.AspNetCore.Server/Internal/HttpContextStreamWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
#endregion

using System.Diagnostics;
using System.IO.Pipelines;
using Grpc.Core;
using Grpc.Shared;
using Microsoft.AspNetCore.Http.Features;

namespace Grpc.AspNetCore.Server.Internal;

Expand All @@ -29,6 +31,8 @@ internal class HttpContextStreamWriter<TResponse> : IServerStreamWriter<TRespons
{
private readonly HttpContextServerCallContext _context;
private readonly Action<TResponse, SerializationContext> _serializer;
private readonly PipeWriter _bodyWriter;
private readonly IHttpRequestLifetimeFeature _requestLifetimeFeature;
private readonly object _writeLock;
private Task? _writeTask;
private bool _completed;
Expand All @@ -39,6 +43,12 @@ public HttpContextStreamWriter(HttpContextServerCallContext context, Action<TRes
_context = context;
_serializer = serializer;
_writeLock = new object();

// Copy HttpContext values.
// This is done to avoid a race condition when reading them from HttpContext later when running in a separate thread.
_bodyWriter = context.HttpContext.Response.BodyWriter;
// Copy lifetime feature because HttpContext.RequestAborted on .NET 6 doesn't return the real cancellation token.
_requestLifetimeFeature = GrpcProtocolHelpers.GetRequestLifetimeFeature(context.HttpContext);
}

public WriteOptions? WriteOptions
Expand Down Expand Up @@ -77,7 +87,7 @@ private async Task WriteCoreAsync(TResponse message, CancellationToken cancellat
{
cancellationToken.ThrowIfCancellationRequested();

if (_completed || _context.CancellationToken.IsCancellationRequested)
if (_completed || _requestLifetimeFeature.RequestAborted.IsCancellationRequested)
{
throw new InvalidOperationException("Can't write the message because the request is complete.");
}
Expand All @@ -91,7 +101,7 @@ private async Task WriteCoreAsync(TResponse message, CancellationToken cancellat
}

// Save write task to track whether it is complete. Must be set inside lock.
_writeTask = _context.HttpContext.Response.BodyWriter.WriteStreamedMessageAsync(message, _context, _serializer, cancellationToken);
_writeTask = _bodyWriter.WriteStreamedMessageAsync(message, _context, _serializer, cancellationToken);
}

await _writeTask;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#endregion

using Grpc.AspNetCore.Server.Internal.CallHandlers;
using Grpc.AspNetCore.Server.Tests.TestObjects;
using Grpc.Core;
using Grpc.Shared.Server;
using Grpc.Tests.Shared;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.Extensions.Logging.Abstractions;
using NUnit.Framework;

namespace Grpc.AspNetCore.Server.Tests;

[TestFixture]
public class DuplexStreamingServerCallHandlerTests
{
private static readonly Marshaller<TestMessage> _marshaller = new Marshaller<TestMessage>((message, context) => { context.Complete(Array.Empty<byte>()); }, context => new TestMessage());

[Test]
public async Task HandleCallAsync_ConcurrentReadAndWrite_Success()
{
// Arrange
var invoker = new DuplexStreamingServerMethodInvoker<TestService, TestMessage, TestMessage>(
(service, reader, writer, context) =>
{
var message = new TestMessage();
var readTask = Task.Run(() => reader.MoveNext());
var writeTask = Task.Run(() => writer.WriteAsync(message));
return Task.WhenAll(readTask, writeTask);
},
new Method<TestMessage, TestMessage>(MethodType.DuplexStreaming, "test", "test", _marshaller, _marshaller),
HttpContextServerCallContextHelper.CreateMethodOptions(),
new TestGrpcServiceActivator<TestService>());
var handler = new DuplexStreamingServerCallHandler<TestService, TestMessage, TestMessage>(invoker, NullLoggerFactory.Instance);

// Verify there isn't a race condition when reading/writing on seperate threads.
// This test primarily exists to ensure that the stream reader and stream writer aren't accessing non-thread safe APIs on HttpContext.
for (var i = 0; i < 10_000; i++)
{
var httpContext = HttpContextHelpers.CreateContext();

// Act
await handler.HandleCallAsync(httpContext);

// Assert
var trailers = httpContext.Features.Get<IHttpResponseTrailersFeature>()!.Trailers;
Assert.AreEqual("0", trailers["grpc-status"].ToString());
}
}
}

0 comments on commit 91be392

Please sign in to comment.