diff --git a/paper_check.py b/paper_check.py
index 5ae0cb4..b8948a5 100644
--- a/paper_check.py
+++ b/paper_check.py
@@ -6,13 +6,32 @@
 import language_tool_python
 from pdf2image import convert_from_path
 from openai import OpenAI
+try:
+    import google.generativeai as genai
+except ImportError:
+    genai = None
 
 DEBUG = False
 
+SUPPORTED_MODELS = {
+    "grok-2-latest": {"provider": "xai", "vision_model": "grok-2-vision-1212"},
+    "gemini-2.5-flash": {"provider": "gemini", "vision_model": "gemini-2.5-flash"},
+    "gemini-2.5-pro": {"provider": "gemini", "vision_model": "gemini-2.5-pro"},
+    "gpt-4o-mini": {"provider": "openai", "vision_model": "gpt-4o-mini"},
+}
+
 class ChecklistEvaluator:
-    def __init__(self, pdf_path, latex_path_or_dir, openai_api_key=None):
+    def __init__(self, pdf_path, latex_path_or_dir, api_key=None, model_name=None, skip_llm=False):
         self.pdf_path = pdf_path
         self.file_text_map = {}
+        self.model_name = model_name
+        self.model_config = SUPPORTED_MODELS.get(model_name) if model_name else None
+        self.provider = self.model_config['provider'] if self.model_config else None
+        self.skip_llm = skip_llm or (self.model_config is None)
+        self.api_key = api_key
+        self._clients = {}
+        if not self.skip_llm and self.model_config is None:
+            raise ValueError(f"Unsupported model '{model_name}'. Supported models: {', '.join(SUPPORTED_MODELS)}")
         if os.path.isdir(latex_path_or_dir):
             # Prioritize loading 'paper.tex' and process \input commands to include corresponding files
             self.latex_root_dir = os.path.abspath(latex_path_or_dir)
@@ -23,8 +42,8 @@ def __init__(self, pdf_path, latex_path_or_dir, openai_api_key=None):
         # Extract sections from the merged LaTeX text
         self.sections = self.extract_sections(self.latex_text)
         self.pdf_text = self.load_pdf()
-        if openai_api_key:
-            os.environ["XAI_API_KEY"] = openai_api_key
+        if self.provider:
+            self._store_api_key()
         self.report = {}
 
     @staticmethod
@@ -54,6 +73,33 @@ def _iterate_latex_files(self):
         else:
             yield "document.tex", self.latex_text
 
+    def _store_api_key(self):
+        """Persist provided API key to appropriate environment variable for downstream SDKs."""
+        if not self.api_key:
+            return
+        if self.provider == "xai":
+            os.environ["XAI_API_KEY"] = self.api_key
+        elif self.provider == "gemini":
+            os.environ["GEMINI_API_KEY"] = self.api_key
+        elif self.provider == "openai":
+            os.environ["OPENAI_API_KEY"] = self.api_key
+
+    def _get_api_key(self):
+        """Return the API key for the configured provider."""
+        if self.provider == "xai":
+            return self.api_key or os.getenv("XAI_API_KEY")
+        if self.provider == "gemini":
+            return self.api_key or os.getenv("GEMINI_API_KEY")
+        if self.provider == "openai":
+            return self.api_key or os.getenv("OPENAI_API_KEY")
+        return None
+
+    def _get_vision_model(self):
+        """Return the vision-capable model to use for figure analysis."""
+        if not self.model_config:
+            return None
+        return self.model_config.get("vision_model", self.model_name)
+
     def load_pdf(self):
         """Extract text from PDF"""
         reader = PdfReader(self.pdf_path)
@@ -732,20 +778,64 @@ def static_check_italics_usage(self):
         return True
 
     # ---------------- LLM Checks ----------------
