Skip to content

Commit

Permalink
add gas and passthrough
Browse files Browse the repository at this point in the history
  • Loading branch information
devkral committed Mar 15, 2024
1 parent 7a66d35 commit 58c2673
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 53 deletions.
20 changes: 16 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ A Limits object has following attributes:
- depth: max depth (default: 20, None disables feature)
- selections: max selections (default: None, None disables feature)
- complexity: max (depth subtree \* selections subtree) (default: 100, None disables feature)
- gas: accumulated gas costs (default: None, None disables feature)
- passthrough: field names specified here will be passed through regardless if specified (default: empty frozen set)

they overwrite django settings if specified.

Expand All @@ -180,10 +182,22 @@ they overwrite django settings if specified.
Sometimes single fields should have different limits:

```python
person1 = Limits(depth=10)(graphene.Field(Person))
from graphene_protector import Limits
person1 = Limits(depth=10)(graphene.Field(Person))
```

Limits are passthroughs for missing parameters

There is also a novel technique named gas: you can assign a field a static value or dynamically calculate it for the field

The decorator is called gas_usage

```python
from graphene_protector import gas_usage
person1 = gas_usage(10)(graphene.Field(Person))
```

Limits are inherited for unspecified parameters
see tests for more examples

## one-time disable limit checks

Expand Down Expand Up @@ -243,8 +257,6 @@ If you want some new or better algorithms integrated just make a PR

# TODO

- add tests and documentation for passthrough
- add tests and documentation for gas
- stop when an open path regex is used. May append an invalid char and check if it is still ignoring
- keep an eye on the performance impact of the new path regex checking
- add tests for auto_snakecase and camelcase_path
Expand Down
78 changes: 49 additions & 29 deletions graphene_protector/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
__all__ = [
"follow_of_type",
"to_camel_case",
"to_snake_case",
"merge_limits",
"gas_for_field",
"limits_for_field",
"check_resource_usage",
"gas_usage",
"LimitsValidationRule",
"decorate_limits",
"decorate_limits_async",
"SchemaMixin",
]

