From 5722529553b4f9332a413232ef215dbc703b88f6 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Mon, 25 Nov 2024 17:39:03 +0100 Subject: [PATCH] Decouple Future initialization from its computation --- shared/src/main/scala/async/futures.scala | 32 ++++++++++++++-------- shared/src/test/scala/FutureBehavior.scala | 24 ++++++++++++++++ 2 files changed, 45 insertions(+), 11 deletions(-) diff --git a/shared/src/main/scala/async/futures.scala b/shared/src/main/scala/async/futures.scala index ec907f76..ff10efde 100644 --- a/shared/src/main/scala/async/futures.scala +++ b/shared/src/main/scala/async/futures.scala @@ -41,6 +41,9 @@ import scala.util.{Failure, Success, Try} trait Future[+T] extends Async.OriginalSource[Try[T]], Cancellable object Future: + trait DeferredFuture[+T] extends Future[T]: + def start(): Unit + /** A future that is completed explicitly by calling its `complete` method. There are three public implementations * * - RunnableFuture: Completion is done by running a block of code @@ -107,7 +110,7 @@ object Future: /** A future that is completed by evaluating `body` as a separate asynchronous operation in the given `scheduler` */ - private class RunnableFuture[+T](body: Async.Spawn ?=> T)(using ac: Async) extends CoreFuture[T]: + private class RunnableFuture[+T](body: Async.Spawn ?=> T)(using ac: Async) extends CoreFuture[T], DeferredFuture[T]: /** RunnableFuture maintains its own inner [[CompletionGroup]], that is separated from the provided Async * instance's. When the future is cancelled, we only cancel this CompletionGroup. This effectively means any @@ -205,16 +208,18 @@ object Future: override def cancel(): Unit = if setCancelled() then this.innerGroup.cancel() + def start(): Unit = + ac.support.scheduleBoundary: + val result = Async.withNewCompletionGroup(innerGroup)(Try({ + val r = body + checkCancellation() + r + }).recoverWith { case _: InterruptedException | _: CancellationException => + Failure(new CancellationException()) + })(using FutureAsync(CompletionGroup.Unlinked)) + complete(result) + link() - ac.support.scheduleBoundary: - val result = Async.withNewCompletionGroup(innerGroup)(Try({ - val r = body - checkCancellation() - r - }).recoverWith { case _: InterruptedException | _: CancellationException => - Failure(new CancellationException()) - })(using FutureAsync(CompletionGroup.Unlinked)) - complete(result) end RunnableFuture @@ -222,7 +227,9 @@ object Future: * future is linked to the given [[Async.Spawn]] scope by default, i.e. it is cancelled when this scope ends. */ def apply[T](body: Async.Spawn ?=> T)(using async: Async, spawnable: Async.Spawn & async.type): Future[T] = - RunnableFuture(body) + val future = RunnableFuture(body) + future.start() + future /** A future that is immediately completed with the given result. */ def now[T](result: Try[T]): Future[T] = @@ -239,6 +246,9 @@ object Future: /** A future that immediately rejects with the given exception. Similar to `Future.now(Failure(exception))`. */ inline def rejected(exception: Throwable): Future[Nothing] = now(Failure(exception)) + def deferred[T](body: Async.Spawn ?=> T)(using async: Async, spawnable: Async.Spawn & async.type): DeferredFuture[T] = + RunnableFuture(body) + extension [T](f1: Future[T]) /** Parallel composition of two futures. If both futures succeed, succeed with their values in a pair. Otherwise, * fail with the failure that was returned first. diff --git a/shared/src/test/scala/FutureBehavior.scala b/shared/src/test/scala/FutureBehavior.scala index e596d672..2c8f9cd4 100644 --- a/shared/src/test/scala/FutureBehavior.scala +++ b/shared/src/test/scala/FutureBehavior.scala @@ -447,4 +447,28 @@ class FutureBehavior extends munit.FunSuite { reader.awaitResult assertEquals(ch.read(), Right(2)) } + + test("deferred futures") { + Async.blocking: + val counter = AtomicInteger(0) + val a = new Array[Future.DeferredFuture[Int]](4) + + a(0) = Future.deferred: + counter.incrementAndGet() + a(1).await + a(2).await + a(1) = Future.deferred: + counter.incrementAndGet() + a(3).await + 4 + a(2) = Future.deferred: + counter.incrementAndGet() + a(3).await + 2 + a(3) = Future.deferred: + counter.incrementAndGet() + 1 + + a.foreach(_.start()) + + assertEquals(a(0).await, 8) + assertEquals(counter.get(), 4) + } }