Skip to content

Commit

Permalink
Merge pull request #9807 from FlorentinD/gds/progress-tracker-cleanup
Browse files Browse the repository at this point in the history
Cleanup Graph Aggregation tasks in case of failure
  • Loading branch information
FlorentinD authored Nov 6, 2024
2 parents 1485e48 + 1d40e77 commit 60c54bf
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.neo4j.gds.TestTaskStore;
import org.neo4j.gds.api.DatabaseId;
import org.neo4j.gds.core.loading.Capabilities.WriteMode;
import org.neo4j.gds.core.loading.GraphStoreCatalog;
import org.neo4j.gds.core.utils.progress.EmptyTaskStore;
import org.neo4j.gds.logging.Log;
import org.neo4j.gds.metrics.projections.ProjectionMetricsService;
import org.neo4j.values.AnyValue;
import org.neo4j.values.storable.NoValue;
import org.neo4j.values.storable.Values;
import org.neo4j.values.virtual.MapValue;
Expand All @@ -37,6 +39,7 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

class ProductGraphAggregatorTest {

Expand Down Expand Up @@ -85,14 +88,15 @@ void shouldImportHighNodeIds() {
@MethodSource("emptyGraphNames")
void shouldFailOnEmptyGraphName(String emptyGraphName, String description) {

TestTaskStore taskStore = new TestTaskStore();
var aggregator = new ProductGraphAggregator(
DatabaseId.random(),
"neo4j",
WriteMode.LOCAL,
QueryEstimator.empty(),
ExecutingQueryProvider.empty(),
ProjectionMetricsService.DISABLED,
EmptyTaskStore.INSTANCE,
taskStore,
Log.noOpLog()
);

Expand All @@ -105,6 +109,8 @@ void shouldFailOnEmptyGraphName(String emptyGraphName, String description) {
MapValue.EMPTY,
NoValue.NO_VALUE
)).withMessageContaining("`graphName` can not be null or blank");

assertThat(taskStore.tasks()).isEmpty();
}

private static Stream<Arguments> emptyGraphNames() {
Expand All @@ -115,4 +121,33 @@ private static Stream<Arguments> emptyGraphNames() {
Arguments.of(" ", "spaces")
);
}

@Test
void shouldCleanupTaskOnFailure() {
TestTaskStore taskStore = new TestTaskStore();
var aggregator = new ProductGraphAggregator(
DatabaseId.random(),
"neo4j",
WriteMode.LOCAL,
QueryEstimator.empty(),
ExecutingQueryProvider.empty(),
ProjectionMetricsService.DISABLED,
taskStore,
Log.noOpLog()
);

assertThatThrownBy(() ->
aggregator.update(new AnyValue[] {
Values.stringValue("my-graph"),
Values.longValue(1L),
Values.stringValue("invalidID"),
MapValue.EMPTY,
MapValue.EMPTY,
NoValue.NO_VALUE }
))
.hasCauseInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("The node has to be either a NODE or an INTEGER, but got String");

assertThat(taskStore.tasks()).isEmpty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ public void update(AnyValue[] input) throws ProcedureException {
NoValue.NO_VALUE
);
} catch (Exception e) {
super.onFailure();
throw new ProcedureException(
Status.Procedure.ProcedureCallFailed,
e,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.neo4j.gds.core.utils.progress.BatchingTaskProgressTracker;
import org.neo4j.gds.core.utils.progress.TaskRegistryFactory;
import org.neo4j.gds.core.utils.progress.TaskStore;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker;
import org.neo4j.gds.core.utils.warnings.EmptyUserLogRegistryFactory;
import org.neo4j.gds.logging.Log;
Expand Down Expand Up @@ -97,6 +98,7 @@ abstract class GraphAggregator implements UserAggregationReducer, UserAggregatio

// #result() may be called twice, we cache the result of the first call to return it again in the second invocation
private @Nullable AggregationResult result;
private ProgressTracker progressTracker;

GraphAggregator(
DatabaseId databaseId,
Expand Down Expand Up @@ -213,7 +215,7 @@ private GraphImporter createGraphImporter(
TaskRegistryFactory.local(username, taskStore),
EmptyUserLogRegistryFactory.INSTANCE
);
var progressTracker = BatchingTaskProgressTracker.create(internalProgressTracker, taskVolume, config.readConcurrency());
this.progressTracker = BatchingTaskProgressTracker.create(internalProgressTracker, taskVolume, config.readConcurrency());

return new GraphImporter(
config,
Expand Down Expand Up @@ -281,6 +283,7 @@ public AnyValue result() throws ProcedureException {
projectionMetric.start();
result = buildGraph();
} catch (Exception e) {
this.onFailure();
projectionMetric.failed(e);
throw new ProcedureException(
Status.Procedure.ProcedureCallFailed,
Expand All @@ -305,6 +308,12 @@ public AnyValue result() throws ProcedureException {
return builder.build();
}

void onFailure() {
if (progressTracker != null) {
this.progressTracker.endSubTaskWithFailure();
}
}

public @Nullable AggregationResult buildGraph() {
var importer = this.importer;
if (importer == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public void update(AnyValue[] input) throws ProcedureException {
input[5]
);
} catch (Exception e) {
super.onFailure();
throw new ProcedureException(
Status.Procedure.ProcedureCallFailed,
e,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,5 +451,7 @@ void shouldRegisterTaskAndLogProgress() {
log.assertContainsMessage(TestLog.INFO, "Graph aggregation :: Build graph store :: Relationships :: Finished");
log.assertContainsMessage(TestLog.INFO, "Graph aggregation :: Build graph store :: Finished");
log.assertContainsMessage(TestLog.INFO, "Graph aggregation :: Finished");
}

assertThat(taskStore.tasks()).isEmpty();
}
}

0 comments on commit 60c54bf

Please sign in to comment.