Skip to content

Commit ea12f3b

Browse files
committed
reuse allocated buffer in stream reads for TCP/Unix sockets
1 parent 1279244 commit ea12f3b

File tree

2 files changed

+117
-74
lines changed

2 files changed

+117
-74
lines changed

io/jvm-native/src/main/scala/fs2/io/net/SocketPlatform.scala

Lines changed: 94 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ package io
2424
package net
2525

2626
import com.comcast.ip4s.{IpAddress, SocketAddress}
27-
import cats.effect.Async
27+
import cats.effect.{Async, Resource}
2828
import cats.effect.std.Mutex
2929
import cats.syntax.all._
3030

@@ -33,82 +33,124 @@ import java.nio.channels.{AsynchronousSocketChannel, CompletionHandler}
3333
import java.nio.{Buffer, ByteBuffer}
3434

3535
private[net] trait SocketCompanionPlatform {
36+
37+
/** Creates a [[Socket]] instance for given `AsynchronousSocketChannel`
38+
* with 16 KiB max. read chunk size and exclusive access guards for both reads abd writes.
39+
*/
3640
private[net] def forAsync[F[_]: Async](
3741
ch: AsynchronousSocketChannel
3842
): F[Socket[F]] =
39-
(Mutex[F], Mutex[F]).mapN { (readMutex, writeMutex) =>
40-
new AsyncSocket[F](ch, readMutex, writeMutex)
43+
forAsync(ch, maxReadChunkSize = 16384, withExclusiveReads = true, withExclusiveWrites = true)
44+
45+
/** Creates a [[Socket]] instance for given `AsynchronousSocketChannel`.
46+
*
47+
* @param ch async socket channel for actual reads and writes
48+
* @param maxReadChunkSize maximum chunk size for [[Socket#reads]] method
49+
* @param withExclusiveReads set to `true` if reads should be guarded by mutex
50+
* @param withExclusiveWrites set to `true` if writes should be guarded by mutex
51+
*/
52+
private[net] def forAsync[F[_]](
53+
ch: AsynchronousSocketChannel,
54+
maxReadChunkSize: Int,
55+
withExclusiveReads: Boolean = false,
56+
withExclusiveWrites: Boolean = false
57+
)(implicit F: Async[F]): F[Socket[F]] = {
58+
def maybeMutex(maybe: Boolean) = F.defer(if (maybe) Mutex[F].map(Some(_)) else F.pure(None))
59+
(maybeMutex(withExclusiveReads), maybeMutex(withExclusiveWrites)).mapN {
60+
(readMutex, writeMutex) => new AsyncSocket[F](ch, readMutex, writeMutex, maxReadChunkSize)
4161
}
62+
}
4263

4364
private[net] abstract class BufferedReads[F[_]](
44-
readMutex: Mutex[F]
65+
readMutex: Option[Mutex[F]],
66+
writeMutex: Option[Mutex[F]],
67+
maxReadChunkSize: Int
4568
)(implicit F: Async[F])
4669
extends Socket[F] {
47-
private[this] final val defaultReadSize = 8192
48-
private[this] var readBuffer: ByteBuffer = ByteBuffer.allocate(defaultReadSize)
70+
private def lock(mutex: Option[Mutex[F]]): Resource[F, Unit] =
71+
mutex match {
72+
case Some(mutex) => mutex.lock
73+
case None => Resource.unit
74+
}
4975

5076
private def withReadBuffer[A](size: Int)(f: ByteBuffer => F[A]): F[A] =
51-
readMutex.lock.surround {
52-
F.delay {
53-
if (readBuffer.capacity() < size)
54-
readBuffer = ByteBuffer.allocate(size)
55-
else
56-
(readBuffer: Buffer).limit(size)
57-
f(readBuffer)
58-
}.flatten
77+
lock(readMutex).surround {
78+
F.delay(ByteBuffer.allocate(size)).flatMap(f)
5979
}
6080

6181
/** Performs a single channel read operation in to the supplied buffer. */
6282
protected def readChunk(buffer: ByteBuffer): F[Int]
6383

64-
/** Copies the contents of the supplied buffer to a `Chunk[Byte]` and clears the buffer contents. */
65-
private def releaseBuffer(buffer: ByteBuffer): F[Chunk[Byte]] =
66-
F.delay {
67-
val read = buffer.position()
68-
val result =
69-
if (read == 0) Chunk.empty
70-
else {
71-
val dest = new Array[Byte](read)
72-
(buffer: Buffer).flip()
73-
buffer.get(dest)
74-
Chunk.array(dest)
75-
}
76-
(buffer: Buffer).clear()
77-
result
78-
}
84+
/** Performs a channel write operation(-s) from the supplied buffer.
85+
*
86+
* Write could be performed multiple times till all buffer remaining contents are written.
87+
*/
88+
protected def writeChunk(buffer: ByteBuffer): F[Unit]
7989

8090
def read(max: Int): F[Option[Chunk[Byte]]] =
8191
withReadBuffer(max) { buffer =>
82-
readChunk(buffer).flatMap { read =>
83-
if (read < 0) F.pure(None)
84-
else releaseBuffer(buffer).map(Some(_))
92+
readChunk(buffer).map { read =>
93+
if (read < 0) None
94+
else if (buffer.position() == 0) Some(Chunk.empty)
95+
else {
96+
(buffer: Buffer).flip()
97+
Some(Chunk.byteBuffer(buffer.asReadOnlyBuffer()))
98+
}
8599
}
86100
}
87101

88102
def readN(max: Int): F[Chunk[Byte]] =
89103
withReadBuffer(max) { buffer =>
90104
def go: F[Chunk[Byte]] =
91105
readChunk(buffer).flatMap { readBytes =>
92-
if (readBytes < 0 || buffer.position() >= max)
93-
releaseBuffer(buffer)
94-
else go
106+
if (readBytes < 0 || buffer.position() >= max) {
107+
(buffer: Buffer).flip()
108+
F.pure(Chunk.byteBuffer(buffer.asReadOnlyBuffer()))
109+
} else go
95110
}
96111
go
97112
}
98113

99114
def reads: Stream[F, Byte] =
100-
Stream.repeatEval(read(defaultReadSize)).unNoneTerminate.unchunks
115+
Stream.resource(lock(readMutex)).flatMap { _ =>
116+
Stream.unfoldChunkEval(ByteBuffer.allocate(maxReadChunkSize)) { case buffer =>
117+
readChunk(buffer).flatMap { read =>
118+
if (read < 0) none[(Chunk[Byte], ByteBuffer)].pure
119+
else if (buffer.position() == 0) {
120+
(Chunk.empty[Byte] -> buffer).some.pure
121+
} else if (buffer.remaining() == 0) {
122+
val bytes = buffer.asReadOnlyBuffer()
123+
val fresh = ByteBuffer.allocate(maxReadChunkSize)
124+
(Chunk.byteBuffer(bytes) -> fresh).some.pure
125+
} else {
126+
val bytes = buffer.duplicate().asReadOnlyBuffer()
127+
val slice = buffer.slice()
128+
(bytes: Buffer).flip()
129+
(Chunk.byteBuffer(bytes) -> slice).some.pure
130+
}
131+
}
132+
}
133+
}
134+
135+
def write(bytes: Chunk[Byte]): F[Unit] =
136+
lock(writeMutex).surround {
137+
F.delay(bytes.toByteBuffer).flatMap(writeChunk)
138+
}
101139

102-
def writes: Pipe[F, Byte, Nothing] =
103-
_.chunks.foreach(write)
140+
def writes: Pipe[F, Byte, Nothing] = { in =>
141+
Stream.resource(lock(writeMutex)).flatMap { _ =>
142+
in.chunks.foreach(chunk => writeChunk(chunk.toByteBuffer))
143+
}
144+
}
104145
}
105146

106147
private final class AsyncSocket[F[_]](
107148
ch: AsynchronousSocketChannel,
108-
readMutex: Mutex[F],
109-
writeMutex: Mutex[F]
149+
readMutex: Option[Mutex[F]],
150+
writeMutex: Option[Mutex[F]],
151+
maxReadChunkSize: Int
110152
)(implicit F: Async[F])
111-
extends BufferedReads[F](readMutex) {
153+
extends BufferedReads[F](readMutex, writeMutex, maxReadChunkSize) {
112154

113155
protected def readChunk(buffer: ByteBuffer): F[Int] =
114156
F.async[Int] { cb =>
@@ -120,24 +162,18 @@ private[net] trait SocketCompanionPlatform {
120162
F.delay(Some(endOfInput.voidError))
121163
}
122164

123-
def write(bytes: Chunk[Byte]): F[Unit] = {
124-
def go(buff: ByteBuffer): F[Unit] =
125-
F.async[Int] { cb =>
126-
ch.write(
127-
buff,
128-
null,
129-
new IntCompletionHandler(cb)
130-
)
131-
F.delay(Some(endOfOutput.voidError))
132-
}.flatMap { written =>
133-
if (written >= 0 && buff.remaining() > 0)
134-
go(buff)
135-
else F.unit
136-
}
137-
writeMutex.lock.surround {
138-
F.delay(bytes.toByteBuffer).flatMap(go)
165+
protected def writeChunk(buffer: ByteBuffer): F[Unit] =
166+
F.async[Int] { cb =>
167+
ch.write(
168+
buffer,
169+
null,
170+
new IntCompletionHandler(cb)
171+
)
172+
F.delay(Some(endOfOutput.voidError))
173+
}.flatMap { written =>
174+
if (written < 0 || buffer.remaining() == 0) F.unit
175+
else writeChunk(buffer)
139176
}
140-
}
141177

142178
def localAddress: F[SocketAddress[IpAddress]] =
143179
F.delay(

io/jvm/src/main/scala/fs2/io/net/unixsocket/UnixSocketsPlatform.scala

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import cats.effect.std.Mutex
2828
import cats.effect.syntax.all._
2929
import cats.syntax.all._
3030
import com.comcast.ip4s.{IpAddress, SocketAddress}
31-
import fs2.{Chunk, Stream}
31+
import fs2.Stream
3232
import fs2.io.file.{Files, Path}
3333
import fs2.io.net.Socket
3434
import java.nio.ByteBuffer
@@ -89,29 +89,36 @@ private[unixsocket] trait UnixSocketsCompanionPlatform {
8989
private def makeSocket[F[_]: Async](
9090
ch: SocketChannel
9191
): F[Socket[F]] =
92-
(Mutex[F], Mutex[F]).mapN { (readMutex, writeMutex) =>
93-
new AsyncSocket[F](ch, readMutex, writeMutex)
92+
makeSocket(ch, maxReadChunkSize = 16384, withExclusiveReads = true, withExclusiveWrites = true)
93+
94+
private def makeSocket[F[_]](
95+
ch: SocketChannel,
96+
maxReadChunkSize: Int,
97+
withExclusiveReads: Boolean,
98+
withExclusiveWrites: Boolean
99+
)(implicit F: Async[F]): F[Socket[F]] = {
100+
def maybeMutex(maybe: Boolean) = F.defer(if (maybe) Mutex[F].map(Some(_)) else F.pure(None))
101+
(maybeMutex(withExclusiveReads), maybeMutex(withExclusiveWrites)).mapN {
102+
(readMutex, writeMutex) => new AsyncSocket[F](ch, readMutex, writeMutex, maxReadChunkSize)
94103
}
104+
}
95105

96106
private final class AsyncSocket[F[_]](
97107
ch: SocketChannel,
98-
readMutex: Mutex[F],
99-
writeMutex: Mutex[F]
108+
readMutex: Option[Mutex[F]],
109+
writeMutex: Option[Mutex[F]],
110+
maxReadChunkSize: Int
100111
)(implicit F: Async[F])
101-
extends Socket.BufferedReads[F](readMutex) {
112+
extends Socket.BufferedReads[F](readMutex, writeMutex, maxReadChunkSize) {
102113

103-
def readChunk(buff: ByteBuffer): F[Int] =
104-
F.blocking(ch.read(buff)).cancelable(close)
114+
protected def readChunk(buffer: ByteBuffer): F[Int] =
115+
F.blocking(ch.read(buffer)).cancelable(close)
105116

106-
def write(bytes: Chunk[Byte]): F[Unit] = {
107-
def go(buff: ByteBuffer): F[Unit] =
108-
F.blocking(ch.write(buff)).cancelable(close) *>
109-
F.delay(buff.remaining <= 0).ifM(F.unit, go(buff))
110-
111-
writeMutex.lock.surround {
112-
F.delay(bytes.toByteBuffer).flatMap(go)
117+
protected def writeChunk(buffer: ByteBuffer): F[Unit] =
118+
F.blocking(ch.write(buffer)).cancelable(close).flatMap { _ =>
119+
if (buffer.remaining() == 0) F.unit
120+
else writeChunk(buffer)
113121
}
114-
}
115122

116123
def localAddress: F[SocketAddress[IpAddress]] = raiseIpAddressError
117124
def remoteAddress: F[SocketAddress[IpAddress]] = raiseIpAddressError

0 commit comments

Comments
 (0)