diff --git a/.gitignore b/.gitignore index bf0f947..4e275fe 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,14 @@ +# Environment files +.env +.env.* + +# Local caches and artifacts +.cache/ +.prompt_history/ +**/bin*.json +**/bin*.json.lock +*.log + # Python stuff *.egg-info .ruff_cache diff --git a/safetytooling/apis/finetuning/openai/check.py b/safetytooling/apis/finetuning/openai/check.py index bb73394..2191f0a 100644 --- a/safetytooling/apis/finetuning/openai/check.py +++ b/safetytooling/apis/finetuning/openai/check.py @@ -149,7 +149,7 @@ def openai_check_finetuning_data( + f"which is over the max context length of {context_length}." + "It will be truncated during fine-tuning" ) - prompts[0].pretty_print([], print_fn=lambda x, *_, **__: logger.warn(x)) + prompts[0].pretty_print([], print_fn=lambda x, *_, **__: logger.warning(x)) # batch_size set to 0.2% of dataset size by default # https://community.openai.com/t/why-is-the-default-batch-size-set-to-1-for-fine-tuning-the-chatgpt-turbo-model/513129 diff --git a/safetytooling/apis/inference/anthropic.py b/safetytooling/apis/inference/anthropic.py index b6bbbf9..69aee23 100644 --- a/safetytooling/apis/inference/anthropic.py +++ b/safetytooling/apis/inference/anthropic.py @@ -40,6 +40,7 @@ def __init__( num_threads: int, prompt_history_dir: Path | None = None, anthropic_api_key: str | None = None, + progress_monitor: object | None = None, ): self.num_threads = num_threads self.prompt_history_dir = prompt_history_dir @@ -50,6 +51,7 @@ def __init__( self.available_requests = asyncio.BoundedSemaphore(int(self.num_threads)) self.kwarg_change_name = {"stop": "stop_sequences"} + self.progress_monitor = progress_monitor async def __call__( self, @@ -107,7 +109,7 @@ async def __call__( error_info = ( f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {format_exc()}" ) - LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") + LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") error_list.append(error_info) api_duration = time.time() - api_start await asyncio.sleep(1.5**i) @@ -208,6 +210,8 @@ async def __call__( if print_prompt_and_response: prompt.pretty_print(responses) + # Progress monitoring is handled centrally in InferenceAPI to avoid double counting. + return responses def make_stream_api_call( diff --git a/safetytooling/apis/inference/api.py b/safetytooling/apis/inference/api.py index 519872e..a2918ab 100644 --- a/safetytooling/apis/inference/api.py +++ b/safetytooling/apis/inference/api.py @@ -30,6 +30,8 @@ from .anthropic import ANTHROPIC_MODELS, AnthropicChatModel from .cache_manager import BaseCacheManager, get_cache_manager +from .openai.base import OpenAIModel +from .progress_monitor import RateLimitProgressMonitor from .gemini.genai import GeminiModel from .gemini.vertexai import GeminiVertexAIModel from .gray_swan import GRAYSWAN_MODELS, GraySwanChatModel @@ -90,6 +92,7 @@ def __init__( vllm_base_url: str = "http://localhost:8000/v1/chat/completions", no_cache: bool = False, oai_embedding_batch_size: int = 2048, + show_progress: bool = False, ): """ Set prompt_history_dir to None to disable saving prompt history. @@ -153,11 +156,14 @@ def __init__( self.cache_manager = get_cache_manager(self.cache_dir, self.use_redis) print(f"{self.cache_manager=}") + self.progress_monitor = RateLimitProgressMonitor(disable=not show_progress) + self._openai_completion = OpenAICompletionModel( frac_rate_limit=self.openai_fraction_rate_limit, prompt_history_dir=self.prompt_history_dir, base_url=self.openai_base_url, openai_api_key=openai_api_key, + progress_monitor=self.progress_monitor, ) self._openai_chat = OpenAIChatModel( @@ -165,6 +171,7 @@ def __init__( prompt_history_dir=self.prompt_history_dir, base_url=self.openai_base_url, openai_api_key=openai_api_key, + progress_monitor=self.progress_monitor, ) self._openai_moderation = OpenAIModerationModel() @@ -176,30 +183,35 @@ def __init__( num_threads=self.anthropic_num_threads, prompt_history_dir=self.prompt_history_dir, anthropic_api_key=anthropic_api_key, + progress_monitor=self.progress_monitor, ) self._huggingface = HuggingFaceModel( num_threads=self.huggingface_num_threads, prompt_history_dir=self.prompt_history_dir, token=os.environ.get("HF_TOKEN", None), + progress_monitor=self.progress_monitor, ) self._gray_swan = GraySwanChatModel( num_threads=self.gray_swan_num_threads, prompt_history_dir=self.prompt_history_dir, - api_key=os.environ.get("GRAYSWAN_API_KEY", None), + api_key=(os.environ.get("GRAYSWAN_API_KEY") or None), + progress_monitor=self.progress_monitor, ) self._together = TogetherChatModel( num_threads=self.together_num_threads, prompt_history_dir=self.prompt_history_dir, - api_key=os.environ.get("TOGETHER_API_KEY", None), + api_key=(os.environ.get("TOGETHER_API_KEY") or None), + progress_monitor=self.progress_monitor, ) self._openrouter = OpenRouterChatModel( num_threads=self.openrouter_num_threads, prompt_history_dir=self.prompt_history_dir, - api_key=os.environ.get("OPENROUTER_API_KEY", None), + api_key=(os.environ.get("OPENROUTER_API_KEY") or None), + progress_monitor=self.progress_monitor, ) self._gemini_vertex = GeminiVertexAIModel(prompt_history_dir=self.prompt_history_dir) @@ -214,7 +226,8 @@ def __init__( num_threads=vllm_num_threads, prompt_history_dir=self.prompt_history_dir, vllm_base_url=self.vllm_base_url, - runpod_api_key=os.environ.get("RUNPOD_API_KEY", None), + runpod_api_key=(os.environ.get("RUNPOD_API_KEY") or None), + progress_monitor=self.progress_monitor, ) # DeepSeek uses the OpenAI API @@ -223,6 +236,7 @@ def __init__( prompt_history_dir=self.prompt_history_dir, base_url=DEEPSEEK_BASE_URL, openai_api_key=os.environ.get("DEEPSEEK_API_KEY", None), + progress_monitor=self.progress_monitor, ) self._batch_models = {} @@ -643,6 +657,45 @@ async def __call__( self.model_timings.setdefault(response.model_id, []).append(response.api_duration) self.model_wait_times.setdefault(response.model_id, []).append(response.duration - response.api_duration) + # Update progress monitor with usage info + if hasattr(self, "progress_monitor") and self.progress_monitor is not None: + try: + total_in_tokens = 0 + total_out_tokens = 0 + for response in candidate_responses: + usage = getattr(response, "usage", None) + if usage is not None: + total_in_tokens += int(getattr(usage, "input_tokens", 0) or 0) + total_out_tokens += int(getattr(usage, "output_tokens", 0) or 0) + + # Register caps if OpenAI and resources are available + if isinstance(model_class, OpenAIModel): + req_res = getattr(model_class, "request_capacity", {}).get(model_id, None) + tok_res = getattr(model_class, "token_capacity", {}).get(model_id, None) + if req_res is not None and tok_res is not None: + self.progress_monitor.register_openai_model(model_id, req_res, tok_res) + else: + # requests-only when no resources are present + self.progress_monitor.register_generic_model(model_id, show_token_bars=False) + else: + # Only requests unless exact token usage is known + show_token_bars = total_in_tokens > 0 or total_out_tokens > 0 + self.progress_monitor.register_generic_model(model_id, show_token_bars=show_token_bars) + + # Increment requests by the number of real API calls made in this branch + request_increment = 1 + if isinstance(model_class, AnthropicChatModel) or isinstance(model_class, HuggingFaceModel) or isinstance(model_class, GeminiModel) or isinstance(model_class, GeminiVertexAIModel): + request_increment = num_candidates + + await self.progress_monitor.update_usage( + model_id=model_id, + input_tokens=(total_in_tokens if (total_in_tokens > 0) else None), + output_tokens=(total_out_tokens if (total_out_tokens > 0) else None), + request_increment=request_increment, + ) + except Exception: + pass + return responses async def ask_single_question( diff --git a/safetytooling/apis/inference/gemini/genai.py b/safetytooling/apis/inference/gemini/genai.py index 1991200..cc69895 100644 --- a/safetytooling/apis/inference/gemini/genai.py +++ b/safetytooling/apis/inference/gemini/genai.py @@ -304,7 +304,7 @@ async def attempt_api_call(model_id): raise e except Exception as e: error_info = f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {format_exc()}" - LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying in {1.5**i} seconds. (Attempt {i})") + LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying in {1.5**i} seconds. (Attempt {i})") await asyncio.sleep(1.5**i) else: break diff --git a/safetytooling/apis/inference/gemini/vertexai.py b/safetytooling/apis/inference/gemini/vertexai.py index d0dbbf2..4176ea6 100644 --- a/safetytooling/apis/inference/gemini/vertexai.py +++ b/safetytooling/apis/inference/gemini/vertexai.py @@ -254,7 +254,7 @@ async def attempt_api_call(model_id): except Exception as e: error_info = f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {format_exc()}" - LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") + LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") await asyncio.sleep(1.5**i) else: break diff --git a/safetytooling/apis/inference/gray_swan.py b/safetytooling/apis/inference/gray_swan.py index 2ddb99e..4f33eae 100644 --- a/safetytooling/apis/inference/gray_swan.py +++ b/safetytooling/apis/inference/gray_swan.py @@ -20,6 +20,7 @@ def __init__( num_threads: int, prompt_history_dir: Path | None = None, api_key: str | None = None, + progress_monitor: object | None = None, ): self.num_threads = num_threads self.prompt_history_dir = prompt_history_dir @@ -32,6 +33,7 @@ def __init__( else: self.aclient = None self.available_requests = asyncio.BoundedSemaphore(int(self.num_threads)) + self.progress_monitor = progress_monitor async def __call__( self, @@ -71,7 +73,7 @@ async def __call__( raise e except Exception as e: error_info = f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {format_exc()}" - LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") + LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") error_list.append(error_info) api_duration = time.time() - api_start await asyncio.sleep(1.5**i) diff --git a/safetytooling/apis/inference/huggingface.py b/safetytooling/apis/inference/huggingface.py index e93129a..8ca6442 100644 --- a/safetytooling/apis/inference/huggingface.py +++ b/safetytooling/apis/inference/huggingface.py @@ -34,6 +34,7 @@ def __init__( num_threads: int, token: str, prompt_history_dir: Path | None = None, + progress_monitor: object | None = None, ): self.num_threads = num_threads self.prompt_history_dir = prompt_history_dir @@ -56,6 +57,7 @@ def __init__( } self.tokenizers = {} + self.progress_monitor = progress_monitor def count_tokens(self, text, model_name): if model_name not in self.tokenizers: @@ -169,10 +171,10 @@ async def __call__( raise e if "503 Service Unavailable" in str(e): - LOGGER.warn(f"503 Service Unavailable error. Waiting 60 seconds before retrying. (Attempt {i})") + LOGGER.warning(f"503 Service Unavailable error. Waiting 60 seconds before retrying. (Attempt {i})") await asyncio.sleep(60) else: - LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") + LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") await asyncio.sleep(1.5**i) else: break @@ -208,4 +210,6 @@ async def __call__( if print_prompt_and_response: prompt.pretty_print(responses) + # Progress monitoring is handled centrally in InferenceAPI to avoid double counting. + return responses diff --git a/safetytooling/apis/inference/openai/base.py b/safetytooling/apis/inference/openai/base.py index 18d10ea..ed10179 100644 --- a/safetytooling/apis/inference/openai/base.py +++ b/safetytooling/apis/inference/openai/base.py @@ -61,12 +61,14 @@ def __init__( prompt_history_dir: Path | None = None, base_url: str | None = None, openai_api_key: str | None = None, + progress_monitor: object | None = None, ): self.frac_rate_limit = frac_rate_limit self.prompt_history_dir = prompt_history_dir self.model_ids = set() self.base_url = base_url self.openai_api_key = openai_api_key + self.progress_monitor = progress_monitor if openai_api_key: self.aclient = openai.AsyncClient(api_key=openai_api_key, base_url=self.base_url) @@ -125,6 +127,17 @@ async def add_model_id(self, model_id: str): self.token_capacity[model_id] = token_capacity self.request_capacity[model_id] = request_capacity + # Register model with progress monitor + if self.progress_monitor is not None: + try: + self.progress_monitor.register_openai_model( + model_id=model_id, + request_resource=request_capacity, + token_resource=token_capacity, + ) + except Exception: + pass + async def __call__( self, model_id: str, @@ -168,7 +181,7 @@ async def attempt_api_call(): raise e except Exception as e: error_info = f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {format_exc()}" - LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") + LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") await asyncio.sleep(1.5**i) else: break diff --git a/safetytooling/apis/inference/openai/embedding.py b/safetytooling/apis/inference/openai/embedding.py index 436e36b..0a5f909 100644 --- a/safetytooling/apis/inference/openai/embedding.py +++ b/safetytooling/apis/inference/openai/embedding.py @@ -69,7 +69,7 @@ async def embed( error_info = ( f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {traceback.format_exc()}" ) - LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") + LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") time.sleep(1.5**i) raise RuntimeError(f"Failed to get a response from the API after {max_attempts} attempts.") diff --git a/safetytooling/apis/inference/openai/moderation.py b/safetytooling/apis/inference/openai/moderation.py index f661c10..8e4a3fb 100644 --- a/safetytooling/apis/inference/openai/moderation.py +++ b/safetytooling/apis/inference/openai/moderation.py @@ -53,7 +53,7 @@ async def _single_moderation_request( return response except Exception as e: error_info = f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {format_exc()}" - LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") + LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") time.sleep(1.5**i) raise RuntimeError(f"Failed to get a response from the API after {max_attempts} attempts.") diff --git a/safetytooling/apis/inference/openrouter.py b/safetytooling/apis/inference/openrouter.py index f44604d..8740e69 100644 --- a/safetytooling/apis/inference/openrouter.py +++ b/safetytooling/apis/inference/openrouter.py @@ -25,6 +25,7 @@ def __init__( num_threads: int, prompt_history_dir: Path | None = None, api_key: str | None = None, + progress_monitor: object | None = None, ): self.num_threads = num_threads self.prompt_history_dir = prompt_history_dir @@ -36,6 +37,7 @@ def __init__( else: self.aclient = None self.available_requests = asyncio.BoundedSemaphore(int(self.num_threads)) + self.progress_monitor = progress_monitor @staticmethod def convert_top_logprobs(data) -> list[dict]: @@ -99,7 +101,7 @@ async def __call__( ): # sometimes gemini will never return a response if model_id == "google/gemini-2.0-flash-001": - LOGGER.warn(f"Empty response from {model_id} (returning empty response)") + LOGGER.warning(f"Empty response from {model_id} (returning empty response)") return [ LLMResponse( model_id=model_id, @@ -124,7 +126,7 @@ async def __call__( raise e except Exception as e: error_info = f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {format_exc()}" - LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") + LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") error_list.append(error_info) api_duration = time.time() - api_start await asyncio.sleep(1.5**i) @@ -144,25 +146,53 @@ async def __call__( "n", 1 ), f"Expected {kwargs.get('n', 1)} choices, got {len(response_data.choices)}" - responses = [ - LLMResponse( - model_id=model_id, - completion=choice.message.content, - stop_reason=choice.finish_reason, - api_duration=api_duration, - duration=duration, - cost=0, - logprobs=( - self.convert_top_logprobs(choice.logprobs) - if hasattr(choice, "logprobs") and choice.logprobs is not None - else None - ), + responses = [] + total_in = 0 + total_out = 0 + for choice in response_data.choices: + # OpenRouter may include usage at top-level + usage = getattr(response_data, "usage", None) + if usage is not None: + # usage may be a dict-like object + in_tok = getattr(usage, "prompt_tokens", None) or getattr(usage, "input_tokens", None) + out_tok = getattr(usage, "completion_tokens", None) or getattr(usage, "output_tokens", None) + else: + in_tok = None + out_tok = None + if in_tok is not None: + total_in += int(in_tok) + if out_tok is not None: + total_out += int(out_tok) + + responses.append( + LLMResponse( + model_id=model_id, + completion=choice.message.content, + stop_reason=choice.finish_reason, + api_duration=api_duration, + duration=duration, + cost=0, + logprobs=( + self.convert_top_logprobs(choice.logprobs) + if hasattr(choice, "logprobs") and choice.logprobs is not None + else None + ), + usage=( + None + if usage is None + else type("U", (), { + "input_tokens": int(in_tok) if in_tok is not None else None, + "output_tokens": int(out_tok) if out_tok is not None else None, + "total_tokens": int((in_tok or 0) + (out_tok or 0)), + })() + ), + ) ) - for choice in response_data.choices - ] self.add_response_to_prompt_file(prompt_file, responses) if print_prompt_and_response: prompt.pretty_print(responses) + # Progress monitoring is handled centrally in InferenceAPI to avoid double counting. + return responses diff --git a/safetytooling/apis/inference/progress_monitor.py b/safetytooling/apis/inference/progress_monitor.py new file mode 100644 index 0000000..f3f56d4 --- /dev/null +++ b/safetytooling/apis/inference/progress_monitor.py @@ -0,0 +1,319 @@ +import asyncio +import time +from collections import deque +from typing import Any, Deque, Dict, Optional, Tuple + +from tqdm.auto import tqdm + + +class RateLimitProgressMonitor: + """ + Multi-progress monitor to visualize per-model rate-limit usage. + + For each model we maintain three bars: + - Requests + - Input tokens per minute + - Output tokens per minute + - Total in/out tokens in the current session + + Notes: + - For OpenAI we know the per-minute caps via the Resource objects. We use those + to set bar totals and compute remaining/used values. + - Token usage split (in vs out) is tracked from actual responses over a rolling + 60-second window using deques. This provides an accurate view of where the + token cap is being spent. + """ + + def __init__(self, disable: bool = False) -> None: + self._lock: asyncio.Lock = asyncio.Lock() + self._disable: bool = disable + # model_id -> (request_bar, in_bar, out_bar) + self._bars: Dict[str, Tuple[Optional[tqdm], Optional[tqdm], Optional[tqdm]]] = {} + # model_id -> caps + self._caps: Dict[str, Dict[str, Optional[int]]] = {} + # model_id -> resource refs for OpenAI (may be None for other providers) + self._resources: Dict[str, Dict[str, Any]] = {} + # rolling windows for last-minute accounting + self._events_requests: Dict[str, Deque[float]] = {} + self._events_in_tokens: Dict[str, Deque[Tuple[float, int]]] = {} + self._events_out_tokens: Dict[str, Deque[Tuple[float, int]]] = {} + # cumulative totals since registration + self._total_requests: Dict[str, int] = {} + self._total_in_tokens: Dict[str, int] = {} + self._total_out_tokens: Dict[str, int] = {} + # totals start time (since last reset) + self._totals_start_time: float = time.time() + # position bookkeeping for tqdm multi-bars + self._next_position: int = 0 + + def _ensure_deques(self, model_id: str) -> None: + if model_id not in self._events_requests: + self._events_requests[model_id] = deque() + if model_id not in self._events_in_tokens: + self._events_in_tokens[model_id] = deque() + if model_id not in self._events_out_tokens: + self._events_out_tokens[model_id] = deque() + if model_id not in self._total_requests: + self._total_requests[model_id] = 0 + if model_id not in self._total_in_tokens: + self._total_in_tokens[model_id] = 0 + if model_id not in self._total_out_tokens: + self._total_out_tokens[model_id] = 0 + + @staticmethod + def _prune_old(ts_deque: Deque, window_seconds: int = 60) -> None: + now = time.time() + while ts_deque and (now - (ts_deque[0][0] if isinstance(ts_deque[0], tuple) else ts_deque[0])) >= window_seconds: + ts_deque.popleft() + + def register_openai_model(self, model_id: str, request_resource: Any, token_resource: Any) -> None: + """ + Register an OpenAI model with known rate-limit resources. + + request_resource/token_resource are instances of the OpenAI Resource class with + attributes: refresh_rate (cap per minute) and value (remaining budget that replenishes). + """ + self._ensure_deques(model_id) + + request_cap = int(request_resource.refresh_rate) + token_cap = int(token_resource.refresh_rate) + self._caps[model_id] = {"request_cap": request_cap, "token_cap": token_cap} + self._resources[model_id] = { + "request": request_resource, + "token": token_resource, + } + + if self._disable: + # still track metrics but don't render bars + self._bars[model_id] = (None, None, None) + return + + if model_id in self._bars and all(bar is not None for bar in self._bars[model_id]): + # already registered + return + + # Allocate three lines for this model + req_bar = tqdm( + total=request_cap, + position=self._next_position, + leave=True, + unit="req", + desc=f"{model_id} | requests", + dynamic_ncols=True, + ) + in_bar = tqdm( + total=token_cap, + position=self._next_position + 1, + leave=True, + unit="tok", + desc=f"{model_id} | in tokens", + dynamic_ncols=True, + ) + out_bar = tqdm( + total=token_cap, + position=self._next_position + 2, + leave=True, + unit="tok", + desc=f"{model_id} | out tokens", + dynamic_ncols=True, + ) + + self._bars[model_id] = (req_bar, in_bar, out_bar) + self._next_position += 3 + + def register_generic_model(self, model_id: str, show_token_bars: bool = False) -> None: + """ + Register a model without known caps. + - If show_token_bars is False, only a request bar is created (no token bars displayed). + - If show_token_bars is True, request + token bars are created and grow with observed usage. + """ + self._ensure_deques(model_id) + # caps unknown + self._caps[model_id] = {"request_cap": None, "token_cap": None} + self._resources[model_id] = {} + + if self._disable: + self._bars[model_id] = (None, None, None) + return + + if model_id in self._bars and all(bar is not None for bar in self._bars[model_id]): + return + + req_bar = tqdm( + total=1, + position=self._next_position, + leave=True, + unit="req", + desc=f"{model_id} | requests", + dynamic_ncols=True, + ) + if show_token_bars: + in_bar = tqdm( + total=1, + position=self._next_position + 1, + leave=True, + unit="tok", + desc=f"{model_id} | in tokens", + dynamic_ncols=True, + ) + out_bar = tqdm( + total=1, + position=self._next_position + 2, + leave=True, + unit="tok", + desc=f"{model_id} | out tokens", + dynamic_ncols=True, + ) + self._bars[model_id] = (req_bar, in_bar, out_bar) + self._next_position += 3 + else: + self._bars[model_id] = (req_bar, None, None) + self._next_position += 1 + + async def update_usage( + self, + model_id: str, + input_tokens: Optional[int], + output_tokens: Optional[int], + request_increment: int = 1, + ) -> None: + """ + Update rolling usage and refresh progress bars for the given model. + """ + async with self._lock: + self._ensure_deques(model_id) + now = time.time() + # push new events + for _ in range(request_increment): + self._events_requests[model_id].append(now) + self._total_requests[model_id] += int(request_increment) + if input_tokens is not None: + self._events_in_tokens[model_id].append((now, int(input_tokens))) + self._total_in_tokens[model_id] += int(input_tokens) + if output_tokens is not None: + self._events_out_tokens[model_id].append((now, int(output_tokens))) + self._total_out_tokens[model_id] += int(output_tokens) + + # prune + self._prune_old(self._events_requests[model_id]) + self._prune_old(self._events_in_tokens[model_id]) + self._prune_old(self._events_out_tokens[model_id]) + + # compute window sums + req_count_window = len(self._events_requests[model_id]) + in_tok_window = sum(v for _, v in self._events_in_tokens[model_id]) + out_tok_window = sum(v for _, v in self._events_out_tokens[model_id]) + + # update progress bars if present + req_bar, in_bar, out_bar = self._bars.get(model_id, (None, None, None)) + caps = self._caps.get(model_id, {"request_cap": None, "token_cap": None}) + request_cap = caps.get("request_cap") + token_cap = caps.get("token_cap") + + # If we have resource refs, compute used from resource directly for requests + resources = self._resources.get(model_id) + if resources is not None: + try: + request_resource = resources.get("request") + # trigger replenish and compute used + used_req = int(request_resource.refresh_rate - request_resource.value) + req_count_window = max(req_count_window, used_req) + except Exception: + pass + + if req_bar is not None: + if request_cap is not None: + req_bar.total = int(request_cap) + req_bar.n = int(min(req_count_window, request_cap)) + else: + # scale dynamically + req_bar.total = max(int(req_bar.total or 1), req_count_window or 1) + req_bar.n = int(req_count_window) + # show totals in postfix + try: + req_bar.set_postfix(total=self._total_requests.get(model_id, 0), window=req_count_window) + except Exception: + pass + req_bar.refresh() + + if in_bar is not None: + if token_cap is not None: + in_bar.total = int(token_cap) + in_bar.n = int(min(in_tok_window, token_cap)) + else: + in_bar.total = max(int(in_bar.total or 1), in_tok_window or 1) + in_bar.n = int(in_tok_window) + try: + in_bar.set_postfix(total=self._total_in_tokens.get(model_id, 0), window=in_tok_window) + except Exception: + pass + in_bar.refresh() + + if out_bar is not None: + if token_cap is not None: + out_bar.total = int(token_cap) + out_bar.n = int(min(out_tok_window, token_cap)) + else: + out_bar.total = max(int(out_bar.total or 1), out_tok_window or 1) + out_bar.n = int(out_tok_window) + try: + out_bar.set_postfix(total=self._total_out_tokens.get(model_id, 0), window=out_tok_window) + except Exception: + pass + out_bar.refresh() + + def close(self) -> None: + if self._disable: + return + for bars in self._bars.values(): + for bar in bars: + if bar is not None: + try: + bar.close() + except Exception: + pass + + def reset_totals(self, model_id: Optional[str] = None) -> None: + """ + Reset cumulative totals for a specific model or for all models. + Does not affect rolling 60s windows or bar positions. + """ + if model_id is None: + self._total_requests = {k: 0 for k in self._total_requests.keys()} + self._total_in_tokens = {k: 0 for k in self._total_in_tokens.keys()} + self._total_out_tokens = {k: 0 for k in self._total_out_tokens.keys()} + else: + self._total_requests[model_id] = 0 + self._total_in_tokens[model_id] = 0 + self._total_out_tokens[model_id] = 0 + self._totals_start_time = time.time() + + def get_snapshot(self, model_id: str) -> Dict[str, Any]: + """ + Return a snapshot of rolling-window usage and caps for a model. + + Keys: + - request_cap, token_cap: caps per minute if known + - req_count_window, in_tok_window, out_tok_window: usage in the last minute + - total_requests, total_in_tokens, total_out_tokens: cumulative totals since registration + """ + self._ensure_deques(model_id) + # prune before snapshot + self._prune_old(self._events_requests[model_id]) + self._prune_old(self._events_in_tokens[model_id]) + self._prune_old(self._events_out_tokens[model_id]) + + caps = self._caps.get(model_id, {"request_cap": None, "token_cap": None}) + return { + "request_cap": caps.get("request_cap"), + "token_cap": caps.get("token_cap"), + "req_count_window": len(self._events_requests[model_id]), + "in_tok_window": sum(v for _, v in self._events_in_tokens[model_id]), + "out_tok_window": sum(v for _, v in self._events_out_tokens[model_id]), + "total_requests": self._total_requests.get(model_id, 0), + "total_in_tokens": self._total_in_tokens.get(model_id, 0), + "total_out_tokens": self._total_out_tokens.get(model_id, 0), + "totals_since": self._totals_start_time, + } + + diff --git a/safetytooling/apis/inference/runpod_vllm.py b/safetytooling/apis/inference/runpod_vllm.py index c6ab3e2..2566df2 100644 --- a/safetytooling/apis/inference/runpod_vllm.py +++ b/safetytooling/apis/inference/runpod_vllm.py @@ -27,6 +27,7 @@ def __init__( prompt_history_dir: Path | None = None, vllm_base_url: str = "http://localhost:8000/v1/chat/completions", runpod_api_key: str | None = None, + progress_monitor: object | None = None, ): self.num_threads = num_threads self.prompt_history_dir = prompt_history_dir @@ -45,6 +46,7 @@ def __init__( "length": StopReason.MAX_TOKENS.value, "stop": StopReason.STOP_SEQUENCE.value, } + self.progress_monitor = progress_monitor async def query(self, model_url: str, payload: dict, session: aiohttp.ClientSession, timeout: int = 1000) -> dict: async with session.post(model_url, headers=self.headers, json=payload, timeout=timeout) as response: @@ -155,7 +157,7 @@ async def __call__( raise e except Exception as e: error_info = f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {format_exc()}" - LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") + LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") await asyncio.sleep(1.5**i) else: break @@ -186,5 +188,6 @@ async def __call__( self.add_response_to_prompt_file(prompt_file, responses) if print_prompt_and_response: prompt.pretty_print(responses) + # Progress monitoring is handled centrally in InferenceAPI to avoid double counting. return responses diff --git a/safetytooling/apis/inference/together.py b/safetytooling/apis/inference/together.py index 8a1f65e..4297998 100644 --- a/safetytooling/apis/inference/together.py +++ b/safetytooling/apis/inference/together.py @@ -46,6 +46,7 @@ def __init__( num_threads: int, prompt_history_dir: Path | None = None, api_key: str | None = None, + progress_monitor: object | None = None, ): self.num_threads = num_threads self.prompt_history_dir = prompt_history_dir @@ -54,6 +55,7 @@ def __init__( else: self.aclient = None self.available_requests = asyncio.BoundedSemaphore(int(self.num_threads)) + self.progress_monitor = progress_monitor @staticmethod def convert_top_logprobs(data) -> list[dict]: @@ -107,11 +109,11 @@ async def __call__( except (TypeError, InvalidRequestError) as e: raise e except ServiceUnavailableError: - LOGGER.warn(f"Service Unavailable or Rate Limited for {model_id}") + LOGGER.warning(f"Service Unavailable or Rate Limited for {model_id}") await asyncio.sleep(10) except Exception as e: error_info = f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {format_exc()}" - LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") + LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})") error_list.append(error_info) api_duration = time.time() - api_start await asyncio.sleep(1.5**i) @@ -145,4 +147,6 @@ async def __call__( if print_prompt_and_response: prompt.pretty_print(responses) + # Progress monitoring is handled centrally in InferenceAPI to avoid double counting. + return responses diff --git a/tests/test_api.py b/tests/test_api.py index 7c34a7a..8f213bf 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,5 +1,6 @@ """Basic tests for the api.""" +import os import pydantic import pytest @@ -30,6 +31,7 @@ async def test_openai(): @pytest.mark.asyncio +@pytest.mark.skipif(not os.environ.get("ANTHROPIC_API_KEY"), reason="ANTHROPIC_API_KEY not available in environment or credits") async def test_claude_3(): utils.setup_environment() resp = await InferenceAPI.get_default_global_api().ask_single_question( @@ -45,6 +47,7 @@ async def test_claude_3(): @pytest.mark.asyncio +@pytest.mark.skipif(not os.environ.get("ANTHROPIC_API_KEY"), reason="ANTHROPIC_API_KEY not available in environment or credits") async def test_claude_3_with_system_prompt(): utils.setup_environment() resp = await InferenceAPI.get_default_global_api().ask_single_question( @@ -104,22 +107,24 @@ async def test_api_with_stop_parameter(): assert ", 5" not in openai_resp[0] assert "4" in openai_resp[0] - # Test with Anthropic model - anthropic_resp = await InferenceAPI.get_default_global_api().ask_single_question( - model_id="claude-3-haiku-20240307", - question="Count from 1 to 10: 1, 2, 3,", - max_tokens=20, - stop=[", 5"], # Should stop before outputting ", 5" - ) + # Test with Anthropic model (skip when no key/credits) + if os.environ.get("ANTHROPIC_API_KEY"): + anthropic_resp = await InferenceAPI.get_default_global_api().ask_single_question( + model_id="claude-3-haiku-20240307", + question="Count from 1 to 10: 1, 2, 3,", + max_tokens=20, + stop=[", 5"], # Should stop before outputting ", 5" + ) - assert isinstance(anthropic_resp, list) - assert len(anthropic_resp) == 1 - assert isinstance(anthropic_resp[0], str) - assert ", 5" not in anthropic_resp[0] - assert "4" in anthropic_resp[0] + assert isinstance(anthropic_resp, list) + assert len(anthropic_resp) == 1 + assert isinstance(anthropic_resp[0], str) + assert ", 5" not in anthropic_resp[0] + assert "4" in anthropic_resp[0] @pytest.mark.asyncio +@pytest.mark.skipif(not os.environ.get("ANTHROPIC_API_KEY"), reason="ANTHROPIC_API_KEY not available in environment or credits") async def test_anthropic_accepts_seed_parameter(): """Test that the seed parameter is ignored without error for Anthropic models.""" utils.setup_environment() @@ -137,6 +142,7 @@ async def test_anthropic_accepts_seed_parameter(): @pytest.mark.asyncio +@pytest.mark.skipif(not os.environ.get("ANTHROPIC_API_KEY"), reason="ANTHROPIC_API_KEY not available in environment or credits") async def test_anthropic_rejects_invalid_parameters(): """Test that invalid parameters raise TypeError with Anthropic models.""" utils.setup_environment() diff --git a/tests/test_other_apis.py b/tests/test_other_apis.py index 0258119..d72167e 100644 --- a/tests/test_other_apis.py +++ b/tests/test_other_apis.py @@ -25,7 +25,7 @@ def create_test_prompt(): ) -@pytest.mark.skipif("OPENROUTER_API_KEY" not in os.environ, reason="OPENROUTER_API_KEY not available in environment") +@pytest.mark.skipif(not os.environ.get("OPENROUTER_API_KEY"), reason="OPENROUTER_API_KEY not available in environment") @pytest.mark.asyncio async def test_openrouter_init(): api = InferenceAPI( @@ -48,7 +48,7 @@ async def test_openrouter_init(): assert isinstance(provider_forced, OpenRouterChatModel) -@pytest.mark.skipif("TOGETHER_API_KEY" not in os.environ, reason="TOGETHER_API_KEY not available in environment") +@pytest.mark.skipif(not os.environ.get("TOGETHER_API_KEY"), reason="TOGETHER_API_KEY not available in environment") @pytest.mark.asyncio async def test_together_init(): api = InferenceAPI( @@ -70,7 +70,7 @@ async def test_together_init(): assert isinstance(provider_forced, TogetherChatModel) -@pytest.mark.skipif("OPENROUTER_API_KEY" not in os.environ, reason="OPENROUTER_API_KEY not available in environment") +@pytest.mark.skipif(not os.environ.get("OPENROUTER_API_KEY"), reason="OPENROUTER_API_KEY not available in environment") @pytest.mark.asyncio async def test_openrouter_call(): api = InferenceAPI(