diff --git a/src/resolution.rs b/src/resolution.rs index 9cc665f..f7507ec 100644 --- a/src/resolution.rs +++ b/src/resolution.rs @@ -1,9 +1,17 @@ use crate::{DependencyGroupSpecifier, DependencyGroups, ResolvedDependencies}; use indexmap::IndexMap; -use pep508_rs::Requirement; +use pep508_rs::{ExtraName, Requirement}; use std::fmt::Display; +use std::str::FromStr; use thiserror::Error; +/// Normalize a group/extra name according to PEP 685. +fn normalize_name(name: &str) -> String { + ExtraName::from_str(name) + .map(|extra| extra.to_string()) + .unwrap_or_else(|_| name.to_string()) +} + #[derive(Debug, Error)] #[error(transparent)] pub struct ResolveError(#[from] ResolveErrorKind); @@ -105,7 +113,16 @@ fn resolve_optional_dependency( return Ok(requirements.clone()); } - let Some(unresolved_requirements) = optional_dependencies.get(extra) else { + let normalized_extra = normalize_name(extra); + + // Find the key in optional_dependencies by comparing normalized versions + // TODO: next breaking release remove this once Extra is added + let unresolved_requirements = optional_dependencies + .iter() + .find(|(key, _)| normalize_name(key) == normalized_extra) + .map(|(_, reqs)| reqs); + + let Some(unresolved_requirements) = unresolved_requirements else { let parent = parents .iter() .last() @@ -460,4 +477,38 @@ mod tests { vec![Requirement::from_str("numpy").unwrap()] ); } + + #[test] + fn optional_dependencies_with_underscores() { + // Test that optional dependency group names with underscores are normalized + // when referenced in extras. PEP 685 specifies that extras should be normalized + // by replacing _, ., - with a single -. + let source = r#" + [project] + name = "foo" + + [project.optional-dependencies] + all = [ + "foo[group-one]", + "foo[group_two]", + ] + group_one = [ + "anyio>=4.9.0", + ] + group-two = [ + "trio>=0.31.0", + ] + "#; + let pyproject_toml = PyProjectToml::new(source).unwrap(); + let resolved_dependencies = pyproject_toml.resolve().unwrap(); + + // Both group-one and group_two should resolve correctly + assert_eq!( + resolved_dependencies.optional_dependencies["all"], + vec![ + Requirement::from_str("anyio>=4.9.0").unwrap(), + Requirement::from_str("trio>=0.31.0").unwrap(), + ] + ); + } }