diff --git a/graphql.go b/graphql.go index 2b1f6a29..def7c86c 100644 --- a/graphql.go +++ b/graphql.go @@ -28,6 +28,12 @@ type Params struct { // one operation. OperationName string + // ValidationRules is for overriding rules of document validation. Default + // SpecifiedRules are ignored when specified this option other than nil. + // So it would be better that combining your rules with SpecifiedRules to + // fill this. + ValidationRules []ValidationRuleFn + // Context may be provided to pass application-specific per-request // information to resolve functions. Context context.Context @@ -84,7 +90,7 @@ func Do(p Params) *Result { } // validate document - validationResult := ValidateDocument(&p.Schema, AST, nil) + validationResult := ValidateDocument(&p.Schema, AST, p.ValidationRules) if !validationResult.IsValid { // run validation finish functions for extensions diff --git a/graphql_test.go b/graphql_test.go index 8b06a7b1..9cc178d9 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -6,6 +6,9 @@ import ( "testing" "github.com/graphql-go/graphql" + "github.com/graphql-go/graphql/language/ast" + "github.com/graphql-go/graphql/language/kinds" + "github.com/graphql-go/graphql/language/visitor" "github.com/graphql-go/graphql/testutil" ) @@ -268,3 +271,122 @@ func TestEmptyStringIsNotNull(t *testing.T) { t.Errorf("wrong result, query: %v, graphql result diff: %v", query, testutil.Diff(expected, result)) } } + +func TestQueryWithCustomRule(t *testing.T) { + // Test graphql.Do() with custom rule, it extracts query name from each + // Tests. + ruleN := len(graphql.SpecifiedRules) + rules := make([]graphql.ValidationRuleFn, ruleN+1) + copy(rules[:ruleN], graphql.SpecifiedRules) + + var ( + queryFound bool + queryName string + ) + rules[ruleN] = func(context *graphql.ValidationContext) *graphql.ValidationRuleInstance { + return &graphql.ValidationRuleInstance{ + VisitorOpts: &visitor.VisitorOptions{ + KindFuncMap: map[string]visitor.NamedVisitFuncs{ + kinds.OperationDefinition: { + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + od, ok := p.Node.(*ast.OperationDefinition) + if ok && od.Operation == "query" { + queryFound = true + if od.Name != nil { + queryName = od.Name.Value + } + } + return visitor.ActionNoChange, nil + }, + }, + }, + }, + } + } + + expectedNames := []string{ + "HeroNameQuery", + "HeroNameAndFriendsQuery", + "HumanByIdQuery", + } + + for i, test := range Tests { + queryFound, queryName = false, "" + params := graphql.Params{ + Schema: test.Schema, + RequestString: test.Query, + VariableValues: test.Variables, + ValidationRules: rules, + } + testGraphql(test, params, t) + if !queryFound { + t.Fatal("can't detect \"query\" operation by validation rule") + } + if queryName != expectedNames[i] { + t.Fatalf("unexpected query name: want=%s got=%s", queryName, expectedNames) + } + } +} + +// TestCustomRuleWithArgs tests graphql.GetArgumentValues() be able to access +// field's argument values from custom validation rule. +func TestCustomRuleWithArgs(t *testing.T) { + fieldDef, ok := testutil.StarWarsSchema.QueryType().Fields()["human"] + if !ok { + t.Fatal("can't retrieve \"human\" field definition") + } + + // a custom validation rule to extract argument values of "human" field. + var actual map[string]interface{} + enter := func(p visitor.VisitFuncParams) (string, interface{}) { + // only interested in "human" field. + fieldNode, ok := p.Node.(*ast.Field) + if !ok || fieldNode.Name == nil || fieldNode.Name.Value != "human" { + return visitor.ActionNoChange, nil + } + // extract argument values by graphql.GetArgumentValues(). + actual = graphql.GetArgumentValues(fieldDef.Args, fieldNode.Arguments, nil) + return visitor.ActionNoChange, nil + } + checkHumanArgs := func(context *graphql.ValidationContext) *graphql.ValidationRuleInstance { + return &graphql.ValidationRuleInstance{ + VisitorOpts: &visitor.VisitorOptions{ + KindFuncMap: map[string]visitor.NamedVisitFuncs{ + kinds.Field: {Enter: enter}, + }, + }, + } + } + + for _, tc := range []struct { + query string + expected map[string]interface{} + }{ + { + `query { human(id: "1000") { name } }`, + map[string]interface{}{"id": "1000"}, + }, + { + `query { human(id: "1002") { name } }`, + map[string]interface{}{"id": "1002"}, + }, + { + `query { human(id: "9999") { name } }`, + map[string]interface{}{"id": "9999"}, + }, + } { + actual = nil + params := graphql.Params{ + Schema: testutil.StarWarsSchema, + RequestString: tc.query, + ValidationRules: append(graphql.SpecifiedRules, checkHumanArgs), + } + result := graphql.Do(params) + if len(result.Errors) > 0 { + t.Fatalf("wrong result, unexpected errors: %v", result.Errors) + } + if !reflect.DeepEqual(actual, tc.expected) { + t.Fatalf("unexpected result: want=%+v got=%+v", tc.expected, actual) + } + } +} diff --git a/values.go b/values.go index 06c08af6..8dc5210f 100644 --- a/values.go +++ b/values.go @@ -67,6 +67,17 @@ func getArgumentValues( return results } +// GetArgumentValues prepares an object map of argument values given a list of +// argument definitions and list of argument AST nodes. +// +// This is an exported version of getArgumentValues(), to ease writing custom +// validation rules. +func GetArgumentValues( + argDefs []*Argument, argASTs []*ast.Argument, + variableValues map[string]interface{}) map[string]interface{} { + return getArgumentValues(argDefs, argASTs, variableValues) +} + // Given a variable definition, and any value of input, return a value which // adheres to the variable definition, or throw an error. func getVariableValue(schema Schema, definitionAST *ast.VariableDefinition, input interface{}) (interface{}, error) {