Skip to content

Commit

Permalink
Attempt supporting multi-file version constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
knedlsepp committed Nov 10, 2023
1 parent c18d8f2 commit 5484a88
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 9 deletions.
109 changes: 108 additions & 1 deletion conda_lock/models/lock_spec.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import hashlib
import json
import pathlib
import typing

from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Type, Union

from pydantic import BaseModel, Field, validator
from typing_extensions import Literal
Expand All @@ -24,23 +26,128 @@ class _BaseDependency(StrictModel):
def sorted_extras(cls, v: List[str]) -> List[str]:
return sorted(v)

def merge(self, other):
if other is None:
return self
if (
self.name != other.name
or self.manager != other.manager
or self.category != other.category
):
raise ValueError(
"Cannot merge incompatible dependencies: {self} != {other}"
)
return _BaseDependency(
name=self.name,
manager=self.manager,
category=self.category,
extras=list(set(self.extras + other.extras)),
)


class VersionedDependency(_BaseDependency):
version: str
build: Optional[str] = None
conda_channel: Optional[str] = None

@staticmethod
def _merge_versions(version1, version2):
if version1 is None or version1 == "":
return version2
if version2 is None or version2 == "":
return version1
return f"{version1},{version2}"

def merge(self, other):
if other is None:
return self
if (
self.build is not None
and other.build is not None
and self.build != other.build
):
raise ValueError(
f"VersionedDependency has two different builds:\n{self}\n{other}"
)

if (
self.conda_channel is not None
and other.conda_channel is not None
and self.conda_channel != other.conda_channel
):
raise ValueError(
f"VersionedDependency has two different conda_channels:\n{self}\n{other}"
)
merged_base = super().merge(other)
return VersionedDependency(
name=merged_base.name,
manager=merged_base.manager,
category=merged_base.category,
extras=merged_base.extras,
version=self._merge_versions(self.version, other.version),
build=self.build or other.build,
conda_channel=self.conda_channel or other.conda_channel,
)


class URLDependency(_BaseDependency):
url: str
hashes: List[str]

def merge(self, other):
if other is None:
return self
if self.url != other.url:
raise ValueError(f"URLDependency has two different urls:\n{self}\n{other}")

if self.hashes != other.hashes:
raise ValueError(
f"URLDependency has two different hashess:\n{self}\n{other}"
)
merged_base = super().merge(other)

return URLDependency(
name=merged_base.name,
manager=merged_base.manager,
category=merged_base.category,
extras=merged_base.extras,
url=self.url,
hashes=self.hashes,
)


class VCSDependency(_BaseDependency):
source: str
vcs: str
rev: Optional[str] = None

def merge(self, other):
if other is None:
return self
if self.source != other.source:
raise ValueError(
f"VCSDependency has two different sources:\n{self}\n{other}"
)

if self.vcs != other.vcs:
raise ValueError()(
f"VCSDependency has two different vcss:\n{self}\n{other}"
)

if self.rev is not None and other.rev is not None and self.rev != other.rev:
raise ValueError(f"VCSDependency has two different revs:\n{self}\n{other}")
merged_base = super().merge(other)

return VCSDependency(
name=merged_base.name,
manager=merged_base.manager,
category=merged_base.category,
extras=merged_base.extras,
source=self.source,
vcs=self.vcs,
rev=self.rev or other.rev,
)


Dependency = Union[VersionedDependency, URLDependency, VCSDependency]

Expand Down
2 changes: 1 addition & 1 deletion conda_lock/src_parser/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def aggregate_lock_specs(
lock_spec.dependencies.get(platform, []) for lock_spec in lock_specs
):
key = (dep.manager, dep.name)
unique_deps[key] = dep
unique_deps[key] = dep.merge(unique_deps.get(key))

dependencies[platform] = list(unique_deps.values())

Expand Down
20 changes: 13 additions & 7 deletions tests/test_conda_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -1622,22 +1622,28 @@ def test_aggregate_lock_specs():
assert actual.content_hash() == expected.content_hash()


def test_aggregate_lock_specs_override_version():
base_spec = LockSpecification(
dependencies={"linux-64": [_make_spec("package", "=1.0")]},
def test_aggregate_lock_specs_combine_version():
first_spec = LockSpecification(
dependencies={"linux-64": [_make_spec("package", ">1.0")]},
channels=[Channel.from_string("conda-forge")],
sources=[Path("base.yml")],
)

override_spec = LockSpecification(
dependencies={"linux-64": [_make_spec("package", "=2.0")]},
second_spec = LockSpecification(
dependencies={"linux-64": [_make_spec("package", "<2.0")]},
channels=[Channel.from_string("internal"), Channel.from_string("conda-forge")],
sources=[Path("override.yml")],
)

result_spec = LockSpecification(
dependencies={"linux-64": [_make_spec("package", ">1.0,<2.0")]},
channels=[Channel.from_string("internal"), Channel.from_string("conda-forge")],
sources=[Path("override.yml")],
)

agg_spec = aggregate_lock_specs([base_spec, override_spec], platforms=["linux-64"])
agg_spec = aggregate_lock_specs([first_spec, second_spec], platforms=["linux-64"])

assert agg_spec.dependencies == override_spec.dependencies
assert agg_spec.dependencies == result_spec.dependencies


def test_aggregate_lock_specs_invalid_channels():
Expand Down

0 comments on commit 5484a88

Please sign in to comment.