-    def call_llm(self, prompt, model="grok-2-latest"):
-        XAI_API_KEY = os.getenv("XAI_API_KEY")
-        client = OpenAI(
-            api_key=XAI_API_KEY,
-            base_url="https://api.x.ai/v1",
-        )
+    def call_llm(self, prompt, max_tokens=2000):
+        """Dispatch a text-only prompt to the configured LLM provider."""
+        if self.skip_llm:
+            return "LLM checks skipped."
+        if not self.provider:
+            raise ValueError("LLM model is not configured.")
+
         try:
-            response = client.completions.create(
-                model=model,
-                max_tokens=10000,
-                prompt=prompt,
-                temperature=0
-            )
-            return response.choices[0].text.strip()
+            if self.provider == "xai":
+                api_key = self._get_api_key()
+                if not api_key:
+                    raise ValueError("XAI_API_KEY is not set. Provide it via environment or --api_key.")
+                if "xai_text" not in self._clients:
+                    self._clients["xai_text"] = OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
+                response = self._clients["xai_text"].completions.create(
+                    model=self.model_name,
+                    max_tokens=max_tokens,
+                    prompt=prompt,
+                    temperature=0
+                )
+                return response.choices[0].text.strip()
+
+            if self.provider == "gemini":
+                if genai is None:
+                    raise ImportError("google-generativeai is not installed. Install it to use Gemini models.")
+                api_key = self._get_api_key()
+                if not api_key:
+                    raise ValueError("Gemini API key is not set. Provide it via environment or --api_key.")
+                if "gemini_text" not in self._clients:
+                    genai.configure(api_key=api_key)
+                    self._clients["gemini_text"] = genai.GenerativeModel(self.model_name)
+                response = self._clients["gemini_text"].generate_content(prompt)
+                if hasattr(response, "text") and response.text:
+                    return response.text
+                candidates = getattr(response, "candidates", [])
+                collected = []
+                for candidate in candidates:
+                    for part in getattr(candidate, "content", []):
+                        text_part = getattr(part, "text", None)
+                        if text_part:
+                            collected.append(text_part)
+                return "\n".join(collected) if collected else ""
+
+            if self.provider == "openai":
+                api_key = self._get_api_key()
+                if not api_key:
+                    raise ValueError("OpenAI API key is not set. Provide it via environment or --api_key.")
+                if "openai_text" not in self._clients:
+                    self._clients["openai_text"] = OpenAI(api_key=api_key)
+                response = self._clients["openai_text"].chat.completions.create(
+                    model=self.model_name,
+                    messages=[{"role": "user", "content": prompt}],
+                    max_tokens=max_tokens,
+                    temperature=0
+                )
+                return response.choices[0].message.content
+
+            raise ValueError(f"Unsupported provider '{self.provider}'.")
         except Exception as e:
            return f"LLM call failed: {e}"
 
@@ -804,18 +894,103 @@ def llm_evaluate_writing_for_section(self, section_title, section_content):
         # return self.call_llm(prompt)
         return self.call_llm(prompt)
 
