diff --git a/graphql_server/aiohttp/graphqlview.py b/graphql_server/aiohttp/graphqlview.py index 9d28f02..84a5f11 100644 --- a/graphql_server/aiohttp/graphqlview.py +++ b/graphql_server/aiohttp/graphqlview.py @@ -19,6 +19,7 @@ from graphql_server.render_graphiql import ( GraphiQLConfig, GraphiQLData, + GraphiQLOptions, render_graphiql_async, ) @@ -39,6 +40,9 @@ class GraphQLView: enable_async = False subscriptions = None headers = None + default_query = None + header_editor_enabled = None + should_persist_headers = None accepted_methods = ["GET", "POST", "PUT", "DELETE"] @@ -174,8 +178,13 @@ async def __call__(self, request): graphiql_html_title=self.graphiql_html_title, jinja_env=self.jinja_env, ) + graphiql_options = GraphiQLOptions( + default_query=self.default_query, + header_editor_enabled=self.header_editor_enabled, + should_persist_headers=self.should_persist_headers, + ) source = await render_graphiql_async( - data=graphiql_data, config=graphiql_config + data=graphiql_data, config=graphiql_config, options=graphiql_options ) return web.Response(text=source, content_type="text/html") diff --git a/graphql_server/flask/graphqlview.py b/graphql_server/flask/graphqlview.py index 33467c9..1b33433 100644 --- a/graphql_server/flask/graphqlview.py +++ b/graphql_server/flask/graphqlview.py @@ -1,3 +1,5 @@ +import copy +from collections.abc import MutableMapping from functools import partial from typing import List @@ -18,6 +20,7 @@ from graphql_server.render_graphiql import ( GraphiQLConfig, GraphiQLData, + GraphiQLOptions, render_graphiql_sync, ) @@ -25,6 +28,7 @@ class GraphQLView(View): schema = None root_value = None + context = None pretty = False graphiql = False graphiql_version = None @@ -34,6 +38,9 @@ class GraphQLView(View): batch = False subscriptions = None headers = None + default_query = None + header_editor_enabled = None + should_persist_headers = None methods = ["GET", "POST", "PUT", "DELETE"] @@ -50,12 +57,18 @@ def __init__(self, **kwargs): self.schema, GraphQLSchema ), "A Schema is required to be provided to GraphQLView." - # noinspection PyUnusedLocal def get_root_value(self): return self.root_value - def get_context_value(self): - return request + def get_context(self): + context = ( + copy.copy(self.context) + if self.context and isinstance(self.context, MutableMapping) + else {} + ) + if isinstance(context, MutableMapping) and "request" not in context: + context.update({"request": request}) + return context def get_middleware(self): return self.middleware @@ -80,7 +93,7 @@ def dispatch_request(self): catch=catch, # Execute options root_value=self.get_root_value(), - context_value=self.get_context_value(), + context_value=self.get_context(), middleware=self.get_middleware(), ) result, status_code = encode_execution_results( @@ -105,8 +118,13 @@ def dispatch_request(self): graphiql_html_title=self.graphiql_html_title, jinja_env=None, ) + graphiql_options = GraphiQLOptions( + default_query=self.default_query, + header_editor_enabled=self.header_editor_enabled, + should_persist_headers=self.should_persist_headers, + ) source = render_graphiql_sync( - data=graphiql_data, config=graphiql_config + data=graphiql_data, config=graphiql_config, options=graphiql_options ) return render_template_string(source) diff --git a/graphql_server/render_graphiql.py b/graphql_server/render_graphiql.py index 8ae4107..c942300 100644 --- a/graphql_server/render_graphiql.py +++ b/graphql_server/render_graphiql.py @@ -201,7 +201,7 @@ class GraphiQLOptions(TypedDict): default_query An optional GraphQL string to use when no query is provided and no stored - query exists from a previous session. If undefined is provided, GraphiQL + query exists from a previous session. If None is provided, GraphiQL will use its own default query. header_editor_enabled An optional boolean which enables the header editor when true. diff --git a/graphql_server/sanic/graphqlview.py b/graphql_server/sanic/graphqlview.py index d3fefaa..110ea2e 100644 --- a/graphql_server/sanic/graphqlview.py +++ b/graphql_server/sanic/graphqlview.py @@ -21,6 +21,7 @@ from graphql_server.render_graphiql import ( GraphiQLConfig, GraphiQLData, + GraphiQLOptions, render_graphiql_async, ) @@ -41,6 +42,9 @@ class GraphQLView(HTTPMethodView): enable_async = False subscriptions = None headers = None + default_query = None + header_editor_enabled = None + should_persist_headers = None methods = ["GET", "POST", "PUT", "DELETE"] @@ -127,8 +131,15 @@ async def dispatch_request(self, request, *args, **kwargs): graphiql_html_title=self.graphiql_html_title, jinja_env=self.jinja_env, ) + graphiql_options = GraphiQLOptions( + default_query=self.default_query, + header_editor_enabled=self.header_editor_enabled, + should_persist_headers=self.should_persist_headers, + ) source = await render_graphiql_async( - data=graphiql_data, config=graphiql_config + data=graphiql_data, + config=graphiql_config, + options=graphiql_options, ) return html(source) diff --git a/graphql_server/webob/graphqlview.py b/graphql_server/webob/graphqlview.py index 3801fee..4eff242 100644 --- a/graphql_server/webob/graphqlview.py +++ b/graphql_server/webob/graphqlview.py @@ -19,6 +19,7 @@ from graphql_server.render_graphiql import ( GraphiQLConfig, GraphiQLData, + GraphiQLOptions, render_graphiql_sync, ) @@ -38,6 +39,9 @@ class GraphQLView: enable_async = False subscriptions = None headers = None + default_query = None + header_editor_enabled = None + should_persist_headers = None charset = "UTF-8" format_error = staticmethod(format_error_default) @@ -117,8 +121,17 @@ def dispatch_request(self, request): graphiql_html_title=self.graphiql_html_title, jinja_env=None, ) + graphiql_options = GraphiQLOptions( + default_query=self.default_query, + header_editor_enabled=self.header_editor_enabled, + should_persist_headers=self.should_persist_headers, + ) return Response( - render_graphiql_sync(data=graphiql_data, config=graphiql_config), + render_graphiql_sync( + data=graphiql_data, + config=graphiql_config, + options=graphiql_options, + ), charset=self.charset, content_type="text/html", ) diff --git a/tests/aiohttp/schema.py b/tests/aiohttp/schema.py index 9198b12..6e5495a 100644 --- a/tests/aiohttp/schema.py +++ b/tests/aiohttp/schema.py @@ -24,8 +24,17 @@ def resolve_raises(*_): resolve=lambda obj, info, *args: info.context["request"].query.get("q"), ), "context": GraphQLField( - GraphQLNonNull(GraphQLString), - resolve=lambda obj, info, *args: info.context["request"], + GraphQLObjectType( + name="context", + fields={ + "session": GraphQLField(GraphQLString), + "request": GraphQLField( + GraphQLNonNull(GraphQLString), + resolve=lambda obj, info: info.context["request"], + ), + }, + ), + resolve=lambda obj, info: info.context, ), "test": GraphQLField( type_=GraphQLString, diff --git a/tests/aiohttp/test_graphqlview.py b/tests/aiohttp/test_graphqlview.py index 0f6becb..0a940f9 100644 --- a/tests/aiohttp/test_graphqlview.py +++ b/tests/aiohttp/test_graphqlview.py @@ -521,8 +521,8 @@ async def test_handles_unsupported_http_methods(client): } -@pytest.mark.parametrize("app", [create_app()]) @pytest.mark.asyncio +@pytest.mark.parametrize("app", [create_app()]) async def test_passes_request_into_request_context(app, client): response = await client.get(url_string(query="{request}", q="testing")) @@ -532,27 +532,42 @@ async def test_passes_request_into_request_context(app, client): } -class TestCustomContext: - @pytest.mark.parametrize( - "app", [create_app(context="CUSTOM CONTEXT")], - ) - @pytest.mark.asyncio - async def test_context_remapped(self, app, client): - response = await client.get(url_string(query="{context}")) +@pytest.mark.asyncio +@pytest.mark.parametrize("app", [create_app(context={"session": "CUSTOM CONTEXT"})]) +async def test_passes_custom_context_into_context(app, client): + response = await client.get(url_string(query="{context { session request }}")) + + _json = await response.json() + assert response.status == 200 + assert "data" in _json + assert "session" in _json["data"]["context"] + assert "request" in _json["data"]["context"] + assert "CUSTOM CONTEXT" in _json["data"]["context"]["session"] + assert "Request" in _json["data"]["context"]["request"] - _json = await response.json() - assert response.status == 200 - assert "Request" in _json["data"]["context"] - assert "CUSTOM CONTEXT" not in _json["data"]["context"] - @pytest.mark.parametrize("app", [create_app(context={"request": "test"})]) - @pytest.mark.asyncio - async def test_request_not_replaced(self, app, client): - response = await client.get(url_string(query="{context}")) +@pytest.mark.asyncio +@pytest.mark.parametrize("app", [create_app(context="CUSTOM CONTEXT")]) +async def test_context_remapped_if_not_mapping(app, client): + response = await client.get(url_string(query="{context { session request }}")) - _json = await response.json() - assert response.status == 200 - assert _json["data"]["context"] == "test" + _json = await response.json() + assert response.status == 200 + assert "data" in _json + assert "session" in _json["data"]["context"] + assert "request" in _json["data"]["context"] + assert "CUSTOM CONTEXT" not in _json["data"]["context"]["request"] + assert "Request" in _json["data"]["context"]["request"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("app", [create_app(context={"request": "test"})]) +async def test_request_not_replaced(app, client): + response = await client.get(url_string(query="{context { request }}")) + + _json = await response.json() + assert response.status == 200 + assert _json["data"]["context"]["request"] == "test" @pytest.mark.asyncio @@ -583,69 +598,68 @@ async def test_post_multipart_data(client): assert await response.json() == {"data": {u"writeTest": {u"test": u"Hello World"}}} -class TestBatchExecutor: - @pytest.mark.asyncio - @pytest.mark.parametrize("app", [create_app(batch=True)]) - async def test_batch_allows_post_with_json_encoding(self, app, client): - response = await client.post( - "/graphql", - data=json.dumps([dict(id=1, query="{test}")]), - headers={"content-type": "application/json"}, - ) +@pytest.mark.asyncio +@pytest.mark.parametrize("app", [create_app(batch=True)]) +async def test_batch_allows_post_with_json_encoding(app, client): + response = await client.post( + "/graphql", + data=json.dumps([dict(id=1, query="{test}")]), + headers={"content-type": "application/json"}, + ) - assert response.status == 200 - assert await response.json() == [{"data": {"test": "Hello World"}}] - - @pytest.mark.asyncio - @pytest.mark.parametrize("app", [create_app(batch=True)]) - async def test_batch_supports_post_json_query_with_json_variables( - self, app, client - ): - response = await client.post( - "/graphql", - data=json.dumps( - [ - dict( - id=1, - query="query helloWho($who: String){ test(who: $who) }", - variables={"who": "Dolly"}, - ) - ] - ), - headers={"content-type": "application/json"}, - ) + assert response.status == 200 + assert await response.json() == [{"data": {"test": "Hello World"}}] - assert response.status == 200 - assert await response.json() == [{"data": {"test": "Hello Dolly"}}] - - @pytest.mark.asyncio - @pytest.mark.parametrize("app", [create_app(batch=True)]) - async def test_batch_allows_post_with_operation_name(self, app, client): - response = await client.post( - "/graphql", - data=json.dumps( - [ - dict( - id=1, - query=""" - query helloYou { test(who: "You"), ...shared } - query helloWorld { test(who: "World"), ...shared } - query helloDolly { test(who: "Dolly"), ...shared } - fragment shared on QueryRoot { - shared: test(who: "Everyone") - } - """, - operationName="helloWorld", - ) - ] - ), - headers={"content-type": "application/json"}, - ) - assert response.status == 200 - assert await response.json() == [ - {"data": {"test": "Hello World", "shared": "Hello Everyone"}} - ] +@pytest.mark.asyncio +@pytest.mark.parametrize("app", [create_app(batch=True)]) +async def test_batch_supports_post_json_query_with_json_variables(app, client): + response = await client.post( + "/graphql", + data=json.dumps( + [ + dict( + id=1, + query="query helloWho($who: String){ test(who: $who) }", + variables={"who": "Dolly"}, + ) + ] + ), + headers={"content-type": "application/json"}, + ) + + assert response.status == 200 + assert await response.json() == [{"data": {"test": "Hello Dolly"}}] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("app", [create_app(batch=True)]) +async def test_batch_allows_post_with_operation_name(app, client): + response = await client.post( + "/graphql", + data=json.dumps( + [ + dict( + id=1, + query=""" + query helloYou { test(who: "You"), ...shared } + query helloWorld { test(who: "World"), ...shared } + query helloDolly { test(who: "Dolly"), ...shared } + fragment shared on QueryRoot { + shared: test(who: "Everyone") + } + """, + operationName="helloWorld", + ) + ] + ), + headers={"content-type": "application/json"}, + ) + + assert response.status == 200 + assert await response.json() == [ + {"data": {"test": "Hello World", "shared": "Hello Everyone"}} + ] @pytest.mark.asyncio diff --git a/tests/flask/schema.py b/tests/flask/schema.py index 5d4c52c..eb51e26 100644 --- a/tests/flask/schema.py +++ b/tests/flask/schema.py @@ -18,10 +18,20 @@ def resolve_raises(*_): "thrower": GraphQLField(GraphQLNonNull(GraphQLString), resolve=resolve_raises), "request": GraphQLField( GraphQLNonNull(GraphQLString), - resolve=lambda obj, info: info.context.args.get("q"), + resolve=lambda obj, info: info.context["request"].args.get("q"), ), "context": GraphQLField( - GraphQLNonNull(GraphQLString), resolve=lambda obj, info: info.context + GraphQLObjectType( + name="context", + fields={ + "session": GraphQLField(GraphQLString), + "request": GraphQLField( + GraphQLNonNull(GraphQLString), + resolve=lambda obj, info: info.context["request"], + ), + }, + ), + resolve=lambda obj, info: info.context, ), "test": GraphQLField( type_=GraphQLString, diff --git a/tests/flask/test_graphqlview.py b/tests/flask/test_graphqlview.py index d2f478d..961a8e0 100644 --- a/tests/flask/test_graphqlview.py +++ b/tests/flask/test_graphqlview.py @@ -489,14 +489,30 @@ def test_passes_request_into_request_context(app, client): assert response_json(response) == {"data": {"request": "testing"}} -@pytest.mark.parametrize( - "app", [create_app(get_context_value=lambda: "CUSTOM CONTEXT")] -) +@pytest.mark.parametrize("app", [create_app(context={"session": "CUSTOM CONTEXT"})]) def test_passes_custom_context_into_context(app, client): - response = client.get(url_string(app, query="{context}")) + response = client.get(url_string(app, query="{context { session request }}")) assert response.status_code == 200 - assert response_json(response) == {"data": {"context": "CUSTOM CONTEXT"}} + res = response_json(response) + assert "data" in res + assert "session" in res["data"]["context"] + assert "request" in res["data"]["context"] + assert "CUSTOM CONTEXT" in res["data"]["context"]["session"] + assert "Request" in res["data"]["context"]["request"] + + +@pytest.mark.parametrize("app", [create_app(context="CUSTOM CONTEXT")]) +def test_context_remapped_if_not_mapping(app, client): + response = client.get(url_string(app, query="{context { session request }}")) + + assert response.status_code == 200 + res = response_json(response) + assert "data" in res + assert "session" in res["data"]["context"] + assert "request" in res["data"]["context"] + assert "CUSTOM CONTEXT" not in res["data"]["context"]["request"] + assert "Request" in res["data"]["context"]["request"] def test_post_multipart_data(app, client): diff --git a/tests/sanic/schema.py b/tests/sanic/schema.py index a129d92..f827c2b 100644 --- a/tests/sanic/schema.py +++ b/tests/sanic/schema.py @@ -24,8 +24,17 @@ def resolve_raises(*_): resolve=lambda obj, info: info.context["request"].args.get("q"), ), "context": GraphQLField( - GraphQLNonNull(GraphQLString), - resolve=lambda obj, info: info.context["request"], + GraphQLObjectType( + name="context", + fields={ + "session": GraphQLField(GraphQLString), + "request": GraphQLField( + GraphQLNonNull(GraphQLString), + resolve=lambda obj, info: info.context["request"], + ), + }, + ), + resolve=lambda obj, info: info.context, ), "test": GraphQLField( type_=GraphQLString, diff --git a/tests/sanic/test_graphqlview.py b/tests/sanic/test_graphqlview.py index 7325e6d..740697c 100644 --- a/tests/sanic/test_graphqlview.py +++ b/tests/sanic/test_graphqlview.py @@ -491,13 +491,30 @@ def test_passes_request_into_request_context(app): assert response_json(response) == {"data": {"request": "testing"}} -@pytest.mark.parametrize("app", [create_app(context="CUSTOM CONTEXT")]) -def test_supports_pretty_printing_on_custom_context_response(app): - _, response = app.client.get(uri=url_string(query="{context}")) +@pytest.mark.parametrize("app", [create_app(context={"session": "CUSTOM CONTEXT"})]) +def test_passes_custom_context_into_context(app): + _, response = app.client.get(uri=url_string(query="{context { session request }}")) - assert response.status == 200 - assert "data" in response_json(response) - assert response_json(response)["data"]["context"] == "" + assert response.status_code == 200 + res = response_json(response) + assert "data" in res + assert "session" in res["data"]["context"] + assert "request" in res["data"]["context"] + assert "CUSTOM CONTEXT" in res["data"]["context"]["session"] + assert "Request" in res["data"]["context"]["request"] + + +@pytest.mark.parametrize("app", [create_app(context="CUSTOM CONTEXT")]) +def test_context_remapped_if_not_mapping(app): + _, response = app.client.get(uri=url_string(query="{context { session request }}")) + + assert response.status_code == 200 + res = response_json(response) + assert "data" in res + assert "session" in res["data"]["context"] + assert "request" in res["data"]["context"] + assert "CUSTOM CONTEXT" not in res["data"]["context"]["request"] + assert "Request" in res["data"]["context"]["request"] @pytest.mark.parametrize("app", [create_app()]) diff --git a/tests/webob/schema.py b/tests/webob/schema.py index f00f14f..e6aa93f 100644 --- a/tests/webob/schema.py +++ b/tests/webob/schema.py @@ -22,8 +22,17 @@ def resolve_raises(*_): resolve=lambda obj, info: info.context["request"].params.get("q"), ), "context": GraphQLField( - GraphQLNonNull(GraphQLString), - resolve=lambda obj, info: info.context["request"], + GraphQLObjectType( + name="context", + fields={ + "session": GraphQLField(GraphQLString), + "request": GraphQLField( + GraphQLNonNull(GraphQLString), + resolve=lambda obj, info: info.context["request"], + ), + }, + ), + resolve=lambda obj, info: info.context, ), "test": GraphQLField( type_=GraphQLString, diff --git a/tests/webob/test_graphqlview.py b/tests/webob/test_graphqlview.py index 6b5f37c..456b5f1 100644 --- a/tests/webob/test_graphqlview.py +++ b/tests/webob/test_graphqlview.py @@ -462,16 +462,30 @@ def test_passes_request_into_request_context(client): assert response_json(response) == {"data": {"request": "testing"}} +@pytest.mark.parametrize("settings", [dict(context={"session": "CUSTOM CONTEXT"})]) +def test_passes_custom_context_into_context(client, settings): + response = client.get(url_string(query="{context { session request }}")) + + assert response.status_code == 200 + res = response_json(response) + assert "data" in res + assert "session" in res["data"]["context"] + assert "request" in res["data"]["context"] + assert "CUSTOM CONTEXT" in res["data"]["context"]["session"] + assert "request" in res["data"]["context"]["request"] + + @pytest.mark.parametrize("settings", [dict(context="CUSTOM CONTEXT")]) -def test_supports_custom_context(client, settings): - response = client.get(url_string(query="{context}")) +def test_context_remapped_if_not_mapping(client, settings): + response = client.get(url_string(query="{context { session request }}")) assert response.status_code == 200 - assert "data" in response_json(response) - assert ( - response_json(response)["data"]["context"] - == "GET /graphql?query=%7Bcontext%7D HTTP/1.0\r\nHost: localhost:80" - ) + res = response_json(response) + assert "data" in res + assert "session" in res["data"]["context"] + assert "request" in res["data"]["context"] + assert "CUSTOM CONTEXT" not in res["data"]["context"]["request"] + assert "request" in res["data"]["context"]["request"] def test_post_multipart_data(client):