diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index 4c5e6739b3b..0202f2f5f21 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -277,6 +277,7 @@ private byte[] getOrCreateBuffer(int partitionId) {
 
   protected void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throws IOException {
     logger.debug("Push giant record, size {}.", numBytes);
+    long start = System.nanoTime();
     int bytesWritten =
         shuffleClient.pushData(
             shuffleId,
@@ -288,8 +289,10 @@ protected void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) thr
             numBytes,
             numMappers,
             numPartitions);
+    long delta = System.nanoTime() - start;
     mapStatusLengths[partitionId].add(bytesWritten);
     writeMetrics.incBytesWritten(bytesWritten);
+    writeMetrics.incWriteTime(delta);
   }
 
   private int getOrUpdateOffset(int partitionId, int serializedRecordSize)
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 5717910eea9..3346deb2ad5 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -346,6 +346,7 @@ private void write0(scala.collection.Iterator iterator) throws IOException {
 
   private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throws IOException {
     logger.debug("Push giant record, size {}.", Utils.bytesToString(numBytes));
+    long start = System.nanoTime();
     int bytesWritten =
         shuffleClient.pushData(
             shuffleId,
@@ -357,8 +358,10 @@ private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throw
             numBytes,
             numMappers,
             numPartitions);
+    long delta = System.nanoTime() - start;
     mapStatusLengths[partitionId].add(bytesWritten);
     writeMetrics.incBytesWritten(bytesWritten);
+    writeMetrics.incWriteTime(delta);
   }
 
   private void cleanupPusher() throws IOException {
diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
index 33eedd9d40d..62d31ee6d25 100644
--- a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
+++ b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
@@ -268,6 +268,7 @@ private void check(
     ShuffleWriteMetrics metrics = taskContext.taskMetrics().shuffleWriteMetrics();
     assertEquals(metrics.recordsWritten(), total.intValue());
     assertEquals(metrics.bytesWritten(), tempFile.length());
+    assertTrue(metrics.writeTime() > 0);
     try (FileInputStream fis = new FileInputStream(tempFile)) {
       Iterator it = newSerializerInstance(serializer).deserializeStream(fis).asKeyValueIterator();