feat: add TeacherConfig and teacher engine factory for distillation #1359
penfever wants to merge 1 commit into NovaSky-AI:main from
Conversation
Adds support for a separate teacher model in distillation workflows:

1. `TeacherConfig` dataclass with `model_path`, `top_k_logprobs`, `num_inference_engines`, TP/PP size, `gpu_memory_utilization`, `enforce_eager`, `backend`, and `engine_init_kwargs` fields.
2. `create_teacher_inference_engines_from_config()` in `main_base.py` that creates dedicated vLLM engines for the teacher model with its own tokenizer (for cross-model retokenization).

Teacher engines differ from student engines: they use `max_logprobs` equal to `top_k_logprobs` (not 1), don't enable sleep mode, and don't set up weight sync (teacher weights are static). This unblocks on-policy distillation with teacher logits and best-of-N distillation examples.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
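Based on the fields listed in the description, a minimal sketch of what such a dataclass might look like. This is an illustration, not the PR's actual code: the field names come from the description above, but all defaults and type annotations here are assumptions.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class TeacherConfig:
    """Sketch of the teacher-model config described above (defaults are guesses)."""

    model_path: Optional[str] = None   # None means no teacher / distillation disabled
    top_k_logprobs: int = 20           # how many logprobs the teacher returns per token
    num_inference_engines: int = 1
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    gpu_memory_utilization: float = 0.9
    enforce_eager: bool = False
    backend: str = "vllm"
    engine_init_kwargs: Dict[str, Any] = field(default_factory=dict)


cfg = TeacherConfig()
print(cfg.model_path is None)  # True — a None model_path signals "no teacher"
```

A mutable default like `engine_init_kwargs` must use `field(default_factory=dict)`; a bare `= {}` would be shared across all instances and is rejected by `dataclass`.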
| "backend": teacher_cfg.backend, | ||
| "engine_init_kwargs": dict(teacher_cfg.engine_init_kwargs), | ||
| "enable_ray_prometheus_stats": False, | ||
| "max_logprobs": teacher_cfg.top_k_logprobs, |
🔴 `max_logprobs` passed as a kwarg to a function that doesn't accept it, causing a TypeError
`create_teacher_inference_engines_from_config` passes `max_logprobs=teacher_cfg.top_k_logprobs` to `create_ray_wrapped_inference_engines` at line 170, but that function has no `max_logprobs` parameter in its signature (ray_wrapped_inference_engine.py:86-118). This raises `TypeError: create_ray_wrapped_inference_engines() got an unexpected keyword argument 'max_logprobs'` at runtime, breaking the teacher distillation feature entirely. The underlying function hardcodes `max_logprobs=1` at ray_wrapped_inference_engine.py:304, so for the teacher's `top_k_logprobs` value to actually propagate, a `max_logprobs` parameter needs to be added to the function signature and used in place of the hardcoded value.
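The failure mode can be reproduced in isolation: Python raises `TypeError` the moment a function is called with a keyword argument its signature does not declare. The toy function below is a stand-in, not the real factory.

```python
def create_engines(model_path: str, backend: str = "vllm"):
    # Toy stand-in for a factory whose signature lacks `max_logprobs`.
    return {"model_path": model_path, "backend": backend}


try:
    # Mirrors the buggy call site: an unknown kwarg fails before the body runs.
    create_engines("teacher-model", max_logprobs=20)
except TypeError as exc:
    print(exc)  # ... got an unexpected keyword argument 'max_logprobs'
```

Note the error is raised at call time, not import time, which is why this slips past anything short of an integration test that exercises the teacher path.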
Prompt for agents
The fix requires changes in two files:
1. In skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py, add a max_logprobs parameter to the create_ray_wrapped_inference_engines function signature (around line 117, before the closing parenthesis), e.g.: max_logprobs: int = 1
Then on line 304, change the hardcoded max_logprobs=1 to use the new parameter: max_logprobs=max_logprobs
2. No changes needed in skyrl/train/entrypoints/main_base.py line 170 since it already passes max_logprobs correctly — it just needs the function to accept the parameter.
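The shape of the fix described above can be sketched as follows: thread a `max_logprobs` parameter through the factory instead of hardcoding it. This is a simplified illustration of the signature change, not the real vLLM/Ray wiring; the returned dict merely stands in for the kwargs the real function would forward to the engine.

```python
from typing import Any, Dict


def create_ray_wrapped_inference_engines(
    model_path: str,
    backend: str = "vllm",
    max_logprobs: int = 1,  # new parameter; default preserves old behavior
) -> Dict[str, Any]:
    # Simplified: return the engine kwargs that would be forwarded to vLLM.
    return {
        "model": model_path,
        "backend": backend,
        "max_logprobs": max_logprobs,  # use the parameter, not the literal 1
    }


# Student engines keep the old behavior; teacher engines request top-k logprobs.
student = create_ray_wrapped_inference_engines("student-model")
teacher = create_ray_wrapped_inference_engines("teacher-model", max_logprobs=20)
print(student["max_logprobs"], teacher["max_logprobs"])  # 1 20
```

Defaulting the new parameter to 1 keeps every existing call site working unchanged, which is why no edit is needed in main_base.py.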
Code Review
This pull request introduces a TeacherConfig and a factory for creating teacher inference engines, which is a key step for enabling distillation. The implementation is logical and well-structured. My review focuses on improving security and configurability. I've identified a significant security risk with hardcoding trust_remote_code=True and have suggested making this a configurable option that defaults to false. Additionally, I've recommended expanding TeacherConfig with more performance-tuning parameters to align it with the student model's configuration capabilities, allowing for better optimization.
| backend: str = "vllm" | ||
| engine_init_kwargs: Dict[str, Any] = field(default_factory=dict) |
The TeacherConfig can be improved by adding more options for security and performance tuning:
- Security (high severity): Add `trust_remote_code: bool = False`. Hardcoding `trust_remote_code=True` when loading the tokenizer is a security risk, as it allows arbitrary code execution from the model repository. This should be an explicit, user-opt-in setting.
- Performance (medium severity): Add `enable_prefix_caching`, `max_num_batched_tokens`, and `max_num_seqs`. These are available for the student model, and adding them here would allow for better performance tuning of the teacher model. For a static teacher, `enable_prefix_caching=True` could be a beneficial default.
| backend: str = "vllm" | |
| engine_init_kwargs: Dict[str, Any] = field(default_factory=dict) | |
| backend: str = "vllm" | |
| trust_remote_code: bool = False | |
| enable_prefix_caching: bool = True | |
| max_num_batched_tokens: Optional[int] = None | |
| max_num_seqs: Optional[int] = None | |
| engine_init_kwargs: Dict[str, Any] = field(default_factory=dict) |
    assert teacher_cfg.model_path is not None, "teacher.model_path must be set for distillation"

    # Load teacher's own tokenizer for cross-model retokenization.
    teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_cfg.model_path, trust_remote_code=True)
Hardcoding trust_remote_code=True introduces a security vulnerability by allowing arbitrary code execution from the model hub. This should be controlled by a configuration flag in TeacherConfig that defaults to False, making the risk explicit to the user.
Suggested change:
-     teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_cfg.model_path, trust_remote_code=True)
+     teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_cfg.model_path, trust_remote_code=teacher_cfg.trust_remote_code)
| "enable_prefix_caching": False, | ||
| "enforce_eager": teacher_cfg.enforce_eager, | ||
| "expert_parallel_size": 1, | ||
| "data_parallel_size": 1, | ||
| "shared_pg": None, # teacher gets its own placement group | ||
| "gpu_memory_utilization": teacher_cfg.gpu_memory_utilization, | ||
| "inference_engine_enable_sleep": False, # teacher doesn't share GPU | ||
| "async_engine": False, | ||
| "max_num_batched_tokens": None, | ||
| "max_num_seqs": None, |
Several engine parameters are hardcoded here (enable_prefix_caching, max_num_batched_tokens, max_num_seqs). For consistency with the student engine configuration and to allow for better performance tuning, these should be configurable via TeacherConfig (as suggested in a separate comment). For a static teacher model, enable_prefix_caching=True could be a good default.
| "enable_prefix_caching": False, | |
| "enforce_eager": teacher_cfg.enforce_eager, | |
| "expert_parallel_size": 1, | |
| "data_parallel_size": 1, | |
| "shared_pg": None, # teacher gets its own placement group | |
| "gpu_memory_utilization": teacher_cfg.gpu_memory_utilization, | |
| "inference_engine_enable_sleep": False, # teacher doesn't share GPU | |
| "async_engine": False, | |
| "max_num_batched_tokens": None, | |
| "max_num_seqs": None, | |
| "enable_prefix_caching": teacher_cfg.enable_prefix_caching, | |
| "enforce_eager": teacher_cfg.enforce_eager, | |
| "expert_parallel_size": 1, | |
| "data_parallel_size": 1, | |
| "shared_pg": None, # teacher gets its own placement group | |
| "gpu_memory_utilization": teacher_cfg.gpu_memory_utilization, | |
| "inference_engine_enable_sleep": False, # teacher doesn't share GPU | |
| "async_engine": False, | |
| "max_num_batched_tokens": teacher_cfg.max_num_batched_tokens, | |
| "max_num_seqs": teacher_cfg.max_num_seqs, |
Summary
- Adds `TeacherConfig` dataclass to `skyrl.train.config` with fields for model_path, top_k_logprobs, TP/PP size, gpu_memory_utilization, backend, and engine_init_kwargs
- Adds `create_teacher_inference_engines_from_config()` to `skyrl.train.entrypoints.main_base` that creates dedicated vLLM engines for a teacher model with its own tokenizer

Test plan
- `TeacherConfig` defaults are sensible (model_path=None means no teacher)
- Configs without a `teacher:` section still work (default TeacherConfig)
- `create_teacher_inference_engines_from_config` creates engines with correct max_logprobs

Dependencies
- Requires adding `max_logprobs` to `create_ray_wrapped_inference_engines`

🤖 Generated with Claude Code