diff --git a/bson/src/main/org/bson/ByteBuf.java b/bson/src/main/org/bson/ByteBuf.java index e44a97dfc67..7cc5bdc77dc 100644 --- a/bson/src/main/org/bson/ByteBuf.java +++ b/bson/src/main/org/bson/ByteBuf.java @@ -125,6 +125,13 @@ public interface ByteBuf { */ ByteBuf flip(); + /** + * States whether this buffer is backed by an accessible byte array. + * + * @return {@code true} if, and only if, this buffer is backed by an array and is not read-only + */ + boolean hasArray(); + /** *
Returns the byte array that backs this buffer (optional operation).
* diff --git a/bson/src/main/org/bson/ByteBufNIO.java b/bson/src/main/org/bson/ByteBufNIO.java index 83bfa7d893a..fff6bbe6930 100644 --- a/bson/src/main/org/bson/ByteBufNIO.java +++ b/bson/src/main/org/bson/ByteBufNIO.java @@ -103,6 +103,11 @@ public ByteBuf flip() { return this; } + @Override + public boolean hasArray() { + return buf.hasArray(); + } + @Override public byte[] array() { return buf.array(); diff --git a/bson/src/main/org/bson/io/OutputBuffer.java b/bson/src/main/org/bson/io/OutputBuffer.java index 00f88cea706..bf70a68ee95 100644 --- a/bson/src/main/org/bson/io/OutputBuffer.java +++ b/bson/src/main/org/bson/io/OutputBuffer.java @@ -196,11 +196,15 @@ public void writeLong(final long value) { writeInt64(value); } - private int writeCharacters(final String str, final boolean checkForNullCharacters) { + protected int writeCharacters(final String str, final boolean checkForNullCharacters) { + return writeCharacters(str, 0, checkForNullCharacters); + } + + protected final int writeCharacters(final String str, int start, final boolean checkForNullCharacters) { int len = str.length(); int total = 0; - for (int i = 0; i < len;) { + for (int i = start; i < len;) { int c = Character.codePointAt(str, i); if (checkForNullCharacters && c == 0x0) { diff --git a/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java b/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java index 40df1b867fd..8287315ac27 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java @@ -16,17 +16,21 @@ package com.mongodb.internal.connection; +import com.mongodb.internal.connection.netty.NettyByteBuf; +import org.bson.BsonSerializationException; import org.bson.ByteBuf; import org.bson.io.OutputBuffer; import java.io.IOException; import java.io.OutputStream; +import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.ArrayList; import java.util.List; import static com.mongodb.assertions.Assertions.assertTrue; import static com.mongodb.assertions.Assertions.notNull; +import static java.lang.String.format; /** *This class is not part of the public API and may be removed or changed at any time
@@ -273,6 +277,95 @@ public void close() { } } + @Override + protected int writeCharacters(final String str, final boolean checkForNullCharacters) { + ensureOpen(); + ByteBuf buf = getCurrentByteBuffer(); + if ((buf.remaining() >= str.length() + 1)) { + if (buf.hasArray()) { + return writeCharactersOnArray(str, checkForNullCharacters, buf); + } else if (buf instanceof NettyByteBuf) { + io.netty.buffer.ByteBuf nettyBuffer = ((NettyByteBuf) buf).asByteBuf(); + if (nettyBuffer.nioBufferCount() == 1) { + return writeCharactersOnInternalNioNettyByteBuf(str, checkForNullCharacters, buf, nettyBuffer); + } + } + } + return super.writeCharacters(str, 0, checkForNullCharacters); + } + + private int writeCharactersOnInternalNioNettyByteBuf(String str, boolean checkForNullCharacters, ByteBuf buf, io.netty.buffer.ByteBuf nettyBuffer) { + int len = str.length(); + final ByteBuffer nioBuffer = internalNioBufferOf(nettyBuffer, len + 1); + int i = 0; + int pos = nioBuffer.position(); + for (; i < len; i++) { + char c = str.charAt(i); + if (checkForNullCharacters && c == 0x0) { + throw new BsonSerializationException(format("BSON cstring '%s' is not valid because it contains a null character " + + "at index %d", str, i)); + } + if (c >= 0x80) { + break; + } + nioBuffer.put(pos + i, (byte) c); + } + if (i == len) { + int total = len + 1; + nioBuffer.put(pos + len, (byte) 0); + position += total; + buf.position(buf.position() + total); + return len + 1; + } + // ith character is not ASCII + if (i > 0) { + position += i; + buf.position(buf.position() + i); + } + return i + super.writeCharacters(str, i, checkForNullCharacters); + } + + private static ByteBuffer internalNioBufferOf(io.netty.buffer.ByteBuf buf, int minCapacity) { + io.netty.buffer.ByteBuf unwrap; + while ((unwrap = buf.unwrap()) != null) { + buf = unwrap; + } + assert buf.unwrap() == null; + buf.ensureWritable(minCapacity); + return buf.internalNioBuffer(buf.writerIndex(), buf.writableBytes()); + } + + private int writeCharactersOnArray(String str, boolean checkForNullCharacters, ByteBuf buf) { + int i = 0; + byte[] array = buf.array(); + int pos = buf.position(); + int len = str.length(); + for (; i < len; i++) { + char c = str.charAt(i); + if (checkForNullCharacters && c == 0x0) { + throw new BsonSerializationException(format("BSON cstring '%s' is not valid because it contains a null character " + + "at index %d", str, i)); + } + if (c >= 0x80) { + break; + } + array[pos + i] = (byte) c; + } + if (i == len) { + int total = len + 1; + array[pos + len] = 0; + position += total; + buf.position(pos + total); + return len + 1; + } + // ith character is not ASCII + if (i > 0) { + position += i; + buf.position(pos + i); + } + return i + super.writeCharacters(str, i, checkForNullCharacters); + } + private static final class BufferPositionPair { private final int bufferIndex; private int position; diff --git a/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java b/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java index fa8cde2e517..48a80be5158 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java +++ b/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java @@ -208,6 +208,11 @@ private int getShort(final int index) { return (short) (get(index) & 0xff | (get(index + 1) & 0xff) << 8); } + @Override + public boolean hasArray() { + return false; + } + @Override public byte[] array() { throw new UnsupportedOperationException("Not implemented yet!"); diff --git a/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java b/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java index 074e77de04f..d5c22e2a5ca 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java +++ b/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java @@ -95,6 +95,10 @@ public ByteBuf flip() { return this; } + public boolean hasArray() { + return proxied.hasArray(); + } + @Override public byte[] array() { return proxied.array(); diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputTest.java index 3a8a2c83acb..b54a976ef93 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputTest.java @@ -17,18 +17,23 @@ package com.mongodb.internal.connection; import com.mongodb.assertions.Assertions; +import com.mongodb.internal.connection.ByteBufSpecification.NettyBufferProvider; import org.bson.BsonSerializationException; import org.bson.ByteBuf; import org.bson.types.ObjectId; +import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.charset.CharacterCodingException; import java.nio.charset.StandardCharsets; import java.util.concurrent.ThreadLocalRandom; @@ -45,6 +50,85 @@ import static org.junit.jupiter.api.Assertions.assertThrows; final class ByteBufferBsonOutputTest { + + // Test all combinations: useBranch (true/false) × BufferProvider implementations + static Arguments[] bufferProviders() { + return new Arguments[]{ + Arguments.of(false, new SimpleBufferProvider()), + Arguments.of(true, new SimpleBufferProvider()), + Arguments.of(false, new NettyBufferProvider()), + Arguments.of(true, new NettyBufferProvider())}; + } + + /** Generic method to test writing strings **/ + + private static String expectedNullCharExceptionMessage(String value) { + if (value.isEmpty()) { + return null; + } + int zeroIndex = value.indexOf(0); + if (zeroIndex == -1) { + return null; + } + return "BSON cstring '" + value + "' is not valid because it contains a null character at index " + zeroIndex; + } + + private static void writeStringTest(BufferProvider bufferProvider, String value, boolean useBranch, boolean cstring) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(bufferProvider)) { + String expectedNullCharEx = null; + if (cstring) { + expectedNullCharEx = expectedNullCharExceptionMessage(value); + } + final byte[] expectedEncodedBytes; + if (expectedNullCharEx == null) { + expectedEncodedBytes = expectedStringBytesOf(value, cstring); + } else { + expectedEncodedBytes = null; + } + try { + if (useBranch) { + try (ByteBufferBsonOutput.Branch branch = out.branch()) { + if (cstring) { + branch.writeCString(value); + } else { + branch.writeString(value); + } + } + } else { + if (cstring) { + out.writeCString(value); + } else { + out.writeString(value); + } + } + if (expectedNullCharEx != null) { + Assertions.fail("Expected BsonSerializationException"); + } + } catch (BsonSerializationException e) { + if (expectedNullCharEx != null) { + assertEquals(expectedNullCharEx, e.getMessage()); + return; + } + } + assertArrayEquals(expectedEncodedBytes, out.toByteArray()); + assertEquals(expectedEncodedBytes.length, out.getPosition()); + assertEquals(expectedEncodedBytes.length, out.size()); + } + } + + private static @NotNull byte[] expectedStringBytesOf(String v, boolean cstring) { + byte[] encoded = v.getBytes(StandardCharsets.UTF_8); + ByteBuffer expected = ByteBuffer.allocate((cstring ? 0 : 4) + encoded.length + 1).order(ByteOrder.LITTLE_ENDIAN); + if (!cstring) { + expected.putInt((byte) (encoded.length + 1)); + } + expected.put(encoded); + expected.put((byte) 0); + return expected.array(); + } + + /** Tests **/ + @DisplayName("constructor should throw if buffer provider is null") @Test @SuppressWarnings("try") @@ -92,9 +176,9 @@ void positionAndSizeShouldBe0AfterConstructor(final String branchState) { @DisplayName("should write a byte") @ParameterizedTest - @ValueSource(booleans = {false, true}) - void shouldWriteByte(final boolean useBranch) { - try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + @MethodSource("bufferProviders") + void shouldWriteByte(final boolean useBranch, final BufferProvider bufferProvider) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(bufferProvider)) { byte v = 11; if (useBranch) { try (ByteBufferBsonOutput.Branch branch = out.branch()) { @@ -111,9 +195,9 @@ void shouldWriteByte(final boolean useBranch) { @DisplayName("should write a bytes") @ParameterizedTest - @ValueSource(booleans = {false, true}) - void shouldWriteBytes(final boolean useBranch) { - try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + @MethodSource("bufferProviders") + void shouldWriteBytes(final boolean useBranch, final BufferProvider bufferProvider) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(bufferProvider)) { byte[] v = {1, 2, 3, 4}; if (useBranch) { try (ByteBufferBsonOutput.Branch branch = out.branch()) { @@ -226,123 +310,51 @@ void shouldWriteObjectId(final boolean useBranch) { @DisplayName("should write an empty string") @ParameterizedTest - @ValueSource(booleans = {false, true}) - void shouldWriteEmptyString(final boolean useBranch) { - try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { - String v = ""; - if (useBranch) { - try (ByteBufferBsonOutput.Branch branch = out.branch()) { - branch.writeString(v); - } - } else { - out.writeString(v); - } - assertArrayEquals(new byte[] {1, 0, 0, 0, 0}, out.toByteArray()); - assertEquals(5, out.getPosition()); - assertEquals(5, out.size()); - } + @MethodSource("bufferProviders") + void shouldWriteEmptyString(final boolean useBranch, final BufferProvider bufferProvider) { + writeStringTest(bufferProvider, "", useBranch, false); } @DisplayName("should write an ASCII string") @ParameterizedTest - @ValueSource(booleans = {false, true}) - void shouldWriteAsciiString(final boolean useBranch) { - try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { - String v = "Java"; - if (useBranch) { - try (ByteBufferBsonOutput.Branch branch = out.branch()) { - branch.writeString(v); - } - } else { - out.writeString(v); - } - assertArrayEquals(new byte[] {5, 0, 0, 0, 0x4a, 0x61, 0x76, 0x61, 0}, out.toByteArray()); - assertEquals(9, out.getPosition()); - assertEquals(9, out.size()); - } + @MethodSource("bufferProviders") + void shouldWriteAsciiString(final boolean useBranch, final BufferProvider bufferProvider) { + writeStringTest(bufferProvider, "JavaIsACool\u0000Language", useBranch, false); } @DisplayName("should write a UTF-8 string") @ParameterizedTest - @ValueSource(booleans = {false, true}) - void shouldWriteUtf8String(final boolean useBranch) { - try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { - String v = "\u0900"; - if (useBranch) { - try (ByteBufferBsonOutput.Branch branch = out.branch()) { - branch.writeString(v); - } - } else { - out.writeString(v); - } - assertArrayEquals(new byte[] {4, 0, 0, 0, (byte) 0xe0, (byte) 0xa4, (byte) 0x80, 0}, out.toByteArray()); - assertEquals(8, out.getPosition()); - assertEquals(8, out.size()); - } + @MethodSource("bufferProviders") + void shouldWriteUtf8String(final boolean useBranch, final BufferProvider bufferProvider) { + writeStringTest(bufferProvider, "Java\u0080I\u0000sACool\u0900Language", useBranch, false); } @DisplayName("should write an empty CString") @ParameterizedTest - @ValueSource(booleans = {false, true}) - void shouldWriteEmptyCString(final boolean useBranch) { - try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { - String v = ""; - if (useBranch) { - try (ByteBufferBsonOutput.Branch branch = out.branch()) { - branch.writeCString(v); - } - } else { - out.writeCString(v); - } - assertArrayEquals(new byte[] {0}, out.toByteArray()); - assertEquals(1, out.getPosition()); - assertEquals(1, out.size()); - } + @MethodSource("bufferProviders") + void shouldWriteEmptyCString(final boolean useBranch, final BufferProvider bufferProvider) { + writeStringTest(bufferProvider, "", useBranch, true); } @DisplayName("should write an ASCII CString") @ParameterizedTest - @ValueSource(booleans = {false, true}) - void shouldWriteAsciiCString(final boolean useBranch) { - try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { - String v = "Java"; - if (useBranch) { - try (ByteBufferBsonOutput.Branch branch = out.branch()) { - branch.writeCString(v); - } - } else { - out.writeCString(v); - } - assertArrayEquals(new byte[] {0x4a, 0x61, 0x76, 0x61, 0}, out.toByteArray()); - assertEquals(5, out.getPosition()); - assertEquals(5, out.size()); - } + @MethodSource("bufferProviders") + void shouldWriteAsciiCString(final boolean useBranch, final BufferProvider bufferProvider) { + writeStringTest(bufferProvider, "JavaIsACoolLanguage", useBranch, true); } @DisplayName("should write a UTF-8 CString") @ParameterizedTest - @ValueSource(booleans = {false, true}) - void shouldWriteUtf8CString(final boolean useBranch) { - try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { - String v = "\u0900"; - if (useBranch) { - try (ByteBufferBsonOutput.Branch branch = out.branch()) { - branch.writeCString(v); - } - } else { - out.writeCString(v); - } - assertArrayEquals(new byte[] {(byte) 0xe0, (byte) 0xa4, (byte) 0x80, 0}, out.toByteArray()); - assertEquals(4, out.getPosition()); - assertEquals(4, out.size()); - } + @MethodSource("bufferProviders") + void shouldWriteUtf8CString(final boolean useBranch, final BufferProvider bufferProvider) { + writeStringTest(bufferProvider, "Java\u0080IsACool\u0900Language", useBranch, true); } @DisplayName("should get byte buffers as little endian") @ParameterizedTest - @ValueSource(booleans = {false, true}) - void shouldGetByteBuffersAsLittleEndian(final boolean useBranch) { - try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + @MethodSource("bufferProviders") + void shouldGetByteBuffersAsLittleEndian(final boolean useBranch, final BufferProvider bufferProvider) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(bufferProvider)) { byte[] v = {1, 0, 0, 0}; if (useBranch) { try (ByteBufferBsonOutput.Branch branch = out.branch()) { @@ -357,35 +369,23 @@ void shouldGetByteBuffersAsLittleEndian(final boolean useBranch) { @DisplayName("null character in CString should throw SerializationException") @ParameterizedTest - @ValueSource(booleans = {false, true}) - void nullCharacterInCStringShouldThrowSerializationException(final boolean useBranch) { - try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { - String v = "hell\u0000world"; - if (useBranch) { - try (ByteBufferBsonOutput.Branch branch = out.branch()) { - assertThrows(BsonSerializationException.class, () -> branch.writeCString(v)); - } - } else { - assertThrows(BsonSerializationException.class, () -> out.writeCString(v)); - } - } + @MethodSource("bufferProviders") + void nullCharacterInCStringShouldThrowSerializationException(final boolean useBranch, final BufferProvider bufferProvider) { + writeStringTest(bufferProvider, "hello\u0000world", useBranch, true); + } + + @DisplayName("null character in UTF-8 CString should throw SerializationException") + @ParameterizedTest + @MethodSource("bufferProviders") + void nullCharacterInUtf8CStringShouldThrowSerializationException(final boolean useBranch, final BufferProvider bufferProvider) { + writeStringTest(bufferProvider, "hello\u0080\u0000world", useBranch, true); } @DisplayName("null character in String should not throw SerializationException") @ParameterizedTest - @ValueSource(booleans = {false, true}) - void nullCharacterInStringShouldNotThrowSerializationException(final boolean useBranch) { - try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { - String v = "h\u0000i"; - if (useBranch) { - try (ByteBufferBsonOutput.Branch branch = out.branch()) { - branch.writeString(v); - } - } else { - out.writeString(v); - } - assertArrayEquals(new byte[] {4, 0, 0, 0, (byte) 'h', 0, (byte) 'i', 0}, out.toByteArray()); - } + @MethodSource("bufferProviders") + void nullCharacterInStringShouldNotThrowSerializationException(final boolean useBranch, final BufferProvider bufferProvider) { + writeStringTest(bufferProvider, "hello\u0000world", useBranch, false); } @DisplayName("write Int32 at position should throw with invalid position") @@ -409,9 +409,9 @@ void writeInt32AtPositionShouldThrowWithInvalidPosition(final boolean useBranch, @DisplayName("should write Int32 at position") @ParameterizedTest - @ValueSource(booleans = {false, true}) - void shouldWriteInt32AtPosition(final boolean useBranch) { - try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + @MethodSource("bufferProviders") + void shouldWriteInt32AtPosition(final boolean useBranch, final BufferProvider bufferProvider) { + try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(bufferProvider)) { Consumer