Skip to content

Commit 03b0400

Browse files
author
Guido
committed
Add update mutations
1 parent 400750e commit 03b0400

File tree

6 files changed

+187
-39
lines changed

6 files changed

+187
-39
lines changed

src/graphql_sqlalchemy/args.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Optional
2-
31
from graphql import (
42
GraphQLArgument,
53
GraphQLNonNull,
@@ -10,7 +8,15 @@
108
from sqlalchemy.ext.declarative import DeclarativeMeta
119

1210
from .scalars import get_graphql_type_from_column
13-
from .inputs import get_where_type, get_order_type, get_insert_type, get_conflict_type
11+
from .inputs import (
12+
get_where_input_type,
13+
get_order_input_type,
14+
get_insert_input_type,
15+
get_conflict_input_type,
16+
get_inc_input_type,
17+
get_set_input_type,
18+
get_pk_columns_input,
19+
)
1420
from .types import Inputs
1521
from .helpers import get_table
1622

@@ -20,8 +26,8 @@
2026

2127
def make_args(model: DeclarativeMeta, inputs: Inputs) -> GraphQLArgumentMap:
2228
args = {
23-
"order": GraphQLArgument(GraphQLList(GraphQLNonNull(get_order_type(model, inputs)))),
24-
"where": GraphQLArgument(get_where_type(model, inputs)),
29+
"order": GraphQLArgument(GraphQLList(GraphQLNonNull(get_order_input_type(model, inputs)))),
30+
"where": GraphQLArgument(get_where_input_type(model, inputs)),
2531
}
2632

2733
for name, field in PAGINATION_ARGS.items():
@@ -43,17 +49,33 @@ def make_pk_args(model: DeclarativeMeta) -> GraphQLArgumentMap:
4349

4450
def make_insert_args(model: DeclarativeMeta, inputs: Inputs) -> GraphQLArgumentMap:
4551
return {
46-
"objects": GraphQLArgument(GraphQLNonNull(GraphQLList(GraphQLNonNull(get_insert_type(model, inputs))))),
47-
"on_conflict": GraphQLArgument(get_conflict_type(model, inputs)),
52+
"objects": GraphQLArgument(GraphQLNonNull(GraphQLList(GraphQLNonNull(get_insert_input_type(model, inputs))))),
53+
"on_conflict": GraphQLArgument(get_conflict_input_type(model, inputs)),
4854
}
4955

5056

5157
def make_insert_one_args(model: DeclarativeMeta, inputs: Inputs) -> GraphQLArgumentMap:
5258
return {
53-
"object": GraphQLArgument(get_insert_type(model, inputs)),
54-
"on_conflict": GraphQLArgument(get_conflict_type(model, inputs)),
59+
"object": GraphQLArgument(get_insert_input_type(model, inputs)),
60+
"on_conflict": GraphQLArgument(get_conflict_input_type(model, inputs)),
5561
}
5662

5763

5864
def make_delete_args(model: DeclarativeMeta, inputs: Inputs) -> GraphQLArgumentMap:
59-
return {"where": GraphQLArgument(get_where_type(model, inputs))}
65+
return {"where": GraphQLArgument(get_where_input_type(model, inputs))}
66+
67+
68+
def make_update_args(model: DeclarativeMeta, inputs: Inputs) -> GraphQLArgumentMap:
69+
return {
70+
"_inc": GraphQLArgument(get_inc_input_type(model, inputs)),
71+
"_set": GraphQLArgument(get_set_input_type(model, inputs)),
72+
"where": GraphQLArgument(get_where_input_type(model, inputs)),
73+
}
74+
75+
76+
def make_update_by_pk_args(model: DeclarativeMeta, inputs: Inputs) -> GraphQLArgumentMap:
77+
return {
78+
"_inc": GraphQLArgument(get_inc_input_type(model, inputs)),
79+
"_set": GraphQLArgument(get_set_input_type(model, inputs)),
80+
"pk_columns": GraphQLArgument(GraphQLNonNull(get_pk_columns_input(model))),
81+
}

src/graphql_sqlalchemy/dialects/pg/inputs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616
from ...types import Inputs
1717
from ...helpers import get_table
18-
from ...inputs import get_where_type
18+
from ...inputs import get_where_input_type
1919

2020

2121
def get_constraint_enum(model: DeclarativeMeta) -> GraphQLEnumType:
@@ -49,7 +49,7 @@ def get_conflict_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObj
4949
"update_columns": GraphQLInputField(
5050
GraphQLNonNull(GraphQLList(GraphQLNonNull(get_update_column_enums(model))))
5151
),
52-
"where": GraphQLInputField(get_where_type(model, inputs)),
52+
"where": GraphQLInputField(get_where_input_type(model, inputs)),
5353
}
5454

