Skip to content

Commit cc344c0

Browse files
author
Guido
committed
Use merge for conflicts for now
In the future leaving open the possibility of using dialects for different DB flavors
1 parent 593370f commit cc344c0

File tree

5 files changed

+93
-45
lines changed

5 files changed

+93
-45
lines changed

src/graphql_sqlalchemy/dialects/__init__.py

Whitespace-only changes.

src/graphql_sqlalchemy/dialects/pg/__init__.py

Whitespace-only changes.
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from sqlalchemy.ext.declarative import DeclarativeMeta
2+
from graphql import (
3+
GraphQLInputObjectType,
4+
GraphQLList,
5+
GraphQLEnumType,
6+
GraphQLInputField,
7+
GraphQLNonNull,
8+
)
9+
10+
from ...names import (
11+
get_model_conflict_input_name,
12+
get_model_constraint_enum_name,
13+
get_model_constraint_key_name,
14+
get_model_column_update_enum_name,
15+
)
16+
from ...types import Inputs
17+
from ...helpers import get_table
18+
from ...inputs import get_where_type
19+
20+
21+
def get_constraint_enum(model: DeclarativeMeta) -> GraphQLEnumType:
22+
type_name = get_model_constraint_enum_name(model)
23+
24+
fields = {}
25+
for column in get_table(model).primary_key:
26+
key_name = get_model_constraint_key_name(model, column, is_primary_key=True)
27+
fields[key_name] = key_name
28+
29+
return GraphQLEnumType(type_name, fields)
30+
31+
32+
def get_update_column_enums(model: DeclarativeMeta) -> GraphQLEnumType:
33+
type_name = get_model_column_update_enum_name(model)
34+
35+
fields = {}
36+
for column in get_table(model).columns:
37+
fields[column.name] = column.name
38+
39+
return GraphQLEnumType(type_name, fields)
40+
41+
42+
def get_conflict_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
43+
type_name = get_model_conflict_input_name(model)
44+
if type_name in inputs:
45+
return inputs[type_name]
46+
47+
fields = {
48+
"constraint": GraphQLInputField(GraphQLNonNull(get_constraint_enum(model))),
49+
"update_columns": GraphQLInputField(
50+
GraphQLNonNull(GraphQLList(GraphQLNonNull(get_update_column_enums(model))))
51+
),
52+
"where": GraphQLInputField(get_where_type(model, inputs)),
53+
}
54+
55+
input_type = GraphQLInputObjectType(type_name, fields)
56+
inputs[type_name] = input_type
57+
return input_type

src/graphql_sqlalchemy/inputs.py

Lines changed: 2 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
GraphQLString,
99
GraphQLInputFieldMap,
1010
GraphQLNonNull,
11+
GraphQLBoolean,
1112
)
1213

1314
from .names import (
@@ -16,9 +17,6 @@
1617
get_model_insert_input_name,
1718
get_scalar_comparison_name,
1819
get_model_conflict_input_name,
19-
get_model_constraint_enum_name,
20-
get_model_constraint_key_name,
21-
get_model_column_update_enum_name,
2220
)
2321
from .scalars import get_graphql_type_from_column, get_base_comparison_fields, get_string_comparison_fields
2422
from .types import Inputs
@@ -103,38 +101,13 @@ def get_insert_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjec
103101
return input_type
104102

105103

106-
def get_constraint_enum(model: DeclarativeMeta) -> GraphQLEnumType:
107-
type_name = get_model_constraint_enum_name(model)
108-
109-
fields = {}
110-
for column in get_table(model).primary_key:
111-
key_name = get_model_constraint_key_name(model, column, is_primary_key=True)
112-
fields[key_name] = key_name
113-
114-
return GraphQLEnumType(type_name, fields)
115-
116-
117-
def get_update_column_enums(model: DeclarativeMeta) -> GraphQLEnumType:
118-
type_name = get_model_column_update_enum_name(model)
119-
120-
fields = {}
121-
for column in get_table(model).columns:
122-
fields[column.name] = column.name
123-
124-
return GraphQLEnumType(type_name, fields)
125-
126-
127104
def get_conflict_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
128105
type_name = get_model_conflict_input_name(model)
129106
if type_name in inputs:
130107
return inputs[type_name]
131108

132109
fields = {
133-
"constraint": GraphQLInputField(GraphQLNonNull(get_constraint_enum(model))),
134-
"update_columns": GraphQLInputField(
135-
GraphQLNonNull(GraphQLList(GraphQLNonNull(get_update_column_enums(model))))
136-
),
137-
"where": GraphQLInputField(get_where_type(model, inputs)),
110+
"merge": GraphQLInputField(GraphQLNonNull(GraphQLBoolean)),
138111
}
139112

140113
input_type = GraphQLInputObjectType(type_name, fields)

src/graphql_sqlalchemy/resolvers.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from sqlalchemy import Column, or_, and_, not_, true
66
from sqlalchemy.sql import ClauseElement
7-
from sqlalchemy.orm import Query
7+
from sqlalchemy.orm import Query, Session
88
from sqlalchemy.ext.declarative import DeclarativeMeta
99

1010

@@ -131,36 +131,54 @@ def resolver(_root: DeclarativeMeta, info: Any, **kwargs: Dict[str, Any]) -> Dec
131131
return resolver
132132

133133

134+
def session_add_object(
135+
obj: Dict[str, Any], model: DeclarativeMeta, session: Session, on_conflict: Optional[Dict[str, Any]] = None
136+
) -> DeclarativeMeta:
137+
instance = model()
138+
for key, value in obj.items():
139+
setattr(instance, key, value)
140+
141+
if on_conflict and on_conflict["merge"]:
142+
session.merge(instance)
143+
else:
144+
session.add(instance)
145+
return instance
146+
147+
148+
def session_commit(session: Session) -> None:
149+
try:
150+
session.commit()
151+
except Exception:
152+
session.rollback()
153+
raise
154+
155+
134156
def make_insert_resolver(model: DeclarativeMeta) -> Callable:
135157
def resolver(
136-
_root: DeclarativeMeta, info: Any, objects: List[Dict[str, Any]]
158+
_root: DeclarativeMeta, info: Any, objects: List[Dict[str, Any]], on_conflict: Optional[Dict[str, Any]] = None
137159
) -> Dict[str, Union[int, List[DeclarativeMeta]]]:
138160
session = info.context["session"]
139161
models = []
140-
for obj in objects:
141-
instance = model()
142-
models.append(instance)
143-
for key, value in obj.items():
144-
setattr(instance, key, value)
145162

146-
session.add(instance)
163+
with session.no_autoflush:
164+
for obj in objects:
165+
instance = session_add_object(obj, model, session, on_conflict)
166+
models.append(instance)
147167

148-
session.commit()
168+
session_commit(session)
149169
return {"affected_rows": len(models), "returning": models}
150170

151171
return resolver
152172

153173

154174
def make_insert_one_resolver(model: DeclarativeMeta) -> Callable:
155-
def resolver(_root: DeclarativeMeta, info: Any, object: Dict[str, Any]) -> DeclarativeMeta:
175+
def resolver(
176+
_root: DeclarativeMeta, info: Any, object: Dict[str, Any], on_conflict: Optional[Dict[str, Any]] = None
177+
) -> DeclarativeMeta:
156178
session = info.context["session"]
157179

158-
instance = model()
159-
for key, value in object.items():
160-
setattr(instance, key, value)
161-
162-
session.add(instance)
163-
session.commit()
180+
instance = session_add_object(object, model, session, on_conflict)
181+
session_commit(session)
164182
return instance
165183

166184
return resolver

0 commit comments

Comments
 (0)