Skip to content

Commit

Permalink
fix #1689: properly add/remove federation directives
Browse files Browse the repository at this point in the history
  • Loading branch information
t1 committed Jan 8, 2023
1 parent b1181bd commit a7af375
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ public class SchemaBuilder {
private final DirectiveTypeCreator directiveTypeCreator;
private final UnionCreator unionCreator;

private final DotName FEDERATION_ANNOTATIONS_PACKAGE = DotName.createSimple("io.smallrye.graphql.api.federation");

/**
* This builds the Schema from Jandex
*
Expand Down Expand Up @@ -162,13 +160,7 @@ private void addDirectiveTypes(Schema schema) {
// custom directives from annotations
for (AnnotationInstance annotationInstance : ScanningContext.getIndex().getAnnotations(DIRECTIVE)) {
ClassInfo classInfo = annotationInstance.target().asClass();
boolean federationEnabled = Boolean.getBoolean("smallrye.graphql.federation.enabled");
// only add federation-related directive types to the schema if federation is enabled
DotName packageName = classInfo.name().packagePrefixName();
if (packageName == null || !packageName.equals(FEDERATION_ANNOTATIONS_PACKAGE) || federationEnabled) {
schema.addDirectiveType(directiveTypeCreator.create(classInfo));
}

schema.addDirectiveType(directiveTypeCreator.create(classInfo));
}
// bean validation directives
schema.addDirectiveType(BeanValidationDirectivesHelper.CONSTRAINT_DIRECTIVE_TYPE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ public void setRepeatable(boolean repeatable) {
this.repeatable = repeatable;
}

public boolean isFederation() {
return className != null && className.startsWith("io.smallrye.graphql.api.federation");
}

/**
* Helper 'getter' methods, but DON'T add 'get' into their names, otherwise it breaks Quarkus bytecode recording,
* because they would be detected as actual property getters while they are actually not
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,17 @@ private TypeResolver fetchEntityType() {
private void createGraphQLDirectiveTypes() {
if (schema.hasDirectiveTypes()) {
for (DirectiveType directiveType : schema.getDirectiveTypes()) {
createGraphQLDirectiveType(directiveType);
if (enabled(directiveType)) {
createGraphQLDirectiveType(directiveType);
}
}
}
}

private static boolean enabled(DirectiveType directiveType) {
return Config.get().isFederationEnabled() || !directiveType.isFederation();
}

private void createGraphQLDirectiveType(DirectiveType directiveType) {
GraphQLDirective.Builder directiveBuilder = GraphQLDirective.newDirective()
.name(directiveType.getName())
Expand Down Expand Up @@ -372,7 +378,8 @@ private void createGraphQLEnumType(EnumType enumType) {
.description(enumType.getDescription());
// Directives
if (enumType.hasDirectiveInstances()) {
enumBuilder = enumBuilder.withDirectives(createGraphQLDirectives(enumType.getDirectiveInstances()));
enumBuilder = enumBuilder
.withDirectives(createGraphQLDirectives(enumType.getDirectiveInstances()));
}
// Values
for (EnumValue value : enumType.getValues()) {
Expand All @@ -381,7 +388,8 @@ private void createGraphQLEnumType(EnumType enumType) {
.value(value.getValue())
.description(value.getDescription());
if (value.hasDirectiveInstances()) {
definitionBuilder = definitionBuilder.withDirectives(createGraphQLDirectives(value.getDirectiveInstances()));
definitionBuilder = definitionBuilder
.withDirectives(createGraphQLDirectives(value.getDirectiveInstances()));
}
enumBuilder = enumBuilder.value(definitionBuilder.build());
}
Expand Down Expand Up @@ -411,9 +419,8 @@ private void createGraphQLInterfaceType(Type interfaceType) {

// Directives
if (interfaceType.hasDirectiveInstances()) {
for (DirectiveInstance directiveInstance : interfaceType.getDirectiveInstances()) {
interfaceTypeBuilder.withDirective(createGraphQLDirectiveFrom(directiveInstance));
}
interfaceTypeBuilder = interfaceTypeBuilder
.withDirectives(createGraphQLDirectives(interfaceType.getDirectiveInstances()));
}

// Interfaces
Expand Down Expand Up @@ -502,7 +509,7 @@ private GraphQLInputObjectType createGraphQLInputObjectType(InputType inputType)
// Directives
if (inputType.hasDirectiveInstances()) {
inputObjectTypeBuilder = inputObjectTypeBuilder
.withDirectives(createGraphQLDirectives(inputType.getDirectiveInstances()));
.withDirectives(createGraphQLDirectives(inputType.getDirectiveInstances()));
}

// Fields
Expand Down Expand Up @@ -535,9 +542,8 @@ private void createGraphQLObjectType(Type type) {

// Directives
if (type.hasDirectiveInstances()) {
for (DirectiveInstance directiveInstance : type.getDirectiveInstances()) {
objectTypeBuilder.withDirective(createGraphQLDirectiveFrom(directiveInstance));
}
objectTypeBuilder = objectTypeBuilder
.withDirectives(createGraphQLDirectives(type.getDirectiveInstances()));
}

// Fields
Expand Down Expand Up @@ -657,7 +663,8 @@ private GraphQLFieldDefinition createGraphQLFieldDefinitionFromOperation(String

// Directives
if (operation.hasDirectiveInstances()) {
fieldBuilder = fieldBuilder.withDirectives(createGraphQLDirectives(operation.getDirectiveInstances()));
fieldBuilder = fieldBuilder
.withDirectives(createGraphQLDirectives(operation.getDirectiveInstances()));
}

GraphQLFieldDefinition graphQLFieldDefinition = fieldBuilder.build();
Expand All @@ -673,6 +680,7 @@ private GraphQLFieldDefinition createGraphQLFieldDefinitionFromOperation(String

private GraphQLDirective[] createGraphQLDirectives(Collection<DirectiveInstance> directiveInstances) {
return directiveInstances.stream()
.filter(directiveInstance -> enabled(directiveInstance.getType()))
.map(this::createGraphQLDirectiveFrom)
.toArray(GraphQLDirective[]::new);
}
Expand All @@ -695,9 +703,8 @@ private GraphQLFieldDefinition createGraphQLFieldDefinitionFromField(Reference o

// Directives
if (field.hasDirectiveInstances()) {
for (DirectiveInstance directiveInstance : field.getDirectiveInstances()) {
fieldBuilder.withDirective(createGraphQLDirectiveFrom(directiveInstance));
}
fieldBuilder = fieldBuilder
.withDirectives(createGraphQLDirectives(field.getDirectiveInstances()));
}

// Auto Map argument
Expand Down Expand Up @@ -770,10 +777,10 @@ private GraphQLInputObjectField createGraphQLInputObjectFieldFromField(Field fie
// Type
inputFieldBuilder = inputFieldBuilder.type(createGraphQLInputType(field));

// Directives
if (field.hasDirectiveInstances()) {
for (DirectiveInstance directiveInstance : field.getDirectiveInstances()) {
inputFieldBuilder.withDirective(createGraphQLDirectiveFrom(directiveInstance));
}
inputFieldBuilder = inputFieldBuilder
.withDirectives(createGraphQLDirectives(field.getDirectiveInstances()));
}

// Default value (on method)
Expand Down Expand Up @@ -942,12 +949,13 @@ private GraphQLArgument createGraphQLArgument(Argument argument) {
graphQLInputType = GraphQLNonNull.nonNull(graphQLInputType);
}

// Type
argumentBuilder = argumentBuilder.type(graphQLInputType);

// Directives
if (argument.hasDirectiveInstances()) {
for (DirectiveInstance directiveInstance : argument.getDirectiveInstances()) {
argumentBuilder.withDirective(createGraphQLDirectiveFrom(directiveInstance));
}
argumentBuilder = argumentBuilder
.withDirectives(createGraphQLDirectives(argument.getDirectiveInstances()));
}

return argumentBuilder.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import static graphql.introspection.Introspection.DirectiveLocation.INTERFACE;
import static graphql.introspection.Introspection.DirectiveLocation.OBJECT;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toSet;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
Expand All @@ -17,6 +18,7 @@
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.util.EnumSet;
import java.util.Set;
import java.util.stream.Stream;

import org.jboss.jandex.IndexView;
Expand Down Expand Up @@ -121,14 +123,15 @@ void schemaWithInputDirectives() {
@Test
void testSchemaWithFederationDisabled() {
config.federationEnabled = false;
// need to set it as system property because the SchemaBuilder doesn't have access to the Config object
System.setProperty("smallrye.graphql.federation.enabled", "false");

GraphQLSchema graphQLSchema = createGraphQLSchema(Directive.class, Key.class, Keys.class,
TestTypeWithFederation.class, FederationTestApi.class);

assertNull(graphQLSchema.getDirective("key"));
assertEquals(Set.of("include", "specifiedBy", "deprecated", "skip", "constraint"),
graphQLSchema.getDirectives().stream().map(GraphQLDirective::getName).collect(toSet()));
assertNull(graphQLSchema.getDirective("key")); // esp. NOT this one
assertNull(graphQLSchema.getType("_Entity"));
assertNull(graphQLSchema.getType("_Service"));

GraphQLObjectType queryRoot = graphQLSchema.getQueryType();
assertEquals(1, queryRoot.getFields().size());
Expand Down Expand Up @@ -156,66 +159,60 @@ void testSchemaWithFederationDisabled() {
@Test
void testSchemaWithFederationEnabled() {
config.federationEnabled = true;
// need to set it as system property because the SchemaBuilder doesn't have access to the Config object
System.setProperty("smallrye.graphql.federation.enabled", "true");
try {
GraphQLSchema graphQLSchema = createGraphQLSchema(Repeatable.class, Directive.class, Key.class, Keys.class,
TestTypeWithFederation.class, FederationTestApi.class);

GraphQLDirective keyDirective = graphQLSchema.getDirective("key");
assertEquals("key", keyDirective.getName());
assertTrue(keyDirective.isRepeatable());
assertEquals(
"Designates an object type as an entity and specifies its key fields " +
"(a set of fields that the subgraph can use to uniquely identify any instance " +
"of the entity). You can apply multiple @key directives to a single entity " +
"(to specify multiple valid sets of key fields).",
keyDirective.getDescription());
assertEquals(EnumSet.of(OBJECT, INTERFACE), keyDirective.validLocations());
assertEquals(1, keyDirective.getArguments().size());
assertEquals("String", ((GraphQLScalarType) keyDirective.getArgument("fields").getType()).getName());

GraphQLUnionType entityType = (GraphQLUnionType) graphQLSchema.getType("_Entity");
assertNotNull(entityType);
assertEquals(1, entityType.getTypes().size());
assertEquals(TestTypeWithFederation.class.getSimpleName(), entityType.getTypes().get(0).getName());

GraphQLObjectType queryRoot = graphQLSchema.getQueryType();
assertEquals(3, queryRoot.getFields().size());

GraphQLFieldDefinition entities = queryRoot.getField("_entities");
assertEquals(1, entities.getArguments().size());
assertEquals("[_Any!]!", entities.getArgument("representations").getType().toString());
assertEquals("[_Entity]!", entities.getType().toString());

GraphQLFieldDefinition service = queryRoot.getField("_service");
assertEquals(0, service.getArguments().size());
assertEquals("_Service!", service.getType().toString());

GraphQLFieldDefinition query = queryRoot.getField("testTypeWithFederation");
assertEquals(1, query.getArguments().size());
assertEquals(GraphQLString, query.getArgument("arg").getType());
assertEquals("TestTypeWithFederation", ((GraphQLObjectType) query.getType()).getName());

GraphQLObjectType type = graphQLSchema.getObjectType("TestTypeWithFederation");
assertEquals(2, type.getDirectives().size());
assertKeyDirective(type.getDirectives().get(0), "id");
assertKeyDirective(type.getDirectives().get(1), "type id");
assertEquals(3, type.getFields().size());
assertEquals("id", type.getFields().get(0).getName());
assertEquals(GraphQLString, type.getFields().get(0).getType());
assertEquals("type", type.getFields().get(1).getName());
assertEquals(GraphQLString, type.getFields().get(1).getType());
assertEquals("value", type.getFields().get(2).getName());
assertEquals(GraphQLString, type.getFields().get(2).getType());

GraphQLObjectType serviceType = graphQLSchema.getObjectType("_Service");
assertEquals(1, serviceType.getFields().size());
assertEquals("sdl", serviceType.getFields().get(0).getName());
assertEquals("String!", serviceType.getFields().get(0).getType().toString());
} finally {
System.clearProperty("smallrye.graphql.federation.enabled");
}
GraphQLSchema graphQLSchema = createGraphQLSchema(Repeatable.class, Directive.class, Key.class, Keys.class,
TestTypeWithFederation.class, FederationTestApi.class);

GraphQLDirective keyDirective = graphQLSchema.getDirective("key");
assertEquals("key", keyDirective.getName());
assertTrue(keyDirective.isRepeatable());
assertEquals(
"Designates an object type as an entity and specifies its key fields " +
"(a set of fields that the subgraph can use to uniquely identify any instance " +
"of the entity). You can apply multiple @key directives to a single entity " +
"(to specify multiple valid sets of key fields).",
keyDirective.getDescription());
assertEquals(EnumSet.of(OBJECT, INTERFACE), keyDirective.validLocations());
assertEquals(1, keyDirective.getArguments().size());
assertEquals("String", ((GraphQLScalarType) keyDirective.getArgument("fields").getType()).getName());

GraphQLUnionType entityType = (GraphQLUnionType) graphQLSchema.getType("_Entity");
assertNotNull(entityType);
assertEquals(1, entityType.getTypes().size());
assertEquals(TestTypeWithFederation.class.getSimpleName(), entityType.getTypes().get(0).getName());

GraphQLObjectType queryRoot = graphQLSchema.getQueryType();
assertEquals(3, queryRoot.getFields().size());

GraphQLFieldDefinition entities = queryRoot.getField("_entities");
assertEquals(1, entities.getArguments().size());
assertEquals("[_Any!]!", entities.getArgument("representations").getType().toString());
assertEquals("[_Entity]!", entities.getType().toString());

GraphQLFieldDefinition service = queryRoot.getField("_service");
assertEquals(0, service.getArguments().size());
assertEquals("_Service!", service.getType().toString());

GraphQLFieldDefinition query = queryRoot.getField("testTypeWithFederation");
assertEquals(1, query.getArguments().size());
assertEquals(GraphQLString, query.getArgument("arg").getType());
assertEquals("TestTypeWithFederation", ((GraphQLObjectType) query.getType()).getName());

GraphQLObjectType type = graphQLSchema.getObjectType("TestTypeWithFederation");
assertEquals(2, type.getDirectives().size());
assertKeyDirective(type.getDirectives().get(0), "id");
assertKeyDirective(type.getDirectives().get(1), "type id");
assertEquals(3, type.getFields().size());
assertEquals("id", type.getFields().get(0).getName());
assertEquals(GraphQLString, type.getFields().get(0).getType());
assertEquals("type", type.getFields().get(1).getName());
assertEquals(GraphQLString, type.getFields().get(1).getType());
assertEquals("value", type.getFields().get(2).getName());
assertEquals(GraphQLString, type.getFields().get(2).getType());

GraphQLObjectType serviceType = graphQLSchema.getObjectType("_Service");
assertEquals(1, serviceType.getFields().size());
assertEquals("sdl", serviceType.getFields().get(0).getName());
assertEquals("String!", serviceType.getFields().get(0).getType().toString());
}

private static void assertKeyDirective(GraphQLDirective graphQLDirective, String value) {
Expand Down

0 comments on commit a7af375

Please sign in to comment.