diff --git a/lance-spark-3.5_2.12/src/main/java/com/lancedb/lance/spark/write/SparkPositionDeltaWrite.java b/lance-spark-3.5_2.12/src/main/java/com/lancedb/lance/spark/write/SparkPositionDeltaWrite.java index 924b2536..f6f02892 100644 --- a/lance-spark-3.5_2.12/src/main/java/com/lancedb/lance/spark/write/SparkPositionDeltaWrite.java +++ b/lance-spark-3.5_2.12/src/main/java/com/lancedb/lance/spark/write/SparkPositionDeltaWrite.java @@ -121,7 +121,7 @@ private static class PositionDeltaWriteFactory implements DeltaWriterFactory { @Override public DeltaWriter<InternalRow> createWriter(int partitionId, long taskId) { int batch_size = SparkOptions.getBatchSize(config); - LanceArrowWriter arrowWriter = LanceDatasetAdapter.getArrowWriter(sparkSchema, batch_size); + LanceArrowWriter arrowWriter = LanceDatasetAdapter.getArrowWriter(sparkSchema, batch_size, config); WriteParams params = SparkOptions.genWriteParamsFromConfig(config); Callable<List<FragmentMetadata>> fragmentCreator = () -> LanceDatasetAdapter.createFragment(config.getDatasetUri(), arrowWriter, params); diff --git a/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/SparkOptions.java b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/SparkOptions.java index 8b1ad90f..9e682768 100644 --- a/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/SparkOptions.java +++ b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/SparkOptions.java @@ -30,6 +30,8 @@ public class SparkOptions { private static final String max_bytes_per_file = "max_bytes_per_file"; private static final String batch_size = "batch_size"; private static final String topN_push_down = "topN_push_down"; + private static final String arrow_max_allocation_bytes = "arrow.max.allocation.bytes"; + private static final String arrow_var_width_avg_bytes = "arrow.var.width.avg.bytes"; public static ReadOptions genReadOptionFromConfig(LanceConfig config) { ReadOptions.Builder builder = new ReadOptions.Builder(); @@ -88,4 +90,36 @@ public static boolean 
enableTopNPushDown(LanceConfig config) { public static boolean overwrite(LanceConfig config) { return config.getOptions().getOrDefault(write_mode, "append").equalsIgnoreCase("overwrite"); } + + /** + * Get the maximum allocation size for Arrow buffers in bytes. Default is Long.MAX_VALUE to allow + * allocations beyond Integer.MAX_VALUE (2GB). Can be configured to a smaller value if needed to + * limit memory usage. + * + * @param config Lance configuration + * @return Maximum allocation size in bytes + */ + public static long getArrowMaxAllocationBytes(LanceConfig config) { + Map<String, String> options = config.getOptions(); + if (options.containsKey(arrow_max_allocation_bytes)) { + return Long.parseLong(options.get(arrow_max_allocation_bytes)); + } + return Long.MAX_VALUE; + } + + /** + * Get the average number of bytes per element for variable-width vectors (strings, binary). This + * is used to pre-allocate buffers to avoid frequent reallocations. Default is 64 bytes, which is + * a conservative estimate for most workloads. 
+ * + * @param config Lance configuration + * @return Average bytes per variable-width element + */ + public static int getArrowVarWidthAvgBytes(LanceConfig config) { + Map<String, String> options = config.getOptions(); + if (options.containsKey(arrow_var_width_avg_bytes)) { + return Integer.parseInt(options.get(arrow_var_width_avg_bytes)); + } + return 64; + } } diff --git a/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java index 6efe6890..c92f189c 100644 --- a/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java +++ b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java @@ -191,12 +191,14 @@ public static FragmentMetadata deleteRows( } } - public static LanceArrowWriter getArrowWriter(StructType sparkSchema, int batchSize) { + public static LanceArrowWriter getArrowWriter( + StructType sparkSchema, int batchSize, LanceConfig config) { return new LanceArrowWriter( allocator, LanceArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false), sparkSchema, - batchSize); + batchSize, + SparkOptions.getArrowVarWidthAvgBytes(config)); } public static List<FragmentMetadata> createFragment( diff --git a/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/write/LanceArrowWriter.java b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/write/LanceArrowWriter.java index 00fde775..40ac1e42 100644 --- a/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/write/LanceArrowWriter.java +++ b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/write/LanceArrowWriter.java @@ -32,6 +32,7 @@ public class LanceArrowWriter extends ArrowReader { private final Schema schema; private final StructType sparkSchema; private final int batchSize; + private final int avgBytesPerVarWidthElement; @GuardedBy("monitor") private volatile boolean finished = false; @@ -43,13 +44,19 @@ 
public class LanceArrowWriter extends ArrowReader { private final Semaphore loadToken; public LanceArrowWriter( - BufferAllocator allocator, Schema schema, StructType sparkSchema, int batchSize) { + BufferAllocator allocator, + Schema schema, + StructType sparkSchema, + int batchSize, + int avgBytesPerVarWidthElement) { super(allocator); Preconditions.checkNotNull(schema); Preconditions.checkArgument(batchSize > 0); + Preconditions.checkArgument(avgBytesPerVarWidthElement > 0); this.schema = schema; this.sparkSchema = sparkSchema; this.batchSize = batchSize; + this.avgBytesPerVarWidthElement = avgBytesPerVarWidthElement; this.writeToken = new Semaphore(0); this.loadToken = new Semaphore(0); } @@ -79,7 +86,7 @@ public void prepareLoadNextBatch() throws IOException { super.prepareLoadNextBatch(); arrowWriter = com.lancedb.lance.spark.arrow.LanceArrowWriter$.MODULE$.create( - this.getVectorSchemaRoot(), sparkSchema); + this.getVectorSchemaRoot(), sparkSchema, batchSize, avgBytesPerVarWidthElement); // release batch size token for write writeToken.release(batchSize); } diff --git a/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java index e8c7d905..833255d6 100644 --- a/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java +++ b/lance-spark-base_2.12/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java @@ -93,7 +93,7 @@ protected WriterFactory(StructType schema, LanceConfig config) { @Override public DataWriter<InternalRow> createWriter(int partitionId, long taskId) { int batch_size = SparkOptions.getBatchSize(config); - LanceArrowWriter arrowWriter = LanceDatasetAdapter.getArrowWriter(schema, batch_size); + LanceArrowWriter arrowWriter = LanceDatasetAdapter.getArrowWriter(schema, batch_size, config); WriteParams params = SparkOptions.genWriteParamsFromConfig(config); Callable<List<FragmentMetadata>> fragmentCreator = () -> 
LanceDatasetAdapter.createFragment(config.getDatasetUri(), arrowWriter, params); diff --git a/lance-spark-base_2.12/src/main/scala/com/lancedb/lance/spark/arrow/LanceArrowWriter.scala b/lance-spark-base_2.12/src/main/scala/com/lancedb/lance/spark/arrow/LanceArrowWriter.scala index 04c5e57e..d778b5c5 100644 --- a/lance-spark-base_2.12/src/main/scala/com/lancedb/lance/spark/arrow/LanceArrowWriter.scala +++ b/lance-spark-base_2.12/src/main/scala/com/lancedb/lance/spark/arrow/LanceArrowWriter.scala @@ -63,6 +63,58 @@ object LanceArrowWriter { new LanceArrowWriter(root, children.toArray) } + def create( + root: VectorSchemaRoot, + sparkSchema: StructType, + batchSize: Int, + avgBytesPerVarWidthElement: Int): LanceArrowWriter = { + val children = root.getFieldVectors().asScala.zipWithIndex.map { case (vector, index) => + val sparkField = sparkSchema.fields(index) + allocateVectorWithSize(vector, batchSize, sparkField.dataType, avgBytesPerVarWidthElement) + createFieldWriter(vector, sparkField.dataType, sparkField.metadata) + } + new LanceArrowWriter(root, children.toArray) + } + + private def allocateVectorWithSize( + vector: ValueVector, + batchSize: Int, + sparkType: DataType, + avgBytesPerVarWidthElement: Int): Unit = { + vector match { + // FixedSizeListVector: Calculate exact size = batchSize * listSize * elementSize + case fixedSizeList: FixedSizeListVector => + val listSize = fixedSizeList.getListSize() + val totalElements = batchSize * listSize + val dataVector = fixedSizeList.getDataVector() + + // Allocate the underlying data vector + dataVector match { + case fwv: FixedWidthVector => + fwv.allocateNew(totalElements) + case vwv: BaseVariableWidthVector => + // For variable-width elements in fixed-size lists, use configured estimate + vwv.allocateNew(totalElements * avgBytesPerVarWidthElement, totalElements) + case _ => + dataVector.allocateNew() + } + fixedSizeList.setValueCount(batchSize) + + // FixedWidthVectors: Allocate exact count + case fixedWidth: 
FixedWidthVector => + fixedWidth.allocateNew(batchSize) + + // VariableWidthVectors: Allocate with size estimate + case varWidth: BaseVariableWidthVector => + // Use configured average bytes per element + varWidth.allocateNew(batchSize * avgBytesPerVarWidthElement, batchSize) + + // Default: Use default allocation + case _ => + vector.allocateNew() + } + } + private[arrow] def createFieldWriter( vector: ValueVector, sparkType: DataType, diff --git a/lance-spark-base_2.12/src/test/java/com/lancedb/lance/spark/write/LanceArrowWriterTest.java b/lance-spark-base_2.12/src/test/java/com/lancedb/lance/spark/write/LanceArrowWriterTest.java index 2bea1dc4..61328a65 100644 --- a/lance-spark-base_2.12/src/test/java/com/lancedb/lance/spark/write/LanceArrowWriterTest.java +++ b/lance-spark-base_2.12/src/test/java/com/lancedb/lance/spark/write/LanceArrowWriterTest.java @@ -53,8 +53,10 @@ public void test() throws Exception { final int totalRows = 125; final int batchSize = 34; + final int avgBytesPerVarWidthElement = 64; final LanceArrowWriter arrowWriter = - new LanceArrowWriter(allocator, schema, sparkSchema, batchSize); + new LanceArrowWriter( + allocator, schema, sparkSchema, batchSize, avgBytesPerVarWidthElement); AtomicInteger rowsWritten = new AtomicInteger(0); AtomicInteger rowsRead = new AtomicInteger(0);