Skip to content

feat: add domain verification #17832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dev/environment
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,6 @@ HELPDESK_BACKEND="warehouse.helpdesk.services.ConsoleHelpDeskService"

# HELPDESK_NOTIFICATION_SERVICE_URL="https://..."
HELPDESK_NOTIFICATION_BACKEND="warehouse.helpdesk.services.ConsoleAdminNotificationService"

# Example of Domain Status configuration
# DOMAIN_STATUS_BACKEND="warehouse.accounts.services.DomainrDomainStatusService client_id=some_client_id"
13 changes: 12 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@

from warehouse import admin, config, static
from warehouse.accounts import services as account_services
from warehouse.accounts.interfaces import ITokenService, IUserService
from warehouse.accounts.interfaces import (
IDomainStatusService,
ITokenService,
IUserService,
)
from warehouse.admin.flags import AdminFlag, AdminFlagValue
from warehouse.attestations import services as attestations_services
from warehouse.attestations.interfaces import IIntegrityService
Expand Down Expand Up @@ -169,6 +173,7 @@ def pyramid_services(
notification_service,
query_results_cache_service,
search_service,
domain_status_service,
):
services = _Services()

Expand All @@ -194,6 +199,7 @@ def pyramid_services(
services.register_service(notification_service, IAdminNotificationService)
services.register_service(query_results_cache_service, IQueryResultsCache)
services.register_service(search_service, ISearchService)
services.register_service(domain_status_service, IDomainStatusService)

return services

Expand Down Expand Up @@ -543,6 +549,11 @@ def search_service():
return search_services.NullSearchService()


@pytest.fixture
def domain_status_service():
    """Provide a ``NullDomainStatusService`` instance as the test service."""
    null_service = account_services.NullDomainStatusService()
    return null_service


class QueryRecorder:
def __init__(self):
self.queries = []
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/accounts/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from warehouse import accounts
from warehouse.accounts.interfaces import (
IDomainStatusService,
IEmailBreachedService,
IPasswordBreachedService,
ITokenService,
Expand All @@ -25,6 +26,7 @@
from warehouse.accounts.services import (
HaveIBeenPwnedEmailBreachedService,
HaveIBeenPwnedPasswordBreachedService,
NullDomainStatusService,
TokenServiceFactory,
database_login_factory,
)
Expand Down Expand Up @@ -186,6 +188,7 @@ def test_includeme(monkeypatch):
HaveIBeenPwnedEmailBreachedService.create_service,
IEmailBreachedService,
),
pretend.call(NullDomainStatusService.create_service, IDomainStatusService),
pretend.call(RateLimit("10 per 5 minutes"), IRateLimiter, name="user.login"),
pretend.call(RateLimit("10 per 5 minutes"), IRateLimiter, name="ip.login"),
pretend.call(
Expand Down
77 changes: 77 additions & 0 deletions tests/unit/accounts/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from warehouse.accounts import services
from warehouse.accounts.interfaces import (
BurnedRecoveryCode,
IDomainStatusService,
IEmailBreachedService,
InvalidRecoveryCode,
IPasswordBreachedService,
Expand Down Expand Up @@ -1635,3 +1636,79 @@ def test_factory(self):

assert isinstance(svc, services.NullEmailBreachedService)
assert svc.get_email_breach_count("[email protected]") == 0


class TestNullDomainStatusService:
    """Tests for the null domain status backend."""

    def test_verify_service(self):
        # The implementation must conform to the declared interface.
        assert verifyClass(IDomainStatusService, services.NullDomainStatusService)

    def test_get_domain_status(self):
        null_service = services.NullDomainStatusService()

        statuses = null_service.get_domain_status("example.com")

        assert statuses == ["active"]

    def test_factory(self):
        # Neither the context nor the request is inspected by the factory,
        # so empty stubs are sufficient.
        context = pretend.stub()
        request = pretend.stub()

        built = services.NullDomainStatusService.create_service(context, request)

        assert isinstance(built, services.NullDomainStatusService)
        assert built.get_domain_status("example.com") == ["active"]


class TestDomainrDomainStatusService:
    """Tests for the Domainr-backed domain status service."""

    def test_verify_service(self):
        # The implementation must conform to the declared interface.
        assert verifyClass(IDomainStatusService, services.DomainrDomainStatusService)

    def test_successful_domain_status_check(self):
        payload = {
            "status": [{"domain": "example.com", "status": "undelegated inactive"}]
        }
        response = pretend.stub(
            json=lambda: payload,
            raise_for_status=lambda: None,
        )
        session = pretend.stub(get=pretend.call_recorder(lambda *a, **kw: response))
        svc = services.DomainrDomainStatusService(
            session=session, client_id="some_client_id"
        )

        statuses = svc.get_domain_status("example.com")

        # The space-separated status string is split into individual entries.
        assert statuses == ["undelegated", "inactive"]
        expected_call = pretend.call(
            "https://api.domainr.com/v2/status",
            params={"client_id": "some_client_id", "domain": "example.com"},
            timeout=5,
        )
        assert session.get.calls == [expected_call]

    def test_domainr_exception_returns_empty(self):
        class DomainrException(requests.HTTPError):
            def __init__(self):
                self.response = pretend.stub(status_code=400)

        # ``raise_for_status`` raising an HTTP error simulates a failed lookup.
        response = pretend.stub(raise_for_status=pretend.raiser(DomainrException))
        session = pretend.stub(get=pretend.call_recorder(lambda *a, **kw: response))
        svc = services.DomainrDomainStatusService(
            session=session, client_id="some_client_id"
        )

        statuses = svc.get_domain_status("example.com")

        assert statuses == []
        expected_call = pretend.call(
            "https://api.domainr.com/v2/status",
            params={"client_id": "some_client_id", "domain": "example.com"},
            timeout=5,
        )
        assert session.get.calls == [expected_call]

    def test_factory(self):
        settings = {"domain_status.client_id": "some_client_id"}
        request = pretend.stub(
            http=pretend.stub(),
            registry=pretend.stub(settings=settings),
        )
        context = pretend.stub()

        svc = services.DomainrDomainStatusService.create_service(context, request)

        # The factory wires the request's HTTP session and the configured
        # client id straight into the service.
        assert svc._http is request.http
        assert svc.client_id == "some_client_id"
7 changes: 7 additions & 0 deletions tests/unit/admin/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ def test_includeme():
factory="warehouse.accounts.models:UserFactory",
traverse="/{username}",
),
pretend.call(
"admin.user.email_domain_check",
"/admin/users/{username}/email_domain_check/",
domain=warehouse,
factory="warehouse.accounts.models:UserFactory",
traverse="/{username}",
),
pretend.call(
"admin.user.delete",
"/admin/users/{username}/delete/",
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/admin/views/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -1539,3 +1539,26 @@ def test_no_recovery_codes_provided(self, db_request, monkeypatch, user_service)
]
assert result.status_code == 303
assert result.location == "/foobar"


class TestUserEmailDomainCheck:
    """Tests for the admin view that checks a user's email domain status."""

    def test_user_email_domain_check(self, db_request):
        user = UserFactory.create(with_verified_primary_email=True)
        flash = pretend.call_recorder(lambda *a, **kw: None)
        db_request.POST["email_address"] = user.primary_email.email
        db_request.route_path = pretend.call_recorder(lambda *a, **kw: "/foobar")
        db_request.session = pretend.stub(flash=flash)

        result = views.user_email_domain_check(user, db_request)

        # The view redirects and flashes a success message.
        assert isinstance(result, HTTPSeeOther)
        assert result.headers["Location"] == "/foobar"
        expected_flash = pretend.call(
            f"Domain status check for '{user.primary_email.domain}' completed",
            queue="success",
        )
        assert db_request.session.flash.calls == [expected_flash]

        # The check result is recorded on the primary email.
        primary = user.primary_email
        assert primary.domain_last_checked is not None
        assert primary.domain_last_status == ["active"]
10 changes: 10 additions & 0 deletions warehouse/accounts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from celery.schedules import crontab

from warehouse.accounts.interfaces import (
IDomainStatusService,
IEmailBreachedService,
IPasswordBreachedService,
ITokenService,
Expand All @@ -25,6 +26,7 @@
from warehouse.accounts.services import (
HaveIBeenPwnedEmailBreachedService,
HaveIBeenPwnedPasswordBreachedService,
NullDomainStatusService,
NullEmailBreachedService,
NullPasswordBreachedService,
TokenServiceFactory,
Expand Down Expand Up @@ -131,6 +133,14 @@ def includeme(config):
breached_email_class.create_service, IEmailBreachedService
)

# Register our domain status service.
domain_status_class = config.maybe_dotted(
config.registry.settings.get("domain_status.backend", NullDomainStatusService)
)
config.register_service_factory(
domain_status_class.create_service, IDomainStatusService
)

# Register our security policies.
config.set_security_policy(
MultiSecurityPolicy(
Expand Down
7 changes: 7 additions & 0 deletions warehouse/accounts/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,10 @@ def get_email_breach_count(email: str) -> int | None:
"""
Returns count of times the email appears in verified breaches.
"""


class IDomainStatusService(Interface):
    """
    Interface for services that report the status of a domain name.
    """

    def get_domain_status(domain: str) -> list[str]:
        """
        Returns a list of status strings for the given domain.
        """
11 changes: 10 additions & 1 deletion warehouse/accounts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
select,
sql,
)
from sqlalchemy.dialects.postgresql import CITEXT, UUID as PG_UUID
from sqlalchemy.dialects.postgresql import ARRAY, CITEXT, UUID as PG_UUID
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import Mapped, mapped_column
Expand Down Expand Up @@ -424,6 +424,15 @@ class Email(db.ModelBase):
unverify_reason: Mapped[UnverifyReasons | None]
transient_bounces: Mapped[int] = mapped_column(server_default=sql.text("0"))

# Domain validation information
domain_last_checked: Mapped[datetime.datetime | None] = mapped_column(
comment="Last time domain was checked with the domain validation service.",
)
domain_last_status: Mapped[list[str] | None] = mapped_column(
ARRAY(String),
comment="Status strings returned by the domain validation service.",
)

@property
def domain(self):
    """Return the lower-cased text following the last "@" of the address."""
    # rpartition yields the whole string as the tail when no "@" is present,
    # matching the result of split("@")[-1].
    _local, _sep, host = self.email.rpartition("@")
    return host.lower()
Expand Down
47 changes: 47 additions & 0 deletions warehouse/accounts/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import collections
import datetime
import functools
Expand All @@ -18,6 +20,7 @@
import logging
import os
import secrets
import typing
import urllib.parse

import passlib.exc
Expand All @@ -35,6 +38,7 @@

from warehouse.accounts.interfaces import (
BurnedRecoveryCode,
IDomainStatusService,
IEmailBreachedService,
InvalidRecoveryCode,
IPasswordBreachedService,
Expand Down Expand Up @@ -62,6 +66,9 @@
from warehouse.rate_limiting import DummyRateLimiter, IRateLimiter
from warehouse.utils.crypto import BadData, SignatureExpired, URLSafeTimedSerializer

if typing.TYPE_CHECKING:
from pyramid.request import Request

logger = logging.getLogger(__name__)

PASSWORD_FIELD = "password"
Expand Down Expand Up @@ -962,3 +969,43 @@ def create_service(cls, context, request):
def get_email_breach_count(self, email):
# This service allows *every* email as a non-breached email.
return 0


@implementer(IDomainStatusService)
class NullDomainStatusService:
    """
    Domain status backend that performs no lookup and reports every
    domain as active.
    """

    @classmethod
    def create_service(cls, _context, _request):
        """Service factory; the context and request are not used."""
        instance = cls()
        return instance

    def get_domain_status(self, _domain: str) -> list[str]:
        """Return ``["active"]`` regardless of the domain passed in."""
        statuses = ["active"]
        return statuses


@implementer(IDomainStatusService)
class DomainrDomainStatusService:
    """
    Domain status backend that queries the Domainr status API over HTTP.
    """

    def __init__(self, session, client_id):
        # ``session`` is the HTTP session used for outbound requests and
        # ``client_id`` is sent to Domainr with every status lookup.
        self._http = session
        self.client_id = client_id

    @classmethod
    def create_service(cls, _context, request: Request) -> DomainrDomainStatusService:
        """
        Build the service from ``request.http`` and the
        ``domain_status.client_id`` registry setting.
        """
        domainr_client_id = request.registry.settings.get("domain_status.client_id")
        return cls(session=request.http, client_id=domainr_client_id)

    def get_domain_status(self, domain: str) -> list[str]:
        """
        Return the status strings Domainr reports for ``domain``.

        The API returns the statuses as one space-separated string, which is
        split into a list. An empty list is returned (and a warning logged)
        when the request fails or when the response body does not have the
        expected shape, so callers never see an exception from this lookup.

        See https://domainr.com/docs/api/v2/status
        """
        try:
            resp = self._http.get(
                "https://api.domainr.com/v2/status",
                params={"client_id": self.client_id, "domain": domain},
                timeout=5,
            )
            resp.raise_for_status()
        except requests.RequestException as exc:
            logger.warning("Error contacting Domainr: %r", exc)
            return []

        # A 2xx response is not guaranteed to carry a usable payload: the
        # body may not be JSON (ValueError), may lack the "status" entries
        # (LookupError), or may hold values of an unexpected type.
        try:
            return resp.json()["status"][0]["status"].split()
        except (ValueError, LookupError, TypeError, AttributeError) as exc:
            logger.warning("Unexpected response from Domainr: %r", exc)
            return []
7 changes: 7 additions & 0 deletions warehouse/admin/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ def includeme(config):
factory="warehouse.accounts.models:UserFactory",
traverse="/{username}",
)
config.add_route(
"admin.user.email_domain_check",
"/admin/users/{username}/email_domain_check/",
factory="warehouse.accounts.models:UserFactory",
domain=warehouse,
traverse="/{username}",
)
config.add_route(
"admin.user.delete",
"/admin/users/{username}/delete/",
Expand Down
Loading