Skip to content

Commit 4ef2df2

Browse files
radoeringabn
authored andcommitted
improve performance for merging markers from overrides
1 parent 30fe6c7 commit 4ef2df2

File tree

2 files changed

+153
-110
lines changed

2 files changed

+153
-110
lines changed

src/poetry/puzzle/solver.py

Lines changed: 64 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,12 @@ def _solve_in_compatibility_mode(
140140
self,
141141
overrides: tuple[dict[Package, dict[str, Dependency]], ...],
142142
) -> dict[Package, TransitivePackageInfo]:
143-
packages: dict[Package, TransitivePackageInfo] = {}
143+
override_packages: list[
144+
tuple[
145+
dict[Package, dict[str, Dependency]],
146+
dict[Package, TransitivePackageInfo],
147+
]
148+
] = []
144149
for override in overrides:
145150
self._provider.debug(
146151
# ignore the warning as provider does not do interpolation
@@ -149,9 +154,9 @@ def _solve_in_compatibility_mode(
149154
)
150155
self._provider.set_overrides(override)
151156
new_packages = self._solve()
152-
merge_packages_from_override(packages, new_packages, override)
157+
override_packages.append((override, new_packages))
153158

154-
return packages
159+
return merge_override_packages(override_packages)
155160

156161
def _solve(self) -> dict[Package, TransitivePackageInfo]:
157162
if self._provider._overrides:
@@ -406,34 +411,63 @@ def calculate_markers(
406411
transitive_info.markers = transitive_marker
407412

408413

409-
def merge_packages_from_override(
410-
packages: dict[Package, TransitivePackageInfo],
411-
new_packages: dict[Package, TransitivePackageInfo],
412-
override: dict[Package, dict[str, Dependency]],
413-
) -> None:
414-
override_marker: BaseMarker = AnyMarker()
415-
for deps in override.values():
416-
for dep in deps.values():
417-
override_marker = override_marker.intersect(dep.marker.without_extras())
418-
for new_package, new_package_info in new_packages.items():
419-
if package_info := packages.get(new_package):
420-
# update existing package
421-
package_info.depth = max(package_info.depth, new_package_info.depth)
422-
package_info.groups.update(new_package_info.groups)
423-
for group, marker in new_package_info.markers.items():
424-
package_info.markers[group] = package_info.markers.get(
425-
group, EmptyMarker()
426-
).union(override_marker.intersect(marker))
427-
for package in packages:
428-
if package == new_package:
429-
for dep in new_package.requires:
430-
if dep not in package.requires:
431-
package.add_dependency(dep)
432-
414+
def merge_override_packages(
415+
override_packages: list[
416+
tuple[
417+
dict[Package, dict[str, Dependency]], dict[Package, TransitivePackageInfo]
418+
]
419+
],
420+
) -> dict[Package, TransitivePackageInfo]:
421+
result: dict[Package, TransitivePackageInfo] = {}
422+
all_packages: dict[
423+
Package, list[tuple[Package, TransitivePackageInfo, BaseMarker]]
424+
] = {}
425+
for override, o_packages in override_packages:
426+
override_marker: BaseMarker = AnyMarker()
427+
for deps in override.values():
428+
for dep in deps.values():
429+
override_marker = override_marker.intersect(dep.marker.without_extras())
430+
for package, info in o_packages.items():
431+
all_packages.setdefault(package, []).append(
432+
(package, info, override_marker)
433+
)
434+
for package_duplicates in all_packages.values():
435+
base = package_duplicates[0]
436+
package = base[0]
437+
package_info = base[1]
438+
first_override_marker = base[2]
439+
result[package] = package_info
440+
package_info.depth = max(info.depth for _, info, _ in package_duplicates)
441+
package_info.groups = {
442+
g for _, info, _ in package_duplicates for g in info.groups
443+
}
444+
if all(
445+
info.markers == package_info.markers for _, info, _ in package_duplicates
446+
):
447+
# performance shortcut:
448+
# if markers are the same for all overrides,
449+
# we can use less expensive marker operations
450+
override_marker = EmptyMarker()
451+
for _, _, marker in package_duplicates:
452+
override_marker = override_marker.union(marker)
453+
package_info.markers = {
454+
group: override_marker.intersect(marker)
455+
for group, marker in package_info.markers.items()
456+
}
433457
else:
434-
for group, marker in new_package_info.markers.items():
435-
new_package_info.markers[group] = override_marker.intersect(marker)
436-
packages[new_package] = new_package_info
458+
# fallback / general algorithm with performance issues
459+
for group, marker in package_info.markers.items():
460+
package_info.markers[group] = first_override_marker.intersect(marker)
461+
for _, info, override_marker in package_duplicates[1:]:
462+
for group, marker in info.markers.items():
463+
package_info.markers[group] = package_info.markers.get(
464+
group, EmptyMarker()
465+
).union(override_marker.intersect(marker))
466+
for duplicate_package, _, _ in package_duplicates[1:]:
467+
for dep in duplicate_package.requires:
468+
if dep not in package.requires:
469+
package.add_dependency(dep)
470+
return result
437471

438472

439473
@functools.cache

tests/puzzle/test_solver_internals.py

Lines changed: 89 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from poetry.puzzle.solver import PackageNode
1414
from poetry.puzzle.solver import Solver
1515
from poetry.puzzle.solver import depth_first_search
16-
from poetry.puzzle.solver import merge_packages_from_override
16+
from poetry.puzzle.solver import merge_override_packages
1717

1818

1919
if TYPE_CHECKING:
@@ -359,28 +359,29 @@ def test_propagate_markers_with_cycle(package: ProjectPackage, solver: Solver) -
359359
}
360360

361361

362-
def test_merge_packages_from_override_restricted(package: ProjectPackage) -> None:
362+
def test_merge_override_packages_restricted(package: ProjectPackage) -> None:
363363
"""Markers of dependencies should be intersected with override markers."""
364364
a = Package("a", "1")
365365

366-
packages: dict[Package, TransitivePackageInfo] = {}
367-
merge_packages_from_override(
368-
packages,
369-
{
370-
a: TransitivePackageInfo(
371-
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
372-
)
373-
},
374-
{package: {"a": dep("b", 'python_version < "3.9"')}},
375-
)
376-
merge_packages_from_override(
377-
packages,
378-
{
379-
a: TransitivePackageInfo(
380-
0, {"main"}, {"main": parse_marker("sys_platform == 'linux'")}
381-
)
382-
},
383-
{package: {"a": dep("b", 'python_version >= "3.9"')}},
366+
packages = merge_override_packages(
367+
[
368+
(
369+
{package: {"a": dep("b", 'python_version < "3.9"')}},
370+
{
371+
a: TransitivePackageInfo(
372+
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
373+
)
374+
},
375+
),
376+
(
377+
{package: {"a": dep("b", 'python_version >= "3.9"')}},
378+
{
379+
a: TransitivePackageInfo(
380+
0, {"main"}, {"main": parse_marker("sys_platform == 'linux'")}
381+
)
382+
},
383+
),
384+
]
384385
)
385386
assert len(packages) == 1
386387
assert packages[a].groups == {"main"}
@@ -392,28 +393,33 @@ def test_merge_packages_from_override_restricted(package: ProjectPackage) -> Non
392393
}
393394

394395

395-
def test_merge_packages_from_override_extras(package: ProjectPackage) -> None:
396+
def test_merge_override_packages_extras(package: ProjectPackage) -> None:
396397
"""Extras from overrides should not be visible in the resulting marker."""
397398
a = Package("a", "1")
398399

399-
packages: dict[Package, TransitivePackageInfo] = {}
400-
merge_packages_from_override(
401-
packages,
402-
{
403-
a: TransitivePackageInfo(
404-
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
405-
)
406-
},
407-
{package: {"a": dep("b", 'python_version < "3.9" and extra == "foo"')}},
408-
)
409-
merge_packages_from_override(
410-
packages,
411-
{
412-
a: TransitivePackageInfo(
413-
0, {"main"}, {"main": parse_marker("sys_platform == 'linux'")}
414-
)
415-
},
416-
{package: {"a": dep("b", 'python_version >= "3.9" and extra == "foo"')}},
400+
packages = merge_override_packages(
401+
[
402+
(
403+
{package: {"a": dep("b", 'python_version < "3.9" and extra == "foo"')}},
404+
{
405+
a: TransitivePackageInfo(
406+
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
407+
)
408+
},
409+
),
410+
(
411+
{
412+
package: {
413+
"a": dep("b", 'python_version >= "3.9" and extra == "foo"')
414+
}
415+
},
416+
{
417+
a: TransitivePackageInfo(
418+
0, {"main"}, {"main": parse_marker("sys_platform == 'linux'")}
419+
)
420+
},
421+
),
422+
]
417423
)
418424
assert len(packages) == 1
419425
assert packages[a].groups == {"main"}
@@ -425,21 +431,23 @@ def test_merge_packages_from_override_extras(package: ProjectPackage) -> None:
425431
}
426432

427433

428-
def test_merge_packages_from_override_multiple_deps(package: ProjectPackage) -> None:
434+
def test_merge_override_packages_multiple_deps(package: ProjectPackage) -> None:
429435
"""All override markers should be intersected."""
430436
a = Package("a", "1")
431437

432-
packages: dict[Package, TransitivePackageInfo] = {}
433-
merge_packages_from_override(
434-
packages,
435-
{a: TransitivePackageInfo(0, {"main"}, {"main": AnyMarker()})},
436-
{
437-
package: {
438-
"a": dep("b", 'python_version < "3.9"'),
439-
"c": dep("d", 'sys_platform == "linux"'),
440-
},
441-
a: {"e": dep("f", 'python_version >= "3.8"')},
442-
},
438+
packages = merge_override_packages(
439+
[
440+
(
441+
{
442+
package: {
443+
"a": dep("b", 'python_version < "3.9"'),
444+
"c": dep("d", 'sys_platform == "linux"'),
445+
},
446+
a: {"e": dep("f", 'python_version >= "3.8"')},
447+
},
448+
{a: TransitivePackageInfo(0, {"main"}, {"main": AnyMarker()})},
449+
),
450+
]
443451
)
444452

445453
assert len(packages) == 1
@@ -452,44 +460,45 @@ def test_merge_packages_from_override_multiple_deps(package: ProjectPackage) ->
452460
}
453461

454462

455-
def test_merge_packages_from_override_groups(package: ProjectPackage) -> None:
463+
def test_merge_override_packages_groups(package: ProjectPackage) -> None:
456464
a = Package("a", "1")
457465
b = Package("b", "1")
458466

459-
packages: dict[Package, TransitivePackageInfo] = {}
460-
merge_packages_from_override(
461-
packages,
462-
{
463-
a: TransitivePackageInfo(
464-
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
465-
),
466-
b: TransitivePackageInfo(
467-
0,
468-
{"main", "dev"},
467+
packages = merge_override_packages(
468+
[
469+
(
470+
{package: {"a": dep("b", 'python_version < "3.9"')}},
469471
{
470-
"main": parse_marker("sys_platform == 'win32'"),
471-
"dev": parse_marker("sys_platform == 'linux'"),
472+
a: TransitivePackageInfo(
473+
0, {"main"}, {"main": parse_marker("sys_platform == 'win32'")}
474+
),
475+
b: TransitivePackageInfo(
476+
0,
477+
{"main", "dev"},
478+
{
479+
"main": parse_marker("sys_platform == 'win32'"),
480+
"dev": parse_marker("sys_platform == 'linux'"),
481+
},
482+
),
472483
},
473484
),
474-
},
475-
{package: {"a": dep("b", 'python_version < "3.9"')}},
476-
)
477-
merge_packages_from_override(
478-
packages,
479-
{
480-
a: TransitivePackageInfo(
481-
0, {"dev"}, {"dev": parse_marker("sys_platform == 'linux'")}
482-
),
483-
b: TransitivePackageInfo(
484-
0,
485-
{"main", "dev"},
485+
(
486+
{package: {"a": dep("b", 'python_version >= "3.9"')}},
486487
{
487-
"main": parse_marker("platform_machine == 'amd64'"),
488-
"dev": parse_marker("platform_machine == 'aarch64'"),
488+
a: TransitivePackageInfo(
489+
0, {"dev"}, {"dev": parse_marker("sys_platform == 'linux'")}
490+
),
491+
b: TransitivePackageInfo(
492+
0,
493+
{"main", "dev"},
494+
{
495+
"main": parse_marker("platform_machine == 'amd64'"),
496+
"dev": parse_marker("platform_machine == 'aarch64'"),
497+
},
498+
),
489499
},
490500
),
491-
},
492-
{package: {"a": dep("b", 'python_version >= "3.9"')}},
501+
]
493502
)
494503
assert len(packages) == 2
495504
assert packages[a].groups == {"main", "dev"}

0 commit comments

Comments
 (0)