Skip to content

Commit

Permalink
WebSocket Feedback Follow-up (#107662)
Browse files Browse the repository at this point in the history
* Fixes

* State validation update

* Roll back dispose changes, fix mutex logging

* Roll back observe changes

* Add internal flags enum for states

* Update src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketStateHelper.cs

Co-authored-by: Miha Zupan <[email protected]>

* Feedback

---------

Co-authored-by: Miha Zupan <[email protected]>
  • Loading branch information
CarnaViire and MihaZupan authored Sep 18, 2024
1 parent 24e7d1b commit b6b0fb1
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,6 @@ internal static partial class WebSocketValidate
private static readonly SearchValues<char> s_validSubprotocolChars =
SearchValues.Create("!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~");

internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, WebSocketState[] validStates)
=> ThrowIfInvalidState(currentState, isDisposed, innerException: null, validStates ?? []);

internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, Exception? innerException, WebSocketState[]? validStates = null)
{
if (validStates is not null && Array.IndexOf(validStates, currentState) == -1)
{
string invalidStateMessage = SR.Format(
SR.net_WebSockets_InvalidState, currentState, string.Join(", ", validStates));

throw new WebSocketException(WebSocketError.InvalidState, invalidStateMessage, innerException);
}

if (innerException is not null)
{
Debug.Assert(currentState == WebSocketState.Aborted);
throw new OperationCanceledException(nameof(WebSocketState.Aborted), innerException);
}

// Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior.
ObjectDisposedException.ThrowIf(isDisposed, typeof(WebSocket));
}

internal static void ValidateSubprotocol(string subProtocol)
{
ArgumentException.ThrowIfNullOrWhiteSpace(subProtocol);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
<Compile Include="System\Net\WebSockets\Compression\WebSocketInflater.cs" />
<Compile Include="System\Net\WebSockets\ManagedWebSocket.cs" />
<Compile Include="System\Net\WebSockets\ManagedWebSocket.KeepAlive.cs" />
<Compile Include="System\Net\WebSockets\ManagedWebSocketStates.cs" />
<Compile Include="System\Net\WebSockets\NetEventSource.WebSockets.cs" />
<Compile Include="System\Net\WebSockets\ValueWebSocketReceiveResult.cs" />
<Compile Include="System\Net\WebSockets\WebSocket.cs" />
Expand All @@ -31,6 +32,7 @@
<Compile Include="System\Net\WebSockets\WebSocketMessageFlags.cs" />
<Compile Include="System\Net\WebSockets\WebSocketReceiveResult.cs" />
<Compile Include="System\Net\WebSockets\WebSocketState.cs" />
<Compile Include="System\Net\WebSockets\WebSocketStateHelper.cs" />
<Compile Include="$(CommonPath)System\Net\WebSockets\WebSocketDefaults.cs"
Link="Common\System\Net\WebSockets\WebSocketDefaults.cs" />
<Compile Include="$(CommonPath)System\Net\WebSockets\WebSocketValidate.cs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,28 @@ public Task EnterAsync(CancellationToken cancellationToken)
// If cancellation was requested, bail immediately.
// If the mutex is not currently held nor contended, enter immediately.
// Otherwise, fall back to a more expensive likely-asynchronous wait.
return
cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) :
Interlocked.Decrement(ref _gate) >= 0 ? Task.CompletedTask :
Contended(cancellationToken);

if (cancellationToken.IsCancellationRequested)
{
return Task.FromCanceled(cancellationToken);
}

int gate = Interlocked.Decrement(ref _gate);
if (gate >= 0)
{
return Task.CompletedTask;
}

if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Waiting to enter, queue length {-gate}");

return Contended(cancellationToken);

// Everything that follows is the equivalent of:
// return _sem.WaitAsync(cancellationToken);
// if _sem were to be constructed as `new SemaphoreSlim(0)`.

Task Contended(CancellationToken cancellationToken)
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexContended(this, _gate);

var w = new Waiter(this);