5555
input_type = GraphQLInputObjectType(type_name, fields)

src/graphql_sqlalchemy/inputs.py

Lines changed: 58 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sqlalchemy import Column
1+
from sqlalchemy import Column, Integer, Float
22
from sqlalchemy.ext.declarative import DeclarativeMeta
33
from graphql import (
44
GraphQLInputObjectType,
@@ -17,6 +17,9 @@
1717
get_model_insert_input_name,
1818
get_scalar_comparison_name,
1919
get_model_conflict_input_name,
20+
get_model_inc_input_type_name,
21+
get_model_set_input_type_name,
22+
get_model_pk_columns_input_type_name,
2023
)
2124
from .scalars import get_graphql_type_from_column, get_base_comparison_fields, get_string_comparison_fields
2225
from .types import Inputs
@@ -38,12 +41,11 @@ def get_comparison_input_type(column: Column, inputs: Inputs) -> GraphQLInputObj
3841
if scalar == GraphQLString:
3942
fields.update(get_string_comparison_fields())
4043

41-
input_type = GraphQLInputObjectType(type_name, fields)
42-
inputs[type_name] = input_type
43-
return input_type
44+
inputs[type_name] = GraphQLInputObjectType(type_name, fields)
45+
return inputs[type_name]
4446

4547

46-
def get_where_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
48+
def get_where_input_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
4749
type_name = get_model_where_input_name(model)
4850
if type_name in inputs:
4951
return inputs[type_name]
@@ -63,12 +65,11 @@ def get_fields() -> GraphQLInputFieldMap:
6365

6466
return fields
6567

66-
input_type = GraphQLInputObjectType(type_name, get_fields)
67-
inputs[type_name] = input_type
68-
return input_type
68+
inputs[type_name] = GraphQLInputObjectType(type_name, get_fields)
69+
return inputs[type_name]
6970

7071

71-
def get_order_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
72+
def get_order_input_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
7273
type_name = get_model_order_by_input_name(model)
7374

7475
def get_fields() -> GraphQLInputFieldMap:
@@ -82,26 +83,28 @@ def get_fields() -> GraphQLInputFieldMap:
8283

8384
return fields
8485

85-
input_type = GraphQLInputObjectType(type_name, get_fields)
86-
inputs[type_name] = input_type
87-
return input_type
86+
inputs[type_name] = GraphQLInputObjectType(type_name, get_fields)
87+
return inputs[type_name]
8888

8989

90-
def get_insert_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
91-
type_name = get_model_insert_input_name(model)
92-
if type_name in inputs:
93-
return inputs[type_name]
94-
90+
def make_model_fields_input_type(model: DeclarativeMeta, type_name: str) -> GraphQLInputObjectType:
9591
fields = {}
9692
for column in get_table(model).columns:
9793
fields[column.name] = GraphQLInputField(get_graphql_type_from_column(column))
9894

99-
input_type = GraphQLInputObjectType(type_name, fields)
100-
inputs[type_name] = input_type
101-
return input_type
95+
return GraphQLInputObjectType(type_name, fields)
96+
97+
98+
def get_insert_input_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
99+
type_name = get_model_insert_input_name(model)
100+
if type_name in inputs:
101+
return inputs[type_name]
102102

