diff --git a/cmake/Utils.cmake b/cmake/Utils.cmake
index 180b33f15f12..b82949285e78 100644
--- a/cmake/Utils.cmake
+++ b/cmake/Utils.cmake
@@ -120,6 +120,7 @@ function(xgboost_set_cuda_flags target)
     # need to link with CCCL and CUDA runtime.
     target_link_libraries(${target} PRIVATE CCCL::CCCL CUDA::cudart_static)
   endif()
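+  # NVML (NVIDIA Management Library) provides GPU device and monitoring queries.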
+  target_link_libraries(${target} PRIVATE CUDA::nvml)
   target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_CUDA=1)
   target_include_directories(
     ${target} PRIVATE
diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index 16347e85f1fc..0ba848211979 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -5,7 +5,7 @@
     4.0.0
 
     ml.dmlc
-    xgboost-jvm_2.12
+    xgboost-jvm_2.13
     3.1.0-SNAPSHOT
     pom
     XGBoost JVM Package
@@ -43,15 +43,15 @@
     
         UTF-8
         UTF-8
-        1.8
-        1.8
+        17
+        17
         1.20.0
         4.13.2
-        3.5.3
-        3.5.1
+        4.0.0
+        4.0.0
         2.15.0
-        2.12.18
-        2.12
+        2.13.11
+        2.13
         3.4.1
         5
         OFF
@@ -89,6 +89,17 @@
             central maven
             https://repo1.maven.org/maven2
         
+      
+        apache-snapshots
+        https://repository.apache.org/content/repositories/snapshots/
+        
+          true
+        
+        
+          false
+        
+      
+
     
     
     
diff --git a/jvm-packages/xgboost4j-example/pom.xml b/jvm-packages/xgboost4j-example/pom.xml
index 9a8408124c63..c2d83dcca753 100644
--- a/jvm-packages/xgboost4j-example/pom.xml
+++ b/jvm-packages/xgboost4j-example/pom.xml
@@ -5,11 +5,11 @@
     4.0.0
     
         ml.dmlc
-        xgboost-jvm_2.12
+        xgboost-jvm_2.13
         3.1.0-SNAPSHOT
     
     xgboost4j-example
-    xgboost4j-example_2.12
+    xgboost4j-example_2.13
     3.1.0-SNAPSHOT
     jar
     
@@ -26,7 +26,7 @@
     
         
             ml.dmlc
-            xgboost4j-spark_2.12
+            xgboost4j-spark_2.13
             ${project.version}
         
         
@@ -37,7 +37,7 @@
         
         
             ml.dmlc
-            xgboost4j-flink_2.12
+            xgboost4j-flink_2.13
             ${project.version}
         
         
diff --git a/jvm-packages/xgboost4j-flink/pom.xml b/jvm-packages/xgboost4j-flink/pom.xml
index 96fe0563d499..6073afe49a7b 100644
--- a/jvm-packages/xgboost4j-flink/pom.xml
+++ b/jvm-packages/xgboost4j-flink/pom.xml
@@ -5,12 +5,12 @@
     4.0.0
     
         ml.dmlc
-        xgboost-jvm_2.12
+        xgboost-jvm_2.13
         3.1.0-SNAPSHOT
     
 
     xgboost4j-flink
-    xgboost4j-flink_2.12
+    xgboost4j-flink_2.13
     3.1.0-SNAPSHOT
     
       2.2.0
@@ -30,7 +30,7 @@
     
         
             ml.dmlc
-            xgboost4j_2.12
+            xgboost4j_2.13
             ${project.version}
         
         
diff --git a/jvm-packages/xgboost4j-spark-gpu/pom.xml b/jvm-packages/xgboost4j-spark-gpu/pom.xml
index 87d6ab78e04f..0c6bc470028d 100644
--- a/jvm-packages/xgboost4j-spark-gpu/pom.xml
+++ b/jvm-packages/xgboost4j-spark-gpu/pom.xml
@@ -5,11 +5,11 @@
     4.0.0
     
         ml.dmlc
-        xgboost-jvm_2.12
+        xgboost-jvm_2.13
         3.1.0-SNAPSHOT
     
     xgboost4j-spark-gpu
-    xgboost4j-spark-gpu_2.12
+    xgboost4j-spark-gpu_2.13
     JVM Package for XGBoost
     https://github.com/dmlc/xgboost/tree/master/jvm-packages
     
@@ -78,17 +78,17 @@
     
         
             ml.dmlc
-            xgboost4j_2.12
+            xgboost4j_2.13
             ${project.version}
         
         
             ml.dmlc
-            xgboost4j-spark_2.12
+            xgboost4j-spark_2.13
             ${project.version}
             
               
                   ml.dmlc
-                  xgboost4j_2.12
+                  xgboost4j_2.13
               
             
         
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/ExtMemQuantileDMatrix.java b/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/ExtMemQuantileDMatrix.java
index 6c9868be8473..0ebaa41e8959 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/ExtMemQuantileDMatrix.java
+++ b/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/ExtMemQuantileDMatrix.java
@@ -31,7 +31,8 @@ public ExtMemQuantileDMatrix(Iterator<ColumnBatch> iter,
       DMatrix ref,
       int nthread,
       int maxQuantileBatches,
-      int minCachePageBytes) throws XGBoostError {
+      long minCachePageBytes,
+      float cacheHostRatio) throws XGBoostError {
     long[] out = new long[1];
     long[] refHandle = null;
     if (ref != null) {
@@ -39,7 +40,7 @@ public ExtMemQuantileDMatrix(Iterator<ColumnBatch> iter,
       refHandle[0] = ref.getHandle();
     }
     String conf = this.getConfig(missing, maxBin, nthread,
-                                 maxQuantileBatches, minCachePageBytes);
+                                 maxQuantileBatches, minCachePageBytes, cacheHostRatio);
     XGBoostJNI.checkCall(XGBoostJNI.XGExtMemQuantileDMatrixCreateFromCallback(
         iter, refHandle, conf, out));
     handle = out[0];
@@ -50,7 +51,7 @@ public ExtMemQuantileDMatrix(
       float missing,
       int maxBin,
       DMatrix ref) throws XGBoostError {
-    this(iter, missing, maxBin, ref, 0, -1, -1);
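+    // -1 / -1.0f mean "unset": getConfig() only forwards maxQuantileBatches,
+    // minCachePageBytes and cacheHostRatio when they are positive.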
+    this(iter, missing, maxBin, ref, 0, -1, -1, -1.0f);
   }
 
   public ExtMemQuantileDMatrix(
@@ -61,19 +62,25 @@ public ExtMemQuantileDMatrix(
   }
 
   private String getConfig(float missing, int maxBin, int nthread,
-                           int maxQuantileBatches, int minCachePageBytes) {
+                           int maxQuantileBatches, long minCachePageBytes, float cacheHostRatio) {
    Map<String, Object> conf = new java.util.HashMap<>();
     conf.put("missing", missing);
     conf.put("max_bin", maxBin);
     conf.put("nthread", nthread);
 
     if (maxQuantileBatches > 0) {
-      conf.put("max_quantile_batches", maxQuantileBatches);
+      conf.put("max_quantile_blocks", maxQuantileBatches);
     }
     if (minCachePageBytes > 0) {
       conf.put("min_cache_page_bytes", minCachePageBytes);
     }
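+    // cache_host_ratio controls how much of the external-memory cache is kept
+    // in host memory; only values in (0, 1] are forwarded to the native layer.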
 
+    if (cacheHostRatio > 0.0 && cacheHostRatio <= 1.0) {
+      conf.put("cache_host_ratio", cacheHostRatio);
+    }
+
     conf.put("on_host", true);
     conf.put("cache_prefix", ".");
     ObjectMapper mapper = new ObjectMapper();
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/ExtMemQuantileDMatrix.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/ExtMemQuantileDMatrix.scala
index 6c870ad06299..d6cd447fde8c 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/ExtMemQuantileDMatrix.scala
+++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/ExtMemQuantileDMatrix.scala
@@ -28,10 +28,11 @@ class ExtMemQuantileDMatrix private[scala](
            ref: Option[QuantileDMatrix],
            nthread: Int,
            maxQuantileBatches: Int,
-           minCachePageBytes: Int) {
+           minCachePageBytes: Long,
+           cacheHostRatio: Float) {
     this(new jExtMemQuantileDMatrix(iter.asJava, missing, maxBin,
       ref.map(_.jDMatrix).orNull,
-      nthread, maxQuantileBatches, minCachePageBytes))
+      nthread, maxQuantileBatches, minCachePageBytes, cacheHostRatio))
   }
 
   def this(iter: Iterator[ColumnBatch], missing: Float, maxBin: Int) {
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/ExternalMemory.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/ExternalMemory.scala
index 735941e679c9..dd28022684d6 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/ExternalMemory.scala
+++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/ExternalMemory.scala
@@ -18,10 +18,15 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import java.io.File
 import java.nio.file.{Files, Paths}
+import java.util.concurrent.Executors
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.concurrent.duration.DurationInt
 
 import ai.rapids.cudf._
+import org.apache.commons.logging.LogFactory
 
 import ml.dmlc.xgboost4j.java.{ColumnBatch, CudfColumnBatch}
 import ml.dmlc.xgboost4j.scala.spark.Utils.withResource
@@ -61,15 +66,26 @@ private[spark] trait ExternalMemory[T] extends Iterator[Table] with AutoCloseabl
 }
 
 // The data will be cached into disk.
-private[spark] class DiskExternalMemoryIterator(val path: String) extends ExternalMemory[String] {
+private[spark] class DiskExternalMemoryIterator(val parent: String,
+                                                val cacheBatchNumber: Int = 1)
+  extends ExternalMemory[String] {
+
+  private val logger = LogFactory.getLog("XGBoostSparkGpuPlugin")
 
   private lazy val root = {
-    val tmp = path + "/xgboost"
+    val tmp = parent + "/xgboost"
     createDirectory(tmp)
     tmp
   }
 
-  private var counter = 0
+  logger.info(s"DiskExternalMemoryIterator createDirectory $root")
+
+  // Tasks mapping the path to the Future of caching table
+  private val cachingTasksFutures: mutable.HashMap[String, Future[Boolean]] = mutable.HashMap.empty
+  private val executor = Executors.newFixedThreadPool(cacheBatchNumber)
+  implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(executor)
+
+  private var fileCounter = 0
 
   private def createDirectory(dirPath: String): Unit = {
     val path = Paths.get(dirPath)
@@ -78,6 +94,40 @@ private[spark] class DiskExternalMemoryIterator(val path: String) extends Extern
     }
   }
 
+  /**
+   * Cache the table into disk which runs in a separate thread
+   *
+   * @param table to be cached
+   * @param path where to cache the table
+   */
+  private def cacheTableThread(table: Table, path: String): Future[Boolean] = {
+    Future {
+      try {
+        val rows = table.getRowCount
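+        // Rough size estimate in MiB, assuming 4 bytes per value.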
+        val size = rows * table.getNumberOfColumns * 4 / 1024 / 1024
+        logger.info(s"cacheTableThread begin to cache table (rows: $rows, " +
+          s"size: ${size}M) to $path")
+        val names = (1 to table.getNumberOfColumns).map(_.toString)
+        val options = ArrowIPCWriterOptions.builder()
+          .withCallback((t: Table) => {
+            logger.info(s"Table data has been offloaded to the host; writing cache to $path")
+            t.close()
+          })
+          .withColumnNames(names: _*).build()
+        withResource(Table.writeArrowIPCChunked(options, new File(path))) { writer =>
+          writer.write(table)
+        }
+        logger.info(s"cacheTableThread Finished caching table (rows: $rows, " +
+          s"size: ${size}M) to $path ================> Done")
+        true
+      } catch {
+        case e: Throwable =>
+          logger.error(s"Failed to cache table to $path", e)
+          throw e
+      }
+    }
+  }
+
   /**
    * Convert the table to file path which will be cached
    *
@@ -85,13 +135,24 @@ private[spark] class DiskExternalMemoryIterator(val path: String) extends Extern
    * @return the content
    */
   override def convertTable(table: Table): String = {
-    val names = (1 to table.getNumberOfColumns).map(_.toString)
-    val options = ArrowIPCWriterOptions.builder().withColumnNames(names: _*).build()
-    val path = root + "/table_" + counter + "_" + System.nanoTime();
-    counter += 1
-    withResource(Table.writeArrowIPCChunked(options, new File(path))) { writer =>
-      writer.write(table)
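+    // Back-pressure: before scheduling a new caching task, wait for the task
+    // submitted cacheBatchNumber batches earlier to finish.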
+    val index = fileCounter - cacheBatchNumber
+    if (index >= 0 && index < buffers.length) {
+      checkAndWaitCachingDone(buffers(index))
+      logger.info(s"Waiting for ${buffers(index)} done")
     }
+
+    val path = root + "/table_" + fileCounter + "_" + System.nanoTime()
+    fileCounter += 1
+
+    val rows = table.getRowCount
+    val size = rows * table.getNumberOfColumns * 4 / 1024 / 1024
+    logger.info(s"Intend to cache table (rows: $rows, " +
+      s"size: ${size}M) to $path")
+
+    // Increase the reference count of the columns so they are not recycled
+    // while the background caching task is still using them.
+    val newTable = new Table((0 until table.getNumberOfColumns).map(table.getColumn): _*)
+    val future = cacheTableThread(newTable, path)
+    cachingTasksFutures += (path -> future)
     path
   }
 
@@ -106,19 +167,34 @@ private[spark] class DiskExternalMemoryIterator(val path: String) extends Extern
     }
   }
 
+  private def checkAndWaitCachingDone(path: String): Unit = {
+    val futureOpt = cachingTasksFutures.get(path)
+    if (futureOpt.isEmpty) {
+      throw new RuntimeException(s"Failed to find the caching process for $path")
+    }
+    // Wait up to 20 seconds for the caching to finish.
+    // TODO: make the timeout configurable.
+    // Await.result throws a TimeoutException if the caching does not finish in time.
+    val success = Await.result(futureOpt.get, 20.seconds)
+    if (!success) { // Failed to cache
+      throw new RuntimeException(s"Failed to cache table to $path")
+    }
+  }
+
   /**
    * Load the path from disk to the Table
    *
-   * @param name to be loaded
+   * @param path to be loaded
    * @return Table
    */
-  override def loadTable(name: String): Table = {
-    val file = new File(name)
-    if (!file.exists()) {
-      throw new RuntimeException(s"The cache file ${name} doesn't exist" )
-    }
+  override def loadTable(path: String): Table = {
+    val file = new File(path)
+
+    logger.info(s"loadTable to table from to $path")
     try {
-      withResource(Table.readArrowIPCChunked(file)) { reader =>
+      checkAndWaitCachingDone(path)
+
+      val t = withResource(Table.readArrowIPCChunked(file)) { reader =>
         val tables = ArrayBuffer.empty[Table]
         closeOnExcept(tables) { tables =>
           var table = Option(reader.getNextIfAvailable())
@@ -135,6 +211,10 @@ private[spark] class DiskExternalMemoryIterator(val path: String) extends Extern
           tables(0)
         }
       }
+      val rows = t.getRowCount
+      val size = rows * t.getNumberOfColumns * 4 / 1024 / 1024
+      logger.info(s"loadTable done to load to table (rows: $rows, size: $size) from to $path")
+      t
     } catch {
       case e: Throwable =>
         close()
@@ -147,6 +227,7 @@ private[spark] class DiskExternalMemoryIterator(val path: String) extends Extern
   }
 
   override def close(): Unit = {
+    executor.shutdown()
     buffers.foreach { path =>
       val file = new File(path)
       if (file.exists()) {
@@ -158,8 +239,9 @@ private[spark] class DiskExternalMemoryIterator(val path: String) extends Extern
 }
 
 private[spark] object ExternalMemory {
-  def apply(path: Option[String] = None): ExternalMemory[_] = {
-    path.map(new DiskExternalMemoryIterator(_))
+  def apply(path: Option[String] = None,
+            cacheBatchNumber: Int = 1): ExternalMemory[_] = {
+    path.map(new DiskExternalMemoryIterator(_, cacheBatchNumber))
       .getOrElse(throw new RuntimeException("No disk path provided"))
   }
 }
@@ -169,7 +251,7 @@ private[spark] object ExternalMemory {
  *
  * The first round iteration gets the input batch that will be
  *   1. cached in the external memory
- *      2. fed in QuantilDmatrix
+ *   2. fed in QuantileDMatrix
  *      The second round iteration returns the cached batch got from external memory.
  *
  * @param input   the spark input iterator
@@ -177,7 +259,8 @@ private[spark] object ExternalMemory {
  */
 private[scala] class ExternalMemoryIterator(val input: Iterator[Table],
                                             val indices: ColumnIndices,
-                                            val path: Option[String] = None)
+                                            val path: Option[String] = None,
+                                            val cacheBatchNumber: Int = 1)
   extends Iterator[ColumnBatch] {
 
   private var iter = input
@@ -188,7 +271,7 @@ private[scala] class ExternalMemoryIterator(val input: Iterator[Table],
   private var inputNextIsCalled = false
 
   // visible for testing
-  private[spark] val externalMemory = ExternalMemory(path)
+  private[spark] val externalMemory = ExternalMemory(path, cacheBatchNumber)
 
   override def hasNext: Boolean = {
     val value = iter.hasNext
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala
index 01a8842e82b4..4261e2491c76 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala
+++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala
@@ -134,6 +134,12 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
 
     val maxQuantileBatches = estimator.getMaxQuantileBatches
     val minCachePageBytes = estimator.getMinCachePageBytes
+    val cacheBatchNumber = estimator.getCacheBatchNumber
+    val cacheHostRatio = estimator.getCacheHostRatio
+
+    logger.info(s"maxQuantileBatches: $maxQuantileBatches, " +
+      s"minCachePageBytes: $minCachePageBytes cacheBatchNumber: $cacheBatchNumber " +
+      s"cacheHostRatio: $cacheHostRatio")
 
     /** build QuantileDMatrix on the executor side */
     def buildQuantileDMatrix(input: Iterator[Table],
@@ -141,9 +147,9 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
 
       extMemPath match {
         case Some(_) =>
-          val itr = new ExternalMemoryIterator(input, indices, extMemPath)
+          val itr = new ExternalMemoryIterator(input, indices, extMemPath, cacheBatchNumber)
           new ExtMemQuantileDMatrix(itr, missing, maxBin, ref, nthread,
-            maxQuantileBatches, minCachePageBytes)
+            maxQuantileBatches, minCachePageBytes, cacheHostRatio)
 
         case None =>
           val itr = input.map { table =>
@@ -188,7 +194,6 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
 
     val sconf = dataset.sparkSession.conf
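+    // Consider RMM pooling enabled whenever spark.rapids.memory.gpu.pool is not "none".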
     val rmmEnabled: Boolean = try {
-      sconf.get("spark.rapids.memory.gpu.pooling.enabled").toBoolean &&
       sconf.get("spark.rapids.memory.gpu.pool").trim.toLowerCase != "none"
     } catch {
       case _: Throwable => false // Any exception will return false
diff --git a/jvm-packages/xgboost4j-spark/pom.xml b/jvm-packages/xgboost4j-spark/pom.xml
index 904c97a08bcd..8577820a4aba 100644
--- a/jvm-packages/xgboost4j-spark/pom.xml
+++ b/jvm-packages/xgboost4j-spark/pom.xml
@@ -5,11 +5,11 @@
     4.0.0
     
         ml.dmlc
-        xgboost-jvm_2.12
+        xgboost-jvm_2.13
         3.1.0-SNAPSHOT
     
     xgboost4j-spark
-    xgboost4j-spark_2.12
+    xgboost4j-spark_2.13
     
         
             
@@ -46,7 +46,7 @@
     
         
             ml.dmlc
-            xgboost4j_2.12
+            xgboost4j_2.13
             ${project.version}
         
         
diff --git a/jvm-packages/xgboost4j-spark/python/pyproject.toml b/jvm-packages/xgboost4j-spark/python/pyproject.toml
new file mode 100644
index 000000000000..d5853f5a0c0d
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/python/pyproject.toml
@@ -0,0 +1,50 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+[project]
+name = "xgboost4j"
+version = "3.1.0"
+authors = [
+  { name = "Bobby Wang", email = "wbo4958@gmail.com" },
+]
+description = "PySpark wrapper for XGBoost4J-Spark"
+readme = "README.md"
+requires-python = ">=3.10"
+classifiers = [
+  "Programming Language :: Python :: 3",
+  "Programming Language :: Python :: 3.10",
+  "Programming Language :: Python :: 3.11",
+  "Programming Language :: Python :: 3.12",
+  "License :: OSI Approved :: Apache Software License",
+  "Operating System :: OS Independent",
+  "Environment :: GPU :: NVIDIA CUDA :: 11",
+  "Environment :: GPU :: NVIDIA CUDA :: 11.4",
+  "Environment :: GPU :: NVIDIA CUDA :: 11.5",
+  "Environment :: GPU :: NVIDIA CUDA :: 11.6",
+  "Environment :: GPU :: NVIDIA CUDA :: 11.7",
+  "Environment :: GPU :: NVIDIA CUDA :: 11.8",
+  "Environment :: GPU :: NVIDIA CUDA :: 12",
+  "Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.0",
+  "Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.1",
+  "Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.2",
+  "Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.3",
+  "Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.4",
+  "Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.5",
+  "Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.6",
+  "Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.8",
+]
+
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
diff --git a/jvm-packages/xgboost4j-spark/python/src/ml/__init__.py b/jvm-packages/xgboost4j-spark/python/src/ml/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/jvm-packages/xgboost4j-spark/python/src/ml/dmlc/__init__.py b/jvm-packages/xgboost4j-spark/python/src/ml/dmlc/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/jvm-packages/xgboost4j-spark/python/src/ml/dmlc/xgboost4j/__init__.py b/jvm-packages/xgboost4j-spark/python/src/ml/dmlc/xgboost4j/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/jvm-packages/xgboost4j-spark/python/src/ml/dmlc/xgboost4j/scala/__init__.py b/jvm-packages/xgboost4j-spark/python/src/ml/dmlc/xgboost4j/scala/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/jvm-packages/xgboost4j-spark/python/src/ml/dmlc/xgboost4j/scala/spark/__init__.py b/jvm-packages/xgboost4j-spark/python/src/ml/dmlc/xgboost4j/scala/spark/__init__.py
new file mode 100644
index 000000000000..86f5356d3aba
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/python/src/ml/dmlc/xgboost4j/scala/spark/__init__.py
@@ -0,0 +1,5 @@
+import sys
+
+import xgboost4j
+
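+# Register the Python "xgboost4j" package under the JVM-style module path so that
+# imports of ml.dmlc.xgboost4j.scala.spark resolve to this package.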
+sys.modules["ml.dmlc.xgboost4j.scala.spark"] = xgboost4j
diff --git a/jvm-packages/xgboost4j-spark/python/src/xgboost4j/__init__.py b/jvm-packages/xgboost4j-spark/python/src/xgboost4j/__init__.py
new file mode 100644
index 000000000000..ec13e879aaa0
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/python/src/xgboost4j/__init__.py
@@ -0,0 +1,5 @@
+from .estimator import XGBoostClassificationModel, XGBoostClassifier
+
+__version__ = "3.1.0"
+
+__all__ = ["XGBoostClassifier", "XGBoostClassificationModel"]
diff --git a/jvm-packages/xgboost4j-spark/python/src/xgboost4j/estimator.py b/jvm-packages/xgboost4j-spark/python/src/xgboost4j/estimator.py
new file mode 100644
index 000000000000..1680452dc4af
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/python/src/xgboost4j/estimator.py
@@ -0,0 +1,105 @@
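+"""PySpark wrapper around the Scala XGBoost4J-Spark classifier.
+
+A minimal usage sketch (assumes the xgboost4j-spark jar is on the Spark
+classpath and ``train_df`` is a DataFrame with "features" and "label" columns)::
+
+    from xgboost4j import XGBoostClassifier
+
+    classifier = XGBoostClassifier(numRound=10, objective="binary:logistic")
+    model = classifier.fit(train_df)
+    predictions = model.transform(train_df)
+"""
+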
+from typing import Union, List, Any, Optional, Dict, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from py4j.java_gateway import JavaObject
+
+from pyspark import keyword_only
+from pyspark.ml.classification import _JavaProbabilisticClassifier, _JavaProbabilisticClassificationModel
+
+from .params import XGBoostParams
+
+
+class XGBoostClassifier(_JavaProbabilisticClassifier["XGBoostClassificationModel"], XGBoostParams):
+    _input_kwargs: Dict[str, Any]
+
+    @keyword_only
+    def __init__(
+        self,
+        *,
+        featuresCol: Union[str, List[str]] = "features",
+        labelCol: str = "label",
+        predictionCol: str = "prediction",
+        probabilityCol: str = "probability",
+        rawPredictionCol: str = "rawPrediction",
+        # SparkParams
+        numWorkers: Optional[int] = None,
+        numRound: Optional[int] = None,
+        forceRepartition: Optional[bool] = None,
+        numEarlyStoppingRounds: Optional[int] = None,
+        inferBatchSize: Optional[int] = None,
+        missing: Optional[float] = None,
+        useExternalMemory: Optional[bool] = None,
+        maxNumDevicePages: Optional[int] = None,
+        maxQuantileBatches: Optional[int] = None,
+        minCachePageBytes: Optional[int] = None,
+        cacheBatchNumber: Optional[int] = None,
+        cacheHostRatio: Optional[float] = None,
+        feature_names: Optional[List[str]] = None,
+        feature_types: Optional[List[str]] = None,
+        # RabitParams
+        rabitTrackerTimeout: Optional[int] = None,
+        rabitTrackerHostIp: Optional[str] = None,
+        rabitTrackerPort: Optional[int] = None,
+        # GeneralParams
+        booster: Optional[str] = None,
+        device: Optional[str] = None,
+        verbosity: Optional[int] = None,
+        validate_parameters: Optional[bool] = None,
+        nthread: Optional[int] = None,
+        # TreeBoosterParams
+        eta: Optional[float] = None,
+        gamma: Optional[float] = None,
+        max_depth: Optional[int] = None,
+        min_child_weight: Optional[float] = None,
+        max_delta_step: Optional[float] = None,
+        subsample: Optional[float] = None,
+        sampling_method: Optional[str] = None,
+        colsample_bytree: Optional[float] = None,
+        colsample_bylevel: Optional[float] = None,
+        colsample_bynode: Optional[float] = None,
+        reg_lambda: Optional[float] = None,
+        alpha: Optional[float] = None,
+        tree_method: Optional[str] = None,
+        scale_pos_weight: Optional[float] = None,
+        updater: Optional[str] = None,
+        refresh_leaf: Optional[bool] = None,
+        process_type: Optional[str] = None,
+        grow_policy: Optional[str] = None,
+        max_leaves: Optional[int] = None,
+        max_bin: Optional[int] = None,
+        num_parallel_tree: Optional[int] = None,
+        monotone_constraints: Optional[List[int]] = None,
+        interaction_constraints: Optional[str] = None,
+        max_cached_hist_node: Optional[int] = None,
+        # LearningTaskParams
+        objective: Optional[str] = None,
+        num_class: Optional[int] = None,
+        base_score: Optional[float] = None,
+        eval_metric: Optional[str] = None,
+        seed: Optional[int] = None,
+        seed_per_iteration: Optional[bool] = None,
+        tweedie_variance_power: Optional[float] = None,
+        huber_slope: Optional[float] = None,
+        aft_loss_distribution: Optional[str] = None,
+        lambdarank_pair_method: Optional[str] = None,
+        lambdarank_num_pair_per_sample: Optional[int] = None,
+        lambdarank_unbiased: Optional[bool] = None,
+        lambdarank_bias_norm: Optional[float] = None,
+        ndcg_exp_gain: Optional[bool] = None,
+        # DartBoosterParams
+        sample_type: Optional[str] = None,
+        normalize_type: Optional[str] = None,
+        rate_drop: Optional[float] = None,
+        one_drop: Optional[bool] = None,
+        skip_drop: Optional[float] = None,
+        **kwargs: Any,
+    ):
+        super().__init__()
+        self._java_obj = self._new_java_obj(
+            "ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier", self.uid
+        )
+        self._set_params(**self._input_kwargs)
+
+    def _create_model(self, java_model: "JavaObject") -> "XGBoostClassificationModel":
+        return XGBoostClassificationModel(java_model)
+
+
+class XGBoostClassificationModel(_JavaProbabilisticClassificationModel, XGBoostParams):
+    pass
diff --git a/jvm-packages/xgboost4j-spark/python/src/xgboost4j/params.py b/jvm-packages/xgboost4j-spark/python/src/xgboost4j/params.py
new file mode 100644
index 000000000000..733ede929125
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/python/src/xgboost4j/params.py
@@ -0,0 +1,935 @@
+from pyspark.ml.param import Param, Params, TypeConverters
+from typing import List, TypeVar, Any
+
+P = TypeVar("P", bound=Params)
+
+
+class DartBoosterParams(Params):
+    """
+    Parameters specific to the DART (Dropout Additive Regression Trees) boosting algorithm.
+    """
+
+    sampleType = Param(
+        Params._dummy(),
+        "sample_type",
+        "Type of sampling algorithm, options: {'uniform', 'weighted'}",
+        typeConverter=TypeConverters.toString
+    )
+
+    def getSampleType(self) -> str:
+        """Gets the value of sampleType or its default value."""
+        return self.getOrDefault(self.sampleType)
+
+    normalizeType = Param(
+        Params._dummy(),
+        "normalize_type",
+        "Type of normalization algorithm, options: {'tree', 'forest'}",
+        typeConverter=TypeConverters.toString
+    )
+
+    def getNormalizeType(self) -> str:
+        """Gets the value of normalizeType or its default value."""
+        return self.getOrDefault(self.normalizeType)
+
+    rateDrop = Param(
+        Params._dummy(),
+        "rate_drop",
+        "Dropout rate (a fraction of previous trees to drop during the dropout)",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getRateDrop(self) -> float:
+        """Gets the value of rateDrop or its default value."""
+        return float(self.getOrDefault(self.rateDrop))
+
+    oneDrop = Param(
+        Params._dummy(),
+        "one_drop",
+        "When this flag is enabled, at least one tree is always dropped during the dropout "
+        "(allows Binomial-plus-one or epsilon-dropout from the original DART paper)",
+        typeConverter=TypeConverters.toBoolean
+    )
+
+    def getOneDrop(self) -> bool:
+        """Gets the value of oneDrop or its default value."""
+        return bool(self.getOrDefault(self.oneDrop))
+
+    skipDrop = Param(
+        Params._dummy(),
+        "skip_drop",
+        "Probability of skipping the dropout procedure during a boosting iteration.\n"
+        "If a dropout is skipped, new trees are added in the same manner as gbtree.\n"
+        "Note that non-zero skip_drop has higher priority than rate_drop or one_drop.",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getSkipDrop(self) -> float:
+        """Gets the value of skipDrop or its default value."""
+        return float(self.getOrDefault(self.skipDrop))
+
+    def __init__(self):
+        super(DartBoosterParams, self).__init__()
+        self._setDefault(
+            sampleType="uniform",
+            normalizeType="tree",
+            rateDrop=0,
+            skipDrop=0
+        )
+
+
+class GeneralParams(Params):
+    """
+    General parameters for XGBoost.
+    """
+
+    booster = Param(
+        Params._dummy(),
+        "booster",
+        "Which booster to use. Can be gbtree, gblinear or dart; gbtree and dart use tree "
+        "based models while gblinear uses linear functions.",
+        typeConverter=TypeConverters.toString
+    )
+
+    def getBooster(self) -> str:
+        """Gets the value of booster or its default value."""
+        return self.getOrDefault(self.booster)
+
+    device = Param(
+        Params._dummy(),
+        "device",
+        "Device for XGBoost to run. User can set it to one of the following values: "
+        "{cpu, cuda, gpu}",
+        typeConverter=TypeConverters.toString
+    )
+
+    def getDevice(self) -> str:
+        """Gets the value of device or its default value."""
+        return self.getOrDefault(self.device)
+
+    verbosity = Param(
+        Params._dummy(),
+        "verbosity",
+        "Verbosity of printing messages. Valid values are 0 (silent), 1 (warning), "
+        "2 (info), 3 (debug). Sometimes XGBoost tries to change configurations based "
+        "on heuristics, which is displayed as warning message. If there's unexpected "
+        "behaviour, please try to increase value of verbosity.",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getVerbosity(self) -> int:
+        """Gets the value of verbosity or its default value."""
+        return int(self.getOrDefault(self.verbosity))
+
+    validateParameters = Param(
+        Params._dummy(),
+        "validate_parameters",
+        "When set to True, XGBoost will perform validation of input parameters to check "
+        "whether a parameter is used or not. A warning is emitted when there's unknown parameter.",
+        typeConverter=TypeConverters.toBoolean
+    )
+
+    def getValidateParameters(self) -> bool:
+        """Gets the value of validateParameters or its default value."""
+        return bool(self.getOrDefault(self.validateParameters))
+
+    nthread = Param(
+        Params._dummy(),
+        "nthread",
+        "Number of threads used by per worker",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getNthread(self) -> int:
+        """Gets the value of nthread or its default value."""
+        return int(self.getOrDefault(self.nthread))
+
+    def __init__(self):
+        super(GeneralParams, self).__init__()
+        self._setDefault(
+            booster="gbtree",
+            device="cpu",
+            verbosity=1,
+            validateParameters=False,
+            nthread=0
+        )
+
+
+class LearningTaskParams(Params):
+    """
+    Parameters related to the learning task for XGBoost models.
+    """
+
+    objective = Param(
+        Params._dummy(),
+        "objective",
+        "Objective function used for training",
+        typeConverter=TypeConverters.toString
+    )
+
+    def getObjective(self) -> str:
+        """Gets the value of objective or its default value."""
+        return self.getOrDefault(self.objective)
+
+    numClass = Param(
+        Params._dummy(),
+        "num_class",
+        "Number of classes, used by multi:softmax and multi:softprob objectives",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getNumClass(self) -> int:
+        """Gets the value of numClass or its default value."""
+        return int(self.getOrDefault(self.numClass))
+
+    baseScore = Param(
+        Params._dummy(),
+        "base_score",
+        "The initial prediction score of all instances, global bias. The parameter is "
+        "automatically estimated for selected objectives before training. To disable "
+        "the estimation, specify a real number argument. For sufficient number of "
+        "iterations, changing this value will not have too much effect.",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getBaseScore(self) -> float:
+        """Gets the value of baseScore or its default value."""
+        return float(self.getOrDefault(self.baseScore))
+
+    evalMetric = Param(
+        Params._dummy(),
+        "eval_metric",
+        "Evaluation metrics for validation data, a default metric will be assigned "
+        "according to objective (rmse for regression, and logloss for classification, "
+        "mean average precision for rank:map, etc.) User can add multiple evaluation "
+        "metrics. Python users: remember to pass the metrics in as list of parameters "
+        "pairs instead of map, so that latter eval_metric won't override previous ones",
+        typeConverter=TypeConverters.toString
+    )
+
+    def getEvalMetric(self) -> str:
+        """Gets the value of evalMetric or its default value."""
+        return self.getOrDefault(self.evalMetric)
+
+    seed = Param(
+        Params._dummy(),
+        "seed",
+        "Random number seed.",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getSeed(self) -> int:
+        """Gets the value of seed or its default value."""
+        return int(self.getOrDefault(self.seed))
+
+    seedPerIteration = Param(
+        Params._dummy(),
+        "seed_per_iteration",
+        "Seed PRNG determnisticly via iterator number.",
+        typeConverter=TypeConverters.toBoolean
+    )
+
+    def getSeedPerIteration(self) -> bool:
+        """Gets the value of seedPerIteration or its default value."""
+        return bool(self.getOrDefault(self.seedPerIteration))
+
+    tweedieVariancePower = Param(
+        Params._dummy(),
+        "tweedie_variance_power",
+        "Parameter that controls the variance of the Tweedie distribution "
+        "var(y) ~ E(y)^tweedie_variance_power.",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getTweedieVariancePower(self) -> float:
+        """Gets the value of tweedieVariancePower or its default value."""
+        return float(self.getOrDefault(self.tweedieVariancePower))
+
+    huberSlope = Param(
+        Params._dummy(),
+        "huber_slope",
+        "A parameter used for Pseudo-Huber loss to define the (delta) term.",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getHuberSlope(self) -> float:
+        """Gets the value of huberSlope or its default value."""
+        return float(self.getOrDefault(self.huberSlope))
+
+    aftLossDistribution = Param(
+        Params._dummy(),
+        "aft_loss_distribution",
+        "Probability Density Function",
+        typeConverter=TypeConverters.toString
+    )
+
+    def getAftLossDistribution(self) -> str:
+        """Gets the value of aftLossDistribution or its default value."""
+        return self.getOrDefault(self.aftLossDistribution)
+
+    lambdarankPairMethod = Param(
+        Params._dummy(),
+        "lambdarank_pair_method",
+        "pairs for pair-wise learning",
+        typeConverter=TypeConverters.toString
+    )
+
+    def getLambdarankPairMethod(self) -> str:
+        """Gets the value of lambdarankPairMethod or its default value."""
+        return self.getOrDefault(self.lambdarankPairMethod)
+
+    lambdarankNumPairPerSample = Param(
+        Params._dummy(),
+        "lambdarank_num_pair_per_sample",
+        "It specifies the number of pairs sampled for each document when pair method is "
+        "mean, or the truncation level for queries when the pair method is topk. For "
+        "example, to train with ndcg@6, set lambdarank_num_pair_per_sample to 6 and "
+        "lambdarank_pair_method to topk",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getLambdarankNumPairPerSample(self) -> int:
+        """Gets the value of lambdarankNumPairPerSample or its default value."""
+        return int(self.getOrDefault(self.lambdarankNumPairPerSample))
+
+    lambdarankUnbiased = Param(
+        Params._dummy(),
+        "lambdarank_unbiased",
+        "Specify whether do we need to debias input click data.",
+        typeConverter=TypeConverters.toBoolean
+    )
+
+    def getLambdarankUnbiased(self) -> bool:
+        """Gets the value of lambdarankUnbiased or its default value."""
+        return bool(self.getOrDefault(self.lambdarankUnbiased))
+
+    lambdarankBiasNorm = Param(
+        Params._dummy(),
+        "lambdarank_bias_norm",
+        "Lp normalization for position debiasing, default is L2. Only relevant when "
+        "lambdarankUnbiased is set to true.",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getLambdarankBiasNorm(self) -> float:
+        """Gets the value of lambdarankBiasNorm or its default value."""
+        return float(self.getOrDefault(self.lambdarankBiasNorm))
+
+    ndcgExpGain = Param(
+        Params._dummy(),
+        "ndcg_exp_gain",
+        "Whether we should use exponential gain function for NDCG.",
+        typeConverter=TypeConverters.toBoolean
+    )
+
+    def getNdcgExpGain(self) -> bool:
+        """Gets the value of ndcgExpGain or its default value."""
+        return bool(self.getOrDefault(self.ndcgExpGain))
+
+    def __init__(self):
+        super(LearningTaskParams, self).__init__()
+        self._setDefault(
+            objective="reg:squarederror",
+            numClass=0,
+            seed=0,
+            seedPerIteration=False,
+            tweedieVariancePower=1.5,
+            huberSlope=1,
+            lambdarankPairMethod="mean",
+            lambdarankUnbiased=False,
+            lambdarankBiasNorm=2,
+            ndcgExpGain=True
+        )
+
+
+class RabitParams(Params):
+    """
+    Parameters related to Rabit tracker configuration for distributed XGBoost.
+    """
+
+    rabitTrackerTimeout = Param(
+        Params._dummy(),
+        "rabitTrackerTimeout",
+        "The number of seconds before timeout waiting for workers to connect. "
+        "and for the tracker to shutdown.",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getRabitTrackerTimeout(self) -> int:
+        """Gets the value of rabitTrackerTimeout or its default value."""
+        return int(self.getOrDefault(self.rabitTrackerTimeout))
+
+    rabitTrackerHostIp = Param(
+        Params._dummy(),
+        "rabitTrackerHostIp",
+        "The Rabit Tracker host IP address. This is only needed if the host IP "
+        "cannot be automatically guessed.",
+        typeConverter=TypeConverters.toString
+    )
+
+    def getRabitTrackerHostIp(self) -> str:
+        """Gets the value of rabitTrackerHostIp or its default value."""
+        return self.getOrDefault(self.rabitTrackerHostIp)
+
+    rabitTrackerPort = Param(
+        Params._dummy(),
+        "rabitTrackerPort",
+        "The port number for the tracker to listen to. Use a system allocated one by default.",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getRabitTrackerPort(self) -> int:
+        """Gets the value of rabitTrackerPort or its default value."""
+        return int(self.getOrDefault(self.rabitTrackerPort))
+
+    def __init__(self):
+        super(RabitParams, self).__init__()
+        self._setDefault(
+            rabitTrackerTimeout=0,
+            rabitTrackerHostIp="",
+            rabitTrackerPort=0
+        )
+
+
+class TreeBoosterParams(Params):
+    """
+    Parameters for Tree Boosting algorithms.
+    """
+
+    eta = Param(
+        Params._dummy(),
+        "eta",
+        "Step size shrinkage used in update to prevents overfitting. After each boosting step, "
+        "we can directly get the weights of new features, and eta shrinks the feature weights "
+        "to make the boosting process more conservative.",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getEta(self) -> float:
+        """Gets the value of eta or its default value."""
+        return float(self.getOrDefault(self.eta))
+
+    gamma = Param(
+        Params._dummy(),
+        "gamma",
+        "Minimum loss reduction required to make a further partition on a leaf node of the tree. "
+        "The larger gamma is, the more conservative the algorithm will be.",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getGamma(self) -> float:
+        """Gets the value of gamma or its default value."""
+        return float(self.getOrDefault(self.gamma))
+
+    maxDepth = Param(
+        Params._dummy(),
+        "max_depth",
+        "Maximum depth of a tree. Increasing this value will make the model more complex and "
+        "more likely to overfit. 0 indicates no limit on depth. Beware that XGBoost aggressively "
+        "consumes memory when training a deep tree. exact tree method requires non-zero value.",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getMaxDepth(self) -> int:
+        """Gets the value of maxDepth or its default value."""
+        return int(self.getOrDefault(self.maxDepth))
+
+    minChildWeight = Param(
+        Params._dummy(),
+        "min_child_weight",
+        "Minimum sum of instance weight (hessian) needed in a child. If the tree partition "
+        "step results in a leaf node with the sum of instance weight less than "
+        "min_child_weight, then the building process will give up further partitioning. "
+        "In linear regression task, this simply corresponds to minimum number of instances "
+        "needed to be in each node. The larger min_child_weight is, the more conservative "
+        "the algorithm will be.",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getMinChildWeight(self) -> float:
+        """Gets the value of minChildWeight or its default value."""
+        return float(self.getOrDefault(self.minChildWeight))
+
+    maxDeltaStep = Param(
+        Params._dummy(),
+        "max_delta_step",
+        "Maximum delta step we allow each leaf output to be. If the value is set to 0, "
+        "it means there is no constraint. If it is set to a positive value, it can help "
+        "making the update step more conservative. Usually this parameter is not needed, "
+        "but it might help in logistic regression when class is extremely imbalanced. "
+        "Set it to value of 1-10 might help control the update.",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getMaxDeltaStep(self) -> float:
+        """Gets the value of maxDeltaStep or its default value."""
+        return float(self.getOrDefault(self.maxDeltaStep))
+
+    subsample = Param(
+        Params._dummy(),
+        "subsample",
+        "Subsample ratio of the training instances. Setting it to 0.5 means that XGBoost "
+        "would randomly sample half of the training data prior to growing trees. and this "
+        "will prevent overfitting. Subsampling will occur once in every boosting iteration.",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getSubsample(self) -> float:
+        """Gets the value of subsample or its default value."""
+        return float(self.getOrDefault(self.subsample))
+
+    samplingMethod = Param(
+        Params._dummy(),
+        "sampling_method",
+        "The method to use to sample the training instances. The supported sampling methods "
+        "uniform: each training instance has an equal probability of being selected. "
+        "Typically set subsample >= 0.5 for good results.\n"
+        "gradient_based: the selection probability for each training instance is proportional "
+        "to the regularized absolute value of gradients. subsample may be set to as low as "
+        "0.1 without loss of model accuracy. Note that this sampling method is only supported "
+        "when tree_method is set to hist and the device is cuda; other tree methods only "
+        "support uniform sampling.",
+        typeConverter=TypeConverters.toString
+    )
+
+    def getSamplingMethod(self) -> str:
+        """Gets the value of samplingMethod or its default value."""
+        return self.getOrDefault(self.samplingMethod)
+
+    colsampleBytree = Param(
+        Params._dummy(),
+        "colsample_bytree",
+        "Subsample ratio of columns when constructing each tree. Subsampling occurs once "
+        "for every tree constructed.",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getColsampleBytree(self) -> float:
+        """Gets the value of colsampleBytree or its default value."""
+        return float(self.getOrDefault(self.colsampleBytree))
+
+    colsampleBylevel = Param(
+        Params._dummy(),
+        "colsample_bylevel",
+        "Subsample ratio of columns for each level. Subsampling occurs once for every new "
+        "depth level reached in a tree. Columns are subsampled from the set of columns "
+        "chosen for the current tree.",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getColsampleBylevel(self) -> float:
+        """Gets the value of colsampleBylevel or its default value."""
+        return float(self.getOrDefault(self.colsampleBylevel))
+
+    colsampleBynode = Param(
+        Params._dummy(),
+        "colsample_bynode",
+        "Subsample ratio of columns for each node (split). Subsampling occurs once every "
+        "time a new split is evaluated. Columns are subsampled from the set of columns "
+        "chosen for the current level.",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getColsampleBynode(self) -> float:
+        """Gets the value of colsampleBynode or its default value."""
+        return float(self.getOrDefault(self.colsampleBynode))
+
+    # Additional parameters
+
+    lambda_ = Param(
+        Params._dummy(),
+        "lambda",
+        "L2 regularization term on weights. Increasing this value will make model more conservative.",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getLambda(self) -> float:
+        """Gets the value of lambda or its default value."""
+        return float(self.getOrDefault(self.lambda_))
+
+    alpha = Param(
+        Params._dummy(),
+        "alpha",
+        "L1 regularization term on weights. Increasing this value will make model more conservative.",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getAlpha(self) -> float:
+        """Gets the value of alpha or its default value."""
+        return float(self.getOrDefault(self.alpha))
+
+    treeMethod = Param(
+        Params._dummy(),
+        "tree_method",
+        "The tree construction algorithm used in XGBoost, options: {'auto', 'exact', 'approx', 'hist', 'gpu_hist'}",
+        typeConverter=TypeConverters.toString
+    )
+
+    def getTreeMethod(self) -> str:
+        """Gets the value of treeMethod or its default value."""
+        return self.getOrDefault(self.treeMethod)
+
+    scalePosWeight = Param(
+        Params._dummy(),
+        "scale_pos_weight",
+        "Control the balance of positive and negative weights, useful for unbalanced classes. "
+        "A typical value to consider: sum(negative instances) / sum(positive instances)",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getScalePosWeight(self) -> float:
+        """Gets the value of scalePosWeight or its default value."""
+        return float(self.getOrDefault(self.scalePosWeight))
+
+    updater = Param(
+        Params._dummy(),
+        "updater",
+        "A comma separated string defining the sequence of tree updaters to run, providing a modular "
+        "way to construct and to modify the trees. This is an advanced parameter that is usually set "
+        "automatically, depending on some other parameters. However, it could be also set explicitly "
+        "by a user. The following updaters exist:\n"
+        "grow_colmaker: non-distributed column-based construction of trees.\n"
+        "grow_histmaker: distributed tree construction with row-based data splitting based on "
+        "global proposal of histogram counting.\n"
+        "grow_quantile_histmaker: Grow tree using quantized histogram.\n"
+        "grow_gpu_hist: Enabled when tree_method is set to hist along with device=cuda.\n"
+        "grow_gpu_approx: Enabled when tree_method is set to approx along with device=cuda.\n"
+        "sync: synchronizes trees in all distributed nodes.\n"
+        "refresh: refreshes tree's statistics and or leaf values based on the current data. Note "
+        "that no random subsampling of data rows is performed.\n"
+        "prune: prunes the splits where loss < min_split_loss (or gamma) and nodes that have depth "
+        "greater than max_depth.",
+        typeConverter=TypeConverters.toString
+    )
+
+    def getUpdater(self) -> str:
+        """Gets the value of updater or its default value."""
+        return self.getOrDefault(self.updater)
+
+    refreshLeaf = Param(
+        Params._dummy(),
+        "refresh_leaf",
+        "This is a parameter of the refresh updater. When this flag is 1, tree leafs as well as "
+        "tree nodes' stats are updated. When it is 0, only node stats are updated.",
+        typeConverter=TypeConverters.toBoolean
+    )
+
+    def getRefreshLeaf(self) -> bool:
+        """Gets the value of refreshLeaf or its default value."""
+        return bool(self.getOrDefault(self.refreshLeaf))
+
+    processType = Param(
+        Params._dummy(),
+        "process_type",
+        "A type of boosting process to run. options: {default, update}",
+        typeConverter=TypeConverters.toString
+    )
+
+    def getProcessType(self) -> str:
+        """Gets the value of processType or its default value."""
+        return self.getOrDefault(self.processType)
+
+    growPolicy = Param(
+        Params._dummy(),
+        "grow_policy",
+        "Controls a way new nodes are added to the tree. Currently supported only if tree_method "
+        "is set to hist or approx. Choices: depthwise, lossguide. depthwise: split at nodes closest "
+        "to the root. lossguide: split at nodes with highest loss change.",
+        typeConverter=TypeConverters.toString
+    )
+
+    def getGrowPolicy(self) -> str:
+        """Gets the value of growPolicy or its default value."""
+        return self.getOrDefault(self.growPolicy)
+
+    maxLeaves = Param(
+        Params._dummy(),
+        "max_leaves",
+        "Maximum number of nodes to be added. Not used by exact tree method",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getMaxLeaves(self) -> int:
+        """Gets the value of maxLeaves or its default value."""
+        return int(self.getOrDefault(self.maxLeaves))
+
+    maxBins = Param(
+        Params._dummy(),
+        "max_bin",
+        "Maximum number of discrete bins to bucket continuous features. Increasing this number "
+        "improves the optimality of splits at the cost of higher computation time. Only used if "
+        "tree_method is set to hist or approx.",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getMaxBins(self) -> int:
+        """Gets the value of maxBins or its default value."""
+        return int(self.getOrDefault(self.maxBins))
+
+    numParallelTree = Param(
+        Params._dummy(),
+        "num_parallel_tree",
+        "Number of parallel trees constructed during each iteration. This option is used to "
+        "support boosted random forest.",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getNumParallelTree(self) -> int:
+        """Gets the value of numParallelTree or its default value."""
+        return int(self.getOrDefault(self.numParallelTree))
+
+    monotoneConstraints = Param(
+        Params._dummy(),
+        "monotone_constraints",
+        "Constraint of variable monotonicity.",
+        typeConverter=TypeConverters.toListInt
+    )
+
+    def getMonotoneConstraints(self) -> List[int]:
+        """Gets the value of monotoneConstraints or its default value."""
+        return self.getOrDefault(self.monotoneConstraints)
+
+    interactionConstraints = Param(
+        Params._dummy(),
+        "interaction_constraints",
+        "Constraints for interaction representing permitted interactions. The constraints "
+        "must be specified in the form of a nest list, e.g. [[0, 1], [2, 3, 4]], "
+        "where each inner list is a group of indices of features that are allowed to interact "
+        "with each other. See tutorial for more information",
+        typeConverter=TypeConverters.toString
+    )
+
+    def getInteractionConstraints(self) -> str:
+        """Gets the value of interactionConstraints or its default value."""
+        return self.getOrDefault(self.interactionConstraints)
+
+    maxCachedHistNode = Param(
+        Params._dummy(),
+        "max_cached_hist_node",
+        "Maximum number of cached nodes for CPU histogram.",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getMaxCachedHistNode(self) -> int:
+        """Gets the value of maxCachedHistNode or its default value."""
+        return int(self.getOrDefault(self.maxCachedHistNode))
+
+    def __init__(self):
+        super().__init__()
+        self._setDefault(
+            eta=0.3, gamma=0, maxDepth=6, minChildWeight=1, maxDeltaStep=0,
+            subsample=1, samplingMethod="uniform", colsampleBytree=1, colsampleBylevel=1,
+            colsampleBynode=1, lambda_=1, alpha=0, treeMethod="auto", scalePosWeight=1,
+            processType="default", growPolicy="depthwise", maxLeaves=0, maxBins=256,
+            numParallelTree=1, maxCachedHistNode=65536
+        )
+
+
+class SparkParams(Params):
+    """
+    Parameters for XGBoost on Spark.
+    """
+
+    numWorkers = Param(
+        Params._dummy(),
+        "numWorkers",
+        "Number of workers used to train xgboost",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getNumWorkers(self) -> int:
+        """Gets the value of numWorkers or its default value."""
+        return int(self.getOrDefault(self.numWorkers))
+
+    numRound = Param(
+        Params._dummy(),
+        "numRound",
+        "The number of rounds for boosting",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getNumRound(self) -> int:
+        """Gets the value of numRound or its default value."""
+        return int(self.getOrDefault(self.numRound))
+
+    forceRepartition = Param(
+        Params._dummy(),
+        "forceRepartition",
+        "If the partition is equal to numWorkers, xgboost won't repartition the dataset. "
+        "Set forceRepartition to true to force repartition.",
+        typeConverter=TypeConverters.toBoolean
+    )
+
+    def getForceRepartition(self) -> bool:
+        """Gets the value of forceRepartition or its default value."""
+        return bool(self.getOrDefault(self.forceRepartition))
+
+    numEarlyStoppingRounds = Param(
+        Params._dummy(),
+        "numEarlyStoppingRounds",
+        "Stop training Number of rounds of decreasing eval metric to tolerate before stopping training",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getNumEarlyStoppingRounds(self) -> int:
+        """Gets the value of numEarlyStoppingRounds or its default value."""
+        return int(self.getOrDefault(self.numEarlyStoppingRounds))
+
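
numEarlyStoppingRounds follows the usual early-stopping semantics: training stops once the evaluation metric has failed to improve for that many consecutive rounds. A self-contained sketch with the native API and synthetic data, shown only to pin down the semantics:

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X, y = rng.random((512, 10)), rng.integers(0, 2, 512)
dtrain = xgb.DMatrix(X[:400], label=y[:400])
dvalid = xgb.DMatrix(X[400:], label=y[400:])

booster = xgb.train(
    {"objective": "binary:logistic", "eval_metric": "logloss"},
    dtrain,
    num_boost_round=1000,
    evals=[(dvalid, "validation")],
    early_stopping_rounds=10,  # corresponds to numEarlyStoppingRounds
)
```
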
+    inferBatchSize = Param(
+        Params._dummy(),
+        "inferBatchSize",
+        "batch size in rows to be grouped for inference",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getInferBatchSize(self) -> int:
+        """Gets the value of inferBatchSize or its default value."""
+        return int(self.getOrDefault(self.inferBatchSize))
+
+    missing = Param(
+        Params._dummy(),
+        "missing",
+        "The value treated as missing",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getMissing(self) -> float:
+        """Gets the value of missing or its default value."""
+        return float(self.getOrDefault(self.missing))
+
+    featureNames = Param(
+        Params._dummy(),
+        "feature_names",
+        "an array of feature names",
+        typeConverter=TypeConverters.toListString
+    )
+
+    def getFeatureNames(self) -> List[str]:
+        """Gets the value of featureNames or its default value."""
+        return self.getOrDefault(self.featureNames)
+
+    featureTypes = Param(
+        Params._dummy(),
+        "feature_types",
+        "an array of feature types",
+        typeConverter=TypeConverters.toListString
+    )
+
+    def getFeatureTypes(self) -> List[str]:
+        """Gets the value of featureTypes or its default value."""
+        return self.getOrDefault(self.featureTypes)
+
+    useExternalMemory = Param(
+        Params._dummy(),
+        "useExternalMemory",
+        "Whether to use the external memory or not when building QuantileDMatrix. Please note that "
+        "useExternalMemory is useful only when `device` is set to `cuda` or `gpu`. When "
+        "useExternalMemory is enabled, the directory specified by spark.local.dir if set will be "
+        "used to cache the temporary files, if spark.local.dir is not set, the /tmp directory "
+        "will be used.",
+        typeConverter=TypeConverters.toBoolean
+    )
+
+    def getUseExternalMemory(self) -> bool:
+        """Gets the value of useExternalMemory or its default value."""
+        return bool(self.getOrDefault(self.useExternalMemory))
+
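
A usage sketch for the external-memory cache. The dictionary keys simply mirror the Param names defined in this class (plus `device` from the general parameters); how they are wired into a concrete estimator is out of scope here, and the spark.local.dir value is a made-up path shown only to illustrate where the temporary cache files land:

```python
from pyspark.sql import SparkSession

# spark.local.dir, if set, is where the temporary cache files are written;
# otherwise /tmp is used.
spark = (
    SparkSession.builder
    .config("spark.local.dir", "/local/nvme/spark-tmp")
    .getOrCreate()
)

xgb_params = {
    "device": "cuda",          # external memory is only effective on GPU
    "useExternalMemory": True,
    "numWorkers": 4,
}
```
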
+    maxNumDevicePages = Param(
+        Params._dummy(),
+        "maxNumDevicePages",
+        "Maximum number of pages cached in device",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getMaxNumDevicePages(self) -> int:
+        """Gets the value of maxNumDevicePages or its default value."""
+        return int(self.getOrDefault(self.maxNumDevicePages))
+
+    maxQuantileBatches = Param(
+        Params._dummy(),
+        "maxQuantileBatches",
+        "Maximum quantile batches",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getMaxQuantileBatches(self) -> int:
+        """Gets the value of maxQuantileBatches or its default value."""
+        return int(self.getOrDefault(self.maxQuantileBatches))
+
+    minCachePageBytes = Param(
+        Params._dummy(),
+        "minCachePageBytes",
+        "Minimum number of bytes for each ellpack page in cache. Only used for in-host",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getMinCachePageBytes(self) -> int:
+        """Gets the value of minCachePageBytes or its default value."""
+        return int(self.getOrDefault(self.minCachePageBytes))
+
+    cacheBatchNumber = Param(
+        Params._dummy(),
+        "cacheBatchNumber",
+        "Maximum batches to be allowed to be cached. When enabling ExternalMemory, to overlap "
+        "the caching time, we put the caching process run in the backgroud, this number is to "
+        "limit how many batches must be cached before continuing to handling the current batch",
+        typeConverter=TypeConverters.toInt
+    )
+
+    def getCacheBatchNumber(self) -> int:
+        """Gets the value of cacheBatchNumber or its default value."""
+        return int(self.getOrDefault(self.cacheBatchNumber))
+
+    cacheHostRatio = Param(
+        Params._dummy(),
+        "cacheHostRatio",
+        "Used by the GPU implementation. For GPU-based inputs, XGBoost can split the cache into "
+        "host and device caches to reduce the data transfer overhead. This parameter specifies "
+        "the size of host cache compared to the size of the entire cache: host / (host + device)",
+        typeConverter=TypeConverters.toFloat
+    )
+
+    def getCacheHostRatio(self) -> float:
+        """Gets the value of cacheHostRatio or its default value."""
+        return float(self.getOrDefault(self.cacheHostRatio))
+
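
A worked example of the host / (host + device) split described above, assuming a hypothetical 16 GiB cache:

```python
# cacheHostRatio = host / (host + device). With 16 GiB of cache and a ratio of
# 0.75, 12 GiB stays in host memory and 4 GiB stays on the device.
total_cache_bytes = 16 * (1 << 30)
cache_host_ratio = 0.75

host_bytes = int(total_cache_bytes * cache_host_ratio)
device_bytes = total_cache_bytes - host_bytes
assert host_bytes == 12 * (1 << 30)
assert device_bytes == 4 * (1 << 30)
```
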
+    # featuresCols holds multiple feature column names; it is referenced in the defaults below.
+    featuresCols = Param(
+        Params._dummy(),
+        "featuresCols",
+        "Feature column names",
+        typeConverter=TypeConverters.toListString
+    )
+
+    def __init__(self):
+        super(SparkParams, self).__init__()
+        self._setDefault(
+            numRound=100,
+            numWorkers=1,
+            inferBatchSize=(32 << 10),
+            numEarlyStoppingRounds=0,
+            forceRepartition=False,
+            missing=float("nan"),
+            featuresCols=[],
+            featureNames=[],
+            featureTypes=[],
+            useExternalMemory=False,
+            maxNumDevicePages=-1,
+            maxQuantileBatches=-1,
+            minCachePageBytes=-1,
+            cacheBatchNumber=1,
+            cacheHostRatio=-1.0,
+        )
+
+
+class XGBoostParams(SparkParams, DartBoosterParams, GeneralParams,
+                    LearningTaskParams, RabitParams, TreeBoosterParams):
+
+    def _set_params(self: "P", **kwargs: Any) -> "P":
+        if "featuresCol" in kwargs:
+            v = kwargs.pop("featuresCol")
+            if isinstance(v, str):
+                self._set(**{"featuresCol": v})
+            elif isinstance(v, list):
+                self._set(**{"featuresCols": v})
+
+        return self._set(**kwargs)
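
_set_params dispatches on the type of featuresCol: a plain string keeps the usual Spark ML vector column, while a list of column names is redirected to featuresCols. A standalone sketch of that routing; route_features_col is an illustrative helper, not part of the package:

```python
from typing import Any, Dict, List, Union

def route_features_col(value: Union[str, List[str]]) -> Dict[str, Any]:
    """Mirror the featuresCol dispatch: str -> featuresCol, list -> featuresCols."""
    if isinstance(value, str):
        return {"featuresCol": value}
    if isinstance(value, list):
        return {"featuresCols": value}
    raise TypeError("featuresCol must be a column name or a list of column names")

assert route_features_col("features") == {"featuresCol": "features"}
assert route_features_col(["f0", "f1"]) == {"featuresCols": ["f0", "f1"]}
```
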
diff --git a/jvm-packages/xgboost4j-spark/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator b/jvm-packages/xgboost4j-spark/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
new file mode 100644
index 000000000000..79d09926acc9
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
@@ -0,0 +1,3 @@
+ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
+ml.dmlc.xgboost4j.scala.spark.XGBoostRanker
+ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
index a5acf2475977..dcbbb2a4d0f8 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
@@ -31,7 +31,7 @@ import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.util.{DefaultParamsWritable, MLReader, MLWritable, MLWriter}
 import org.apache.spark.ml.xgboost.{SparkUtils, XGBProbabilisticClassifierParams}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql._
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions.{array, col, udf}
 import org.apache.spark.sql.types._
 
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala
index 2d94aade5ac6..73ae808574fe 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala
@@ -193,16 +193,32 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe
 
   final def getMaxQuantileBatches: Int = $(maxQuantileBatches)
 
-  final val minCachePageBytes = new IntParam(this, "minCachePageBytes", "Minimum number of " +
+  final val minCachePageBytes = new LongParam(this, "minCachePageBytes", "Minimum number of " +
     "bytes for each ellpack page in cache. Only used for in-host")
 
-  final def getMinCachePageBytes: Int = $(minCachePageBytes)
+  final def getMinCachePageBytes: Long = $(minCachePageBytes)
+
+  final val cacheBatchNumber = new IntParam(this, "cacheBatchNumber",
+    "Maximum batches to be allowed to be cached. When enabling ExternalMemory, to overlap " +
+    "the caching time, we put the caching process run in the backgroud, this number is to " +
+      "limit how many batches must be cached before continuing to handling the current batch.",
+    ParamValidators.gtEq(1))
+
+  final def getCacheBatchNumber: Int = $(cacheBatchNumber)
+
+  final val cacheHostRatio = new FloatParam(this, "cacheHostRatio",
+    "Used by the GPU implementation. For GPU-based inputs, XGBoost can split the cache into " +
+      "host and device caches to reduce the data transfer overhead. This parameter specifies " +
+      "the size of host cache compared to the size of the entire cache: host / (host + device)")
+
+  final def getCacheHostRatio: Float = $(cacheHostRatio)
 
   setDefault(numRound -> 100, numWorkers -> 1, inferBatchSize -> (32 << 10),
     numEarlyStoppingRounds -> 0, forceRepartition -> false, missing -> Float.NaN,
     featuresCols -> Array.empty, customObj -> null, customEval -> null,
     featureNames -> Array.empty, featureTypes -> Array.empty, useExternalMemory -> false,
-    maxQuantileBatches -> -1, minCachePageBytes -> -1)
+    maxQuantileBatches -> -1, minCachePageBytes -> -1, cacheBatchNumber -> 1,
+    cacheHostRatio -> -1.0f)
 
   addNonXGBoostParam(numWorkers, numRound, numEarlyStoppingRounds, inferBatchSize, featuresCol,
     labelCol, baseMarginCol, weightCol, predictionCol, leafPredictionCol, contribPredictionCol,
@@ -248,7 +264,13 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe
 
   def setMaxQuantileBatches(value: Int): T = set(maxQuantileBatches, value).asInstanceOf[T]
 
-  def setMinCachePageBytes(value: Int): T = set(minCachePageBytes, value).asInstanceOf[T]
+  def setMinCachePageBytes(value: Long): T = set(minCachePageBytes, value).asInstanceOf[T]
+
+  def setCacheBatchNumber(value: Int): T = set(cacheBatchNumber, value)
+    .asInstanceOf[T]
+
+  def setCacheHostRatio(value: Float): T = set(cacheHostRatio, value)
+    .asInstanceOf[T]
 
   protected[spark] def featureIsArrayType(schema: StructType): Boolean =
     schema(getFeaturesCol).dataType.isInstanceOf[ArrayType]
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
index 8ff9839be7ee..f81c69d255a7 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
@@ -21,7 +21,7 @@ import java.io.{File, FileInputStream}
 import org.apache.commons.io.IOUtils
 import org.apache.spark.SparkContext
 import org.apache.spark.ml.linalg.Vectors
-import org.apache.spark.sql._
+import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.scalatest.BeforeAndAfterEach
 import org.scalatest.funsuite.AnyFunSuite
 
diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml
index b9c144dd044f..063d3519425e 100644
--- a/jvm-packages/xgboost4j/pom.xml
+++ b/jvm-packages/xgboost4j/pom.xml
@@ -5,11 +5,11 @@
     4.0.0
     
         ml.dmlc
-        xgboost-jvm_2.12
+        xgboost-jvm_2.13
         3.1.0-SNAPSHOT
     
     xgboost4j
-    xgboost4j_2.12
+    xgboost4j_2.13
     3.1.0-SNAPSHOT
     jar
 
diff --git a/src/common/cuda_dr_utils.cc b/src/common/cuda_dr_utils.cc
index e51a76e00a2d..17d39e02ca11 100644
--- a/src/common/cuda_dr_utils.cc
+++ b/src/common/cuda_dr_utils.cc
@@ -40,6 +40,7 @@ CuDriverApi::CuDriverApi(std::int32_t cu_major, std::int32_t cu_minor, std::int3
   safe_load("cuGetErrorName", &this->cuGetErrorName);
   safe_load("cuDeviceGetAttribute", &this->cuDeviceGetAttribute);
   safe_load("cuDeviceGet", &this->cuDeviceGet);
+  safe_load("cuDeviceGetUuid", &this->cuDeviceGetUuid);
 #if defined(CUDA_HW_DECOM_AVAILABLE)
   // CTK 12.8
   if (((cu_major == 12 && cu_minor >= 8) || cu_major > 12) && (kdm_major >= 570)) {
diff --git a/src/common/cuda_dr_utils.h b/src/common/cuda_dr_utils.h
index 9b9c1da5cd21..ed646227d639 100644
--- a/src/common/cuda_dr_utils.h
+++ b/src/common/cuda_dr_utils.h
@@ -48,6 +48,7 @@ struct CuDriverApi {
   // Device attributes
   using DeviceGetAttribute = CUresult(int *pi, CUdevice_attribute attrib, CUdevice dev);
   using DeviceGet = CUresult(CUdevice *device, int ordinal);
+  using DeviceGetUuid = CUresult(CUuuid *uuid, CUdevice dev);
 
 #if defined(CUDA_HW_DECOM_AVAILABLE)
   using BatchDecompressAsync = CUresult(CUmemDecompressParams *paramsArray, size_t count,
@@ -76,6 +77,7 @@ struct CuDriverApi {
   GetErrorName *cuGetErrorName{nullptr};              // NOLINT
   DeviceGetAttribute *cuDeviceGetAttribute{nullptr};  // NOLINT
   DeviceGet *cuDeviceGet{nullptr};                    // NOLINT
+  DeviceGetUuid *cuDeviceGetUuid{nullptr};            // NOLINT
 
 #if defined(CUDA_HW_DECOM_AVAILABLE)
 
diff --git a/src/common/cuda_pinned_allocator.h b/src/common/cuda_pinned_allocator.h
index 04549d54d3e4..a0790e415097 100644
--- a/src/common/cuda_pinned_allocator.h
+++ b/src/common/cuda_pinned_allocator.h
@@ -11,6 +11,7 @@
 #include <limits>  // for numeric_limits
 #include <memory>  // for unique_ptr
 #include <new>     // for bad_array_new_length
+#include <sys/mman.h>  // for madvise, MADV_HUGEPAGE
 
 #include "common.h"
 
@@ -89,9 +90,23 @@ struct SamAllocPolicy {
     }
 
     size_type n_bytes = cnt * sizeof(value_type);
-    pointer result = reinterpret_cast<pointer>(std::malloc(n_bytes));
-    if (!result) {
-      throw std::bad_alloc{};
+    auto constexpr kAlign = 1024ul * 1024ul * 512ul;
+    pointer result = nullptr;
+    if (n_bytes >= kAlign) {
+      result = reinterpret_cast<pointer>(std::aligned_alloc(kAlign, n_bytes));
+      if (!result) {
+        throw std::bad_alloc{};
+      }
+      if (madvise(result, n_bytes, MADV_HUGEPAGE) != 0) {
+        std::int32_t errsv = errno;
+        auto err = std::error_code{errsv, std::system_category()};
+        LOG(FATAL) << err.message();
+      }
+    } else {
+      result = reinterpret_cast<pointer>(std::malloc(n_bytes));
+      if (!result) {
+        throw std::bad_alloc{};
+      }
     }
     dh::safe_cuda(cudaHostRegister(result, n_bytes, cudaHostRegisterDefault));
     return result;
diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh
index 1b60670d0d68..90b29ff42932 100644
--- a/src/common/quantile.cuh
+++ b/src/common/quantile.cuh
@@ -181,7 +181,7 @@ class SketchContainer {
     this->Current().shrink_to_fit();
     this->Other().clear();
     this->Other().shrink_to_fit();
-    LOG(DEBUG) << "Quantile memory cost:" << this->MemCapacityBytes();
+    LOG(DEBUG) << "Quantile memory cost:" << common::HumanMemUnit(this->MemCapacityBytes());
   }
 
   /* \brief Merge quantiles from other GPU workers. */
diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu
index 4f4c862e4a8e..b94eea784429 100644
--- a/src/data/ellpack_page_source.cu
+++ b/src/data/ellpack_page_source.cu
@@ -21,8 +21,44 @@
 #include "ellpack_page_source.h"
 #include "proxy_dmatrix.cuh"  // for Dispatch
 #include "xgboost/base.h"     // for bst_idx_t
+#include <nvml.h>  // for NVML device handle and CPU-affinity APIs
 
 namespace xgboost::data {
+#define safe_nvml(call)                                          \
+  do {                                                           \
+    auto __status = (call);                                      \
+    if (__status != NVML_SUCCESS) {                              \
+      LOG(FATAL) << (#call) << ":" << nvmlErrorString(__status); \
+    }                                                            \
+  } while (0)
+
+void SetOptimalCpuAffinity() {
+  // FIXME: nvmlDeviceSetCpuAffinity is Linux-only; this path needs a guard on Windows.
+  safe_nvml(nvmlInit());
+
+  std::int32_t ordinal = curt::CurrentDevice();
+  nvmlDevice_t device;
+  CUuuid dev_uuid;
+
+  std::stringstream s;
+  std::unordered_set<int> dashPos{0, 4, 6, 8, 10};
+
+  cudr::GetGlobalCuDriverApi().cuDeviceGetUuid(&dev_uuid, ordinal);
+
+  s << "GPU";
+  for (int i = 0; i < 16; i++) {
+    if (dashPos.count(i)) {
+      s << '-';
+    }
+    s << std::hex << std::setfill('0') << std::setw(2) << (0xFF & (int)dev_uuid.bytes[i]);
+  }
+  std::cout << "s:" << s.str() << std::endl;
+  // fixme: maybe check not not supported error
+  safe_nvml(nvmlDeviceGetHandleByUUID(s.str().c_str(), &device));
+  safe_nvml(nvmlDeviceSetCpuAffinity(device));
+  safe_nvml(nvmlShutdown());
+}
+
 /**
  * Cache
  */
@@ -144,6 +180,7 @@ class EllpackHostCacheStreamImpl {
                    std::size_t{1});
       return n_bytes;
     };
+    static thread_local bool nvml_set = false;
     // Finish writing a (concatenated) cache page.
     auto commit_page = [cache_host_ratio, get_host_nbytes](EllpackPageImpl const* old_impl) {
       CHECK_EQ(old_impl->gidx_buffer.Resource()->Type(), common::ResourceHandler::kCudaMalloc);
@@ -154,6 +191,13 @@ class EllpackHostCacheStreamImpl {
       // Host cache
       auto n_bytes = get_host_nbytes(old_impl);
       CHECK_LE(n_bytes, old_impl->gidx_buffer.size_bytes());
+
+      if (!nvml_set) {
+        SetOptimalCpuAffinity();
+        nvml_set = true;
+      }
+
       new_impl->gidx_buffer =
           common::MakeFixedVecWithPinnedMalloc(n_bytes);
       if (n_bytes > 0) {
@@ -231,6 +275,13 @@ class EllpackHostCacheStreamImpl {
     CHECK_EQ(this->cache_->h_pages.size(), this->cache_->d_pages.size());
     auto [h_page, d_page] = this->cache_->At(this->ptr_);
 
+    static thread_local bool nvml_set = false;
+    if (!nvml_set) {
+      SetOptimalCpuAffinity();
+      nvml_set = true;
+    }
+
     auto ctx = Context{}.MakeCUDA(dh::CurrentDevice());
     auto out_impl = out->Impl();
     if (prefetch_copy) {
diff --git a/src/data/quantile_dmatrix.cu b/src/data/quantile_dmatrix.cu
index be13c260406b..6acbe34d3d47 100644
--- a/src/data/quantile_dmatrix.cu
+++ b/src/data/quantile_dmatrix.cu
@@ -103,6 +103,9 @@ void MakeSketches(Context const* ctx,
       }
       if (sketches.back().second > (1ul << (sketches.size() - 1)) ||
           sketches.back().second == static_cast(max_quantile_blocks)) {
+        LOG(DEBUG) << "Prune sub-stream. sketch:" << (1ul << (sketches.size() - 1))
+                   << " mqb:" << max_quantile_blocks
+                   << " sub-stream size:" << sketches.back().second;
         // Cut the sub-stream.
         auto n_cuts_per_feat =
             common::detail::RequiredSampleCutsPerColumn(p.max_bin, ext_info.accumulated_rows);