// We need to register for cancellation before storing the waiter into the list.
Expand Down Expand Up @@ -178,18 +187,18 @@ static void OnCancellation(object? state, CancellationToken cancellationToken)
/// <remarks>The caller must logically own the mutex. This is not validated.</remarks>
public void Exit()
{
if (Interlocked.Increment(ref _gate) < 1)
// This is the equivalent of:
// _sem.Release();
// if _sem were to be constructed as `new SemaphoreSlim(0)`.
int gate = Interlocked.Increment(ref _gate);
if (gate < 1)
{
// This is the equivalent of:
// _sem.Release();
// if _sem were to be constructed as `new SemaphoreSlim(0)`.
if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Unblocking next waiter on exit, remaining queue length {-_gate}", nameof(Exit));
Contended();
}

void Contended()
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexContended(this, _gate);

Waiter? w;
lock (SyncObj)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Buffers.Binary;
using System.Diagnostics;
using System.Runtime.ExceptionServices;
Expand All @@ -13,8 +12,6 @@ namespace System.Net.WebSockets
internal sealed partial class ManagedWebSocket : WebSocket
{
private bool IsUnsolicitedPongKeepAlive => _keepAlivePingState is null;
private static bool IsValidSendState(WebSocketState state) => Array.IndexOf(s_validSendStates, state) != -1;
private static bool IsValidReceiveState(WebSocketState state) => Array.IndexOf(s_validReceiveStates, state) != -1;

private void HeartBeat()
{
Expand All @@ -36,21 +33,19 @@ private void UnsolicitedPongHeartBeat()
TrySendKeepAliveFrameAsync(MessageOpcode.Pong));
}

private ValueTask TrySendKeepAliveFrameAsync(MessageOpcode opcode, ReadOnlyMemory<byte>? payload = null)
private ValueTask TrySendKeepAliveFrameAsync(MessageOpcode opcode, ReadOnlyMemory<byte> payload = default)
{
Debug.Assert(opcode is MessageOpcode.Pong || !IsUnsolicitedPongKeepAlive && opcode is MessageOpcode.Ping);
Debug.Assert((opcode is MessageOpcode.Pong) || (!IsUnsolicitedPongKeepAlive && opcode is MessageOpcode.Ping));

if (!IsValidSendState(_state))
if (!WebSocketStateHelper.IsValidSendState(_state))
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Cannot send keep-alive frame in {nameof(_state)}={_state}");

// we can't send any frames, but no need to throw as we are not observing errors anyway
return ValueTask.CompletedTask;
}

payload ??= ReadOnlyMemory<byte>.Empty;

return SendFrameAsync(opcode, endOfMessage: true, disableCompression: true, payload.Value, CancellationToken.None);
return SendFrameAsync(opcode, endOfMessage: true, disableCompression: true, payload, CancellationToken.None);
}

private void KeepAlivePingHeartBeat()
Expand All @@ -76,7 +71,7 @@ private void KeepAlivePingHeartBeat()