103+
inputs[type_name] = make_model_fields_input_type(model, type_name)
104+
return inputs[type_name]
103105

104-
def get_conflict_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
106+
107+
def get_conflict_input_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
105108
type_name = get_model_conflict_input_name(model)
106109
if type_name in inputs:
107110
return inputs[type_name]
@@ -113,3 +116,37 @@ def get_conflict_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObj
113116
input_type = GraphQLInputObjectType(type_name, fields)
114117
inputs[type_name] = input_type
115118
return input_type
119+
120+
121+
def get_inc_input_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
122+
type_name = get_model_inc_input_type_name(model)
123+
if type_name in inputs:
124+
return inputs[type_name]
125+
126+
fields = {}
127+
for column in get_table(model).columns:
128+
if isinstance(column.type, (Integer, Float)):
129+
fields[column.name] = GraphQLInputField(get_graphql_type_from_column(column))
130+
131+
inputs[type_name] = GraphQLInputObjectType(type_name, fields)
132+
return inputs[type_name]
133+
134+
135+
def get_set_input_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
136+
type_name = get_model_set_input_type_name(model)
137+
if type_name in inputs:
138+
return inputs[type_name]
139+
140+
inputs[type_name] = make_model_fields_input_type(model, type_name)
141+
return inputs[type_name]
142+
143+
144+
def get_pk_columns_input(model: DeclarativeMeta) -> GraphQLInputObjectType:
145+
type_name = get_model_pk_columns_input_type_name(model)
146+
primary_key = get_table(model).primary_key
147+
148+
fields = {}
149+
for column in primary_key.columns:
150+
fields[column.name] = GraphQLInputField(GraphQLNonNull(get_graphql_type_from_column(column)))
151+
152+
return GraphQLInputObjectType(type_name, fields)

src/graphql_sqlalchemy/names.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,23 @@ def get_model_delete_name(model: DeclarativeMeta) -> str:
6666

6767
def get_model_delete_by_pk_name(model: DeclarativeMeta) -> str:
6868
return f"delete_{get_table_name(model)}_by_pk"
69+
70+
71+
def get_model_update_name(model: DeclarativeMeta) -> str:
72+
return f"update_{get_table_name(model)}"
73+
74+
75+
def get_model_update_by_pk_name(model: DeclarativeMeta) -> str:
76+
return f"update_{get_table_name(model)}_by_pk"
77+
78+
79+
def get_model_inc_input_type_name(model: DeclarativeMeta) -> str:
80+
return f"{get_table_name(model)}_inc_input"
81+
82+
83+
def get_model_set_input_type_name(model: DeclarativeMeta) -> str:
84+
return f"{get_table_name(model)}_set_input"
85+
86+
87+
def get_model_pk_columns_input_type_name(model: DeclarativeMeta) -> str:
88+
return f"{get_table_name(model)}_pk_columns_input"

