Skip to content

Commit 17eb243

Browse files
Multiple providers and unit tests (GH-13)
2 parents c6b4c73 + ff8385f commit 17eb243

18 files changed

+250
-149
lines changed

README.md

-6
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,6 @@
99
FastAPI OAuth2 is a middleware-based social authentication mechanism supporting several auth providers. It depends on
1010
the [social-core](https://github.com/python-social-auth/social-core) authentication backends.
1111

12-
## Features to be implemented
13-
14-
- Use multiple OAuth2 providers at the same time
15-
* There need to be provided a way to configure the OAuth2 for multiple providers
16-
- Customizable OAuth2 routes
17-
1812
## Installation
1913

2014
```shell

examples/demonstration/.env

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1-
OAUTH2_CLIENT_ID=eccd08d6736b7999a32a
2-
OAUTH2_CLIENT_SECRET=642999c1c5f2b3df8b877afdc78252ef5b594d31
1+
# These id and secret are generated especially for testing purposes,
2+
# if you have your own, please use them, otherwise you can use these.
3+
OAUTH2_GITHUB_CLIENT_ID=eccd08d6736b7999a32a
4+
OAUTH2_GITHUB_CLIENT_SECRET=642999c1c5f2b3df8b877afdc78252ef5b594d31
5+
6+
OAUTH2_GOOGLE_CLIENT_ID=105851609656-uueuan570963mnnf4288nv40eieh9f5l.apps.googleusercontent.com
7+
OAUTH2_GOOGLE_CLIENT_SECRET=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
38

49
JWT_SECRET=secret
510
JWT_ALGORITHM=HS256

examples/demonstration/config.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from dotenv import load_dotenv
44
from social_core.backends.github import GithubOAuth2
5+
from social_core.backends.google import GoogleOAuth2
56

67
from fastapi_oauth2.claims import Claims
78
from fastapi_oauth2.client import OAuth2Client
@@ -17,14 +18,22 @@
1718
clients=[
1819
OAuth2Client(
1920
backend=GithubOAuth2,
20-
client_id=os.getenv("OAUTH2_CLIENT_ID"),
21-
client_secret=os.getenv("OAUTH2_CLIENT_SECRET"),
22-
# redirect_uri="http://127.0.0.1:8000/",
21+
client_id=os.getenv("OAUTH2_GITHUB_CLIENT_ID"),
22+
client_secret=os.getenv("OAUTH2_GITHUB_CLIENT_SECRET"),
2323
scope=["user:email"],
2424
claims=Claims(
2525
picture="avatar_url",
2626
identity=lambda user: "%s:%s" % (user.get("provider"), user.get("id")),
2727
),
2828
),
29+
OAuth2Client(
30+
backend=GoogleOAuth2,
31+
client_id=os.getenv("OAUTH2_GOOGLE_CLIENT_ID"),
32+
client_secret=os.getenv("OAUTH2_GOOGLE_CLIENT_SECRET"),
33+
scope=["openid", "profile", "email"],
34+
claims=Claims(
35+
identity=lambda user: "%s:%s" % (user.get("provider"), user.get("sub")),
36+
),
37+
),
2938
]
3039
)

examples/demonstration/main.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from fastapi import APIRouter
22
from fastapi import FastAPI
3+
from fastapi.staticfiles import StaticFiles
34
from sqlalchemy.orm import Session
45

56
from config import oauth2_config
@@ -24,16 +25,18 @@ async def on_auth(auth: Auth, user: User):
2425
db: Session = next(get_db())
2526
query = db.query(UserModel)
2627
if user.identity and not query.filter_by(identity=user.identity).first():
28+
# create a local user by OAuth2 user's data if it does not exist yet
2729
UserModel(**{
28-
"identity": user.get("identity"),
29-
"username": user.get("username"),
30-
"image": user.get("image"),
31-
"email": user.get("email"),
32-
"name": user.get("name"),
30+
"identity": user.identity, # User property
31+
"username": user.get("username"), # custom attribute
32+
"name": user.display_name, # User property
33+
"image": user.picture, # User property
34+
"email": user.email, # User property
3335
}).save(db)
3436

3537

3638
app = FastAPI()
3739
app.include_router(app_router)
3840
app.include_router(oauth2_router)
41+
app.mount("/static", StaticFiles(directory="static"), name="static")
3942
app.add_middleware(OAuth2Middleware, config=oauth2_config, callback=on_auth)
+5
Loading
Loading

examples/demonstration/templates/index.html

+17-5
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,30 @@
2121
<a href="/auth" style="display: flex; align-items: center; color: #dfdfd6; margin-right: 1rem; text-decoration: none;">
2222
Simulate Login
2323
</a>
24-
<a href="/oauth2/github/auth" style="display: flex; align-items: center;">
25-
<svg style="height: 50px; width: 50px;" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16">
26-
<path fill="#dfdfd6" d="M7.499,1C3.91,1,1,3.906,1,7.49c0,2.867,1.862,5.299,4.445,6.158C5.77,13.707,6,13.375,6,13.125 c0-0.154,0.003-0.334,0-0.875c-1.808,0.392-2.375-0.875-2.375-0.875c-0.296-0.75-0.656-0.963-0.656-0.963 c-0.59-0.403,0.044-0.394,0.044-0.394C3.666,10.064,4,10.625,4,10.625c0.5,0.875,1.63,0.791,2,0.625 c0-0.397,0.044-0.688,0.154-0.873C4.111,10.02,2.997,8.84,3,7.208c0.002-0.964,0.335-1.715,0.876-2.269 C3.639,4.641,3.479,3.625,3.961,3c1.206,0,1.927,0.873,1.927,0.873s0.565-0.248,1.61-0.248c1.045,0,1.608,0.234,1.608,0.234 S9.829,3,11.035,3c0.482,0.625,0.322,1.641,0.132,1.918C11.684,5.461,12,6.21,12,7.208c0,1.631-1.11,2.81-3.148,3.168 C8.982,10.572,9,10.842,9,11.25c0,0.867,0,1.662,0,1.875c0,0.25,0.228,0.585,0.558,0.522C12.139,12.787,14,10.356,14,7.49 C14,3.906,11.09,1,7.499,1z"></path>
27-
</svg>
28-
</a>
24+
{% for provider in request.auth.clients %}
25+
<a href="/oauth2/{{ provider }}/auth" style="display: flex; align-items: center;">
26+
<img
27+
alt="{{ provider }} icon"
28+
src="/static/{{ provider }}.svg"
29+
style="width: 50px; height: 50px; margin-right: 1rem;"
30+
>
31+
</a>
32+
{% endfor %}
2933
{% endif %}
3034
</div>
3135
</header>
3236
<section
3337
style="display: flex; flex-direction: column; align-items: center; justify-content: center; height: calc(100vh - 70px);">
3438
{% if request.user.is_authenticated %}
3539
<h1>Hi, {{ request.user.display_name }}</h1>
40+
<h3>
41+
You're signed in using
42+
{% if request.auth.provider %}
43+
external {{ request.auth.provider.provider }} OAuth2 provider.
44+
{% else %}
45+
local authentication system.
46+
{% endif %}
47+
</h3>
3648
<p>This is what your JWT contains currently</p>
3749
<pre>{{ json.dumps(request.user, indent=4) }}</pre>
3850
{% else %}

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ license_files = LICENSE
2727
platforms = unix, linux, osx, win32
2828
classifiers =
2929
Operating System :: OS Independent
30-
Development Status :: 2 - Pre-Alpha
30+
Development Status :: 3 - Alpha
3131
Framework :: FastAPI
3232
Programming Language :: Python
3333
Programming Language :: Python :: 3

src/fastapi_oauth2/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.0-alpha.1"
1+
__version__ = "1.0.0-alpha.2"

src/fastapi_oauth2/core.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import httpx
1212
from oauthlib.oauth2 import WebApplicationClient
13+
from oauthlib.oauth2.rfc6749.errors import CustomOAuth2Error
1314
from social_core.backends.oauth import BaseOAuth2
1415
from social_core.strategy import BaseStrategy
1516
from starlette.exceptions import HTTPException
@@ -46,9 +47,10 @@ class OAuth2Core:
4647

4748
client_id: str = None
4849
client_secret: str = None
49-
callback_url: Optional[str] = None
5050
scope: Optional[List[str]] = None
5151
claims: Optional[Claims] = None
52+
provider: str = None
53+
redirect_uri: str = None
5254
backend: BaseOAuth2 = None
5355
_oauth_client: Optional[WebApplicationClient] = None
5456

@@ -108,9 +110,12 @@ async def token_redirect(self, request: Request) -> RedirectResponse:
108110
auth = httpx.BasicAuth(self.client_id, self.client_secret)
109111
async with httpx.AsyncClient() as session:
110112
response = await session.post(token_url, headers=headers, content=content, auth=auth)
111-
token = self.oauth_client.parse_request_body_response(json.dumps(response.json()))
112-
token_data = self.standardize(self.backend.user_data(token.get("access_token")))
113-
access_token = request.auth.jwt_create(token_data)
113+
try:
114+
token = self.oauth_client.parse_request_body_response(json.dumps(response.json()))
115+
token_data = self.standardize(self.backend.user_data(token.get("access_token")))
116+
access_token = request.auth.jwt_create(token_data)
117+
except (CustomOAuth2Error, Exception) as e:
118+
raise OAuth2LoginError(400, str(e))
114119

115120
response = RedirectResponse(self.redirect_uri or request.base_url)
116121
response.set_cookie(

src/fastapi_oauth2/middleware.py

+13-17
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import Dict
77
from typing import List
88
from typing import Optional
9-
from typing import Sequence
109
from typing import Tuple
1110
from typing import Union
1211

@@ -39,16 +38,15 @@ class Auth(AuthCredentials):
3938
scopes: List[str]
4039
clients: Dict[str, OAuth2Core] = {}
4140

42-
provider: str
43-
default_provider: str = "local"
41+
_provider: OAuth2Core = None
4442

45-
def __init__(
46-
self,
47-
scopes: Optional[Sequence[str]] = None,
48-
provider: str = default_provider,
49-
) -> None:
50-
super().__init__(scopes)
51-
self.provider = provider
43+
@property
44+
def provider(self) -> Union[OAuth2Core, None]:
45+
return self._provider
46+
47+
@provider.setter
48+
def provider(self, identifier) -> None:
49+
self._provider = self.clients.get(identifier)
5250

5351
@classmethod
5452
def set_http(cls, http: bool) -> None:
@@ -145,18 +143,16 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]:
145143
return Auth(), User()
146144

147145
user = User(Auth.jwt_decode(param))
148-
user.update(provider=user.get("provider", Auth.default_provider))
149-
auth = Auth(user.pop("scope", []), user.get("provider"))
150-
client = Auth.clients.get(auth.provider)
151-
claims = client.claims if client else Claims()
152-
user = user.use_claims(claims)
146+
auth = Auth(user.pop("scope", []))
147+
auth.provider = user.get("provider")
148+
claims = auth.provider.claims if auth.provider else {}
153149

154150
# Call the callback function on authentication
155151
if callable(self.callback):
156-
coroutine = self.callback(auth, user)
152+
coroutine = self.callback(auth, user.use_claims(claims))
157153
if issubclass(type(coroutine), Awaitable):
158154
await coroutine
159-
return auth, user
155+
return auth, user.use_claims(claims)
160156

161157

162158
class OAuth2Middleware:

src/fastapi_oauth2/security.py

+11-18
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
from typing import Any
2-
from typing import Callable
3-
from typing import Dict
41
from typing import Optional
5-
from typing import Tuple
62
from typing import Type
73

84
from fastapi.security import OAuth2 as FastAPIOAuth2
@@ -12,32 +8,29 @@
128
from starlette.requests import Request
139

1410

15-
def use_cookies(cls: Type[FastAPIOAuth2]) -> Callable[[Tuple[Any], Dict[str, Any]], FastAPIOAuth2]:
16-
"""OAuth2 classes wrapped with this decorator will use cookies for the Authorization header."""
11+
class OAuth2Cookie(type):
12+
"""OAuth2 classes using this metaclass will use cookies for the Authorization header."""
13+
14+
def __new__(metacls, name, bases, attrs) -> Type:
15+
instance = super().__new__(metacls, name, bases, attrs)
1716

18-
def _use_cookies(*args, **kwargs) -> FastAPIOAuth2:
1917
async def __call__(self: FastAPIOAuth2, request: Request) -> Optional[str]:
2018
authorization = request.headers.get("Authorization", request.cookies.get("Authorization"))
2119
if authorization:
2220
request._headers = Headers({**request.headers, "Authorization": authorization})
23-
return await super(cls, self).__call__(request)
24-
25-
cls.__call__ = __call__
26-
return cls(*args, **kwargs)
21+
return await instance.__base__.__call__(self, request)
2722

28-
return _use_cookies
23+
instance.__call__ = __call__
24+
return instance
2925

3026

31-
@use_cookies
32-
class OAuth2(FastAPIOAuth2):
27+
class OAuth2(FastAPIOAuth2, metaclass=OAuth2Cookie):
3328
"""Wrapper class of the `fastapi.security.OAuth2` class."""
3429

3530

36-
@use_cookies
37-
class OAuth2PasswordBearer(FastAPIPasswordBearer):
31+
class OAuth2PasswordBearer(FastAPIPasswordBearer, metaclass=OAuth2Cookie):
3832
"""Wrapper class of the `fastapi.security.OAuth2PasswordBearer` class."""
3933

4034

41-
@use_cookies
42-
class OAuth2AuthorizationCodeBearer(FastAPICodeBearer):
35+
class OAuth2AuthorizationCodeBearer(FastAPICodeBearer, metaclass=OAuth2Cookie):
4336
"""Wrapper class of the `fastapi.security.OAuth2AuthorizationCodeBearer` class."""

tests/conftest.py

+60
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,18 @@
33

44
import pytest
55
import social_core.backends as backends
6+
from fastapi import APIRouter
7+
from fastapi import Depends
8+
from fastapi import FastAPI
9+
from fastapi import Request
10+
from social_core.backends.github import GithubOAuth2
611
from social_core.backends.oauth import BaseOAuth2
12+
from starlette.responses import Response
13+
14+
from fastapi_oauth2.client import OAuth2Client
15+
from fastapi_oauth2.middleware import OAuth2Middleware
16+
from fastapi_oauth2.router import router as oauth2_router
17+
from fastapi_oauth2.security import OAuth2
718

819
package_path = backends.__path__[0]
920

@@ -24,3 +35,52 @@ def backends():
2435
except ImportError:
2536
continue
2637
return backend_instances
38+
39+
40+
@pytest.fixture
41+
def get_app():
42+
def fixture_wrapper(authentication: OAuth2 = None):
43+
if not authentication:
44+
authentication = OAuth2()
45+
46+
oauth2 = authentication
47+
application = FastAPI()
48+
app_router = APIRouter()
49+
50+
@app_router.get("/user")
51+
def user(request: Request, _: str = Depends(oauth2)):
52+
return request.user
53+
54+
@app_router.get("/auth")
55+
def auth(request: Request):
56+
access_token = request.auth.jwt_create({
57+
"name": "test",
58+
"sub": "test",
59+
"id": "test",
60+
})
61+
response = Response()
62+
response.set_cookie(
63+
"Authorization",
64+
value=f"Bearer {access_token}",
65+
max_age=request.auth.expires,
66+
expires=request.auth.expires,
67+
httponly=request.auth.http,
68+
)
69+
return response
70+
71+
application.include_router(app_router)
72+
application.include_router(oauth2_router)
73+
application.add_middleware(OAuth2Middleware, config={
74+
"allow_http": True,
75+
"clients": [
76+
OAuth2Client(
77+
backend=GithubOAuth2,
78+
client_id="test_id",
79+
client_secret="test_secret",
80+
),
81+
],
82+
})
83+
84+
return application
85+
86+
return fixture_wrapper

tests/test_backends.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import pytest
2+
3+
from fastapi_oauth2.client import OAuth2Client
4+
from fastapi_oauth2.core import OAuth2Core
5+
6+
7+
@pytest.mark.anyio
8+
async def test_core_init_with_all_backends(backends):
9+
for backend in backends:
10+
try:
11+
OAuth2Core(OAuth2Client(
12+
backend=backend,
13+
client_id="test_client_id",
14+
client_secret="test_client_secret",
15+
))
16+
except (NotImplementedError, Exception):
17+
assert False

0 commit comments

Comments
 (0)