+    def call_vision_llm(self, prompt, base64_image, max_tokens=2000):
+        """Dispatch a multimodal prompt (text + image) to the configured provider."""
+        if self.skip_llm:
+            return "LLM checks skipped."
+        if not self.provider:
+            raise ValueError("LLM model is not configured.")
+
+        vision_model = self._get_vision_model()
+        try:
+            if self.provider == "xai":
+                api_key = self._get_api_key()
+                if not api_key:
+                    raise ValueError("XAI_API_KEY is not set. Provide it via environment or --api_key.")
+                if "xai_vision" not in self._clients:
+                    self._clients["xai_vision"] = OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
+                response = self._clients["xai_vision"].chat.completions.create(
+                    model=vision_model,
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": prompt},
+                                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}" }}
+                            ]
+                        }
+                    ],
+                    max_tokens=max_tokens
+                )
+                return response.choices[0].message.content
+
+            if self.provider == "gemini":
+                if genai is None:
+                    raise ImportError("google-generativeai is not installed. Install it to use Gemini models.")
+                api_key = self._get_api_key()
+                if not api_key:
+                    raise ValueError("Gemini API key is not set. Provide it via environment or --api_key.")
+                if "gemini_vision" not in self._clients:
+                    genai.configure(api_key=api_key)
+                    self._clients["gemini_vision"] = genai.GenerativeModel(vision_model)
+                import base64 as _base64
+                image_bytes = _base64.b64decode(base64_image)
+                response = self._clients["gemini_vision"].generate_content(
+                    [
+                        {"text": prompt},
+                        {"inline_data": {"mime_type": "image/png", "data": image_bytes}}
+                    ]
+                )
+                if hasattr(response, "text") and response.text:
+                    return response.text
+                candidates = getattr(response, "candidates", [])
+                collected = []
+                for candidate in candidates:
+                    for part in getattr(candidate, "content", []):
+                        text_part = getattr(part, "text", None)
+                        if text_part:
+                            collected.append(text_part)
+                return "\n".join(collected) if collected else ""
+
+            if self.provider == "openai":
+                api_key = self._get_api_key()
+                if not api_key:
+                    raise ValueError("OpenAI API key is not set. Provide it via environment or --api_key.")
+                if "openai_multimodal" not in self._clients:
+                    self._clients["openai_multimodal"] = OpenAI(api_key=api_key)
+                response = self._clients["openai_multimodal"].responses.create(
+                    model=vision_model,
+                    input=[
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": prompt},
+                                {"type": "input_image", "image": {"b64": base64_image}}
+                            ]
+                        }
+                    ],
+                    max_output_tokens=max_tokens
+                )
+                return getattr(response, "output_text", None) or ""
+
+            raise ValueError(f"Unsupported provider '{self.provider}'.")
+        except Exception as e:
+            return f"LLM vision call failed: {e}"
+
     def llm_evaluate_figures_from_pdf(self):
         """
         Uses pdf2image to convert each page of the PDF to an image,
-        then calls the visual model grok-2-vision to check the figures.
+        then calls the configured vision-capable model to check the figures.
         """
         import base64
         import tempfile
-        from pathlib import Path
-
+
+        if self.skip_llm:
+            return "LLM checks skipped."
+
         images = convert_from_path(self.pdf_path)
         results = {}
-
+
         # Create temp directory to save images
         with tempfile.TemporaryDirectory() as temp_dir:
             for i, image in enumerate(images):
