[WIP]V0.11.0 dev-Token level re-inference #4508
base: v0.11.0-dev
Conversation
Cover network errors (only for DP).
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces a significant fault tolerance and token-level re-inference mechanism for vLLM on the Ascend platform. The changes include new components for fault detection, a recovery handler chain, and distributed coordination for error recovery. My review has identified several critical and high-severity issues that should be addressed. These include critical bugs such as incorrect assertion syntax which bypasses important safety checks, and missing resource cleanup logic on failure paths which will lead to resource leaks. Additionally, there are high-severity design concerns, including the use of internal, non-public APIs, hardcoded model-specific configurations that limit the feature's applicability, and a hardcoded fault-tolerance level that restricts functionality. Addressing these points will improve the robustness, maintainability, and generality of this new fault tolerance system.
assert(
    torch.distributed.is_initialized()
),"Default torch process group must be initialized"

assert(
    torch.distributed.is_gloo_available()
),"Gloo process group must be available"
The assert statements on lines 36 and 40 use tuple syntax assert(condition, message), which is a common pitfall in Python. A non-empty tuple always evaluates to True, so these assertions will never fail, even if the conditions torch.distributed.is_initialized() or torch.distributed.is_gloo_available() are false. This is a critical issue as it bypasses important validations, potentially allowing the program to proceed in an invalid state and causing hard-to-debug errors later.
Suggested change:

assert torch.distributed.is_initialized(), "Default torch process group must be initialized"
assert torch.distributed.is_gloo_available(), "Gloo process group must be available"
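For reference, the pitfall described above can be reproduced in isolation: a parenthesized (condition, message) pair is a non-empty tuple and is therefore always truthy (recent CPython versions also emit a SyntaxWarning for it). A minimal, purely illustrative example:

```python
# Illustration of the assert-with-tuple pitfall (not code from this PR).
assert (False, "this never fires")          # a 2-tuple is truthy, so no AssertionError
assert False, "this raises AssertionError"  # the corrected form fails as expected
```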
elif torch.equal(ft_action,FaultAction.RAISE_EXCEPTION):
    logger.info(f"Raise exception at rank {self.rank}")
    # TODO: Need to clear cache for current batch and destroy all group
    raise e
elif torch.equal(ft_action,FaultAction.RETURN):
    logger.info(f"Abort current batch at rank {self.rank}")
    # TODO: Need to clear cache for current batch and destroy all group
    return None
else:
    # TODO: Need to clear cache for current batch and destroy all group
    logger.info(f"Unknown fault action found at rank {self.rank}")
    raise e
The fault_tolerance_decorator catches exceptions but does not clean up resources for the failed batch when the recovery action is RAISE_EXCEPTION or RETURN. The TODO comments indicate this is a known missing piece. Without freeing the KV cache blocks and other resources associated with the aborted requests, this will lead to resource leaks, eventually causing the system to hang or crash when it runs out of memory or cache blocks. This is a critical issue that must be addressed.
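As a rough sketch of the missing cleanup (all names below are hypothetical and only illustrate the kind of work the TODOs refer to; they are not APIs from this PR or from vLLM):

```python
import torch.distributed as dist

def _abort_current_batch(batch_req_ids, kv_cache_manager, transient_groups):
    # Hypothetical cleanup sketch: release the KV-cache blocks held by the
    # requests that were in flight when the fault occurred, so aborting the
    # batch does not leak blocks.
    for req_id in batch_req_ids:
        kv_cache_manager.free(req_id)
    # Tear down any process groups created for the failed step so a later
    # recovery can rebuild communicators from a clean state.
    for group in transient_groups:
        dist.destroy_process_group(group)
```

The decorator would call such a helper on the RAISE_EXCEPTION, RETURN, and unknown-action branches before re-raising or returning.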
class FaultStatus(Enum):
    """
    Fault status which fault_tolerance put into fault_queue
    """
    ACTIVE = torch.tensor([0])
    UCE_ERR = torch.tensor([1])
    FORCE_STOP = torch.tensor([2])
    NETWORK_ERR = torch.tensor([3])


class FaultCommand:
    """
    Fault command which rank 0 broadcast in fault_aware
    """
    INIT_CMD = torch.tensor([0])
    SILENCE_CMD = torch.tensor([1])
    STOP_DEVICE_CMD = torch.tensor([2])


class UCEType(Enum):
    """
    Specific uce type for HBM UCE
    """
    WEIGHTS_UCE = "WEIGHTS_UCE"
    KVCACHE_UCE = "KVCACHE_UCE"
    ACTIVATION_UCE = "ACTIVATION_UCE"
    UNKNOWN_UCE = "UNKNOWN_UCE"


class RecoveryStatus:
    SUCCESS = torch.tensor([0])
    FAILED = torch.tensor([1])


class FaultAction:
    RAISE_EXCEPTION = torch.tensor([0])
    RETURN = torch.tensor([1])
    RECOMPUTE = torch.tensor([2])
The classes FaultStatus, FaultCommand, RecoveryStatus, and FaultAction use torch.tensor objects as values for constants. This is unconventional and can lead to subtle bugs related to device placement or tensor properties. It is also inconsistent, as FaultStatus is an Enum while the others are plain classes. A better practice is to use simple integer values in enums (e.g., using enum.IntEnum) and create tensors from them only when needed for distributed communication. This would improve readability, maintainability, and avoid potential pitfalls with using tensors as constants.
from enum import IntEnum


class FaultStatus(IntEnum):
    """
    Fault status which fault_tolerance put into fault_queue
    """
    ACTIVE = 0
    UCE_ERR = 1
    FORCE_STOP = 2
    NETWORK_ERR = 3


class FaultCommand(IntEnum):
    """
    Fault command which rank 0 broadcast in fault_aware
    """
    INIT_CMD = 0
    SILENCE_CMD = 1
    STOP_DEVICE_CMD = 2


class UCEType(Enum):
    """
    Specific uce type for HBM UCE
    """
    WEIGHTS_UCE = "WEIGHTS_UCE"
    KVCACHE_UCE = "KVCACHE_UCE"
    ACTIVATION_UCE = "ACTIVATION_UCE"
    UNKNOWN_UCE = "UNKNOWN_UCE"


class RecoveryStatus(IntEnum):
    SUCCESS = 0
    FAILED = 1


class FaultAction(IntEnum):
    RAISE_EXCEPTION = 0
    RETURN = 1
    RECOMPUTE = 2
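With the IntEnum version above, a tensor can be built only at the communication site; a minimal sketch (assuming the FaultAction definition from the suggestion):

```python
import torch

# Materialize the enum as a tensor only for the collective itself, then map
# the received value back to the enum on the other side.
action = torch.tensor([int(FaultAction.RECOMPUTE)], dtype=torch.int64)
# torch.distributed.broadcast(action, src=0)  # on the gloo group used for fault handling
received = FaultAction(action.item())
```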
memory_block_info = ctx.memory_block_info
if not memory_block_info.initialized:
    memory_block_info.initialize()
uce_ptrs = _get_uce_addr()
The code uses _get_uce_addr (imported on line 12), which is an internal API from torch_npu.npu.utils as indicated by the leading underscore. Relying on internal APIs is risky because they are not part of the public contract and can be changed or removed without notice in future versions of torch_npu. This could break the fault tolerance feature unexpectedly upon a library update.
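One way to reduce that risk, sketched below rather than taken from the PR, is to isolate the private import and fail with an explicit error if a future torch_npu release removes or renames the symbol:

```python
# Defensive-import sketch for the internal helper; the wrapper name is made up.
try:
    from torch_npu.npu.utils import _get_uce_addr  # internal API, may change
except ImportError as exc:
    _get_uce_addr = None
    _IMPORT_ERROR = exc

def get_uce_addresses():
    if _get_uce_addr is None:
        raise RuntimeError(
            "torch_npu no longer exposes _get_uce_addr; "
            "UCE recovery must be updated for this torch_npu version"
        ) from _IMPORT_ERROR
    return _get_uce_addr()
```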
def map_to_original_param(self,merged_name:str,mapping_config:Dict[str,List[Tuple[str,Any]]] = None) -> List[str]:
    default_mapping={
        "qkv_proj":[
            ("q_proj","q"),
            ("k_proj","k"),
            ("v_proj","v"),
        ],
        "gate_up_proj":[
            ("gate_proj",0),
            ("up_proj",1)
        ]
    }
    mapping = mapping_config if mapping_config is not None else default_mapping
    original_names = []
    for merged_param_name,mappings in mapping.items():
        if merged_param_name in merged_name:
            for original_param_name,_ in mappings:
                original_name = merged_name.replace(merged_param_name,original_param_name)
                original_names.append(original_name)
            break
    if not original_names:
        return [merged_name]
    return original_names
The map_to_original_param method contains a hardcoded default_mapping for converting merged parameter names (like qkv_proj) back to their original names. This mapping is specific to certain model architectures (e.g., Llama-style models) and will cause the weight recovery feature to fail for models that use different naming conventions. The presence of the unused _load_mapping_config function suggests this was intended to be configurable. To make this feature more general and robust, the mapping should be loaded from a model-specific configuration.
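A sketch of the configurable direction this points at; the registry, its keys, and the helper name below are assumptions for illustration (the review notes the PR already contains an unused _load_mapping_config that could serve this purpose):

```python
from typing import Any, Dict, List, Tuple

# Llama-style default, matching the hardcoded table in the PR.
LLAMA_STYLE_MAPPING: Dict[str, List[Tuple[str, Any]]] = {
    "qkv_proj": [("q_proj", "q"), ("k_proj", "k"), ("v_proj", "v")],
    "gate_up_proj": [("gate_proj", 0), ("up_proj", 1)],
}

def load_mapping_for_model(model_type: str) -> Dict[str, List[Tuple[str, Any]]]:
    # Hypothetical per-architecture registry; models with different merged
    # parameter layouts would register their own entries instead of silently
    # reusing the Llama-style names.
    registry: Dict[str, Dict[str, List[Tuple[str, Any]]]] = {
        "llama": LLAMA_STYLE_MAPPING,
    }
    return registry.get(model_type, LLAMA_STYLE_MAPPING)
```

map_to_original_param could then take its mapping_config from such a lookup instead of the inline default.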
self.fault_tolerance = FaultTolerance(
    vllm_config=self.vllm_config,
    model=self.model_runner.model,
    level=FaultToleranceLevel.BASIC
The fault tolerance level is hardcoded to FaultToleranceLevel.BASIC. This prevents users from enabling FULL fault tolerance, which is defined to include features like KV cache UCE recovery. To allow users to leverage all fault tolerance capabilities, this level should be made configurable, for example, through VllmConfig or AscendConfig.
Suggested change:

level=self.vllm_config.ascend_config.fault_tolerance_level
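A possible shape for that plumbing, purely as a sketch; the option name, its delivery via additional config, and the numeric values are assumptions, not part of this PR:

```python
from enum import IntEnum

class FaultToleranceLevel(IntEnum):
    # BASIC and FULL are the levels referenced in the review; values are illustrative.
    BASIC = 1
    FULL = 2

def resolve_fault_tolerance_level(additional_config: dict) -> FaultToleranceLevel:
    # Hypothetical helper: read the level from an additional-config dict and
    # default to BASIC, preserving today's behavior when nothing is set.
    raw = (additional_config or {}).get("fault_tolerance_level", "BASIC")
    return FaultToleranceLevel[str(raw).upper()]
```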
Provides token-level re-inference capability for vLLM on the Ascend platform; currently only token recomputation in network link failure scenarios under TP parallelism is supported.