Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,35 @@ default AsyncRunnable thenRunRetryingWhile(
});
}

/**
* This method is equivalent to a while loop, where the condition is checked before each iteration.
* If the condition returns {@code false} on the first check, the body is never executed.
*
* @param whileCheck a condition to check before each iteration; the loop continues as long as this condition returns true
* @param loopBodyRunnable the asynchronous task to be executed in each iteration of the loop
* @return the composition of this and the looping branch
* @see AsyncCallbackLoop
*/
default AsyncRunnable thenRunWhileLoop(final BooleanSupplier whileCheck, final AsyncRunnable loopBodyRunnable) {
return thenRun(finalCallback -> {
LoopState loopState = new LoopState();
new AsyncCallbackLoop(loopState, iterationCallback -> {

if (loopState.breakAndCompleteIf(() -> !whileCheck.getAsBoolean(), iterationCallback)) {
return;
}
loopBodyRunnable.finish((result, t) -> {
if (t != null) {
iterationCallback.completeExceptionally(t);
return;
}
iterationCallback.complete(iterationCallback);
});

}).run(finalCallback);
});
}

/**
* This method is equivalent to a do-while loop, where the loop body is executed first and
* then the condition is checked to determine whether the loop should continue.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright 2008-present MongoDB, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.mongodb.internal.async;

import com.mongodb.annotations.NotThreadSafe;
import com.mongodb.lang.Nullable;

/**
* A trampoline that converts recursive callback invocations into an iterative loop,
* preventing stack overflow in async loops.
*
* <p>When async loop iterations complete synchronously on the same thread, callback
* recursion occurs: each iteration's {@code callback.onResult()} immediately triggers
* the next iteration, causing unbounded stack growth. For example, a 1000-iteration
* loop would create > 1000 stack frames and cause {@code StackOverflowError}.</p>
*
* <p>The trampoline intercepts this recursion: instead of executing the next iteration
* immediately (which would deepen the stack), it enqueues the continuation and returns, allowing
* the stack to unwind. A flat loop at the top then processes enqueued continuation iteratively,
* maintaining constant stack depth regardless of iteration count.</p>
*
* <p>Since async chains are sequential, at most one task is pending at any time.
* The trampoline uses a single slot rather than a queue.</p>
*
* The first call on a thread becomes the "trampoline owner" and runs the drain loop.
* Subsequent (re-entrant) calls on the same thread enqueue their continuation and return immediately.</p>
*
* <p>This class is not part of the public API and may be removed or changed at any time</p>
*/
@NotThreadSafe
public final class AsyncTrampoline {

private static final ThreadLocal<ContinuationHolder> TRAMPOLINE = new ThreadLocal<>();

private AsyncTrampoline() {}

/**
* Execute continuation through the trampoline. If no trampoline is active, become the owner
* and drain all enqueued continuations. If a trampoline is already active, enqueue and return.
*/
public static void run(final Runnable continuation) {
ContinuationHolder continuationHolder = TRAMPOLINE.get();
if (continuationHolder != null) {
continuationHolder.enqueue(continuation);
} else {
continuationHolder = new ContinuationHolder();
TRAMPOLINE.set(continuationHolder);
try {
continuation.run();
while (continuationHolder.continuation != null) {
Runnable continuationToRun = continuationHolder.continuation;
continuationHolder.continuation = null;
continuationToRun.run();
}
} finally {
TRAMPOLINE.remove();
}
}
}

/**
* A single-slot container for continuation.
* At most one continuation is pending at any time in a sequential async chain.
*/
@NotThreadSafe
private static final class ContinuationHolder {
@Nullable
private Runnable continuation;

void enqueue(final Runnable continuation) {
if (this.continuation != null) {
throw new AssertionError("Trampoline slot already occupied");
}
this.continuation = continuation;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.mongodb.internal.async.function;

import com.mongodb.annotations.NotThreadSafe;
import com.mongodb.internal.async.AsyncTrampoline;
import com.mongodb.internal.async.SingleResultCallback;
import com.mongodb.lang.Nullable;

Expand Down Expand Up @@ -62,9 +63,11 @@ public void run(final SingleResultCallback<Void> callback) {
@NotThreadSafe
private class LoopingCallback implements SingleResultCallback<Void> {
private final SingleResultCallback<Void> wrapped;
private final Runnable nextIteration;

LoopingCallback(final SingleResultCallback<Void> callback) {
wrapped = callback;
nextIteration = () -> AsyncCallbackLoop.this.body.run(this);
}

@Override
Expand All @@ -80,7 +83,7 @@ public void onResult(@Nullable final Void result, @Nullable final Throwable t) {
return;
}
if (continueLooping) {
body.run(this);
AsyncTrampoline.run(nextIteration);
} else {
wrapped.onResult(result, null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import static com.mongodb.assertions.Assertions.assertNotNull;
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
import static org.junit.jupiter.api.Assertions.assertEquals;

abstract class AsyncFunctionsAbstractTest extends AsyncFunctionsTestBase {
private static final TimeoutContext TIMEOUT_CONTEXT = new TimeoutContext(new TimeoutSettings(0, 0, 0, 0L, 0));
Expand Down Expand Up @@ -723,6 +724,120 @@ void testTryCatchTestAndRethrow() {
});
}

@Test
void testWhile() {
// last iteration: 3 < 3 = 1
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 1(transition to next iteration) = 4
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 4(transition to next iteration) = 7
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 7(transition to next iteration) = 10
assertBehavesSameVariations(10,
() -> {
int counter = 0;
while (counter < 3 && plainTest(counter)) {
counter++;
sync(counter);
}
},
(callback) -> {
MutableValue<Integer> counter = new MutableValue<>(0);
beginAsync().thenRunWhileLoop(() -> counter.get() < 3 && plainTest(counter.get()), c2 -> {
counter.set(counter.get() + 1);
async(counter.get(), c2);
}).finish(callback);
});
}

@Test
void testWhileWithThenRun() {
// while: last iteration: 3 < 3 = 1
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 1(transition to next iteration) = 4
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 4(transition to next iteration) = 7
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 7(transition to next iteration) = 10
// trailing sync: 1(exception) + 1(success) = 2
// 6(while exception) + 4(while success) * 2(trailing sync) = 14
assertBehavesSameVariations(14,
() -> {
int counter = 0;
while (counter < 3 && plainTest(counter)) {
counter++;
sync(counter);
}
sync(counter + 1);
},
(callback) -> {
MutableValue<Integer> counter = new MutableValue<>(0);
beginAsync().thenRun(c -> {
beginAsync().thenRunWhileLoop(() -> counter.get() < 3 && plainTest(counter.get()), c2 -> {
counter.set(counter.get() + 1);
async(counter.get(), c2);
}).finish(c);
}).thenRun(c -> {
async(counter.get() + 1, c);
}).finish(callback);
});
}

@Test
void testNestedWhileLoops() {
// inner while: 4 success + 6 exception = 10
// last inner iteration: 3 < 3 = 1
// 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 1(transition to next iteration) = 12
// 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 12(transition to next iteration) = 56
// 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 56(transition to next iteration) = 232
assertBehavesSameVariations(232,
() -> {
int outer = 0;
while (outer < 3 && plainTest(outer)) {
int inner = 0;
while (inner < 3 && plainTest(inner)) {
sync(outer + inner);
inner++;
}
outer++;
}
},
(callback) -> {
MutableValue<Integer> outer = new MutableValue<>(0);
beginAsync().thenRunWhileLoop(() -> outer.get() < 3 && plainTest(outer.get()), c -> {
MutableValue<Integer> inner = new MutableValue<>(0);
beginAsync().thenRunWhileLoop(
() -> inner.get() < 3 && plainTest(inner.get()),
c2 -> {
beginAsync().thenRun(c3 -> {
async(outer.get() + inner.get(), c3);
}).thenRun(c3 -> {
inner.set(inner.get() + 1);
c3.complete(c3);
}).finish(c2);
}
).thenRun(c2 -> {
outer.set(outer.get() + 1);
c2.complete(c2);
}).finish(c);
}).finish(callback);
});
}

@Test
void testWhileLoopStackConstant() {
int depthWith100 = maxStackDepthForIterations(100);
int depthWith10000 = maxStackDepthForIterations(10_000);
assertEquals(depthWith100, depthWith10000, "Stack depth should be constant regardless of iteration count (trampoline)");
}
Comment thread
vbabanin marked this conversation as resolved.

private int maxStackDepthForIterations(final int iterations) {
Comment thread
strogiyotec marked this conversation as resolved.
MutableValue<Integer> counter = new MutableValue<>(0);
MutableValue<Integer> maxDepth = new MutableValue<>(0);
beginAsync().thenRunWhileLoop(() -> counter.get() < iterations, c -> {
maxDepth.set(Math.max(maxDepth.get(), Thread.currentThread().getStackTrace().length));
counter.set(counter.get() + 1);
c.complete(c);
}).finish((v, t) -> {});

assertEquals(iterations, counter.get());
return maxDepth.get();
}

@Test
void testRetryLoop() {
assertBehavesSameVariations(InvocationTracker.DEPTH_LIMIT * 2 + 1,
Expand Down Expand Up @@ -768,6 +883,65 @@ void testDoWhileLoop() {
});
}

@Test
void testNestedDoWhileLoops() {
// inner do-while: 3 success + 5 exception = 8
// last outer iteration: 3 < 3 = 1
// 5(inner exception) + 3(inner success) * 1(transition to next iteration) = 8
// 5(inner exception) + 3(inner success) * (1(outer plainTest exception) + 1(outer plainTest false) + 8(transition to next iteration)) = 35
// 5(inner exception) + 3(inner success) * (1(outer plainTest exception) + 1(outer plainTest false) + 35(transition to next iteration)) = 116
assertBehavesSameVariations(116,
() -> {
int outer = 0;
do {
int inner = 0;
do {
sync(outer + inner);
inner++;
} while (inner < 3 && plainTest(inner));
outer++;
} while (outer < 3 && plainTest(outer));
},
(callback) -> {
MutableValue<Integer> outer = new MutableValue<>(0);
beginAsync().thenRunDoWhileLoop(c -> {
MutableValue<Integer> inner = new MutableValue<>(0);
beginAsync().thenRunDoWhileLoop(c2 -> {
beginAsync().thenRun(c3 -> {
async(outer.get() + inner.get(), c3);
}).thenRun(c3 -> {
inner.set(inner.get() + 1);
c3.complete(c3);
}).finish(c2);
}, () -> inner.get() < 3 && plainTest(inner.get())
).thenRun(c2 -> {
outer.set(outer.get() + 1);
c2.complete(c2);
}).finish(c);
}, () -> outer.get() < 3 && plainTest(outer.get())).finish(callback);
});
}

@Test
void testDoWhileLoopStackConstant() {
int depthWith100 = maxDoWhileStackDepthForIterations(100);
int depthWith10000 = maxDoWhileStackDepthForIterations(10_000);
assertEquals(depthWith100, depthWith10000,
"Stack depth should be constant regardless of iteration count");
Comment thread
vbabanin marked this conversation as resolved.
}

private int maxDoWhileStackDepthForIterations(final int iterations) {
MutableValue<Integer> counter = new MutableValue<>(0);
MutableValue<Integer> maxDepth = new MutableValue<>(0);
beginAsync().thenRunDoWhileLoop(c -> {
maxDepth.set(Math.max(maxDepth.get(), Thread.currentThread().getStackTrace().length));
counter.set(counter.get() + 1);
c.complete(c);
}, () -> counter.get() < iterations).finish((v, t) -> {});
assertEquals(iterations, counter.get());
return maxDepth.get();
}

@Test
void testFinallyWithPlainInsideTry() {
// (in try: normal flow + exception + exception) * (in finally: normal + exception) = 6
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.function.Consumer;
import java.util.function.Supplier;

import static java.lang.String.format;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand Down Expand Up @@ -272,14 +273,16 @@ private <T> void assertBehavesSame(final Supplier<T> sync, final Runnable betwee
}

assertTrue(wasCalledFuture.isDone(), "callback should have been called");
assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched");
assertEquals(expectedValue, actualValue.get());
assertEquals(expectedException == null, actualException.get() == null,
"both or neither should have produced an exception");
format("both or neither should have produced an exception. Expected exception: %s, actual exception: %s",
expectedException,
actualException.get()));
if (expectedException != null) {
assertEquals(expectedException.getMessage(), actualException.get().getMessage());
assertEquals(expectedException.getClass(), actualException.get().getClass());
}
assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched");
assertEquals(expectedValue, actualValue.get());

listener.clear();
}
Expand Down