if (_keepAlivePingState.PingSent)
{
if (Environment.TickCount64 > _keepAlivePingState.PingTimeoutTimestamp)
if (now > _keepAlivePingState.PingTimeoutTimestamp)
{
if (NetEventSource.Log.IsEnabled())
{
Expand All @@ -92,7 +87,7 @@ private void KeepAlivePingHeartBeat()
}
else
{
if (Environment.TickCount64 > _keepAlivePingState.NextPingRequestTimestamp)
if (now > _keepAlivePingState.NextPingRequestTimestamp)
{
_keepAlivePingState.OnNextPingRequestCore(); // we are holding the lock
shouldSendPing = true;
Expand All @@ -119,18 +114,12 @@ private async ValueTask SendPingAsync(long pingPayload)
{
Debug.Assert(_keepAlivePingState != null);

byte[] pingPayloadBuffer = ArrayPool<byte>.Shared.Rent(sizeof(long));
byte[] pingPayloadBuffer = new byte[sizeof(long)];
BinaryPrimitives.WriteInt64BigEndian(pingPayloadBuffer, pingPayload);
try
{
await TrySendKeepAliveFrameAsync(MessageOpcode.Ping, pingPayloadBuffer.AsMemory(0, sizeof(long))).ConfigureAwait(false);

if (NetEventSource.Log.IsEnabled()) NetEventSource.KeepAlivePingSent(this, pingPayload);
}
finally
{
ArrayPool<byte>.Shared.Return(pingPayloadBuffer);
}
await TrySendKeepAliveFrameAsync(MessageOpcode.Ping, pingPayloadBuffer).ConfigureAwait(false);

if (NetEventSource.Log.IsEnabled()) NetEventSource.KeepAlivePingSent(this, pingPayload);
}

// "Observe" either a ValueTask result, or any exception, ignoring it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,6 @@ internal sealed partial class ManagedWebSocket : WebSocket
/// <summary>Encoding for the payload of text messages: UTF-8 encoding that throws if invalid bytes are discovered, per the RFC.</summary>
private static readonly UTF8Encoding s_textEncoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true);

/// <summary>Valid states to be in when calling SendAsync.</summary>
private static readonly WebSocketState[] s_validSendStates = { WebSocketState.Open, WebSocketState.CloseReceived };
/// <summary>Valid states to be in when calling ReceiveAsync.</summary>
private static readonly WebSocketState[] s_validReceiveStates = { WebSocketState.Open, WebSocketState.CloseSent };
/// <summary>Valid states to be in when calling CloseOutputAsync.</summary>
private static readonly WebSocketState[] s_validCloseOutputStates = { WebSocketState.Open, WebSocketState.CloseReceived };
/// <summary>Valid states to be in when calling CloseAsync.</summary>
private static readonly WebSocketState[] s_validCloseStates = { WebSocketState.Open, WebSocketState.CloseReceived, WebSocketState.CloseSent };

/// <summary>The maximum size in bytes of a message frame header that includes mask bytes.</summary>
internal const int MaxMessageHeaderLength = 14;
/// <summary>The maximum size of a control message payload.</summary>
Expand Down Expand Up @@ -337,7 +328,7 @@ public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessag

try
{
ThrowIfInvalidState(s_validSendStates);
ThrowIfInvalidState(WebSocketStateHelper.ValidSendStates);
}
catch (Exception exc)
{
Expand Down Expand Up @@ -377,7 +368,7 @@ public override Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buf

try
{
ThrowIfInvalidState(s_validReceiveStates);
ThrowIfInvalidState(WebSocketStateHelper.ValidReceiveStates);

return ReceiveAsyncPrivate<WebSocketReceiveResult>(buffer, cancellationToken).AsTask();
}
Expand All @@ -394,7 +385,7 @@ public override ValueTask<ValueWebSocketReceiveResult> ReceiveAsync(Memory<byte>

try
{
ThrowIfInvalidState(s_validReceiveStates);
ThrowIfInvalidState(WebSocketStateHelper.ValidReceiveStates);

return ReceiveAsyncPrivate<ValueWebSocketReceiveResult>(buffer, cancellationToken);
}
Expand All @@ -413,7 +404,7 @@ public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? status

try
{
ThrowIfInvalidState(s_validCloseStates);
ThrowIfInvalidState(WebSocketStateHelper.ValidCloseStates);
}
catch (Exception exc)
{
Expand All @@ -436,7 +427,7 @@ private async Task CloseOutputAsyncCore(WebSocketCloseStatus closeStatus, string
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);

ThrowIfInvalidState(s_validCloseOutputStates);
ThrowIfInvalidState(WebSocketStateHelper.ValidCloseOutputStates);

await SendCloseFrameAsync(closeStatus, statusDescription, cancellationToken).ConfigureAwait(false);

Expand Down Expand Up @@ -1737,9 +1728,9 @@ private void ThrowIfOperationInProgress(bool operationCompleted, [CallerMemberNa
cancellationToken);
}

private void ThrowIfDisposed() => ThrowIfInvalidState();
private void ThrowIfDisposed() => ThrowIfInvalidState(validStates: ManagedWebSocketStates.All);

private void ThrowIfInvalidState(WebSocketState[]? validStates = null)
private void ThrowIfInvalidState(ManagedWebSocketStates validStates)
{
bool disposed = _disposed;
WebSocketState state = _state;
Expand All @@ -1758,7 +1749,7 @@ private void ThrowIfInvalidState(WebSocketState[]? validStates = null)

if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"_state={state}, _disposed={disposed}, _keepAlivePingState.Exception={keepAliveException}");

WebSocketValidate.ThrowIfInvalidState(state, disposed, keepAliveException, validStates);
WebSocketStateHelper.ThrowIfInvalidState(state, disposed, keepAliveException, validStates);
}

// From https://github.com/aspnet/WebSockets/blob/aa63e27fce2e9202698053620679a9a1059b501e/src/Microsoft.AspNetCore.WebSockets.Protocol/Utilities.cs#L75
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace System.Net.WebSockets
{
[Flags]
internal enum ManagedWebSocketStates
{
None = 0,

// WebSocketState.None = 0 -- this state is invalid for the managed implementation
// WebSocketState.Connecting = 1 -- this state is invalid for the managed implementation
Open = 0x04, // WebSocketState.Open = 2
CloseSent = 0x08, // WebSocketState.CloseSent = 3
CloseReceived = 0x10, // WebSocketState.CloseReceived = 4
Closed = 0x20, // WebSocketState.Closed = 5
Aborted = 0x40, // WebSocketState.Aborted = 6

All = Open | CloseSent | CloseReceived | Closed | Aborted
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ internal sealed partial class NetEventSource

private const int MutexEnterId = SendStopId + 1;
private const int MutexExitId = MutexEnterId + 1;
private const int MutexContendedId = MutexExitId + 1;

//
// Keep-Alive
Expand Down Expand Up @@ -185,10 +184,6 @@ private void MutexEnter(string objName, string memberName) =>
private void MutexExit(string objName, string memberName) =>
WriteEvent(MutexExitId, objName, memberName);

[Event(MutexContendedId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)]
private void MutexContended(string objName, string memberName, int queueLength) =>
WriteEvent(MutexContendedId, objName, memberName, queueLength);

[NonEvent]
public static void MutexEntered(object? obj, [CallerMemberName] string? memberName = null)
{
Expand All @@ -203,13 +198,6 @@ public static void MutexExited(object? obj, [CallerMemberName] string? memberNam
Log.MutexExit(IdOf(obj), memberName ?? MissingMember);
}

[NonEvent]
public static void MutexContended(object? obj, int gateValue, [CallerMemberName] string? memberName = null)
{
Debug.Assert(Log.IsEnabled());
Log.MutexContended(IdOf(obj), memberName ?? MissingMember, -gateValue);
}

//
// WriteEvent overloads
//
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;

namespace System.Net.WebSockets
{
internal static class WebSocketStateHelper
{
/// <summary>Valid states to be in when calling SendAsync.</summary>
internal const ManagedWebSocketStates ValidSendStates = ManagedWebSocketStates.Open | ManagedWebSocketStates.CloseReceived;
/// <summary>Valid states to be in when calling ReceiveAsync.</summary>
internal const ManagedWebSocketStates ValidReceiveStates = ManagedWebSocketStates.Open | ManagedWebSocketStates.CloseSent;
/// <summary>Valid states to be in when calling CloseOutputAsync.</summary>
internal const ManagedWebSocketStates ValidCloseOutputStates = ManagedWebSocketStates.Open | ManagedWebSocketStates.CloseReceived;
/// <summary>Valid states to be in when calling CloseAsync.</summary>
internal const ManagedWebSocketStates ValidCloseStates = ManagedWebSocketStates.Open | ManagedWebSocketStates.CloseReceived | ManagedWebSocketStates.CloseSent;

internal static bool IsValidSendState(WebSocketState state) => ValidSendStates.HasFlag(ToFlag(state));

internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, Exception? innerException, ManagedWebSocketStates validStates)
{
ManagedWebSocketStates state = ToFlag(currentState);

if ((state & validStates) == 0)
{
string invalidStateMessage = SR.Format(
SR.net_WebSockets_InvalidState, currentState, validStates);

throw new WebSocketException(WebSocketError.InvalidState, invalidStateMessage, innerException);
}

if (innerException is not null)
{
Debug.Assert(state == ManagedWebSocketStates.Aborted);
throw new OperationCanceledException(nameof(WebSocketState.Aborted), innerException);
}

// Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior.
ObjectDisposedException.ThrowIf(isDisposed, typeof(WebSocket));
}

private static ManagedWebSocketStates ToFlag(WebSocketState value)
{
ManagedWebSocketStates flag = (ManagedWebSocketStates)(1 << (int)value);
Debug.Assert(Enum.IsDefined(flag));
return flag;
}
}
}

0 comments on commit b6b0fb1

Please sign in to comment.