diff --git a/dev/environment b/dev/environment index 45a356ac8c35..2c211998cdee 100644 --- a/dev/environment +++ b/dev/environment @@ -91,3 +91,6 @@ HELPDESK_BACKEND="warehouse.helpdesk.services.ConsoleHelpDeskService" # HELPDESK_NOTIFICATION_SERVICE_URL="https://..." HELPDESK_NOTIFICATION_BACKEND="warehouse.helpdesk.services.ConsoleAdminNotificationService" + +# Example of Domain Status configuration +# DOMAIN_STATUS_BACKEND="warehouse.accounts.services.DomainrDomainStatusService client_id=some_client_id" diff --git a/tests/conftest.py b/tests/conftest.py index d8b64944582b..c16f1137339e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,7 +44,11 @@ from warehouse import admin, config, static from warehouse.accounts import services as account_services -from warehouse.accounts.interfaces import ITokenService, IUserService +from warehouse.accounts.interfaces import ( + IDomainStatusService, + ITokenService, + IUserService, +) from warehouse.admin.flags import AdminFlag, AdminFlagValue from warehouse.attestations import services as attestations_services from warehouse.attestations.interfaces import IIntegrityService @@ -169,6 +173,7 @@ def pyramid_services( notification_service, query_results_cache_service, search_service, + domain_status_service, ): services = _Services() @@ -194,6 +199,7 @@ def pyramid_services( services.register_service(notification_service, IAdminNotificationService) services.register_service(query_results_cache_service, IQueryResultsCache) services.register_service(search_service, ISearchService) + services.register_service(domain_status_service, IDomainStatusService) return services @@ -543,6 +549,11 @@ def search_service(): return search_services.NullSearchService() +@pytest.fixture +def domain_status_service(): + return account_services.NullDomainStatusService() + + class QueryRecorder: def __init__(self): self.queries = [] diff --git a/tests/unit/accounts/test_core.py b/tests/unit/accounts/test_core.py index ef00be4092ce..58a7a4655292 100644 --- a/tests/unit/accounts/test_core.py +++ b/tests/unit/accounts/test_core.py @@ -17,6 +17,7 @@ from warehouse import accounts from warehouse.accounts.interfaces import ( + IDomainStatusService, IEmailBreachedService, IPasswordBreachedService, ITokenService, @@ -25,6 +26,7 @@ from warehouse.accounts.services import ( HaveIBeenPwnedEmailBreachedService, HaveIBeenPwnedPasswordBreachedService, + NullDomainStatusService, TokenServiceFactory, database_login_factory, ) @@ -186,6 +188,7 @@ def test_includeme(monkeypatch): HaveIBeenPwnedEmailBreachedService.create_service, IEmailBreachedService, ), + pretend.call(NullDomainStatusService.create_service, IDomainStatusService), pretend.call(RateLimit("10 per 5 minutes"), IRateLimiter, name="user.login"), pretend.call(RateLimit("10 per 5 minutes"), IRateLimiter, name="ip.login"), pretend.call( diff --git a/tests/unit/accounts/test_services.py b/tests/unit/accounts/test_services.py index ea6b13a555ee..4a99a8758608 100644 --- a/tests/unit/accounts/test_services.py +++ b/tests/unit/accounts/test_services.py @@ -30,6 +30,7 @@ from warehouse.accounts import services from warehouse.accounts.interfaces import ( BurnedRecoveryCode, + IDomainStatusService, IEmailBreachedService, InvalidRecoveryCode, IPasswordBreachedService, @@ -1635,3 +1636,79 @@ def test_factory(self): assert isinstance(svc, services.NullEmailBreachedService) assert svc.get_email_breach_count("foo@example.com") == 0 + + +class TestNullDomainStatusService: + def test_verify_service(self): + assert verifyClass(IDomainStatusService, services.NullDomainStatusService) + + def test_get_domain_status(self): + svc = services.NullDomainStatusService() + assert svc.get_domain_status("example.com") == ["active"] + + def test_factory(self): + context = pretend.stub() + request = pretend.stub() + svc = services.NullDomainStatusService.create_service(context, request) + + assert isinstance(svc, services.NullDomainStatusService) + assert svc.get_domain_status("example.com") == ["active"] + + +class TestDomainrDomainStatusService: + def test_verify_service(self): + assert verifyClass(IDomainStatusService, services.DomainrDomainStatusService) + + def test_successful_domain_status_check(self): + response = pretend.stub( + json=lambda: { + "status": [{"domain": "example.com", "status": "undelegated inactive"}] + }, + raise_for_status=lambda: None, + ) + session = pretend.stub(get=pretend.call_recorder(lambda *a, **kw: response)) + svc = services.DomainrDomainStatusService( + session=session, client_id="some_client_id" + ) + + assert svc.get_domain_status("example.com") == ["undelegated", "inactive"] + assert session.get.calls == [ + pretend.call( + "https://api.domainr.com/v2/status", + params={"client_id": "some_client_id", "domain": "example.com"}, + timeout=5, + ) + ] + + def test_domainr_exception_returns_empty(self): + class DomainrException(requests.HTTPError): + def __init__(self): + self.response = pretend.stub(status_code=400) + + response = pretend.stub(raise_for_status=pretend.raiser(DomainrException)) + session = pretend.stub(get=pretend.call_recorder(lambda *a, **kw: response)) + svc = services.DomainrDomainStatusService( + session=session, client_id="some_client_id" + ) + + assert svc.get_domain_status("example.com") == [] + assert session.get.calls == [ + pretend.call( + "https://api.domainr.com/v2/status", + params={"client_id": "some_client_id", "domain": "example.com"}, + timeout=5, + ) + ] + + def test_factory(self): + context = pretend.stub() + request = pretend.stub( + http=pretend.stub(), + registry=pretend.stub( + settings={"domain_status.client_id": "some_client_id"} + ), + ) + svc = services.DomainrDomainStatusService.create_service(context, request) + + assert svc._http is request.http + assert svc.client_id == "some_client_id" diff --git a/tests/unit/admin/test_routes.py b/tests/unit/admin/test_routes.py index 49e25656f932..3abe9b372a51 100644 --- a/tests/unit/admin/test_routes.py +++ b/tests/unit/admin/test_routes.py @@ -89,6 +89,13 @@ def test_includeme(): factory="warehouse.accounts.models:UserFactory", traverse="/{username}", ), + pretend.call( + "admin.user.email_domain_check", + "/admin/users/{username}/email_domain_check/", + domain=warehouse, + factory="warehouse.accounts.models:UserFactory", + traverse="/{username}", + ), pretend.call( "admin.user.delete", "/admin/users/{username}/delete/", diff --git a/tests/unit/admin/views/test_users.py b/tests/unit/admin/views/test_users.py index 2503de2c71b3..bf9cf89a2364 100644 --- a/tests/unit/admin/views/test_users.py +++ b/tests/unit/admin/views/test_users.py @@ -1539,3 +1539,26 @@ def test_no_recovery_codes_provided(self, db_request, monkeypatch, user_service) ] assert result.status_code == 303 assert result.location == "/foobar" + + +class TestUserEmailDomainCheck: + def test_user_email_domain_check(self, db_request): + user = UserFactory.create(with_verified_primary_email=True) + db_request.POST["email_address"] = user.primary_email.email + db_request.route_path = pretend.call_recorder(lambda *a, **kw: "/foobar") + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + + result = views.user_email_domain_check(user, db_request) + + assert isinstance(result, HTTPSeeOther) + assert result.headers["Location"] == "/foobar" + assert db_request.session.flash.calls == [ + pretend.call( + f"Domain status check for '{user.primary_email.domain}' completed", + queue="success", + ) + ] + assert user.primary_email.domain_last_checked is not None + assert user.primary_email.domain_last_status == ["active"] diff --git a/warehouse/accounts/__init__.py b/warehouse/accounts/__init__.py index bad7b340d6a8..11d64295fdf2 100644 --- a/warehouse/accounts/__init__.py +++ b/warehouse/accounts/__init__.py @@ -13,6 +13,7 @@ from celery.schedules import crontab from warehouse.accounts.interfaces import ( + IDomainStatusService, IEmailBreachedService, IPasswordBreachedService, ITokenService, @@ -25,6 +26,7 @@ from warehouse.accounts.services import ( HaveIBeenPwnedEmailBreachedService, HaveIBeenPwnedPasswordBreachedService, + NullDomainStatusService, NullEmailBreachedService, NullPasswordBreachedService, TokenServiceFactory, @@ -131,6 +133,14 @@ def includeme(config): breached_email_class.create_service, IEmailBreachedService ) + # Register our domain status service. + domain_status_class = config.maybe_dotted( + config.registry.settings.get("domain_status.backend", NullDomainStatusService) + ) + config.register_service_factory( + domain_status_class.create_service, IDomainStatusService + ) + # Register our security policies. config.set_security_policy( MultiSecurityPolicy( diff --git a/warehouse/accounts/interfaces.py b/warehouse/accounts/interfaces.py index aa3f7f501039..15f7114955c0 100644 --- a/warehouse/accounts/interfaces.py +++ b/warehouse/accounts/interfaces.py @@ -298,3 +298,10 @@ def get_email_breach_count(email: str) -> int | None: """ Returns count of times the email appears in verified breaches. """ + + +class IDomainStatusService(Interface): + def get_domain_status(domain: str) -> list[str]: + """ + Returns a list of status strings for the given domain. + """ diff --git a/warehouse/accounts/models.py b/warehouse/accounts/models.py index 4a8571fd97d3..55e93a8c65e2 100644 --- a/warehouse/accounts/models.py +++ b/warehouse/accounts/models.py @@ -29,7 +29,7 @@ select, sql, ) -from sqlalchemy.dialects.postgresql import CITEXT, UUID as PG_UUID +from sqlalchemy.dialects.postgresql import ARRAY, CITEXT, UUID as PG_UUID from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import Mapped, mapped_column @@ -424,6 +424,15 @@ class Email(db.ModelBase): unverify_reason: Mapped[UnverifyReasons | None] transient_bounces: Mapped[int] = mapped_column(server_default=sql.text("0")) + # Domain validation information + domain_last_checked: Mapped[datetime.datetime | None] = mapped_column( + comment="Last time domain was checked with the domain validation service.", + ) + domain_last_status: Mapped[list[str] | None] = mapped_column( + ARRAY(String), + comment="Status strings returned by the domain validation service.", + ) + @property def domain(self): return self.email.split("@")[-1].lower() diff --git a/warehouse/accounts/services.py b/warehouse/accounts/services.py index a20a0f131158..efb9904c4a7a 100644 --- a/warehouse/accounts/services.py +++ b/warehouse/accounts/services.py @@ -10,6 +10,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import collections import datetime import functools @@ -18,6 +20,7 @@ import logging import os import secrets +import typing import urllib.parse import passlib.exc @@ -35,6 +38,7 @@ from warehouse.accounts.interfaces import ( BurnedRecoveryCode, + IDomainStatusService, IEmailBreachedService, InvalidRecoveryCode, IPasswordBreachedService, @@ -62,6 +66,9 @@ from warehouse.rate_limiting import DummyRateLimiter, IRateLimiter from warehouse.utils.crypto import BadData, SignatureExpired, URLSafeTimedSerializer +if typing.TYPE_CHECKING: + from pyramid.request import Request + logger = logging.getLogger(__name__) PASSWORD_FIELD = "password" @@ -962,3 +969,43 @@ def create_service(cls, context, request): def get_email_breach_count(self, email): # This service allows *every* email as a non-breached email. return 0 + + +@implementer(IDomainStatusService) +class NullDomainStatusService: + @classmethod + def create_service(cls, _context, _request): + return cls() + + def get_domain_status(self, _domain: str) -> list[str]: + return ["active"] + + +@implementer(IDomainStatusService) +class DomainrDomainStatusService: + def __init__(self, session, client_id): + self._http = session + self.client_id = client_id + + @classmethod + def create_service(cls, _context, request: Request) -> DomainrDomainStatusService: + domainr_client_id = request.registry.settings.get("domain_status.client_id") + return cls(session=request.http, client_id=domainr_client_id) + + def get_domain_status(self, domain: str) -> list[str]: + """ + Check if a domain is available or not. + See https://domainr.com/docs/api/v2/status + """ + try: + resp = self._http.get( + "https://api.domainr.com/v2/status", + params={"client_id": self.client_id, "domain": domain}, + timeout=5, + ) + resp.raise_for_status() + except requests.RequestException as exc: + logger.warning("Error contacting Domainr: %r", exc) + return [] + + return resp.json()["status"][0]["status"].split() diff --git a/warehouse/admin/routes.py b/warehouse/admin/routes.py index 4c011f120de6..b19932fa45ba 100644 --- a/warehouse/admin/routes.py +++ b/warehouse/admin/routes.py @@ -87,6 +87,13 @@ def includeme(config): factory="warehouse.accounts.models:UserFactory", traverse="/{username}", ) + config.add_route( + "admin.user.email_domain_check", + "/admin/users/{username}/email_domain_check/", + factory="warehouse.accounts.models:UserFactory", + domain=warehouse, + traverse="/{username}", + ) config.add_route( "admin.user.delete", "/admin/users/{username}/delete/", diff --git a/warehouse/admin/templates/admin/users/detail.html b/warehouse/admin/templates/admin/users/detail.html index 69fa269babf2..5be7ef344880 100644 --- a/warehouse/admin/templates/admin/users/detail.html +++ b/warehouse/admin/templates/admin/users/detail.html @@ -532,33 +532,45 @@