
feat: add TeacherConfig and teacher engine factory for distillation#1359

Closed
penfever wants to merge 1 commit into NovaSky-AI:main from penfever:penfever/port-teacher-engine-config

Conversation

@penfever penfever commented Mar 20, 2026

Summary

  • Adds TeacherConfig dataclass to skyrl.train.config with fields for model_path, top_k_logprobs, TP/PP size, gpu_memory_utilization, backend, and engine_init_kwargs
  • Adds create_teacher_inference_engines_from_config() to skyrl.train.entrypoints.main_base that creates dedicated vLLM engines for a teacher model with its own tokenizer
  • Teacher engines use max_logprobs=top_k_logprobs (not 1), don't enable sleep mode, and don't set up weight sync (static weights)
  • Unblocks on-policy distillation with teacher logits and best-of-N distillation examples
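A minimal sketch of what the `TeacherConfig` dataclass described above might look like. Only `backend` and `engine_init_kwargs` match lines quoted later in this PR; the other field names come from the summary and their defaults are assumptions:

```python
# Hypothetical sketch of TeacherConfig; defaults beyond backend and
# engine_init_kwargs are assumptions, not taken from the actual PR diff.
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

@dataclass
class TeacherConfig:
    model_path: Optional[str] = None   # None means no teacher configured
    top_k_logprobs: int = 20           # assumed default
    num_inference_engines: int = 1
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    gpu_memory_utilization: float = 0.9
    enforce_eager: bool = False
    backend: str = "vllm"
    engine_init_kwargs: Dict[str, Any] = field(default_factory=dict)
```

With `model_path=None` as the default, configs without a `teacher:` section get a no-op teacher for free, matching the test plan below.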

Test plan

  • Verify TeacherConfig defaults are sensible (model_path=None means no teacher)
  • Verify existing configs without teacher: section still work (default TeacherConfig)
  • Verify create_teacher_inference_engines_from_config creates engines with correct max_logprobs
  • Verify teacher tokenizer is loaded separately from student tokenizer

Dependencies

🤖 Generated with Claude Code



Adds support for a separate teacher model in distillation workflows:

1. TeacherConfig dataclass with model_path, top_k_logprobs,
   num_inference_engines, TP/PP size, gpu_memory_utilization,
   enforce_eager, backend, and engine_init_kwargs fields.

2. create_teacher_inference_engines_from_config() in main_base.py
   that creates dedicated vLLM engines for the teacher model with
   its own tokenizer (for cross-model retokenization).

Teacher engines differ from student engines: they use max_logprobs
equal to top_k_logprobs (not 1), don't enable sleep mode, and don't
set up weight sync (teacher weights are static).

This unblocks on-policy distillation with teacher logits and
best-of-N distillation examples.
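The teacher-vs-student differences above can be sketched as a kwargs-builder. This is a hypothetical stand-in, not the actual factory: `build_teacher_engine_kwargs` and the trimmed `TeacherConfig` are invented for illustration, though the key names mirror those quoted in the review comments below:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

@dataclass
class TeacherConfig:  # trimmed stand-in for the real config
    model_path: Optional[str] = None
    top_k_logprobs: int = 20
    gpu_memory_utilization: float = 0.9
    backend: str = "vllm"
    engine_init_kwargs: Dict[str, Any] = field(default_factory=dict)

def build_teacher_engine_kwargs(cfg: TeacherConfig) -> Dict[str, Any]:
    """Assemble engine kwargs the way the teacher factory might."""
    assert cfg.model_path is not None, "teacher.model_path must be set for distillation"
    return {
        "model_path": cfg.model_path,
        "backend": cfg.backend,
        # Teacher needs top-k logprobs, not the student's max_logprobs=1.
        "max_logprobs": cfg.top_k_logprobs,
        # Static weights: no sleep mode, no weight sync, own placement group.
        "inference_engine_enable_sleep": False,
        "gpu_memory_utilization": cfg.gpu_memory_utilization,
        **cfg.engine_init_kwargs,
    }
```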

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

@devin-ai-integration devin-ai-integration bot left a comment


Devin Review found 1 potential issue.


"backend": teacher_cfg.backend,
"engine_init_kwargs": dict(teacher_cfg.engine_init_kwargs),
"enable_ray_prometheus_stats": False,
"max_logprobs": teacher_cfg.top_k_logprobs,

🔴 max_logprobs passed as kwarg to function that doesn't accept it, causing TypeError

create_teacher_inference_engines_from_config passes max_logprobs=teacher_cfg.top_k_logprobs to create_ray_wrapped_inference_engines at line 170, but that function does not have a max_logprobs parameter in its signature (ray_wrapped_inference_engine.py:86-118). This will raise TypeError: create_ray_wrapped_inference_engines() got an unexpected keyword argument 'max_logprobs' at runtime, completely breaking the teacher distillation feature. The underlying function hardcodes max_logprobs=1 on ray_wrapped_inference_engine.py:304, so to make the teacher's top_k_logprobs value actually propagate, a max_logprobs parameter needs to be added to the function signature and used in place of the hardcoded value.

Prompt for agents
The fix requires changes in one file:

1. In skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py, add a max_logprobs parameter to the create_ray_wrapped_inference_engines function signature (around line 117, before the closing parenthesis), e.g.: max_logprobs: int = 1
Then on line 304, change the hardcoded max_logprobs=1 to use the new parameter: max_logprobs=max_logprobs

2. No changes needed in skyrl/train/entrypoints/main_base.py line 170 since it already passes max_logprobs correctly — it just needs the function to accept the parameter.
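The fix described above can be sketched as follows. This is a simplified stand-in, not the real `create_ray_wrapped_inference_engines` signature; the surrounding parameters and the returned dicts are invented for illustration:

```python
# Hypothetical sketch of threading max_logprobs through the factory
# instead of hardcoding it. Real signature has many more parameters.
from typing import Any, Dict, List

def create_ray_wrapped_inference_engines(
    model_path: str,
    num_engines: int = 1,
    max_logprobs: int = 1,  # new parameter; 1 preserves the old student default
) -> List[Dict[str, Any]]:
    engines = []
    for _ in range(num_engines):
        engine_kwargs = {
            "model": model_path,
            # was: "max_logprobs": 1  (hardcoded at ray_wrapped_inference_engine.py:304)
            "max_logprobs": max_logprobs,
        }
        engines.append(engine_kwargs)  # stand-in for spawning the real engine actor
    return engines

# The teacher call site can now pass its top-k without a TypeError:
teacher_engines = create_ray_wrapped_inference_engines("teacher-model", max_logprobs=20)
```

Existing student call sites are unaffected because the new parameter defaults to the previously hardcoded value.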

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a TeacherConfig and a factory for creating teacher inference engines, which is a key step for enabling distillation. The implementation is logical and well-structured. My review focuses on improving security and configurability. I've identified a significant security risk with hardcoding trust_remote_code=True and have suggested making this a configurable option that defaults to false. Additionally, I've recommended expanding TeacherConfig with more performance-tuning parameters to align it with the student model's configuration capabilities, allowing for better optimization.

Comment on lines +711 to +712
backend: str = "vllm"
engine_init_kwargs: Dict[str, Any] = field(default_factory=dict)

high

The TeacherConfig can be improved by adding more options for security and performance tuning:

  • Security (high severity): Add trust_remote_code: bool = False. Hardcoding trust_remote_code=True when loading the tokenizer is a security risk as it allows arbitrary code execution from the model repository. This should be an explicit, user-opt-in setting.
  • Performance (medium severity): Add enable_prefix_caching, max_num_batched_tokens, and max_num_seqs. These are available for the student model, and adding them here would allow for better performance tuning of the teacher model. For a static teacher, enable_prefix_caching=True could be a beneficial default.
Suggested change:
  - backend: str = "vllm"
  - engine_init_kwargs: Dict[str, Any] = field(default_factory=dict)
  + backend: str = "vllm"
  + trust_remote_code: bool = False
  + enable_prefix_caching: bool = True
  + max_num_batched_tokens: Optional[int] = None
  + max_num_seqs: Optional[int] = None
  + engine_init_kwargs: Dict[str, Any] = field(default_factory=dict)

assert teacher_cfg.model_path is not None, "teacher.model_path must be set for distillation"

# Load teacher's own tokenizer for cross-model retokenization.
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_cfg.model_path, trust_remote_code=True)

high

Hardcoding trust_remote_code=True introduces a security vulnerability by allowing arbitrary code execution from the model hub. This should be controlled by a configuration flag in TeacherConfig that defaults to False, making the risk explicit to the user.

Suggested change:
  - teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_cfg.model_path, trust_remote_code=True)
  + teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_cfg.model_path, trust_remote_code=teacher_cfg.trust_remote_code)

Comment on lines +156 to +165
"enable_prefix_caching": False,
"enforce_eager": teacher_cfg.enforce_eager,
"expert_parallel_size": 1,
"data_parallel_size": 1,
"shared_pg": None, # teacher gets its own placement group
"gpu_memory_utilization": teacher_cfg.gpu_memory_utilization,
"inference_engine_enable_sleep": False, # teacher doesn't share GPU
"async_engine": False,
"max_num_batched_tokens": None,
"max_num_seqs": None,

medium

Several engine parameters are hardcoded here (enable_prefix_caching, max_num_batched_tokens, max_num_seqs). For consistency with the student engine configuration and to allow for better performance tuning, these should be configurable via TeacherConfig (as suggested in a separate comment). For a static teacher model, enable_prefix_caching=True could be a good default.

Suggested change:
  - "enable_prefix_caching": False,
  + "enable_prefix_caching": teacher_cfg.enable_prefix_caching,
    "enforce_eager": teacher_cfg.enforce_eager,
    "expert_parallel_size": 1,
    "data_parallel_size": 1,
    "shared_pg": None, # teacher gets its own placement group
    "gpu_memory_utilization": teacher_cfg.gpu_memory_utilization,
    "inference_engine_enable_sleep": False, # teacher doesn't share GPU
    "async_engine": False,
  - "max_num_batched_tokens": None,
  - "max_num_seqs": None,
  + "max_num_batched_tokens": teacher_cfg.max_num_batched_tokens,
  + "max_num_seqs": teacher_cfg.max_num_seqs,

@penfever

Closing for now — this is infrastructure for distillation features (#5, #6 in our tracker) that aren't ready for upstream yet. Will resubmit as part of a distillation PR.

@penfever penfever closed this Mar 20, 2026