src/graphql_sqlalchemy/resolvers.py

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def order_query(model: DeclarativeMeta, query: Query, order: Optional[List[Dict[
9999

100100
def make_object_resolver(model: DeclarativeMeta) -> Callable:
101101
def resolver(
102-
_root: DeclarativeMeta,
102+
_root: None,
103103
info: Any,
104104
where: Optional[Dict[str, Any]] = None,
105105
order: Optional[List[Dict[str, Any]]] = None,
@@ -124,7 +124,7 @@ def resolver(
124124

125125

126126
def make_pk_resolver(model: DeclarativeMeta) -> Callable:
127-
def resolver(_root: DeclarativeMeta, info: Any, **kwargs: Dict[str, Any]) -> DeclarativeMeta:
127+
def resolver(_root: None, info: Any, **kwargs: Dict[str, Any]) -> DeclarativeMeta:
128128
session = info.context["session"]
129129
return session.query(model).get(kwargs)
130130

@@ -155,7 +155,7 @@ def session_commit(session: Session) -> None:
155155

156156
def make_insert_resolver(model: DeclarativeMeta) -> Callable:
157157
def resolver(
158-
_root: DeclarativeMeta, info: Any, objects: List[Dict[str, Any]], on_conflict: Optional[Dict[str, Any]] = None
158+
_root: None, info: Any, objects: List[Dict[str, Any]], on_conflict: Optional[Dict[str, Any]] = None
159159
) -> Dict[str, Union[int, List[DeclarativeMeta]]]:
160160
session = info.context["session"]
161161
models = []
@@ -173,7 +173,7 @@ def resolver(
173173

174174
def make_insert_one_resolver(model: DeclarativeMeta) -> Callable:
175175
def resolver(
176-
_root: DeclarativeMeta, info: Any, object: Dict[str, Any], on_conflict: Optional[Dict[str, Any]] = None
176+
_root: None, info: Any, object: Dict[str, Any], on_conflict: Optional[Dict[str, Any]] = None
177177
) -> DeclarativeMeta:
178178
session = info.context["session"]
179179

@@ -186,7 +186,7 @@ def resolver(
186186

187187
def make_delete_resolver(model: DeclarativeMeta) -> Callable:
188188
def resolver(
189-
_root: DeclarativeMeta, info: Any, where: Optional[Dict[str, Any]] = None
189+
_root: None, info: Any, where: Optional[Dict[str, Any]] = None
190190
) -> Dict[str, Union[int, List[DeclarativeMeta]]]:
191191
session = info.context["session"]
192192
query = session.query(model)
@@ -202,7 +202,7 @@ def resolver(
202202

203203

204204
def make_delete_by_pk_resolver(model: DeclarativeMeta) -> Callable:
205-
def resolver(_root: DeclarativeMeta, info: Any, **kwargs: Dict[str, Any]) -> List[DeclarativeMeta]:
205+
def resolver(_root: None, info: Any, **kwargs: Dict[str, Any]) -> List[DeclarativeMeta]:
206206
session = info.context["session"]
207207

208208
row = session.query(model).get(kwargs)
@@ -212,3 +212,56 @@ def resolver(_root: DeclarativeMeta, info: Any, **kwargs: Dict[str, Any]) -> Lis
212212
return row
213213

214214
return resolver
215+
216+
217+
def update_query(
218+
query: Query, model: DeclarativeMeta, _set: Optional[Dict[str, Any]], _inc: Optional[Dict[str, Any]],
219+
) -> int:
220+
affected = 0
221+
if _inc:
222+
to_increment = {}
223+
for column_name, increment in _inc.items():
224+
to_increment[column_name] = getattr(model, column_name) + increment
225+
226+
affected += query.update(to_increment)
227+
228+
if _set:
229+
affected += query.update(_set)
230+
231+
return affected
232+
233+
234+
def make_update_resolver(model: DeclarativeMeta) -> Callable:
235+
def resolver(
236+
_root: None, info: Any, where: Dict[str, Any], _set: Optional[Dict[str, Any]], _inc: Optional[Dict[str, Any]],
237+
) -> Dict[str, Union[int, List[DeclarativeMeta]]]:
238+
session = info.context["session"]
239+
query = session.query(model)
240+
query = filter_query(model, query, where)
241+
affected = update_query(query, model, _set, _inc)
242+
243+
return {
244+
"affected_rows": affected,
245+
"returning": query.all(),
246+
}
247+
248+
return resolver
249+
250+
251+
def make_update_by_pk_resolver(model: DeclarativeMeta) -> Callable:
252+
def resolver(
253+
_root: None,
254+
info: Any,
255+
pk_columns: Dict[str, Any],
256+
_set: Optional[Dict[str, Any]],
257+
_inc: Optional[Dict[str, Any]],
258+
) -> Optional[DeclarativeMeta]:
259+
session = info.context["session"]
260+
query = session.query(model).filter_by(**pk_columns)
261+
262+
if update_query(query, model, _set, _inc):
263+
return query.one()
264+
265+
return None
266+
267+
return resolver

0 commit comments

Comments
 (0)