@@ -178,35 +178,62 @@ def parse_args():
     return parser.parse_args()
 
 
-def get_dataset_optimal_tokens(dataset_info):
+def get_dataset_optimal_tokens(dataset_info, model_name=None):
     """
-    Determine optimal token limit based on dataset complexity and reasoning requirements.
+    Determine optimal token limit based on dataset complexity, reasoning requirements, and model capabilities.
 
     Token limits are optimized for structured response generation while maintaining
-    efficiency across different reasoning complexity levels.
+    efficiency across different reasoning complexity levels and model architectures.
+
+    Args:
+        dataset_info: Dataset information object
+        model_name: Model identifier (e.g., "openai/gpt-oss-20b", "Qwen/Qwen3-30B-A3B")
     """
     dataset_name = dataset_info.name.lower()
     difficulty = dataset_info.difficulty_level.lower()
 
-    # Optimized token limits per dataset (increased for reasoning mode support)
-    dataset_tokens = {
-        "gpqa": 1500,  # Graduate-level scientific reasoning
+    # Determine model type and capabilities
+    model_multiplier = 1.0
+    if model_name:
+        model_lower = model_name.lower()
+        if "qwen" in model_lower:
+            # Qwen models are more efficient and can handle longer contexts
+            model_multiplier = 1.5
+        elif "deepseek" in model_lower:
+            # DeepSeek models (e.g., V3.1) are capable and can handle longer contexts
+            model_multiplier = 1.5
+        elif "gpt-oss" in model_lower:
+            # GPT-OSS models use baseline token limits
+            model_multiplier = 1.0
+        # Default to baseline for unknown models
+
+    # Base token limits per dataset (optimized for gpt-oss20b baseline)
+    base_dataset_tokens = {
+        "gpqa": 3000,  # Graduate-level scientific reasoning (increased for complex multi-step reasoning)
         "truthfulqa": 800,  # Misconception analysis
         "hellaswag": 800,  # Natural continuation reasoning
         "arc": 800,  # Elementary/middle school science
         "commonsenseqa": 1000,  # Common sense reasoning
-        "mmlu": 600 if difficulty == "undergraduate" else 800,  # Academic knowledge
+        "mmlu": 3000,  # Academic knowledge (increased for complex technical domains like engineering/chemistry)
     }
 
-    # Find matching dataset
-    for dataset_key, tokens in dataset_tokens.items():
+    # Find matching dataset and apply model multiplier
+    base_tokens = None
+    for dataset_key, tokens in base_dataset_tokens.items():
         if dataset_key in dataset_name:
-            return tokens
+            base_tokens = tokens
+            break
 
+    # Fallback to difficulty-based tokens if dataset not found
+    if base_tokens is None:
+        difficulty_tokens = {"graduate": 300, "hard": 300, "moderate": 200, "easy": 150}
+        base_tokens = difficulty_tokens.get(difficulty, 200)
+
-    # Default based on difficulty level
-    difficulty_tokens = {"graduate": 300, "hard": 300, "moderate": 200, "easy": 150}
+    # Apply model-specific multiplier and round to nearest 50
+    final_tokens = int(base_tokens * model_multiplier)
+    final_tokens = ((final_tokens + 25) // 50) * 50  # Round to nearest 50
 
-    return difficulty_tokens.get(difficulty, 200)
+    return final_tokens
 
 
 def get_available_models(endpoint: str, api_key: str = "") -> List[str]:
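For reference, a minimal sketch (not part of the diff) of how the new limits resolve under the values above. The dataset objects are hypothetical stand-ins exposing only the `name` and `difficulty_level` attributes the function reads.

from types import SimpleNamespace

# Hypothetical stand-ins for dataset_info; the real objects come from the benchmark code.
gpqa = SimpleNamespace(name="gpqa_diamond", difficulty_level="graduate")
custom = SimpleNamespace(name="custom_eval", difficulty_level="graduate")

# Baseline (gpt-oss or unknown model): GPQA resolves to 3000 tokens.
print(get_dataset_optimal_tokens(gpqa))                                             # 3000
# Qwen/DeepSeek 1.5x multiplier: 3000 * 1.5 = 4500, already a multiple of 50.
print(get_dataset_optimal_tokens(gpqa, model_name="Qwen/Qwen3-30B-A3B"))            # 4500
# Unmatched dataset falls back to difficulty ("graduate" -> 300), then 300 * 1.5 = 450.
print(get_dataset_optimal_tokens(custom, model_name="deepseek-ai/DeepSeek-V3.1"))   # 450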
@@ -507,6 +534,20 @@ def evaluate_model_vllm_multimode(
         q.cot_content is not None and q.cot_content.strip() for q in questions[:10]
     )
 
+    # Debug: Show CoT content status for first few questions
+    print(f" CoT Debug - Checking first 10 questions:")
+    for i, q in enumerate(questions[:10]):
+        cot_status = (
+            "None"
+            if q.cot_content is None
+            else (
+                f"'{q.cot_content[:50]}...'"
+                if len(q.cot_content) > 50
+                else f"'{q.cot_content}'"
+            )
+        )
+        print(f" Q{i+1}: CoT = {cot_status}")
+
     if has_cot_content:
         print(f" Dataset has CoT content - using 3 modes: NR, XC, NR_REASONING")
     else:
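A rough, self-contained illustration (outside the diff) of the display rule the new debug block applies; the question objects are made-up stand-ins exposing only the cot_content attribute.

from types import SimpleNamespace

# Made-up questions; only the cot_content attribute matters here.
questions = [
    SimpleNamespace(cot_content=None),
    SimpleNamespace(cot_content="Eliminate B and D; A fits the data."),
]

for i, q in enumerate(questions):
    text = q.cot_content
    # Same rule as the diff: None, full text if 50 chars or fewer, else a truncated preview.
    shown = "None" if text is None else (f"'{text[:50]}...'" if len(text) > 50 else f"'{text}'")
    print(f"Q{i+1}: CoT = {shown}")
# Q1: CoT = None
# Q2: CoT = 'Eliminate B and D; A fits the data.'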
@@ -827,28 +868,31 @@ def main():
print(f"Router models: {router_models}")
print(f"vLLM models: {vllm_models}")

# Determine optimal token limit for this dataset
if args.max_tokens:
optimal_tokens = args.max_tokens
print(f"Using user-specified max_tokens: {optimal_tokens}")
else:
optimal_tokens = get_dataset_optimal_tokens(dataset_info)
print(
f"Using dataset-optimal max_tokens: {optimal_tokens} (for {dataset_info.name})"
)
# Function to get optimal tokens for a specific model
# For fair comparison, use consistent token limits regardless of model name
def get_model_optimal_tokens(model_name):
if args.max_tokens:
return args.max_tokens
else:
# Use base dataset tokens without model-specific multipliers for fair comparison
return get_dataset_optimal_tokens(dataset_info, model_name=None)

# Router evaluation (NR-only)
if args.run_router and router_endpoint and router_models:
for model in router_models:
model_tokens = get_model_optimal_tokens(model)
print(f"\nEvaluating router model: {model}")
print(
f"Using max_tokens: {model_tokens} (dataset-optimized for fair comparison)"
)
rt_df = evaluate_model_router_transparent(
questions=questions,
dataset=dataset,
model=model,
endpoint=router_endpoint,
api_key=router_api_key,
concurrent_requests=args.concurrent_requests,
max_tokens=optimal_tokens,
max_tokens=model_tokens,
temperature=args.temperature,
)
analysis = analyze_results(rt_df)
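A small sketch (not in the diff) of the effect of the fair-comparison helper: because it always forwards model_name=None, every model receives the same dataset-derived limit. The model list is illustrative, and the snippet assumes the surrounding main() scope (args, dataset_info, get_model_optimal_tokens) as defined above.

# Sketch only; relies on the helper and dataset_info from main().
for model in ["openai/gpt-oss-20b", "Qwen/Qwen3-30B-A3B", "deepseek-ai/DeepSeek-V3.1"]:
    # model_name=None inside the helper means the 1.5x Qwen/DeepSeek multiplier is never
    # applied here, so all three models get the identical dataset-based token budget.
    print(model, get_model_optimal_tokens(model))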
@@ -863,15 +907,19 @@ def main():
     # Direct vLLM evaluation (NR/XC with reasoning ON/OFF)
     if args.run_vllm and vllm_endpoint and vllm_models:
         for model in vllm_models:
+            model_tokens = get_model_optimal_tokens(model)
             print(f"\nEvaluating vLLM model: {model}")
+            print(
+                f"Using max_tokens: {model_tokens} (dataset-optimized for fair comparison)"
+            )
             vdf = evaluate_model_vllm_multimode(
                 questions=questions,
                 dataset=dataset,
                 model=model,
                 endpoint=vllm_endpoint,
                 api_key=vllm_api_key,
                 concurrent_requests=args.concurrent_requests,
-                max_tokens=optimal_tokens,
+                max_tokens=model_tokens,
                 temperature=args.temperature,
                 exec_modes=args.vllm_exec_modes,
             )