import re
from collections.abc import Callable
from dataclasses import fields, replace
Expand Down Expand Up @@ -64,6 +79,7 @@ def merge_limits(old_limits: Limits, new_limits: Limits):
_limits = {}
for field in fields(new_limits):
value = getattr(new_limits, field.name)
# passthrough is always set so there is no issue
if value is not MISSING:
_limits[field.name] = value
return replace(old_limits, **_limits)
Expand All @@ -85,7 +101,7 @@ def gas_for_field(schema_field, **kwargs) -> int:
if hasattr(schema_field, "_graphene_protector_gas"):
retval = getattr(schema_field, "_graphene_protector_gas")
if callable(retval):
retval = retval(**kwargs)
retval = retval(schema_field=schema_field, **kwargs)
return retval
if hasattr(schema_field, "__func__"):
schema_field = getattr(schema_field, "__func__")
Expand Down Expand Up @@ -146,6 +162,30 @@ def check_resource_usage(
if isinstance(field, FragmentSpreadNode):
field = validation_context.get_fragment(field.name.value)

try:
schema_field = getattr(schema, fieldname)
except AttributeError:
_name = None
if hasattr(field, "name"):
_name = field.name
if hasattr(_name, "value"):
_name = _name.value
if (
hasattr(schema, "fields")
and not isinstance(schema, GraphQLInterfaceType)
and _name
):
schema_field = schema.fields[_name]
else:
schema_field = schema

# add gas for field
retval.gas_used += get_gas_for_field(
schema_field,
parent=schema,
fieldname=fieldname,
)

if isinstance(field, (GraphQLUnionType, GraphQLInterfaceType)):
merged_limits = limits
local_union_selections = 0
Expand Down Expand Up @@ -207,35 +247,13 @@ def check_resource_usage(
del local_union_selections
del local_gas
elif field.selection_set:
try:
schema_field = getattr(schema, fieldname)
except AttributeError:
_name = None
if hasattr(field, "name"):
_name = field.name
if hasattr(_name, "value"):
_name = _name.value
if (
hasattr(schema, "fields")
and not isinstance(schema, GraphQLInterfaceType)
and _name
):
schema_field = schema.fields[_name]
else:
schema_field = schema
merged_limits, sub_limits = get_limits_for_field(
schema_field,
limits,
parent=schema,
fieldname=fieldname,
)
# add gas for selection field
retval.gas_used += get_gas_for_field(
schema_field,
parent=schema,
fieldname=fieldname,
)
allow_reset = True
allow_restart_counters = True
field_contributes_to_score = True
_npath = "{}/{}".format(
_path,
Expand All @@ -245,10 +263,12 @@ def check_resource_usage(
field_contributes_to_score = False
# must be seperate from condition above
if sub_limits is not MISSING:
if id(sub_limits) in _seen_limits:
allow_reset = False
id_sub_limits = id(sub_limits)
# loop detected, cannot reset via sub_limits
if id_sub_limits in _seen_limits:
allow_restart_counters = False
else:
_seen_limits.add(id(sub_limits))
_seen_limits.add(id_sub_limits)
if isinstance(
schema_field,
(GraphQLUnionType, GraphQLInterfaceType, GraphQLObjectType),
Expand All @@ -269,10 +289,10 @@ def check_resource_usage(
get_gas_for_field=get_gas_for_field,
# field_contributes_to_score will be casted to 1 for True
level_depth=level_depth + field_contributes_to_score
if sub_limits.depth is MISSING or not allow_reset
if sub_limits.depth is MISSING or not allow_restart_counters
else 1,
level_complexity=level_complexity + field_contributes_to_score
if sub_limits.complexity is MISSING or not allow_reset
if sub_limits.complexity is MISSING or not allow_restart_counters
else 1,
_seen_limits=_seen_limits,
_path=_npath,
Expand Down
29 changes: 24 additions & 5 deletions graphene_protector/misc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,27 @@
__all__ = [
"MISSING",
"Limits",
"UsagesResult",
"DEFAULT_LIMITS",
"MISSING_LIMITS",
"EarlyStop",
"ResourceLimitReached",
"DepthLimitReached",
"SelectionsLimitReached",
"ComplexityLimitReached",
"GasLimitReached",
"default_path_ignore_pattern",
]

import copy
import sys
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Set, Union

from graphql.error import GraphQLError

_empty_set = frozenset()


class MISSING:
"""
Expand All @@ -29,10 +47,11 @@ class Limits:
gas: Union[int, None, MISSING] = MISSING
# only for sublimits not for main Limit instance
# passthrough for not missing limits
passthrough: Set[str] = field(default_factory=set)
passthrough: Set[str] = _empty_set

def __call__(self, field):
setattr(field, "_graphene_protector_limits", self)
# ensure every decoration has an own id
setattr(field, "_graphene_protector_limits", copy.copy(self))
return field


Expand Down Expand Up @@ -64,11 +83,11 @@ class SelectionsLimitReached(ResourceLimitReached):
pass


class GasLimitReached(ResourceLimitReached):
class ComplexityLimitReached(ResourceLimitReached):
pass


class ComplexityLimitReached(ResourceLimitReached):
class GasLimitReached(ResourceLimitReached):
pass


Expand Down
5 changes: 3 additions & 2 deletions tests/graphene/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import graphene
from graphene import relay
from graphql_relay import from_global_id

from graphene_protector import gas_usage


class SomeNode(graphene.ObjectType):
Expand Down Expand Up @@ -28,7 +29,7 @@ class Edge:
class Query(graphene.ObjectType):
node = relay.Node.Field()

hello = graphene.String()
hello = gas_usage(lambda **kwargs: 1)(graphene.String())

def resolve_hello(root, info):
return "World"
Expand Down
4 changes: 3 additions & 1 deletion tests/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
GraphQLString,
)

from graphene_protector import gas_usage

try:
field = GraphQLField(GraphQLString, resolve=lambda *_: "World")
except TypeError:
field = GraphQLField(GraphQLString, resolver=lambda *_: "World")

Query = GraphQLObjectType(
"Query",
lambda: {"hello": field},
lambda: {"hello": gas_usage(1)(field)},
)
8 changes: 4 additions & 4 deletions tests/strawberry/schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Iterable, List, Optional, Union

import strawberry
from strawberry.relay import from_base64
from strawberry.types import Info

from graphene_protector import gas_usage


@strawberry.input
class PersonFilter:
Expand Down Expand Up @@ -43,9 +44,7 @@ class Query:
# node: strawberry.relay.Node = strawberry.relay.node()
@strawberry.field()
@staticmethod
def node(
info, id: strawberry.relay.GlobalID
) -> Optional[strawberry.relay.Node]:
def node(info, id: strawberry.relay.GlobalID) -> Optional[strawberry.relay.Node]:
return id.resolve_node(info=info, required=False)

@strawberry.field
Expand All @@ -57,6 +56,7 @@ def persons(
Person2(name="Zoe", child=Person1(name="Hubert")),
]

@gas_usage(lambda **_kwargs: 4)
@strawberry.field
def in_out(self, into: List[str]) -> List[str]:
return into
Expand Down
10 changes: 5 additions & 5 deletions tests/test_graphql_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
from graphql import parse, validate
from graphql.type import GraphQLSchema

from graphene_protector import Limits, SchemaMixin, ValidationRule
from graphene_protector import Limits, LimitsValidationRule, SchemaMixin

from .graphql.schema import Query


class Schema(GraphQLSchema, SchemaMixin):
protector_default_limits = Limits(
depth=2, selections=None, complexity=None, gas=None
)
protector_default_limits = Limits(depth=2, selections=None, complexity=None, gas=1)
auto_camelcase = False


Expand All @@ -23,4 +21,6 @@ def test_simple(self):
query=Query,
)
query_ast = parse("{ hello }")
self.assertFalse(validate(schema, query_ast, [ValidationRule]))
self.assertFalse(validate(schema, query_ast, [LimitsValidationRule]))
query_ast = parse("{ hello, hello1: hello }")
self.assertTrue(validate(schema, query_ast, [LimitsValidationRule]))
15 changes: 14 additions & 1 deletion tests/testgraphene_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,26 @@ class TestGraphene(unittest.TestCase):
def test_simple(self):
schema = ProtectorSchema(
query=Query,
limits=Limits(depth=2, selections=None, complexity=None, gas=None),
limits=Limits(depth=2, selections=None, complexity=None, gas=1),
types=[SomeNode],
)
self.assertIsInstance(schema, GrapheneSchema)
result = schema.execute("{ hello }")
self.assertFalse(result.errors)
self.assertDictEqual(result.data, {"hello": "World"})

def test_gas(self):
schema = ProtectorSchema(
query=Query,
limits=Limits(depth=None, selections=None, complexity=None, gas=1),
types=[SomeNode],
)
self.assertIsInstance(schema, GrapheneSchema)
result = schema.execute("{ hello }")
self.assertFalse(result.errors)
self.assertDictEqual(result.data, {"hello": "World"})
result = schema.execute("{ hello, hello1: hello }")
self.assertTrue(result.errors)

def test_node(self):
schema = ProtectorSchema(
Expand Down
21 changes: 21 additions & 0 deletions tests/testgraphene_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,20 @@ class Person4(graphene.ObjectType):
age = graphene.Int()
depth = graphene.Int()
child = Limits(depth=2)(graphene.Field(Person))
child2 = Limits(depth=2, passthrough={"depth"})(graphene.Field(Person))

def resolve_child(self, info):
if self.depth == 0:
return None

return Person(id=self.id + 1, depth=self.depth - 1, age=34)

def resolve_child2(self, info):
if self.depth == 0:
return None

return Person(id=self.id + 1, depth=self.depth - 1, age=34)


class Query(graphene.ObjectType):
setDirectly = Limits(depth=2)(graphene.Field(Person))
Expand Down Expand Up @@ -161,6 +168,20 @@ def test_set_hierachy(self):
}
}
}
"""
result = schema.execute(query)
self.assertTrue(result.errors)
# test passthrough
query = """
query something{
setHierachy{
child {
child2 {
age
}
}
}
}
"""
result = schema.execute(query)
self.assertTrue(result.errors)
Loading

0 comments on commit 58c2673

Please sign in to comment.