diff --git a/graphql.go b/graphql.go index 8520956..f1dc4ea 100644 --- a/graphql.go +++ b/graphql.go @@ -34,18 +34,36 @@ func NewClient(url string, httpClient *http.Client) *Client { // with a query derived from q, populating the response into it. // q should be a pointer to struct that corresponds to the GraphQL schema. func (c *Client) Query(ctx context.Context, q interface{}, variables map[string]interface{}) error { - return c.do(ctx, queryOperation, q, variables) + return c.do(ctx, queryOperation, q, variables, nil) +} + +// QueryWithExtensions executes a single GraphQL query request, +// with a query derived from q, populating the response into it. +// q should be a pointer to struct that corresponds to the GraphQL schema. +// Additionally, this will capture the extensions from the response. +// extensions should be a pointer that corresponds to the extensions schema. +func (c *Client) QueryWithExtensions(ctx context.Context, q interface{}, variables map[string]interface{}, extensions interface{}) error { + return c.do(ctx, queryOperation, q, variables, extensions) } // Mutate executes a single GraphQL mutation request, // with a mutation derived from m, populating the response into it. // m should be a pointer to struct that corresponds to the GraphQL schema. func (c *Client) Mutate(ctx context.Context, m interface{}, variables map[string]interface{}) error { - return c.do(ctx, mutationOperation, m, variables) + return c.do(ctx, mutationOperation, m, variables, nil) +} + +// MutateWithExtensions executes a single GraphQL mutation request, +// with a mutation derived from m, populating the response into it. +// m should be a pointer to struct that corresponds to the GraphQL schema. +// Additionally, this will capture the extensions from the response. +// extensions should be a pointer that corresponds to the extensions schema. +func (c *Client) MutateWithExtensions(ctx context.Context, m interface{}, variables map[string]interface{}, extensions interface{}) error { + return c.do(ctx, mutationOperation, m, variables, extensions) } // do executes a single GraphQL operation. -func (c *Client) do(ctx context.Context, op operationType, v interface{}, variables map[string]interface{}) error { +func (c *Client) do(ctx context.Context, op operationType, v interface{}, variables map[string]interface{}, extensions interface{}) error { var query string switch op { case queryOperation: @@ -75,9 +93,9 @@ func (c *Client) do(ctx context.Context, op operationType, v interface{}, variab return fmt.Errorf("non-200 OK status code: %v body: %q", resp.Status, body) } var out struct { - Data *json.RawMessage - Errors errors - //Extensions interface{} // Unused. + Data *json.RawMessage + Errors errors + Extensions *json.RawMessage } err = json.NewDecoder(resp.Body).Decode(&out) if err != nil { @@ -91,6 +109,13 @@ func (c *Client) do(ctx context.Context, op operationType, v interface{}, variab return err } } + if extensions != nil && out.Extensions != nil { + err := jsonutil.UnmarshalGraphQL(*out.Extensions, extensions) + if err != nil { + // TODO: Consider including response body in returned error, if deemed helpful. + return err + } + } if len(out.Errors) > 0 { return out.Errors } diff --git a/graphql_test.go b/graphql_test.go index e09dcc9..a3a75ba 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -153,6 +153,40 @@ func TestClient_Query_emptyVariables(t *testing.T) { } } +func TestClient_QueryWithExtensions(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/graphql", func(w http.ResponseWriter, req *http.Request) { + body := mustRead(req.Body) + if got, want := body, `{"query":"{user{name}}"}`+"\n"; got != want { + t.Errorf("got body: %v, want %v", got, want) + } + w.Header().Set("Content-Type", "application/json") + mustWrite(w, `{"data": {"user": {"name": "Gopher"}}, "extensions": {"cost": {"actualQueryCost":50}}}`) + }) + client := graphql.NewClient("/graphql", &http.Client{Transport: localRoundTripper{handler: mux}}) + + var q struct { + User struct { + Name string + } + } + var e struct { + Cost struct { + ActualQueryCost int + } + } + err := client.QueryWithExtensions(context.Background(), &q, map[string]interface{}{}, &e) + if err != nil { + t.Fatal(err) + } + if got, want := q.User.Name, "Gopher"; got != want { + t.Errorf("got q.User.Name: %q, want: %q", got, want) + } + if got, want := e.Cost.ActualQueryCost, 50; got != want { + t.Errorf("got e.QueryCost: %q, want: %q", got, want) + } +} + // localRoundTripper is an http.RoundTripper that executes HTTP transactions // by using handler directly, instead of going over an HTTP connection. type localRoundTripper struct {