refactor: update sampling evaluation logic #104
base: main
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough: Replaces the frequency-distribution baseline with thresholding-aware expected probabilities, adds per-batch valid-token mask computation (top_k/top_p/combined), increases sampling to 500k trials with batched handling, and validates samples using per-batch TVD and expanded error reporting.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant Input as Input Probabilities
    participant Mask as Valid Sampling Mask
    participant Masked as Masked Probabilities
    participant Sampler as Sampler (500k trials)
    participant Validator as Correctness Validator
    participant Stats as Per-Batch Stats
    Input->>Mask: probs, method, params
    activate Mask
    Note over Mask: compute per-batch boolean mask\n(top_k / top_p / combined, tie & eps)
    Mask-->>Masked: mask
    deactivate Mask
    Input->>Masked: apply mask -> expected_probs
    activate Masked
    Masked-->>Sampler: expected_probs (batched, padded)
    deactivate Masked
    Sampler->>Sampler: generate large-sample counts\n(500,000 trials, batched)
    Sampler-->>Validator: sample frequency distributions
    Validator->>Validator: validate each sample\nagainst per-batch mask
    Validator-->>Stats: compute per-batch TVD, abs/rel errors
    activate Stats
    Note over Stats: aggregate max TVD, max errors
    Stats-->>Validator: metrics
    deactivate Stats
    Validator-->>Input: correctness payload\n(expected_probs, per-batch TVD list, max errors)
```
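To make the per-batch TVD step in the diagram concrete, here is a minimal sketch in PyTorch; the function name, the `counters`/`expected_probs` shapes, and the toy numbers are illustrative assumptions, not the code in this PR.

```python
import torch

def per_batch_tvd(counters: torch.Tensor, expected_probs: torch.Tensor) -> torch.Tensor:
    """Total variation distance between sampled frequencies and expected probabilities.

    counters:       [batch, vocab] integer sample counts
    expected_probs: [batch, vocab] thresholding-aware expected distribution
    Returns a [batch] tensor of TVD values.
    """
    freqs = counters.float() / counters.sum(dim=-1, keepdim=True).clamp_min(1)
    return 0.5 * (freqs - expected_probs).abs().sum(dim=-1)

# The worst-case (max) TVD across the batch drives the pass/fail decision.
counters = torch.tensor([[4980, 5020, 0], [3400, 3300, 3300]])
expected = torch.tensor([[0.5, 0.5, 0.0], [1 / 3, 1 / 3, 1 / 3]])
max_tvd = per_batch_tvd(counters, expected).max().item()
```

TVD here is half the L1 distance between the empirical frequencies and the expected distribution, so identical distributions score 0 and disjoint ones score 1.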
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (2 passed)
Summary of Changes

Hello @zanderjiang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly refactors the sampling evaluation logic to enhance its robustness and accuracy, particularly for scenarios involving batched inputs. The core changes involve moving from a single-dimension frequency distribution to per-input probability distributions for TVD calculation, ensuring that the evaluation correctly reflects the behavior of each batch element. It also introduces a more precise method for validating sampled tokens against the generated mask.

Highlights
Code Review
This pull request refactors the sampling evaluation logic to correctly handle batched inputs and improve efficiency. The changes replace the previous frequency distribution computation with a more robust method that calculates expected probabilities and validates samples against a generated mask. The logic for collecting sample distributions is also updated to efficiently handle large numbers of trials by repeating inputs.
My review identifies two potential issues: a critical bug that could lead to a ZeroDivisionError when the input batch size is zero, and a high-severity issue with incorrect handling of scalar sample outputs in batch mode. I've provided suggestions to fix both of these to make the evaluation logic more robust.
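To illustrate the repeat-inputs idea in isolation, here is a self-contained sketch; `sample_frequencies`, the toy multinomial runnable, and the constants are assumptions for illustration, not this repository's API.

```python
import torch

def sample_frequencies(runnable, probs: torch.Tensor, num_trials: int,
                       target_batch_size: int = 10_000) -> torch.Tensor:
    """Estimate per-row sampling frequencies by repeating the input batch."""
    original_batch_size, vocab_size = probs.shape
    repeat_count = max(1, target_batch_size // original_batch_size)
    padded = probs.repeat(repeat_count, 1)          # [repeat_count * B, vocab]
    counters = torch.zeros(original_batch_size, vocab_size, dtype=torch.int64)
    collected = 0
    while collected < num_trials:
        samples = runnable(probs=padded)            # [repeat_count * B] token ids
        per_row = samples.reshape(repeat_count, original_batch_size)
        for b in range(original_batch_size):
            counters[b] += torch.bincount(per_row[:, b], minlength=vocab_size)
        collected += repeat_count                   # draws gathered per original row
    return counters.float() / counters.sum(dim=-1, keepdim=True)

# Toy usage: a multinomial "kernel" stands in for the runnable under test.
probs = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]])
freqs = sample_frequencies(lambda probs: torch.multinomial(probs, 1).squeeze(-1),
                           probs, num_trials=50_000)
```

Each forward pass yields `repeat_count` draws per original batch row, which is why one large padded batch can replace thousands of small forward passes.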
```python
original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
vocab_size = inputs["probs"].shape[-1]
counter = torch.zeros(vocab_size, dtype=torch.int64, device=torch.device(device))

trials_needed = (num_trials + batch_size - 1) // batch_size
total_samples_collected = 0

# Repeat entire input batch to fill up to target_batch_size for efficient sampling
target_batch_size = 10000
repeat_count = target_batch_size // original_batch_size
actual_batch_size = repeat_count * original_batch_size
```
There's a potential ZeroDivisionError here. If inputs["probs"] has a shape like (0, vocab_size), original_batch_size will be 0, causing a crash on line 316 when calculating repeat_count.
Additionally, if original_batch_size is larger than target_batch_size, repeat_count will be 0, leading to an actual_batch_size of 0. This will create 0-sized tensors and likely cause issues in the runnable.
I suggest handling the original_batch_size == 0 case explicitly and ensuring repeat_count is at least 1 to prevent these issues.
```python
original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
vocab_size = inputs["probs"].shape[-1]

if original_batch_size == 0:
    return torch.empty((0, vocab_size), dtype=torch.float32, device=torch.device(device))

# Repeat entire input batch to fill up to target_batch_size for efficient sampling
target_batch_size = 10000
repeat_count = max(1, target_batch_size // original_batch_size)
actual_batch_size = repeat_count * original_batch_size
```

```python
if samples.dim() == 0:
    # Single sample - assign to first batch element
    sample_idx = samples.item()
    counter[sample_idx] += 1
    total_samples_collected += 1
else:  # Batch of samples
    for i in range(samples.numel()):
        sample_idx = samples.flatten()[i].item()
        counter[sample_idx] += 1
        total_samples_collected += 1

frequency = counter.float() / total_samples_collected
return frequency


def _check_thresholding(
    samples: torch.Tensor, probs: torch.Tensor, method: str, params: Dict[str, Any]
) -> bool:
    """Check if samples conform to the specified thresholding method.

    Parameters
    ----------
    samples : torch.Tensor
        Sampled token indices.
    probs : torch.Tensor
        Probability distribution used for sampling.
    method : str
        Thresholding method: "top_k", "top_p", "top_k_top_p", or "none".
    params : Dict[str, Any]
        Sampling parameters (top_k, top_p values).

    Returns
    -------
    bool
        True if samples are valid, False otherwise.
    """
    batch_size, vocab_size = probs.shape
    device = probs.device

    for i in range(batch_size):
        prob_row = probs[i]
        sample = samples[i].item()

        if method == "top_k":
            if "top_k" not in params:
                raise ValueError("top_k parameter is required for top_k thresholding but not found")
            k = (
                int(params["top_k"][i].item())
                if params["top_k"].dim() > 0
                else int(params["top_k"].item())
            )

            if 0 < k < vocab_size:
                sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
                pivot = sorted_prob_desc[k - 1]
                mask_top_k = (prob_row >= pivot).int()
                if mask_top_k[sample] != 1:
                    return False

        elif method == "top_p":
            if "top_p" not in params:
                raise ValueError("top_p parameter is required for top_p thresholding but not found")
            p = (
                float(params["top_p"][i].item())
                if params["top_p"].dim() > 0
                else float(params["top_p"].item())
            )

            if 0 < p < 1:
                eps = 1e-4  # numerical stability
                sorted_probs, indices = torch.sort(prob_row, descending=False)
                cdf = torch.cumsum(sorted_probs, dim=0)
                valid_mask = cdf > (1 - p) - eps
                valid_indices = indices[valid_mask]

                if sample not in valid_indices:
                    return False

        elif method == "top_k_top_p":
            if "top_k" not in params or "top_p" not in params:
                raise ValueError(
                    "top_k and top_p parameters are both required for top_k_top_p thresholding but not found"
                )
            k = (
                int(params["top_k"][i].item())
                if params["top_k"].dim() > 0
                else int(params["top_k"].item())
            )
            p = (
                float(params["top_p"][i].item())
                if params["top_p"].dim() > 0
                else float(params["top_p"].item())
            )

            if 0 < k < vocab_size:
                sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
                pivot = sorted_prob_desc[k - 1]
                mask_top_k = (prob_row >= pivot).int()
            else:
                mask_top_k = torch.ones(vocab_size, dtype=torch.int32, device=device)

            if 0 < p < 1:
                eps = 1e-4
                sorted_probs_asc, indices = torch.sort(prob_row, descending=False)
                cdf = torch.cumsum(sorted_probs_asc, dim=0)
                mask_top_p = torch.zeros(vocab_size, dtype=torch.int32, device=device)
                valid_p_mask = cdf > (1 - p) - eps
                mask_top_p[indices[valid_p_mask]] = 1
            else:
                mask_top_p = torch.ones(vocab_size, dtype=torch.int32, device=device)

            joint_mask = torch.minimum(mask_top_k, mask_top_p)

            if joint_mask[sample] != 1:
                return False

    return True
```

```python
counters[0, sample_idx] += 1
total_samples_per_batch += 1
```
When samples.dim() == 0, the code assumes a batch size of 1 and assigns the sample to the first batch element's counter. This is incorrect if original_batch_size > 1, as it would misattribute samples and lead to an incorrect frequency distribution for all batch items. The runnable should be expected to return a batch of samples matching actual_batch_size. If it returns a scalar when a batch is expected, it's a contract violation that should be flagged with an error.
```python
if samples.dim() == 0:
    if actual_batch_size != 1:
        raise ValueError(
            f"Expected a batch of samples (size {actual_batch_size}), but got a scalar."
        )
    # Single sample - assign to first batch element
    sample_idx = samples.item()
    counters[0, sample_idx] += 1
    total_samples_per_batch += 1
```
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer_bench/bench/evaluators/sampling.py (1)
120-121: cuda synchronize called with a string; also guard for non‑CUDA devices.

torch.cuda.synchronize expects a CUDA device or index. Passing a str like "cuda:0" may fail; on CPU it always fails. Convert to torch.device and guard by type.

```diff
- torch.cuda.synchronize(device)
+ _dev = torch.device(device)
+ if _dev.type == "cuda":
+     torch.cuda.synchronize(_dev)
```
🧹 Nitpick comments (9)
flashinfer_bench/bench/evaluators/sampling.py (9)
127-129: normalize_outputs receives device=str here but torch.device elsewhere. Make consistent.

Pass torch.device(device) for consistency and to avoid downstream type assumptions.

```diff
- out_normalized = normalize_outputs(
-     out, device=device, output_names=output_names, output_dtypes=output_dtypes
- )
+ out_normalized = normalize_outputs(
+     out, device=torch.device(device),
+     output_names=output_names, output_dtypes=output_dtypes
+ )
```
224-234: Thresholding method detection should use params, not name.

Relying on defn.name can drift from runtime params; infer from presence/values of top_k/top_p.

```diff
-def _detect_thresholding_method(defn: Definition) -> str:
-    name = defn.name.lower()
-    if "top_k_top_p" in name:
-        return "top_k_top_p"
-    elif "top_k" in name:
-        return "top_k"
-    elif "top_p" in name:
-        return "top_p"
-    else:
-        return "none"  # no thresholding
+def _detect_thresholding_method(defn: Definition, params: Optional[Dict[str, Any]] = None) -> str:
+    params = params or {}
+    has_k = "top_k" in params and (int(params["top_k"].item()) if isinstance(params["top_k"], torch.Tensor) and params["top_k"].dim() == 0 else True)
+    has_p = "top_p" in params
+    if has_k and has_p:
+        return "top_k_top_p"
+    if has_k:
+        return "top_k"
+    if has_p:
+        return "top_p"
+    return "none"
```

And update call sites to pass params.
254-271: Ruff TRY003: long exception messages.

Trim or refactor messages to constants to satisfy TRY003.

```diff
- raise ValueError(f"top_k parameter required for {method} but not found")
+ raise ValueError("missing required parameter: top_k")
 ...
- raise ValueError(f"top_p parameter required for {method} but not found")
+ raise ValueError("missing required parameter: top_p")
```

Also applies to: 273-276
349-351: total_samples_per_batch accounting is fragile for unusual outputs; compute from observed shapes.

Derive per‑iteration contribution from samples.numel() to avoid assumptions.

```diff
-    trials_needed = (num_trials + actual_batch_size - 1) // actual_batch_size
-    total_samples_per_batch = 0
+    trials_needed = (num_trials + actual_batch_size - 1) // actual_batch_size
+    total_samples_per_batch = 0
 @@
-        else:
-            # slice and accumulate per original batch element
-            samples_flat = samples.flatten()
-            for i in range(samples_flat.numel()):
-                batch_idx = i % original_batch_size
-                sample_idx = samples_flat[i].item()
-                counters[batch_idx, sample_idx] += 1
-            total_samples_per_batch += repeat_count
+        else:
+            # slice and accumulate per original batch element
+            samples_flat = samples.flatten()
+            for i in range(samples_flat.numel()):
+                batch_idx = i % original_batch_size
+                sample_idx = int(samples_flat[i])
+                counters[batch_idx, sample_idx] += 1
+            total_samples_per_batch += samples_flat.numel() // original_batch_size
```

Also applies to: 380-381
352-361: Move invariant computations out of the sampling loop.

output_names/output_dtypes don’t change per iteration; compute once before the loop for speed.

```diff
-    for _ in range(trials_needed):
-        with torch.no_grad():
-            out = runnable(**padded_inputs)
-
-        output_names = list(defn.outputs.keys())
-        output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}
+    output_names = list(defn.outputs.keys())
+    output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}
+    for _ in range(trials_needed):
+        with torch.no_grad():
+            out = runnable(**padded_inputs)
```
365-377: Count samples with vectorized bincount per batch to reduce Python loops.

This greatly speeds up 500k trials.

```diff
-        samples_flat = samples.flatten()
-        for i in range(samples_flat.numel()):
-            batch_idx = i % original_batch_size
-            sample_idx = samples_flat[i].item()
-            counters[batch_idx, sample_idx] += 1
-        total_samples_per_batch += repeat_count
+        samples = samples.view(-1)  # [actual_batch_size]
+        # reshape to [repeat_count, original_batch_size] if divisible
+        if samples.numel() % original_batch_size == 0:
+            reshaped = samples.view(-1, original_batch_size)
+            for b in range(original_batch_size):
+                counts = torch.bincount(reshaped[:, b], minlength=vocab_size)
+                counters[b] += counts
+            total_samples_per_batch += reshaped.size(0)
+        else:
+            # fallback to scalar loop (rare)
+            for i in range(samples.numel()):
+                counters[i % original_batch_size, int(samples[i])] += 1
+            total_samples_per_batch += samples.numel() // original_batch_size
```
352-355: Optional: synchronize inside the sampling loop only when using CUDA.

Keeps timing/state consistent if kernels are async.

```diff
-        with torch.no_grad():
-            out = runnable(**padded_inputs)
+        with torch.no_grad():
+            out = runnable(**padded_inputs)
+        _dev = torch.device(device)
+        if _dev.type == "cuda":
+            torch.cuda.synchronize(_dev)
```
52-60: Nit: normalize logits once; consider explicit dtype.

If inp["probs"] is already probs, softmax again can skew. Gate behind a flag or defn meta.

```diff
-        if "probs" in inp:
+        if "probs" in inp and defn.inputs.get("probs", {}).get("is_logits", True):
             inp["probs"] = torch.softmax(
                 inp["probs"], dim=-1
             )  # convert logits to probs for sampling
```
246-302: Mask construction looks correct; consider vectorizing top_k/top_p over batch.

Current per‑row loops are fine functionally; vectorization would simplify and speed up.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer_bench/bench/evaluators/sampling.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer_bench/bench/evaluators/sampling.py (4)
flashinfer_bench/utils.py (1)
- dtype_str_to_torch_dtype (39-45)

flashinfer_bench/data/trace.py (2)
- Correctness (72-95)
- EvaluationStatus (126-146)

flashinfer_bench/bench/utils.py (2)
- make_eval (261-279)
- compute_error_stats (89-116)

flashinfer_bench/compile/runnable.py (1)
- Runnable (6-38)
🪛 GitHub Actions: .github/workflows/linting.yaml
flashinfer_bench/bench/evaluators/sampling.py
[error] 1-1: trailing-whitespace: Hooks detected and fixed trailing whitespace. 1 file was modified by this hook.
[error] 1-1: black: Reformatted 1 file(s). All done! 1 file reformatted, 101 files left unchanged. Files were modified by this hook.
🪛 Ruff (0.14.1)
flashinfer_bench/bench/evaluators/sampling.py
256-256: Avoid specifying long messages outside the exception class
(TRY003)
275-275: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.10
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.12
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.9
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.11
🔇 Additional comments (1)
flashinfer_bench/bench/evaluators/sampling.py (1)
1-5: No issues found; review comment is accurate.

The file contains no trailing whitespace and is properly formatted per Black standards. The .pre-commit-config.yaml is configured with the appropriate hooks (trailing-whitespace, black, isort). The review comment correctly identifies that formatting has been applied and appropriately advises running hooks locally before pushing; this is informational guidance with no outstanding fixes needed.
```python
thresholding_method = _detect_thresholding_method(defn)
params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}
valid_mask = _compute_valid_sampling_mask(inp["probs"], thresholding_method, params)

masked_probs = inp["probs"] * valid_mask.float()
expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)

outputs.append({"expected_probs": expected_probs})
```
Guard against zero-sum masked probs to avoid NaNs.
If the valid_mask zeros out all tokens (edge params), masked_probs.sum can be 0 leading to NaNs in expected_probs. Clamp the denominator.
```diff
-masked_probs = inp["probs"] * valid_mask.float()
-expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
+masked_probs = inp["probs"] * valid_mask.float()
+denom = masked_probs.sum(dim=-1, keepdim=True)
+# Avoid NaNs if no tokens survive; fall back to uniform over valid_mask
+denom = torch.where(denom > 0, denom, torch.ones_like(denom))
+expected_probs = masked_probs / denom
+# If denom was 0, distribute uniformly across valid tokens
+zero_rows = (masked_probs.sum(dim=-1, keepdim=True) == 0)
+if zero_rows.any():
+    uniform = valid_mask.float() / valid_mask.float().sum(dim=-1, keepdim=True).clamp_min(1)
+    expected_probs = torch.where(zero_rows, uniform, expected_probs)
```

🤖 Prompt for AI Agents
In flashinfer_bench/bench/evaluators/sampling.py around lines 62 to 69, the code
divides by masked_probs.sum which can be zero if valid_mask zeros out all
tokens; change the denominator to a clamped value to avoid NaNs by computing the
sum with keepdim=True and then applying .clamp_min(eps) (use a small constant
like 1e-12) before dividing so expected_probs = masked_probs / denom_clamped,
ensuring you preserve shapes and device when creating eps.
```python
if not valid_mask[batch_idx, sample_idx]:
    correctness = Correctness(
        max_relative_error=float("inf"), max_absolute_error=float("inf")
    )
    message = (
        f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}"
    )
    print(message, file=sys.stderr)
```
Avoid truthiness on 0‑dim Torch bool Tensor.
Using if not valid_mask[...] can raise “Boolean value of Tensor is ambiguous.” Convert to Python bool.
```diff
- if not valid_mask[batch_idx, sample_idx]:
+ if not bool(valid_mask[batch_idx, sample_idx].item()):
```

📝 Committable suggestion
+ if not bool(valid_mask[batch_idx, sample_idx].item()):📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-if not valid_mask[batch_idx, sample_idx]:
+if not bool(valid_mask[batch_idx, sample_idx].item()):
     correctness = Correctness(
         max_relative_error=float("inf"), max_absolute_error=float("inf")
     )
     message = (
         f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}"
     )
     print(message, file=sys.stderr)
```
🤖 Prompt for AI Agents
In flashinfer_bench/bench/evaluators/sampling.py around lines 159 to 166, the
condition uses truthiness on a 0‑dim Torch bool Tensor (if not
valid_mask[batch_idx, sample_idx]) which can raise “Boolean value of Tensor is
ambiguous.” Convert the tensor to a Python bool by calling .item() (and .cpu()
if it may be on GPU) before negation, and use that boolean in the if check so
the branch evaluates correctly.
```python
sol_freqs = _sample_token_distributions(
    sol_runnable, inp, device, defn, num_trials=500000
)
torch.cuda.synchronize(device)
except Exception:
```
Second cuda synchronize has same issue; apply the same guard.
Mirror the earlier fix after sampling distributions.
```diff
- torch.cuda.synchronize(device)
+ _dev = torch.device(device)
+ if _dev.type == "cuda":
+     torch.cuda.synchronize(_dev)
```

🧰 Tools
+ torch.cuda.synchronize(_dev)🧰 Tools
🪛 Ruff (0.14.1)
179-179: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
In flashinfer_bench/bench/evaluators/sampling.py around lines 175 to 179, the
second torch.cuda.synchronize(device) call needs the same protective guard as
the earlier synchronize to avoid raising on non-CUDA or unavailable CUDA
devices; wrap the synchronize in the same conditional/try-except used previously
(e.g., only call if device.type == "cuda" and torch.cuda.is_available(), or call
inside a try: ... except Exception: pass block) so any CUDA sync errors are
handled identically to the earlier fix.
```diff
 numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
 correctness = Correctness(
-    max_relative_error=max_rel, max_absolute_error=max_abs, extra={"tvd": tvd}
+    max_relative_error=max_rel,
+    max_absolute_error=max_abs,
+    extra={"tvd": max_tvd, "tvds_per_batch": tvds}
 )
```
Protect against missing cfg.sampling_tvd_threshold.
If cfg.sampling_tvd_threshold is None/absent, comparison raises. Default or fail fast with message.
```diff
- numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
+ tvd_thresh = getattr(cfg, "sampling_tvd_threshold", None)
+ if tvd_thresh is None:
+     raise ValueError("cfg.sampling_tvd_threshold must be set for sampling evaluation")
+ numerical_incorrect = max_tvd > tvd_thresh
```

📝 Committable suggestion
+ numerical_incorrect = max_tvd > tvd_thresh📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
+tvd_thresh = getattr(cfg, "sampling_tvd_threshold", None)
+if tvd_thresh is None:
+    raise ValueError("cfg.sampling_tvd_threshold must be set for sampling evaluation")
+numerical_incorrect = max_tvd > tvd_thresh
 correctness = Correctness(
     max_relative_error=max_rel,
     max_absolute_error=max_abs,
     extra={"tvd": max_tvd, "tvds_per_batch": tvds}
 )
```
🤖 Prompt for AI Agents
In flashinfer_bench/bench/evaluators/sampling.py around lines 203 to 208, the
code compares max_tvd against cfg.sampling_tvd_threshold without handling the
case where that config value is None or missing; add a guard before the
comparison that checks if cfg.sampling_tvd_threshold is None and fail fast with
a clear ValueError (or optionally set a documented default threshold) so the
comparison never raises a TypeError; update the subsequent comparison to use the
validated threshold variable.
```python
target_batch_size = 10000
repeat_count = target_batch_size // original_batch_size
actual_batch_size = repeat_count * original_batch_size
```
repeat_count can be 0 when original_batch_size > 10_000 → division by zero and empty batches.
Ensure at least one repeat; prefer ceil to keep high utilization.
```diff
- target_batch_size = 10000
- repeat_count = target_batch_size // original_batch_size
- actual_batch_size = repeat_count * original_batch_size
+ import math
+ target_batch_size = 10000
+ repeat_count = max(1, math.ceil(target_batch_size / original_batch_size))
+ actual_batch_size = repeat_count * original_batch_size
```

Also consider making target_batch_size configurable.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-target_batch_size = 10000
-repeat_count = target_batch_size // original_batch_size
-actual_batch_size = repeat_count * original_batch_size
+import math
+target_batch_size = 10000
+repeat_count = max(1, math.ceil(target_batch_size / original_batch_size))
+actual_batch_size = repeat_count * original_batch_size
```
🤖 Prompt for AI Agents
In flashinfer_bench/bench/evaluators/sampling.py around lines 315-318, the
current computation uses integer division target_batch_size //
original_batch_size which can yield 0 when original_batch_size > 10000 (leading
to empty batches) and underutilization; change to use a ceiling division and
ensure at least one repeat: compute repeat_count = max(1, ceil(target_batch_size
/ original_batch_size)) (or equivalent integer math), then set actual_batch_size
= repeat_count * original_batch_size, and expose target_batch_size as a
configurable parameter (with validation to be a positive int) so it can be tuned
instead of hardcoding 10000.
Actionable comments posted: 0
♻️ Duplicate comments (6)
flashinfer_bench/bench/evaluators/sampling.py (6)
172-172: Guard CUDA synchronization for non-CUDA devices.

torch.cuda.synchronize(device) will raise an error if device is not CUDA or if CUDA is unavailable. Add a guard to only synchronize for CUDA devices.

Apply this diff:

```diff
- torch.cuda.synchronize(device)
+ _dev = torch.device(device)
+ if _dev.type == "cuda":
+     torch.cuda.synchronize(_dev)
```
351-355: Validate runnable contract when receiving scalar samples.

If samples.dim() == 0 (scalar) but actual_batch_size > 1, the runnable has violated its contract by returning a single sample when a batch was expected. The current code silently assigns this to the first batch element, which produces incorrect frequency distributions.

Apply this diff to catch contract violations:

```diff
         if samples.dim() == 0:
+            if actual_batch_size != 1:
+                raise ValueError(
+                    f"Expected a batch of samples (size {actual_batch_size}), but got a scalar."
+                )
             # Single sample - assign to first batch element
             sample_idx = samples.item()
             counters[0, sample_idx] += 1
             total_samples_per_batch += 1
```
62-63: Guard against zero-sum masked probabilities to prevent NaNs.

If valid_mask zeros out all tokens (possible with edge-case top_k/top_p parameters), masked_probs.sum() will be zero, resulting in NaN values in expected_probs.

Apply this diff to add a safeguard:

```diff
 masked_probs = inp["probs"] * valid_mask.float()
-expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
+denom = masked_probs.sum(dim=-1, keepdim=True).clamp_min(1e-12)
+expected_probs = masked_probs / denom
```
197-197: Validate cfg.sampling_tvd_threshold before comparison.

If cfg.sampling_tvd_threshold is None or missing, the comparison will raise a TypeError. Validate the attribute exists and has a valid value.

Apply this diff:

```diff
- numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
+ tvd_thresh = getattr(cfg, "sampling_tvd_threshold", None)
+ if tvd_thresh is None:
+     raise ValueError("cfg.sampling_tvd_threshold must be set for sampling evaluation")
+ numerical_incorrect = max_tvd > tvd_thresh
```
155-155: Convert Tensor to Python bool to avoid ambiguity error.

Using if not valid_mask[batch_idx, sample_idx] directly on a 0-dimensional boolean Tensor can raise "Boolean value of Tensor is ambiguous." Convert to a Python boolean first.

Apply this diff:

```diff
- if not valid_mask[batch_idx, sample_idx]:
+ if not valid_mask[batch_idx, sample_idx].item():
```
297-303: Fix repeat_count calculation to prevent zero-batch and handle large batches.

Multiple critical issues:
- If original_batch_size > target_batch_size (e.g., 15000), repeat_count becomes 0, leading to actual_batch_size = 0, empty tensors, and division by zero at line 366.
- If inputs["probs"].shape[0] is 0 (empty batch), the fallback to 1 is semantically incorrect.
- Integer division underutilizes batch capacity when original_batch_size doesn't evenly divide target_batch_size.

Apply this diff to ensure at least one repeat and handle edge cases:

```diff
-    original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
+    original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
+
+    if original_batch_size == 0:
+        return torch.empty((0, vocab_size), dtype=torch.float32, device=torch.device(device))
+
     vocab_size = inputs["probs"].shape[-1]

     # Repeat entire input batch to fill up to target_batch_size for efficient sampling
     target_batch_size = 10000
-    repeat_count = target_batch_size // original_batch_size
+    repeat_count = max(1, target_batch_size // original_batch_size)
     actual_batch_size = repeat_count * original_batch_size
```
🧹 Nitpick comments (1)
flashinfer_bench/bench/evaluators/sampling.py (1)
250-250: Consider using custom exception classes for complex error messages.

Static analysis suggests avoiding long messages in exception constructors. While the current approach works, custom exception classes can improve maintainability for complex error scenarios.
Based on static analysis hints.
Also applies to: 265-265
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer_bench/bench/evaluators/sampling.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer_bench/bench/evaluators/sampling.py (4)
flashinfer_bench/utils.py (1)
- dtype_str_to_torch_dtype (39-45)

flashinfer_bench/data/trace.py (2)
- Correctness (72-95)
- EvaluationStatus (126-146)

flashinfer_bench/bench/utils.py (2)
- make_eval (261-279)
- compute_error_stats (89-116)

flashinfer_bench/compile/runnable.py (1)
- Runnable (6-38)
🪛 Ruff (0.14.3)
flashinfer_bench/bench/evaluators/sampling.py
250-250: Avoid specifying long messages outside the exception class
(TRY003)
265-265: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.13
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.10
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.12
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.11
🔇 Additional comments (3)
flashinfer_bench/bench/evaluators/sampling.py (3)
179-195: Well-structured per-batch TVD and error computation.

The implementation correctly computes TVD and error statistics for each batch element independently, then aggregates the worst-case metrics. This approach properly handles batch-wise validation and provides detailed diagnostics.
230-287: Robust thresholding mask implementation with proper tie-breaking.

The _compute_valid_sampling_mask function correctly implements:
- Top-k tie-breaking by including all tokens with probability ≥ k-th largest
- Top-p epsilon tolerance for numerical precision
- Proper handling of edge cases (k=0, k≥vocab_size, p=0, p≥1)
- Batch-wise mask computation with correct indexing (a vectorized sketch of this idea follows below)
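For reference, a vectorized version of this masking logic might look like the sketch below; it assumes per-row `top_k`/`top_p` tensors of shape `[batch]` and illustrates the tie-breaking and epsilon behaviour described above, not the PR's `_compute_valid_sampling_mask` itself.

```python
from typing import Optional
import torch

def valid_sampling_mask(
    probs: torch.Tensor,
    top_k: Optional[torch.Tensor] = None,
    top_p: Optional[torch.Tensor] = None,
    eps: float = 1e-4,
) -> torch.Tensor:
    """Boolean [batch, vocab] mask of tokens a thresholded sampler may return."""
    batch, vocab = probs.shape
    mask = torch.ones(batch, vocab, dtype=torch.bool, device=probs.device)

    if top_k is not None:
        # k <= 0 or k >= vocab is treated as "disabled" (keep every token)
        k = torch.where((top_k <= 0) | (top_k >= vocab), torch.full_like(top_k, vocab), top_k).long()
        sorted_desc, _ = torch.sort(probs, dim=-1, descending=True)
        pivot = sorted_desc.gather(-1, (k - 1).unsqueeze(-1))  # k-th largest per row
        mask &= probs >= pivot                                  # >= keeps ties at the pivot

    if top_p is not None:
        # p <= 0 or p >= 1 is treated as "disabled"
        p = torch.where((top_p <= 0) | (top_p >= 1), torch.ones_like(top_p), top_p)
        sorted_asc, idx = torch.sort(probs, dim=-1, descending=False)
        cdf = torch.cumsum(sorted_asc, dim=-1)
        keep_sorted = cdf > (1.0 - p.unsqueeze(-1)) - eps       # eps tolerance on the CDF
        keep = torch.zeros(batch, vocab, dtype=torch.long, device=probs.device)
        keep.scatter_(-1, idx, keep_sorted.long())              # map back to original token order
        mask &= keep.bool()

    return mask
```

Rows whose mask ends up all False (possible for degenerate parameters) would still need the zero-sum guard suggested elsewhere in this review.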
305-329: Proper input batching with dimension-aware repetition.

The input padding logic correctly handles different tensor types:
- Repeats probability tensors along batch dimension
- Properly expands scalar and batched sampling parameters
- Maintains correct shapes for multi-dimensional inputs
This PR addresses several issues with the previous sampling evaluation logic:

The previous version compressed all input probs into a single-dimension frequency vector, which breaks down when the input tensor's batch size is greater than 1. This PR addresses that by retaining the input shape for the sampled token distributions.

For the sampled tokens, we compute the TVD of each input's probability distribution against the ground truth. The Evaluation class records the worst (max) TVD among all input batch elements.

To reduce the number of correctness-sampling iterations, we repeat the original input tensor 10,000 // original_batch_size times. This still lets us sample the non-deterministic kernel while running fewer forward passes, reducing benchmarking time (see the short arithmetic sketch below).
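As a rough illustration of the arithmetic (the batch size of 4 is an assumed example; 10,000 and 500,000 are the values used in this PR, and the exact per-element trial accounting depends on the implementation):

```python
original_batch_size = 4          # example input batch size
target_batch_size = 10_000       # padding target from this PR
num_trials = 500_000             # total sampling trials

repeat_count = target_batch_size // original_batch_size                      # 2500
actual_batch_size = repeat_count * original_batch_size                       # 10000 samples per forward pass
forward_passes = (num_trials + actual_batch_size - 1) // actual_batch_size   # 50
draws_per_batch_element = forward_passes * repeat_count                      # 125000
```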
Summary by CodeRabbit
New Features
Bug Fixes
Refactor