@@ -838,43 +1013,14 @@
                 7. Figures/tables must be referenced in the text and the text should tell the reader how to read it and what to take away from it
                 8. Figures should appear right after (using [h]) where they are mentioned (in most cases except layout issues)"""
 
-                # Get API key (using XAI_API_KEY instead of OPENAI_API_KEY)
-                XAI_API_KEY = os.getenv("XAI_API_KEY")
-
                 try:
-                    # Initialize client
-                    client = OpenAI(
-                        api_key=XAI_API_KEY,
-                        base_url="https://api.x.ai/v1",
-                    )
-
-                    # Call the API with vision capabilities
-                    response = client.chat.completions.create(
-                        model="grok-2-vision-1212",  # Using the specific Grok vision model
-                        messages=[
-                            {
-                                "role": "user",
-                                "content": [
-                                    {"type": "text", "text": prompt},
-                                    {
-                                        "type": "image_url",
-                                        "image_url": {
-                                            "url": f"data:image/png;base64,{base64_image}"
-                                        }
-                                    }
-                                ]
-                            }
-                        ],
-                        max_tokens=10000
-                    )
-
-                    # Store the result
-                    results[f"page_{i+1}"] = response.choices[0].message.content
-
+                    # Store the result from the configured multimodal model
+                    results[f"page_{i+1}"] = self.call_vision_llm(prompt, base64_image)
+
                 except Exception as e:
                     results[f"page_{i+1}"] = f"Image analysis failed: {e}"
                     # break
-
+
         return results
 
     def run_all_checks(self):
@@ -892,19 +1038,26 @@
         # Currently unable to handle latex format, will report a large number of spelling errors
         # self.static_spell_check()
 
-        # Then call LLM for subjective evaluation (by module)
-        # Example: LLM checks for titles, introduction, and figures
-        self.report['LLM_Titles'] = self.llm_evaluate_titles(title, section_titles)
-        self.report['LLM_Introduction'] = self.llm_evaluate_introduction()
-        self.report['LLM_Figures_Vision'] = self.llm_evaluate_figures_from_pdf()
-
-        # # Check writing quality for each section
-        llm_writing_results = {}
-        for section_title, section_content in self.sections.items():
-            result = self.llm_evaluate_writing_for_section(section_title, section_content)
-            llm_writing_results[section_title] = result
-
-        self.report['LLM_Writing'] = llm_writing_results
+        if not self.skip_llm:
+            # Then call LLM for subjective evaluation (by module)
+            # Example: LLM checks for titles, introduction, and figures
+            self.report['LLM_Titles'] = self.llm_evaluate_titles(title, section_titles)
+            self.report['LLM_Introduction'] = self.llm_evaluate_introduction()
+            self.report['LLM_Figures_Vision'] = self.llm_evaluate_figures_from_pdf()
+
+            # # # Check writing quality for each section
+            llm_writing_results = {}
+            for section_title, section_content in self.sections.items():
+                result = self.llm_evaluate_writing_for_section(section_title, section_content)
+                llm_writing_results[section_title] = result
+
+            self.report['LLM_Writing'] = llm_writing_results
+        else:
+            skip_message = "LLM checks skipped."
+            self.report['LLM_Titles'] = skip_message
+            self.report['LLM_Introduction'] = skip_message
+            self.report['LLM_Figures_Vision'] = skip_message
+            self.report['LLM_Writing'] = {section_title: skip_message for section_title in self.sections}
 
         return self.report
 
@@ -912,44 +1065,85 @@ def main():
     parser = argparse.ArgumentParser(description="Check paper drafts against quality checklist using LLM and static methods.")
     parser.add_argument("--pdf", required=True, help="Path to the paper draft PDF file")
     parser.add_argument("--latex", required=True, help="Path to LaTeX source file or directory")
-    parser.add_argument("--api_key", required=False, help="OpenAI API key (optional, or set via XAI_API_KEY environment variable)")
+    parser.add_argument(
+        "--model",
+        required=False,
+        help=f"LLM model to use. Supported models: {', '.join(SUPPORTED_MODELS.keys())}"
+    )
+    parser.add_argument(
+        "--api_key",
+        required=False,
+        help="API key for the selected LLM provider (optional if set via environment variable)."
+    )
     args = parser.parse_args()
 
-    evaluator = ChecklistEvaluator(args.pdf, args.latex, args.api_key)
+    model_name = args.model
+    skip_llm = False
+
+    if not model_name:
+        choice = input("No --model specified. Do you want to skip LLM checks? (y/N): ").strip().lower()
+        if choice in ("y", "yes"):
+            skip_llm = True
+        else:
+            available_models = ", ".join(SUPPORTED_MODELS.keys())
+            while not model_name:
+                candidate = input(f"Please specify the model name ({available_models}): ").strip()
+                if candidate in SUPPORTED_MODELS:
+                    model_name = candidate
+                else:
+                    print(f"Invalid model name '{candidate}'. Supported models: {available_models}")
+    else:
+        if model_name not in SUPPORTED_MODELS:
+            raise ValueError(f"Unsupported model '{model_name}'. Supported models: {', '.join(SUPPORTED_MODELS.keys())}")
+
+    evaluator = ChecklistEvaluator(
+        args.pdf,
+        args.latex,
+        api_key=args.api_key,
+        model_name=model_name,
+        skip_llm=skip_llm
+    )
     report = evaluator.run_all_checks()
-
-    # Save the report as a markdown document
-    markdown_output = "# Paper Quality Check Report\n\n"
-
-    # Add a summary section
-    markdown_output += "## Summary\n\n"
-    total_checks = 0
-    passed_checks = 0
-
+
+    # Separate static checks (with 'result') from LLM checks
+    static_checks = []
+    llm_checks = []
+    other_entries = []
     for check_name, check_data in report.items():
         if isinstance(check_data, dict) and 'result' in check_data:
-            total_checks += 1
-            if check_data['result'] is True:
-                passed_checks += 1
+            static_checks.append((check_name, check_data))
+        elif check_name.startswith('LLM_'):
+            llm_checks.append((check_name, check_data))
+        else:
+            other_entries.append((check_name, check_data))
 
-    if total_checks > 0:
+    # Save the report as a markdown document
+    markdown_output = "# Paper Quality Check Report\n\n"
+
+    # Static summary
+    markdown_output += "## Static Check Summary\n\n"
+    if static_checks:
+        total_checks = len(static_checks)
+        passed_checks = sum(1 for _, data in static_checks if data.get('result') is True)
         pass_rate = (passed_checks / total_checks) * 100
         markdown_output += f"- **Pass Rate**: {pass_rate:.1f}% ({passed_checks}/{total_checks} checks passed)\n\n"
-
-    # Add detailed results for each check
+    else:
+        markdown_output += "No static checks were executed.\n\n"
+
+    # Static detailed results
     markdown_output += "## Detailed Results\n\n"
-
-    for check_name, check_data in report.items():
-        if isinstance(check_data, dict) and 'result' in check_data:
+    if not static_checks:
+        markdown_output += "No static check results available.\n\n"
+    else:
+        for check_name, check_data in static_checks:
             status = "✅ PASS" if check_data['result'] is True else "❌ FAIL"
             check_title = check_data.get('check', check_name)
             markdown_output += f"### {status}: {check_title}\n\n"
             markdown_output += f"**Details**: {check_data.get('detail', 'No detailed information')}\n\n"
-
-            if 'failed_content' in check_data and check_data['failed_content']:
+
+            failed_content = check_data.get('failed_content')
+            if failed_content:
                 markdown_output += "**Error content**:\n\n"
-                failed_content = check_data['failed_content']
-
                 if isinstance(failed_content, list):
                     for i, item in enumerate(failed_content, 1):
                         markdown_output += f"{i}. {item}\n"
@@ -963,12 +1157,15 @@
                        markdown_output += f"  {value}\n"
                else:
                    markdown_output += f"{failed_content}\n"
-
                markdown_output += "\n"
-        elif isinstance(check_data, dict) and check_name.startswith('LLM_'):
-            # Handle nested LLM evaluation results
+
+    # LLM section
+    markdown_output += "## LLM Check\n\n"
+    if not llm_checks:
+        markdown_output += "No LLM checks were executed.\n\n"
+    else:
+        for check_name, check_data in llm_checks:
             markdown_output += f"### {check_name}\n\n"
-
             if isinstance(check_data, dict):
                 for section, section_data in check_data.items():
                     if isinstance(section_data, dict):
@@ -977,6 +1174,12 @@
                            markdown_output += f"**{key}**: {value}\n\n"
                        else:
                            markdown_output += f"**{section}**: {section_data}\n\n"
+            else:
+                markdown_output += f"{check_data}\n\n"
+
+    # Any other entries
+    for check_name, check_data in other_entries:
+        markdown_output += f"## {check_name}\n\n{check_data}\n\n"
 
     # Write the markdown report to a file
     report_filename = os.path.splitext(os.path.basename(args.pdf))[0] + "_quality_report.md"
diff --git a/requirements.txt b/requirements.txt
index fdcbb7d..f9b275e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,3 +23,4 @@ tqdm==4.67.1
 typing-inspection==0.4.2
 typing_extensions==4.15.0
 urllib3==2.5.0
+google-generativeai==0.8.4