|
5 | 5 | from dataclasses import field |
6 | 6 | from typing import Callable |
7 | 7 | from typing import cast |
| 8 | +from urllib.parse import urlparse |
| 9 | +from urllib.parse import urlunparse |
8 | 10 |
|
9 | 11 | from django.apps import apps |
10 | 12 | from django.conf import settings |
|
14 | 16 | from django.template.loader import get_template |
15 | 17 | from django.urls import reverse |
16 | 18 | from django.urls.exceptions import NoReverseMatch |
| 19 | +from django.utils.functional import Promise |
17 | 20 | from django.utils.safestring import mark_safe |
18 | 21 |
|
19 | 22 | from django_simple_nav._templates import get_template_engine |
@@ -63,7 +66,7 @@ def get_template_name(self) -> str: |
63 | 66 | @dataclass(frozen=True) |
64 | 67 | class NavItem: |
65 | 68 | title: str |
66 | | - url: str | None = None |
| 69 | + url: str | Callable[..., str] | None = None |
67 | 70 | permissions: list[str | Callable[[HttpRequest], bool]] = field(default_factory=list) |
68 | 71 | extra_context: dict[str, object] = field(default_factory=dict) |
69 | 72 |
|
@@ -92,14 +95,33 @@ def get_title(self) -> str: |
92 | 95 | def get_url(self) -> str: |
93 | 96 | url: str | None |
94 | 97 |
|
95 | | - try: |
96 | | - url = reverse(self.url) |
97 | | - except NoReverseMatch: |
98 | | - url = self.url |
| 98 | + if isinstance(self.url, Promise): |
| 99 | + # django.urls.base.reverse_lazy |
| 100 | + url = str(self.url) |
| 101 | + elif callable(self.url): |
| 102 | + # django.urls.base.reverse (or some other basic callable) |
| 103 | + url = self.url() |
| 104 | + else: |
| 105 | + try: |
| 106 | + url = reverse(self.url) |
| 107 | + except NoReverseMatch: |
| 108 | + url = self.url |
99 | 109 |
|
100 | 110 | if url is not None: |
101 | | - if settings.APPEND_SLASH and not url.endswith("/"): # pyright: ignore[reportAny] |
102 | | - url += "/" |
| 111 | + parsed_url = urlparse(url) |
| 112 | + path = parsed_url.path |
| 113 | + if settings.APPEND_SLASH and not path.endswith("/"): |
| 114 | + path += "/" |
| 115 | + url = urlunparse( |
| 116 | + ( |
| 117 | + parsed_url.scheme, |
| 118 | + parsed_url.netloc, |
| 119 | + path, |
| 120 | + parsed_url.params, |
| 121 | + parsed_url.query, |
| 122 | + parsed_url.fragment, |
| 123 | + ) |
| 124 | + ) |
103 | 125 | return url |
104 | 126 |
|
105 | 127 | msg = f"{self.__class__!r} must define 'url' or override 'get_url()'" |
|
0 commit comments