diff --git a/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs b/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs index 35364a0af8476b..fc283c5c180710 100644 --- a/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs +++ b/src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs @@ -2831,6 +2831,24 @@ public abstract class WrappingConnectedStreamConformanceTests : ConnectedStreamC protected virtual bool ExtraZeroByteReadsAllowed => false; + private static bool IsAsync(ReadWriteMode mode) => + mode == ReadWriteMode.AsyncArray || + mode == ReadWriteMode.AsyncMemory || + mode == ReadWriteMode.AsyncAPM; + + private Task CreateWrappedStreamsAsync(StreamPair streams, bool asyncOnly = false, bool leaveOpen = false) + { + (Stream writeable, Stream readable) = GetReadWritePair(streams); + + if (asyncOnly) + { + writeable = new AsyncOnlyStream(writeable); + readable = new AsyncOnlyStream(readable); + } + + return CreateWrappedConnectedStreamsAsync((writeable, readable), leaveOpen); + } + [Theory] [InlineData(false)] [InlineData(true)] @@ -2845,22 +2863,36 @@ public virtual async Task Flush_FlushesUnderlyingStream(bool flushAsync) (Stream writeable, Stream readable) = GetReadWritePair(streams); var tracker = new CallTrackingStream(writeable); - using StreamPair wrapper = await CreateWrappedConnectedStreamsAsync((tracker, readable)); + StreamPair wrapper = await CreateWrappedStreamsAsync((tracker, readable), asyncOnly: flushAsync); + + try + { + int orig = tracker.TimesCalled(nameof(tracker.Flush)) + tracker.TimesCalled(nameof(tracker.FlushAsync)); - int orig = tracker.TimesCalled(nameof(tracker.Flush)) + tracker.TimesCalled(nameof(tracker.FlushAsync)); + tracker.WriteByte(1); - tracker.WriteByte(1); + if (flushAsync) + { + await wrapper.Stream1.FlushAsync(); + } + else + { + wrapper.Stream1.Flush(); + } - if (flushAsync) - { - await wrapper.Stream1.FlushAsync(); + Assert.InRange(tracker.TimesCalled(nameof(tracker.Flush)) + tracker.TimesCalled(nameof(tracker.FlushAsync)), orig + 1, int.MaxValue); } - else + finally { - wrapper.Stream1.Flush(); + if (flushAsync) + { + await wrapper.DisposeAsync(); + } + else + { + wrapper.Dispose(); + } } - - Assert.InRange(tracker.TimesCalled(nameof(tracker.Flush)) + tracker.TimesCalled(nameof(tracker.FlushAsync)), orig + 1, int.MaxValue); } [Theory] @@ -2876,28 +2908,52 @@ public virtual async Task Dispose_Flushes(bool useAsync, bool leaveOpen) } using StreamPair streams = ConnectedStreams.CreateBidirectional(); - using StreamPair wrapper = await CreateWrappedConnectedStreamsAsync(streams, leaveOpen); - (Stream writeable, Stream readable) = GetReadWritePair(wrapper); - - await Task.WhenAll( - Task.Run(async () => - { - writeable.WriteByte(1); + StreamPair wrapper = await CreateWrappedStreamsAsync(streams, useAsync, leaveOpen); + try + { + (Stream writeable, Stream readable) = GetReadWritePair(wrapper); - if (useAsync) + await Task.WhenAll( + Task.Run(async () => { - await writeable.DisposeAsync(); - } - else + if (useAsync) + { + await writeable.WriteAsync(new byte[] { 1 }); + await writeable.DisposeAsync(); + } + else + { + writeable.WriteByte(1); + writeable.Dispose(); + } + }), + Task.Run(async () => { - writeable.Dispose(); - } - }), - Task.Run(() => + if (useAsync) + { + byte[] buffer = new byte[1]; + Assert.Equal(1, await readable.ReadAsync(buffer)); + Assert.Equal(1, buffer[0]); + await readable.DisposeAsync(); + } + else + { + Assert.Equal(1, readable.ReadByte()); + readable.Dispose(); + } + })); + } + finally + { + if (useAsync) { - Assert.Equal(1, readable.ReadByte()); - readable.Dispose(); - })); + await wrapper.DisposeAsync(); + } + else + { + wrapper.Dispose(); + } + } } [Theory] @@ -2914,42 +2970,63 @@ public virtual async Task Dispose_ClosesInnerStreamIfDesired(bool useAsync, bool using StreamPair streams = ConnectedStreams.CreateBidirectional(); (Stream writeable, Stream readable) = GetReadWritePair(streams); - using StreamPair wrapper = await CreateWrappedConnectedStreamsAsync((writeable, readable), leaveOpen); - (Stream writeableWrapper, Stream readableWrapper) = GetReadWritePair(wrapper); + StreamPair wrapper = await CreateWrappedStreamsAsync((writeable, readable), useAsync, leaveOpen); + try + { + (Stream writeableWrapper, Stream readableWrapper) = GetReadWritePair(wrapper); - await Task.WhenAll( - Task.Run(async () => - { - if (useAsync) + await Task.WhenAll( + Task.Run(async () => { - await writeableWrapper.DisposeAsync(); - } - else + if (useAsync) + { + await writeableWrapper.DisposeAsync(); + } + else + { + writeableWrapper.Dispose(); + } + }), + Task.Run(async () => { - writeableWrapper.Dispose(); - } - }), - Task.Run(async () => + if (useAsync) + { + await readableWrapper.DisposeAsync(); + } + else + { + readableWrapper.Dispose(); + } + })); + + if (leaveOpen) + { + await WhenAllOrAnyFailed( + writeable.WriteAsync(new byte[] { 42 }, 0, 1), + Task.Run(() => readable.ReadByte())); + } + else { if (useAsync) { - await readableWrapper.DisposeAsync(); + await Assert.ThrowsAsync(async () => { await writeable.WriteAsync(new byte[] { 42 }, 0, 1); }); } else { - readableWrapper.Dispose(); + Assert.Throws(() => writeable.WriteByte(42)); } - })); - - if (leaveOpen) - { - await WhenAllOrAnyFailed( - writeable.WriteAsync(new byte[] { 42 }, 0, 1), - Task.Run(() => readable.ReadByte())); + } } - else + finally { - Assert.Throws(() => writeable.WriteByte(42)); + if (useAsync) + { + await wrapper.DisposeAsync(); + } + else + { + wrapper.Dispose(); + } } } @@ -2963,7 +3040,7 @@ public virtual async Task UseWrappedAfterClose_Success() using StreamPair streams = ConnectedStreams.CreateBidirectional(); - using (StreamPair wrapper = await CreateWrappedConnectedStreamsAsync(streams, leaveOpen: true)) + using (StreamPair wrapper = await CreateWrappedStreamsAsync(streams, leaveOpen: true)) { foreach ((Stream writeable, Stream readable) in GetReadWritePairs(wrapper)) { @@ -2983,9 +3060,9 @@ public virtual async Task UseWrappedAfterClose_Success() public virtual async Task NestedWithinSelf_ReadWrite_Success() { using StreamPair streams = ConnectedStreams.CreateBidirectional(); - using StreamPair wrapper1 = await CreateWrappedConnectedStreamsAsync(streams); - using StreamPair wrapper2 = await CreateWrappedConnectedStreamsAsync(wrapper1); - using StreamPair wrapper3 = await CreateWrappedConnectedStreamsAsync(wrapper2); + using StreamPair wrapper1 = await CreateWrappedStreamsAsync(streams); + using StreamPair wrapper2 = await CreateWrappedStreamsAsync(wrapper1); + using StreamPair wrapper3 = await CreateWrappedStreamsAsync(wrapper2); if (Bidirectional(wrapper3) && FlushGuaranteesAllDataWritten) { @@ -3046,6 +3123,8 @@ public virtual async Task ZeroByteRead_PerformsZeroByteReadOnUnderlyingStreamWhe return; } + bool useAsync = IsAsync(mode); + // This is the data we will send across the connected streams. We assume this data will both // (a) produce at least two readable bytes, so we can unblock the reader and read a single byte without clearing its buffer; and // (b) produce no more than 1K of readable bytes, so we can clear the reader buffer below. @@ -3056,53 +3135,66 @@ public virtual async Task ZeroByteRead_PerformsZeroByteReadOnUnderlyingStreamWhe (Stream innerWriteable, Stream innerReadable) = GetReadWritePair(innerStreams); var tracker = new ZeroByteReadTrackingStream(innerReadable, ExtraZeroByteReadsAllowed); - using StreamPair streams = await CreateWrappedConnectedStreamsAsync((innerWriteable, tracker)); - - (Stream writeable, Stream readable) = GetReadWritePair(streams); - - for (int iter = 0; iter < 2; iter++) + StreamPair streams = await CreateWrappedStreamsAsync((innerWriteable, tracker), useAsync); + try { - // Register to be signalled for the zero byte read. - var signalTask = tracker.WaitForZeroByteReadAsync(); + (Stream writeable, Stream readable) = GetReadWritePair(streams); - // Issue zero byte read against wrapper stream. - Task zeroByteRead = Task.Run(() => ReadAsync(mode, readable, Array.Empty(), 0, 0)); + for (int iter = 0; iter < 2; iter++) + { + // Register to be signalled for the zero byte read. + var signalTask = tracker.WaitForZeroByteReadAsync(); - // The tracker stream will signal us when the zero byte read actually happens. - await signalTask; + // Issue zero byte read against wrapper stream. + Task zeroByteRead = Task.Run(() => ReadAsync(mode, readable, Array.Empty(), 0, 0)); - // Write some data (see notes above re 'data') - await writeable.WriteAsync(data); - if (FlushRequiredToWriteData) - { - await writeable.FlushAsync(); - } + // The tracker stream will signal us when the zero byte read actually happens. + await signalTask; - // Reader should be unblocked, and we should have issued a zero byte read against the underlying stream as part of unblocking. - int bytesRead = await zeroByteRead; - Assert.Equal(0, bytesRead); + // Write some data (see notes above re 'data') + await writeable.WriteAsync(data); + if (FlushRequiredToWriteData) + { + await writeable.FlushAsync(); + } - byte[] buffer = new byte[1024]; + // Reader should be unblocked, and we should have issued a zero byte read against the underlying stream as part of unblocking. + int bytesRead = await zeroByteRead; + Assert.Equal(0, bytesRead); - // Should be able to read one byte without blocking - var readTask = ReadAsync(mode, readable, buffer, 0, 1); - Assert.True(readTask.IsCompleted); - bytesRead = await readTask; - Assert.Equal(1, bytesRead); + byte[] buffer = new byte[1024]; - // Issue zero byte read against wrapper stream. Since there is still data available, this should complete immediately and not do another zero-byte read. - readTask = ReadAsync(mode, readable, Array.Empty(), 0, 0); - Assert.True(readTask.IsCompleted); - Assert.Equal(0, await readTask); + // Should be able to read one byte without blocking + var readTask = ReadAsync(mode, readable, buffer, 0, 1); + Assert.True(readTask.IsCompleted); + bytesRead = await readTask; + Assert.Equal(1, bytesRead); - // Clear the reader stream of any buffered data by doing a large read, which again should not block. - readTask = ReadAsync(mode, readable, buffer, 1, buffer.Length - 1); - Assert.True(readTask.IsCompleted); - bytesRead += await readTask; + // Issue zero byte read against wrapper stream. Since there is still data available, this should complete immediately and not do another zero-byte read. + readTask = ReadAsync(mode, readable, Array.Empty(), 0, 0); + Assert.True(readTask.IsCompleted); + Assert.Equal(0, await readTask); - if (FlushGuaranteesAllDataWritten) + // Clear the reader stream of any buffered data by doing a large read, which again should not block. + readTask = ReadAsync(mode, readable, buffer, 1, buffer.Length - 1); + Assert.True(readTask.IsCompleted); + bytesRead += await readTask; + + if (FlushGuaranteesAllDataWritten) + { + AssertExtensions.SequenceEqual(data.AsSpan(), buffer.AsSpan(0, bytesRead)); + } + } + } + finally + { + if (useAsync) + { + await streams.DisposeAsync(); + } + else { - AssertExtensions.SequenceEqual(data.AsSpan(), buffer.AsSpan(0, bytesRead)); + streams.Dispose(); } } } @@ -3173,10 +3265,69 @@ public override ValueTask ReadAsync(Memory buffer, CancellationToken return base.ReadAsync(buffer, cancellationToken); } } + + private sealed class AsyncOnlyStream : Stream + { + private readonly Stream _innerStream; + private bool _disposed; + + public AsyncOnlyStream(Stream innerStream) + { + _innerStream = innerStream; + } + + public override bool CanRead => _innerStream.CanRead; + public override bool CanSeek => _innerStream.CanSeek; + public override bool CanWrite => _innerStream.CanWrite; + public override long Length => _innerStream.Length; + public override long Position { get => _innerStream.Position; set => _innerStream.Position = value; } + public override long Seek(long offset, SeekOrigin origin) => _innerStream.Seek(offset, origin); + public override void SetLength(long value) => _innerStream.SetLength(value); + + public override void Flush() + { + ObjectDisposedException.ThrowIf(_disposed, _innerStream); + throw new NotSupportedException("Synchronous operations are not supported."); + } + + public override int Read(byte[] buffer, int offset, int count) + { + ObjectDisposedException.ThrowIf(_disposed, _innerStream); + throw new NotSupportedException("Synchronous operations are not supported."); + } + + public override void Write(byte[] buffer, int offset, int count) + { + ObjectDisposedException.ThrowIf(_disposed, _innerStream); + throw new NotSupportedException("Synchronous operations are not supported."); + } + + public override ValueTask ReadAsync(Memory buffer, System.Threading.CancellationToken cancellationToken = default) => _innerStream.ReadAsync(buffer, cancellationToken); + public override ValueTask WriteAsync(ReadOnlyMemory buffer, System.Threading.CancellationToken cancellationToken = default) => _innerStream.WriteAsync(buffer, cancellationToken); + public override Task FlushAsync(System.Threading.CancellationToken cancellationToken = default) => _innerStream.FlushAsync(cancellationToken); + + protected override void Dispose(bool disposing) + { + if (_disposed) return; + + _disposed = true; + _innerStream.Dispose(); + } + + public override ValueTask DisposeAsync() + { + if (_disposed) return default; + + _disposed = true; + return _innerStream.DisposeAsync(); + } + + public override void Close() => _innerStream.Close(); + } } /// Provides a disposable, enumerable tuple of two streams. - public class StreamPair : IDisposable, IEnumerable + public class StreamPair : IDisposable, IAsyncDisposable, IEnumerable { public readonly Stream Stream1, Stream2; @@ -3202,6 +3353,13 @@ public virtual void Dispose() Task.Run(() => Stream2?.Dispose())); } + public virtual async ValueTask DisposeAsync() + { + await Task.WhenAll( + Stream1?.DisposeAsync().AsTask() ?? Task.CompletedTask, + Stream2?.DisposeAsync().AsTask() ?? Task.CompletedTask).ConfigureAwait(false); + } + public IEnumerator GetEnumerator() { yield return Stream1; diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamConformanceTests.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamConformanceTests.cs index dd01c99c4b5a49..d919ea25f839d6 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamConformanceTests.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamConformanceTests.cs @@ -38,14 +38,10 @@ await new[] { // TLS 1.3 can generate some extra messages and we may get reset if test sends unidirectional traffic // and extra packet stays in socket buffer. - - // This ping-ping should flush leftovers from the handshake. - // We use sync method to preserve socket in default blocking state - // (as we don't go back once Async is used at least once) - ssl1.Write(new byte[1]); - ssl2.Write(new byte[1]); - Assert.Equal(1, ssl2.Read(new byte[1])); - Assert.Equal(1, ssl1.Read(new byte[1])); + await ssl1.WriteAsync(new byte[1]); + await ssl2.WriteAsync(new byte[1]); + Assert.Equal(1, await ssl2.ReadAsync(new byte[1])); + Assert.Equal(1, await ssl1.ReadAsync(new byte[1])); } return new StreamPair(ssl1, ssl2);