feat: add heuristics for checkpoint files prefetching. #4765


Merged · 4 commits · Jun 3, 2025
70 changes: 44 additions & 26 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -13,14 +13,16 @@
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import psutil
import safetensors
import torch
import torch._dynamo.config

import tensorrt_llm.bindings.internal.userbuffers as ub
from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
from tensorrt_llm._utils import (is_trace_enabled, local_mpi_rank,
local_mpi_size, nvtx_range, release_gc,
torch_dtype_to_str, trace_func)
from tensorrt_llm.bindings.executor import GuidedDecodingConfig
from tensorrt_llm.logger import logger
@@ -129,6 +131,14 @@ def validate_and_set_kv_cache_quant(model_config: ModelConfig,
model_config.quant_config.kv_cache_quant_algo = mapped_pyt_quant


def _prefetch_one_file(file_name, rank):
if os.path.exists(file_name):
logger.info(f"Rank {rank} prefetching {file_name} to memory...")
with open(file_name, 'rb') as f:
f.read()
logger.info(f"Rank {rank} finished prefetching {file_name}.")


def prefetch_files(file_names: List[str], mapping: Mapping):
"""
Prefetch safetensors files to memory so that the weight loading will be much faster.
@@ -137,33 +147,35 @@ def prefetch_files(file_names: List[str], mapping: Mapping):
heuristics about when to prefetch and when not to.
"""

def _prefetch_one_file(file_name, rank):
if os.path.exists(file_name):
logger.info(f"Rank {rank} prefetching {file_name} to memory...")
with open(file_name, 'rb') as f:
f.read()
logger.info(f"Rank {rank} finished prefetching {file_name}.")

# Find out the files to prefetch for the current rank.
# Each rank loads files with indices rank, rank + world_size, rank + 2*world_size, etc.
local_file_names = file_names[mapping.rank::mapping.world_size]

processes = []
for file_name in local_file_names:
process = multiprocessing.Process(target=_prefetch_one_file,
args=(file_name, mapping.rank))
process.start()
processes.append(process)

for process in processes:
process.join()
# Each rank loads files with indices local_rank, local_rank + local_mpi_size, local_rank + 2*local_mpi_size, etc.
local_file_names = file_names[local_mpi_rank()::local_mpi_size()]

max_processes = min(multiprocessing.cpu_count() * 2, 16)
with multiprocessing.Pool(processes=max_processes) as pool:
pool.starmap(
_prefetch_one_file,
[(file_name, mapping.rank) for file_name in local_file_names],
)
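A minimal, self-contained sketch (not part of this PR) of the approach shown above: stride the checkpoint shards by local rank so each rank on a node warms a disjoint subset, and use a bounded process pool with a module-level worker (module-level so it can be pickled). The helper names `warm_page_cache` and `prefetch_for_rank`, the shard names, and the rank values are hypothetical stand-ins for the PR's `_prefetch_one_file` and the `local_mpi_rank()`/`local_mpi_size()` striding.

```python
# Illustrative sketch only; names and values are made up.
import multiprocessing
import os
from typing import List


def warm_page_cache(file_name: str, rank: int) -> None:
    # Reading the whole file once pulls it into the OS page cache, so the
    # later safetensors load hits memory instead of disk.
    if os.path.exists(file_name):
        with open(file_name, "rb") as f:
            f.read()


def prefetch_for_rank(file_names: List[str], local_rank: int, local_size: int) -> None:
    # Rank r on a node takes files r, r + local_size, r + 2 * local_size, ...
    my_files = file_names[local_rank::local_size]
    # Bound the number of concurrent readers; the worker must be a
    # module-level function so the pool can pickle it.
    workers = min(multiprocessing.cpu_count() * 2, 16)
    with multiprocessing.Pool(processes=workers) as pool:
        pool.starmap(warm_page_cache, [(f, local_rank) for f in my_files])


if __name__ == "__main__":
    shards = [f"model-{i:05d}-of-00008.safetensors" for i in range(8)]
    # With 4 local ranks, rank 1 would prefetch shards 1 and 5.
    prefetch_for_rank(shards, local_rank=1, local_size=4)
```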


def load_weights(checkpoint_dir: str, mapping: Mapping):
def load_weights(
checkpoint_dir: str,
mapping: Mapping,
):
weights = {}
weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
if weight_files:
prefetch_files(weight_files, mapping)
# Prefetch the weight files to CPU memory if their total size is less than 90% of the available memory.
# This is a heuristic to avoid prefetching files that are too large and would cause file cache thrashing.
prefetch_size = sum(os.path.getsize(file) for file in weight_files)
enable_prefetch = prefetch_size < psutil.virtual_memory(
).available * 0.9
if enable_prefetch:
logger.info(
f"Prefetching {prefetch_size / (1024**3):.2f}GB checkpoint files."
)
prefetch_files(weight_files, mapping)
for file in weight_files:
logger.info(f"Loading {file}")
part_weights = safetensors.torch.load_file(file)
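The availability check added to load_weights can be read in isolation as follows; this is a sketch, not the PR's code, and the `should_prefetch` helper and checkpoint path are hypothetical. Only the 90% threshold and the `psutil.virtual_memory().available` comparison mirror the change above.

```python
# Illustrative sketch of the prefetch-size heuristic: prefetch only when the
# whole checkpoint fits comfortably in currently available RAM, to avoid
# evicting useful pages from the file cache.
import glob
import os

import psutil


def should_prefetch(checkpoint_dir: str, threshold: float = 0.9) -> bool:
    weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
    total_bytes = sum(os.path.getsize(f) for f in weight_files)
    return total_bytes < psutil.virtual_memory().available * threshold


if __name__ == "__main__":
    # A 140 GB checkpoint on a host with 100 GB of free RAM would skip
    # prefetching; a 15 GB checkpoint on the same host would prefetch.
    print(should_prefetch("/path/to/checkpoint"))
```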
@@ -922,16 +934,22 @@ def init_meta_tensor(t: torch.Tensor):
model = AutoModelForCausalLM.from_config(config)

model.to("cuda")
rank_model_storage = get_rank_model_storage(model)
logger.info(
f"Rank {self.mapping.rank} uses {get_rank_model_storage(model) / (1024**3):.2f} GB for model weights."
f"Rank {self.mapping.rank} uses {rank_model_storage / (1024**3):.2f} GB for model weights."
)

if load_format == LoadFormat.AUTO:
if hasattr(model, 'llm_checkpoint_dir'):
weights = load_weights(model.llm_checkpoint_dir,
self.mapping)
weights = load_weights(
model.llm_checkpoint_dir,
self.mapping,
)
else:
weights = load_weights(checkpoint_dir, self.mapping)
weights = load_weights(
checkpoint_dir,
self.mapping,
)

model.load_weights(weights)
