[POC][RFC] Metals BSP third-party code navigation #21143

Draft · wants to merge 1 commit into main
2 changes: 2 additions & 0 deletions .gitignore
@@ -21,6 +21,8 @@ __pycache__/
 .idea
 /.vscode/
 .cache
+.metals
+.scala-build
 .pants.d
 # TODO: We can probably delete these 3. They have not been used in a long time, if ever.
 # In fact there's a lot of things we can clean up in this .gitignore. It's not harmful
130 changes: 103 additions & 27 deletions src/python/pants/backend/scala/bsp/rules.py
@@ -2,6 +2,7 @@
 # Licensed under the Apache License, Version 2.0 (see LICENSE).
 from __future__ import annotations

+import dataclasses
 import logging
 import textwrap
 from dataclasses import dataclass
@@ -50,7 +51,7 @@
 )
 from pants.core.util_rules.system_binaries import BashBinary, ReadlinkBinary
 from pants.engine.addresses import Addresses
-from pants.engine.fs import AddPrefix, CreateDigest, Digest, FileContent, MergeDigests, Workspace
+from pants.engine.fs import AddPrefix, CreateDigest, Digest, FileContent, FileEntry, MergeDigests, Workspace
 from pants.engine.internals.native_engine import Snapshot
 from pants.engine.internals.selectors import Get, MultiGet
 from pants.engine.process import Process, ProcessResult
@@ -114,7 +115,7 @@ class ThirdpartyModulesRequest:
 @dataclass(frozen=True)
 class ThirdpartyModules:
     resolve: CoursierResolveKey
-    entries: dict[CoursierLockfileEntry, ClasspathEntry]
+    entries: dict[CoursierLockfileEntry, tuple[ClasspathEntry, list[CoursierLockfileEntry]]]

Contributor: The inner list should be a tuple instead, since "frozen" dataclasses need to be hashable and list is not hashable (it is mutable).

Contributor: Also, maybe it would be better to use a new dataclass instead of a 2-tuple? Then you could give a name to each component, which would help with readability.

Contributor: And is there any particular ordering of these CoursierLockfileEntry instances?
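
A minimal sketch of that suggestion (the class and field names here are hypothetical):

    @dataclass(frozen=True)
    class ThirdpartyModuleEntry:
        """Hypothetical named replacement for the (ClasspathEntry, sources) 2-tuple."""

        classpath_entry: ClasspathEntry
        # A tuple rather than a list, so the frozen dataclass stays hashable.
        source_entries: tuple[CoursierLockfileEntry, ...]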

     merged_digest: Digest


@@ -128,6 +129,11 @@ async def collect_thirdparty_modules(
     lockfile = await Get(CoursierResolvedLockfile, CoursierResolveKey, resolve)

     applicable_lockfile_entries: dict[CoursierLockfileEntry, CoarsenedTarget] = {}
+    applicable_lockfile_source_entries: dict[CoursierLockfileEntry, CoursierLockfileEntry] = {}
+    applicable_lockfile_source_entries_inverse: dict[
+        CoursierLockfileEntry, list[CoursierLockfileEntry]
+    ] = {}
+
     for ct in coarsened_targets.coarsened_closure():
         for tgt in ct.members:
             if not JvmArtifactFieldSet.is_applicable(tgt):
@@ -142,6 +148,21 @@
                 continue
             applicable_lockfile_entries[entry] = ct

+            artifact_requirement_source = dataclasses.replace(
+                artifact_requirement,
+                coordinate=dataclasses.replace(
+                    artifact_requirement.coordinate, classifier="sources"

Contributor: I recall from the PR description that you had some concern about hard-coding the "sources" classifier. Seems fine to me. I suggest defining a set of constants for well-known classifier names, though, and using a constant here instead of the open-coded "sources" string.
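
For illustration, such constants might look like this (hypothetical names, not an existing Pants API):

    # Hypothetical constants for well-known Maven classifier names.
    SOURCES_CLASSIFIER = "sources"
    JAVADOC_CLASSIFIER = "javadoc"

    # The call above would then read:
    # dataclasses.replace(artifact_requirement.coordinate, classifier=SOURCES_CLASSIFIER)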

+                ),
+            )
+            entrySource = get_entry_for_coord(lockfile, artifact_requirement_source.coordinate)
+            if not entrySource:
+                _logger.warning(
+                    f"No lockfile source entry for {artifact_requirement_source.coordinate} in resolve {resolve.name}."
+                )
+                continue
+            applicable_lockfile_source_entries[entrySource] = entry
+            applicable_lockfile_source_entries_inverse[entry] = [entrySource]
+
     classpath_entries = await MultiGet(
         Get(
             ClasspathEntry,
@@ -151,11 +172,35 @@
         for target in applicable_lockfile_entries.values()
     )

-    resolve_digest = await Get(Digest, MergeDigests(cpe.digest for cpe in classpath_entries))
+    digests = []
+    for cpe in classpath_entries:
+        digests.append(cpe.digest)
+    for alse in applicable_lockfile_source_entries:
+        new_file = FileEntry(alse.file_name, alse.file_digest)
+        digest = await Get(Digest, CreateDigest([new_file]))

Author: I'm aware all these individual CreateDigest calls would not be done one-by-one in a for loop; I just left everything as simple as I could to present the overall approach.

Contributor: Then I suggest a TODO comment here as a reminder to just convert this to a single CreateDigest with a list of FileEntry.
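
A sketch of that conversion, assuming the variables from the loop above:

    # One CreateDigest over all source jars, instead of one Get per entry.
    source_files = [
        FileEntry(alse.file_name, alse.file_digest)
        for alse in applicable_lockfile_source_entries
    ]
    sources_digest = await Get(Digest, CreateDigest(source_files))
    digests.append(sources_digest)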

+        digests.append(digest)
+
+        for dep in alse.dependencies:
+            coord = Coordinate.from_coord_str(dep.to_coord_str())
+            dep_artifact_requirement = ArtifactRequirement(coord)
+            dep_entry = get_entry_for_coord(lockfile, dep_artifact_requirement.coordinate)
+            dep_new_file = FileEntry(dep_entry.file_name, dep_entry.file_digest)

Contributor: Conversion to FileEntry seems to occur more than once. Maybe define a to_file_entry method on the dep_entry?
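
A sketch of that helper as a method on CoursierLockfileEntry (hypothetical, not yet in the codebase):

    @dataclass(frozen=True)
    class CoursierLockfileEntry:
        ...  # existing fields: coord, file_name, file_digest, dependencies, etc.

        def to_file_entry(self) -> FileEntry:
            # Collapses the repeated FileEntry(file_name, file_digest) construction.
            return FileEntry(self.file_name, self.file_digest)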

+            dep_digest = await Get(Digest, CreateDigest([dep_new_file]))
+            digests.append(dep_digest)
+            src_ent = applicable_lockfile_source_entries.get(alse)
+            applicable_lockfile_source_entries_inverse.get(src_ent).append(dep_entry)
+
+    resolve_digest = await Get(Digest, MergeDigests(digests))
+    inverse = dict(zip(classpath_entries, applicable_lockfile_entries))
+
+    s = map(
+        lambda x: (x, applicable_lockfile_source_entries_inverse.get(inverse.get(x))),
+        classpath_entries,
+    )

Comment on lines +175 to +199:

Author: This whole thing is convoluted just to see that it works; I'm just mashing it in here to avoid changing too much elsewhere.

Contributor: This can also just be a list comprehension.
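
For example (a sketch using the names above):

    s = [
        (cpe, applicable_lockfile_source_entries_inverse.get(inverse.get(cpe)))
        for cpe in classpath_entries
    ]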


     return ThirdpartyModules(
         resolve,
-        dict(zip(applicable_lockfile_entries, classpath_entries)),
+        dict(zip(applicable_lockfile_entries, s)),

Contributor: Convert s to a tuple here, given the earlier comment to not include a list in a frozen dataclass.
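
One way to do that at construction time (a sketch; "or ()" guards a missing lookup):

    entries = dict(
        zip(
            applicable_lockfile_entries,
            # Freeze each per-entry source list so the frozen dataclass stays hashable.
            ((cpe, tuple(srcs or ())) for cpe, srcs in s),
        )
    )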

         resolve_digest,
     )

@@ -341,7 +386,8 @@ async def handle_bsp_scalac_options_request(
 ) -> HandleScalacOptionsResult:
     targets = await Get(Targets, BuildTargetIdentifier, request.bsp_target_id)
     thirdparty_modules = await Get(
-        ThirdpartyModules, ThirdpartyModulesRequest(Addresses(tgt.address for tgt in targets))
+        ThirdpartyModules,
+        ThirdpartyModulesRequest(Addresses(tgt.address for tgt in targets)),
     )
     resolve = thirdparty_modules.resolve

@@ -352,12 +398,16 @@

     local_plugins_prefix = f"jvm/resolves/{resolve.name}/plugins"
     local_plugins = await Get(
-        ScalaPlugins, ScalaPluginsRequest.from_target_plugins(scalac_plugin_targets, resolve)
+        ScalaPlugins,
+        ScalaPluginsRequest.from_target_plugins(scalac_plugin_targets, resolve),
     )

     thirdparty_modules_prefix = f"jvm/resolves/{resolve.name}/lib"
     thirdparty_modules_digest, local_plugins_digest = await MultiGet(
-        Get(Digest, AddPrefix(thirdparty_modules.merged_digest, thirdparty_modules_prefix)),
+        Get(
+            Digest,
+            AddPrefix(thirdparty_modules.merged_digest, thirdparty_modules_prefix),
+        ),
         Get(Digest, AddPrefix(local_plugins.classpath.digest, local_plugins_prefix)),
     )

@@ -370,7 +420,7 @@
             build_root.pathlib_path.joinpath(
                 f".pants.d/bsp/{thirdparty_modules_prefix}/{filename}"
             ).as_uri()
-            for cp_entry in thirdparty_modules.entries.values()
+            for cp_entry, _ in thirdparty_modules.entries.values()
             for filename in cp_entry.filenames
         )

@@ -436,6 +486,12 @@ async def bsp_scala_test_classes_request(request: ScalaTestClassesParams) -> ScalaTestClassesResult:
     )


+# -----------------------------------------------------------------------------------------------
+# Dependency Sources
+# -----------------------------------------------------------------------------------------------
+
+# TODO
+
 # -----------------------------------------------------------------------------------------------
 # Dependency Modules
 # -----------------------------------------------------------------------------------------------

@@ -464,33 +520,53 @@ async def scala_bsp_dependency_modules(
         ThirdpartyModules,
         ThirdpartyModulesRequest(Addresses(fs.address for fs in request.field_sets)),
     )
+
     resolve = thirdparty_modules.resolve
+
     resolve_digest = await Get(
-        Digest, AddPrefix(thirdparty_modules.merged_digest, f"jvm/resolves/{resolve.name}/lib")
+        Digest,
+        AddPrefix(thirdparty_modules.merged_digest, f"jvm/resolves/{resolve.name}/lib"),
     )
-    modules = [
-        DependencyModule(
-            name=f"{entry.coord.group}:{entry.coord.artifact}",
-            version=entry.coord.version,
-            data=MavenDependencyModule(
-                organization=entry.coord.group,
-                name=entry.coord.artifact,
-                version=entry.coord.version,
-                scope=None,
-                artifacts=tuple(
-                    MavenDependencyModuleArtifact(
-                        uri=build_root.pathlib_path.joinpath(
-                            f".pants.d/bsp/jvm/resolves/{resolve.name}/lib/{filename}"
-                        ).as_uri()
-                    )
-                    for filename in cp_entry.filenames
-                ),
-            ),
-        )
-        for entry, cp_entry in thirdparty_modules.entries.items()
-    ]
+    modules = []
+
+    for entry, (cp_entry, source_entry) in thirdparty_modules.entries.items():
+        a1 = [
+            MavenDependencyModuleArtifact(
+                uri=build_root.pathlib_path.joinpath(
+                    f".pants.d/bsp/jvm/resolves/{resolve.name}/lib/{filename}"
+                ).as_uri(),
+            )
+            for filename in cp_entry.filenames
+        ]
+
+        a2 = None
+        if source_entry is not None:
+            a2 = [
+                MavenDependencyModuleArtifact(
+                    uri=build_root.pathlib_path.joinpath(
+                        f".pants.d/bsp/jvm/resolves/{resolve.name}/lib/{se.file_name}"
+                    ).as_uri(),
+                    classifier="sources",
+                )
+                for se in source_entry
+            ]
+        else:
+            a2 = []
+
+        modules.append(
+            DependencyModule(
+                name=f"{entry.coord.group}:{entry.coord.artifact}",
+                version=entry.coord.version,
+                data=MavenDependencyModule(
+                    organization=entry.coord.group,
+                    name=entry.coord.artifact,
+                    version=entry.coord.version,
+                    scope=None,
+                    artifacts=tuple(a1 + a2),
+                ),
+            )
+        )

     return BSPDependencyModulesResult(
         modules=tuple(modules),
33 changes: 30 additions & 3 deletions src/python/pants/bsp/util_rules/targets.py
@@ -4,6 +4,7 @@

 import itertools
 import logging
+import typing
 from collections import defaultdict
 from dataclasses import dataclass
 from pathlib import Path
@@ -57,6 +58,7 @@
     Targets,
 )
 from pants.engine.unions import UnionMembership, UnionRule, union
+from pants.jvm.bsp.spec import MavenDependencyModule, MavenDependencyModuleArtifact
 from pants.source.source_root import SourceRootsRequest, SourceRootsResult
 from pants.util.frozendict import FrozenDict
 from pants.util.ordered_set import FrozenOrderedSet, OrderedSet
@@ -397,6 +399,7 @@ async def generate_one_bsp_build_target_request(
     # directory or else be configurable by the user. It is used as a hint in IntelliJ for where to place the
     # corresponding IntelliJ module.
     source_info = await Get(BSPBuildTargetSourcesInfo, BSPBuildTargetInternal, request.bsp_target)
+
     if source_info.source_roots:
         roots = [build_root.pathlib_path.joinpath(p) for p in source_info.source_roots]
     else:
@@ -479,7 +482,6 @@ async def materialize_bsp_build_target_sources(
 ) -> MaterializeBuildTargetSourcesResult:
     bsp_target = await Get(BSPBuildTargetInternal, BuildTargetIdentifier, request.bsp_target_id)
     source_info = await Get(BSPBuildTargetSourcesInfo, BSPBuildTargetInternal, bsp_target)
-
     if source_info.source_roots:
         roots = [build_root.pathlib_path.joinpath(p) for p in source_info.source_roots]
     else:
@@ -516,6 +518,18 @@ async def bsp_build_target_sources(request: SourcesParams) -> SourcesResult:
 # -----------------------------------------------------------------------------------------------


+@dataclass(frozen=True)
+class BSPDependencySourcesRequest:
+    """Hook to allow language backends to provide dependency sources."""
+
+    params: DependencySourcesParams
+
+
+@dataclass(frozen=True)
+class BSPDependencySourcesResult:
+    result: DependencySourcesResult
+
+
 class DependencySourcesHandlerMapping(BSPHandlerMapping):
     method_name = "buildTarget/dependencySources"
     request_type = DependencySourcesParams
@@ -524,9 +538,22 @@ class DependencySourcesHandlerMapping(BSPHandlerMapping):

 @rule
 async def bsp_dependency_sources(request: DependencySourcesParams) -> DependencySourcesResult:
-    # TODO: This is a stub.
+    dependency_modules = await Get(
+        DependencyModulesResult, DependencyModulesParams, DependencyModulesParams(request.targets)
+    )

Comment on lines +541 to +543:

Author: This would have its own backend-specific rule, I suppose; for now it's just pulling from dependency_modules and fully expecting it to be a Maven module, again to try it out without implementing a whole rule indirection for it.

Contributor: Understood. Maybe add a TODO comment as a reminder to create the union hook which makes this backend-agnostic?

Contributor:

> This would have its own backend-specific rule I suppose

Yes, because the core BSP logic should not be aware of the per-language specifics.
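
A rough sketch of that union hook, reworking the BSPDependencySourcesRequest above (hypothetical wiring, modeled on how dependency modules are dispatched per backend):

    @union
    @dataclass(frozen=True)
    class BSPDependencySourcesRequest:
        """Language backends register a UnionRule against this to supply dependency sources."""

        bsp_target_id: BuildTargetIdentifier

    # A backend would then register, e.g.:
    #   UnionRule(BSPDependencySourcesRequest, ScalaBSPDependencySourcesRequest)
    # and the core rule would MultiGet across union members instead of assuming Maven data.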


+    sources = {}
+    for i in dependency_modules.items:
+        for m in i.modules:
+            mavenmod: MavenDependencyModule = m.data
+            for x in mavenmod.artifacts:
+                if x.classifier == "sources":
+                    sources[x.uri] = x
+
+    files = sources.keys()
+
     return DependencySourcesResult(
-        tuple(DependencySourcesItem(target=tgt, sources=()) for tgt in request.targets)
+        tuple(DependencySourcesItem(target=tgt, sources=tuple(files)) for tgt in request.targets)
     )


36 changes: 27 additions & 9 deletions src/python/pants/jvm/resolve/coursier_fetch.py
@@ -213,19 +213,25 @@ def direct_dependencies(
         self, key: CoursierResolveKey, coord: Coordinate
     ) -> tuple[CoursierLockfileEntry, tuple[CoursierLockfileEntry, ...]]:
         """Return the entry for the given Coordinate, and for its direct dependencies."""
-        entries = {(i.coord.group, i.coord.artifact): i for i in self.entries}
-        entry = entries.get((coord.group, coord.artifact))
+        entries = {(i.coord.group, i.coord.artifact, i.coord.classifier): i for i in self.entries}
+        entry = entries.get((coord.group, coord.artifact, coord.classifier))

Author: Had to make sure everything includes the classifier, since (group, artifact) might not be unique anymore.
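
For example (hypothetical coordinates; Coordinate fields as used throughout this diff):

    bin_coord = Coordinate(group="org.example", artifact="lib", version="1.0.0")
    src_coord = dataclasses.replace(bin_coord, classifier="sources")
    # Same (group, artifact) key, so the old dict would silently collapse the two:
    assert (bin_coord.group, bin_coord.artifact) == (src_coord.group, src_coord.artifact)
    assert bin_coord != src_coord  # the classifier is what disambiguates them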

         if entry is None:
             raise self._coordinate_not_found(key, coord)

-        return (entry, tuple(entries[(i.group, i.artifact)] for i in entry.direct_dependencies))
+        return (
+            entry,
+            tuple(
+                entries[(i.group, i.artifact, i.classifier)]
+                for i in entry.direct_dependencies
+            ),
+        )

     def dependencies(
         self, key: CoursierResolveKey, coord: Coordinate
     ) -> tuple[CoursierLockfileEntry, tuple[CoursierLockfileEntry, ...]]:
         """Return the entry for the given Coordinate, and for its transitive dependencies."""
-        entries = {(i.coord.group, i.coord.artifact): i for i in self.entries}
-        entry = entries.get((coord.group, coord.artifact))
+        entries = {(i.coord.group, i.coord.artifact, i.coord.classifier): i for i in self.entries}
+        entry = entries.get((coord.group, coord.artifact, coord.classifier))
         if entry is None:
             raise self._coordinate_not_found(key, coord)

@@ -238,7 +244,8 @@ def dependencies(
             # https://github.com/coursier/coursier/issues/2884
             # As a workaround, if this happens, we want to skip the dependency.
             # TODO Drop the check once the bug is fixed.
-            if (dependency_entry := entries.get((d.group, d.artifact))) is not None
+            if (dependency_entry := entries.get((d.group, d.artifact, d.classifier)))
+            is not None
         ),
     )

@@ -627,14 +634,25 @@ async def coursier_fetch_one_coord(
     report = json.loads(report_contents[0].content)

     report_deps = report["dependencies"]
+
     if len(report_deps) == 0:
         raise CoursierError("Coursier fetch report has no dependencies (i.e. nothing was fetched).")
-    elif len(report_deps) > 1:
+
+    dep = None
+    for report_dep in report_deps:
+        report_dep_coord = Coordinate.from_coord_str(report_dep["coord"])
+        if report_dep_coord == request.coord:
+            if dep is not None:
+                raise CoursierError(
+                    "Coursier fetch report has multiple dependencies, but exactly 1 was expected."
+                )
+            dep = report_dep
+
+    if dep is None:
         raise CoursierError(
-            "Coursier fetch report has multiple dependencies, but exactly 1 was expected."
+            f'Coursier fetch report has no matching dependencies for coord "{request.coord.to_coord_str()}".'
         )

-    dep = report_deps[0]
     resolved_coord = Coordinate.from_coord_str(dep["coord"])
     if resolved_coord != request.coord:
         raise CoursierError(