spring-projects / spring-statemachine

Spring Statemachine is a framework for application developers to use state machine concepts with Spring.
1.57k stars 616 forks source link

Issue restoring persisted state machine with >=4 levels of submachine hierarchy #1166

Open patrickc410 opened 1 month ago

patrickc410 commented 1 month ago

I am having an issue restoring a persisted state machine with >=4 levels of submachine hierarchy to its correct state.

I believe the issue is with the AbstractStateMachine.resetStateMachineReactively method, but I am not quite sure.

I first found this issue while working on a project that uses v3.3 of spring-statemachine. I upgraded to v4.0, but the issue is still present.

I have created 3 simple JUnit test cases to demonstrate the problem.

Test Case 1) 4 Levels of Hierarchy, with persist and restore between each event (FAILS)

The first test case uses 4 levels of submachine hierarchy (machines M1, M2, M3, and M4). Once the machine reaches submachine M4, persisting and restoring always brings it back to the initial state of M4 instead of the state M4 was actually in.

hierarchical-sm-4-levels

package com.example.demo;

import org.junit.jupiter.api.Test;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.statemachine.StateMachine;
import org.springframework.statemachine.config.StateMachineBuilder;
import org.springframework.statemachine.listener.StateMachineListener;
import org.springframework.statemachine.persist.DefaultStateMachinePersister;
import org.springframework.statemachine.persist.StateMachinePersister;
import reactor.core.publisher.Mono;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;

import static com.example.demo.HierarchicalStateMachineTestHelper.getCurrentStateChain;

/**
 * Test class for a hierarchical state machine with 4 levels of hierarchy
 */
public class HierarchicalSM4LevelsPersistRestoreTest {

    private enum States {
        M1_START,
        M1_MIDDLE,
        M1_M2,
        M1_END,
        M2_START,
        M2_MIDDLE,
        M2_M3,
        M2_END,
        M3_START,
        M3_MIDDLE,
        M3_M4,
        M3_END,
        M4_START,
        M4_MIDDLE,
        M4_END
    }

    private enum Events {
        M1_START_TO_MIDDLE,
        M1_MIDDLE_TO_M2,
        M1_M2_TO_END,
        M2_START_TO_MIDDLE,
        M2_MIDDLE_TO_M3,
        M2_M3_TO_END,
        M3_START_TO_MIDDLE,
        M3_MIDDLE_TO_M4,
        M3_M4_TO_END,
        M4_START_TO_MIDDLE,
        M4_MIDDLE_TO_END
    }

    /**
     * Sends each event in turn and, after every event, persists the machine, restores the persisted
     * context into a freshly built machine, and checks that the active state chain survived the
     * round trip. Finally checks the full sequence of state entries recorded by the listener.
     *
     * <p>Checks throw {@link AssertionError} explicitly instead of using the {@code assert}
     * keyword, so they still run when the JVM is started without {@code -ea}.
     */
    @Test
    public void whenStateMachineHas4LevelsOfHierarchy_shouldPersistAndRestoreToSameState() throws Exception {
        // Constants
        final String machineName = "my-machine-name";

        List<Events> events = Arrays.asList(
                Events.M1_START_TO_MIDDLE,  // M1: start        -> M1: middle
                Events.M1_MIDDLE_TO_M2,     // M1: middle       -> M1: M2 (M2: start)
                Events.M2_START_TO_MIDDLE,  // M2: start        -> M2: middle
                Events.M2_MIDDLE_TO_M3,     // M2: middle       -> M2: M3 (M3: start)
                Events.M3_START_TO_MIDDLE,  // M3: start        -> M3: middle
                Events.M3_MIDDLE_TO_M4,     // M3: middle       -> M3: M4 (M4: start)
                Events.M4_START_TO_MIDDLE,  // M4: start        -> M4: middle
                Events.M4_MIDDLE_TO_END,    // M4: middle       -> M4: end
                Events.M3_M4_TO_END,        // M3: M4 (M4: end) -> M3: end
                Events.M2_M3_TO_END,        // M2: M3 (M3: end) -> M2: end
                Events.M1_M2_TO_END         // M1: M2 (M2: end) -> M1: end
        );

        List<List<States>> expectedStatesEntered = List.of(
                List.of(States.M1_START),
                List.of(States.M1_MIDDLE),
                List.of(States.M1_M2),
                List.of(States.M1_M2, States.M2_START),
                List.of(States.M1_M2, States.M2_MIDDLE),
                List.of(States.M1_M2, States.M2_M3),
                List.of(States.M1_M2, States.M2_M3, States.M3_START),
                List.of(States.M1_M2, States.M2_M3, States.M3_MIDDLE),
                List.of(States.M1_M2, States.M2_M3, States.M3_M4),
                List.of(States.M1_M2, States.M2_M3, States.M3_M4, States.M4_START),
                List.of(States.M1_M2, States.M2_M3, States.M3_M4, States.M4_MIDDLE),
                List.of(States.M1_M2, States.M2_M3, States.M3_M4, States.M4_END),
                List.of(States.M1_M2, States.M2_M3, States.M3_END),
                List.of(States.M1_M2, States.M2_END),
                List.of(States.M1_END)
        );

        // Build listeners. The same collector is registered on every machine built below (original
        // and each restored copy), so it accumulates entries across all of them. The expected list
        // above therefore assumes a restore does not record any extra state entries.
        var listenerCollector = new HierarchicalStateMachineTestHelper.StateMachineStateEntryCollectorListener<States, Events>();

        // Build persistence
        var persist = new HierarchicalStateMachineTestHelper.InMemoryStateMachinePersist<States, Events, String>();
        StateMachinePersister<States, Events, String> persister = new DefaultStateMachinePersister<>(persist);

        // Build state machine
        StateMachine<States, Events> stateMachine = buildStateMachine(List.of(listenerCollector));
        stateMachine.startReactively().block();

        for (Events event: events) {
            // Send event
            stateMachine.sendEvent(Mono.just(MessageBuilder.withPayload(event).build())).blockLast();

            // Persist
            List<States> currentStateChainBeforePersist = getCurrentStateChain(stateMachine);
            persister.persist(stateMachine, machineName);

            // Restore into a brand-new machine; subsequent events go to the restored instance
            StateMachine<States, Events> newStateMachine = buildStateMachine(List.of(listenerCollector));
            stateMachine = persister.restore(newStateMachine, machineName);
            List<States> currentStateChainAfterRestore = getCurrentStateChain(stateMachine);

            // Compare
            if (!Objects.equals(currentStateChainBeforePersist, currentStateChainAfterRestore)) {
                throw new AssertionError("State machine state chain before and after persist/restore should be the same, but was not. \n\tBefore persist: %s\n\tAfter restore:  %s".formatted(currentStateChainBeforePersist, currentStateChainAfterRestore));
            }
        }

        // States entered
        var statesEntered = listenerCollector.getStatesEnteredStateChains();
        if (!Objects.equals(statesEntered, expectedStatesEntered)) {
            throw new AssertionError("States entered was not as expected. \n\tExpected: %s\n\tActual:    %s".formatted(expectedStatesEntered, statesEntered));
        }
    }

    /**
     * Build a hierarchical state machine with 4 levels of hierarchy.
     *
     * <p>M1 is the top-level machine; M1_M2 hosts submachine M2, M2_M3 hosts M3, and M3_M4 hosts M4.
     * Transitions inside a submachine are scoped to their parent state via {@code .state(parent)}.
     *
     * @param listeners List of state machine listeners; may be null or empty
     * @return New (not yet started) state machine
     * @throws Exception if the builder configuration fails
     */
    private static StateMachine<States, Events> buildStateMachine(List<StateMachineListener<States, Events>> listeners) throws Exception {
        StateMachineBuilder.Builder<States, Events> builder = StateMachineBuilder.builder();

        // Listeners
        if (listeners != null && !listeners.isEmpty()) {
            for (var listener : listeners) {
                builder.configureConfiguration().withConfiguration().listener(listener);
            }
        }

        builder.configureStates()
            .withStates()
            .initial(States.M1_START)
            .state(States.M1_START)
            .state(States.M1_MIDDLE)
            .state(States.M1_M2)
            .state(States.M1_END)
            .and()
            .withStates()
                .parent(States.M1_M2)
                .initial(States.M2_START)
                .state(States.M2_START)
                .state(States.M2_MIDDLE)
                .state(States.M2_M3)
                .state(States.M2_END)
                .and()
                .withStates()
                    .parent(States.M2_M3)
                    .initial(States.M3_START)
                    .state(States.M3_START)
                    .state(States.M3_MIDDLE)
                    .state(States.M3_M4)
                    .state(States.M3_END)
                    .and()
                    .withStates()
                        .parent(States.M3_M4)
                        .initial(States.M4_START)
                        .state(States.M4_START)
                        .state(States.M4_MIDDLE)
                        .state(States.M4_END);

        // M1 (top level) transitions
        builder.configureTransitions()
            .withExternal()
                .source(States.M1_START)
                .target(States.M1_MIDDLE)
                .event(Events.M1_START_TO_MIDDLE)
            .and()
            .withExternal()
                .source(States.M1_MIDDLE)
                .target(States.M1_M2)
                .event(Events.M1_MIDDLE_TO_M2)
            .and()
            .withExternal()
                .source(States.M1_M2)
                .target(States.M1_END)
                .event(Events.M1_M2_TO_END);

        // M2 transitions (inside M1_M2)
        builder.configureTransitions()
            .withExternal()
                .state(States.M1_M2) // parent
                .source(States.M2_START)
                .target(States.M2_MIDDLE)
                .event(Events.M2_START_TO_MIDDLE)
            .and()
            .withExternal()
                .state(States.M1_M2) // parent
                .source(States.M2_MIDDLE)
                .target(States.M2_M3)
                .event(Events.M2_MIDDLE_TO_M3)
            .and()
            .withExternal()
                .state(States.M1_M2) // parent
                .source(States.M2_M3)
                .target(States.M2_END)
                .event(Events.M2_M3_TO_END);

        // M3 transitions (inside M2_M3)
        builder.configureTransitions()
            .withExternal()
                .state(States.M2_M3) // parent
                .source(States.M3_START)
                .target(States.M3_MIDDLE)
                .event(Events.M3_START_TO_MIDDLE)
            .and()
            .withExternal()
                .state(States.M2_M3) // parent
                .source(States.M3_MIDDLE)
                .target(States.M3_M4)
                .event(Events.M3_MIDDLE_TO_M4)
            .and()
            .withExternal()
                .state(States.M2_M3) // parent
                .source(States.M3_M4)
                .target(States.M3_END)
                .event(Events.M3_M4_TO_END);

        // M4 transitions (inside M3_M4)
        builder.configureTransitions()
            .withExternal()
                .state(States.M3_M4) // parent
                .source(States.M4_START)
                .target(States.M4_MIDDLE)
                .event(Events.M4_START_TO_MIDDLE)
            .and()
            .withExternal()
                .state(States.M3_M4) // parent
                .source(States.M4_MIDDLE)
                .target(States.M4_END)
                .event(Events.M4_MIDDLE_TO_END);

        return builder.build();
    }
}

Below is the failure:

java.lang.AssertionError: State machine state chain before and after persist/restore should be the same, but was not. 
    Before persist: [M1_M2, M2_M3, M3_M4, M4_MIDDLE]
    After restore:  [M1_M2, M2_M3, M3_M4, M4_START]
    at com.example.demo.HierarchicalSM4LevelsPersistRestoreTest.whenStateMachineHas4LevelsOfHierarchy_shouldPersistAndRestoreToSameState(HierarchicalSM4LevelsPersistRestoreTest.java:118)
    at java.base/java.lang.reflect.Method.invoke(Method.java:578)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)

Test Case 2) 4 Levels of Hierarchy, with no persist and restore between each event (PASSES)

This test case rules out the possibility that the state machine is misconfigured from the start. When we do not persist and restore the machine between events, it moves through all the states we expect.

package com.example.demo;

import org.junit.jupiter.api.Test;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.statemachine.StateMachine;
import org.springframework.statemachine.config.StateMachineBuilder;
import org.springframework.statemachine.listener.StateMachineListener;
import org.springframework.statemachine.persist.DefaultStateMachinePersister;
import org.springframework.statemachine.persist.StateMachinePersister;
import reactor.core.publisher.Mono;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;

import static com.example.demo.HierarchicalStateMachineTestHelper.getCurrentStateChain;

/**
 * Test class for a hierarchical state machine with 4 levels of hierarchy
 */
public class HierarchicalSM4LevelsNoPersistRestoreTest {

    private enum States {
        M1_START,
        M1_MIDDLE,
        M1_M2,
        M1_END,
        M2_START,
        M2_MIDDLE,
        M2_M3,
        M2_END,
        M3_START,
        M3_MIDDLE,
        M3_M4,
        M3_END,
        M4_START,
        M4_MIDDLE,
        M4_END
    }

    private enum Events {
        M1_START_TO_MIDDLE,
        M1_MIDDLE_TO_M2,
        M1_M2_TO_END,
        M2_START_TO_MIDDLE,
        M2_MIDDLE_TO_M3,
        M2_M3_TO_END,
        M3_START_TO_MIDDLE,
        M3_MIDDLE_TO_M4,
        M3_M4_TO_END,
        M4_START_TO_MIDDLE,
        M4_MIDDLE_TO_END
    }

    /**
     * Control case: sends every event to a single machine (no persist/restore in between) and
     * checks the full sequence of state entries recorded by the listener.
     *
     * <p>The check throws {@link AssertionError} explicitly instead of using the {@code assert}
     * keyword, so it still runs when the JVM is started without {@code -ea}.
     */
    @Test
    public void whenStateMachineHas4LevelsOfHierarchy_shouldEnterExpectedStates() throws Exception {
        List<Events> events = Arrays.asList(
                Events.M1_START_TO_MIDDLE,  // M1: start        -> M1: middle
                Events.M1_MIDDLE_TO_M2,     // M1: middle       -> M1: M2 (M2: start)
                Events.M2_START_TO_MIDDLE,  // M2: start        -> M2: middle
                Events.M2_MIDDLE_TO_M3,     // M2: middle       -> M2: M3 (M3: start)
                Events.M3_START_TO_MIDDLE,  // M3: start        -> M3: middle
                Events.M3_MIDDLE_TO_M4,     // M3: middle       -> M3: M4 (M4: start)
                Events.M4_START_TO_MIDDLE,  // M4: start        -> M4: middle
                Events.M4_MIDDLE_TO_END,    // M4: middle       -> M4: end
                Events.M3_M4_TO_END,        // M3: M4 (M4: end) -> M3: end
                Events.M2_M3_TO_END,        // M2: M3 (M3: end) -> M2: end
                Events.M1_M2_TO_END         // M1: M2 (M2: end) -> M1: end
        );

        List<List<States>> expectedStatesEntered = List.of(
                List.of(States.M1_START),
                List.of(States.M1_MIDDLE),
                List.of(States.M1_M2),
                List.of(States.M1_M2, States.M2_START),
                List.of(States.M1_M2, States.M2_MIDDLE),
                List.of(States.M1_M2, States.M2_M3),
                List.of(States.M1_M2, States.M2_M3, States.M3_START),
                List.of(States.M1_M2, States.M2_M3, States.M3_MIDDLE),
                List.of(States.M1_M2, States.M2_M3, States.M3_M4),
                List.of(States.M1_M2, States.M2_M3, States.M3_M4, States.M4_START),
                List.of(States.M1_M2, States.M2_M3, States.M3_M4, States.M4_MIDDLE),
                List.of(States.M1_M2, States.M2_M3, States.M3_M4, States.M4_END),
                List.of(States.M1_M2, States.M2_M3, States.M3_END),
                List.of(States.M1_M2, States.M2_END),
                List.of(States.M1_END)
        );

        // Build listeners
        var listenerCollector = new HierarchicalStateMachineTestHelper.StateMachineStateEntryCollectorListener<States, Events>();

        // Build state machine
        StateMachine<States, Events> stateMachine = buildStateMachine(List.of(listenerCollector));
        stateMachine.startReactively().block();

        for (Events event: events) {
            // Send event
            stateMachine.sendEvent(Mono.just(MessageBuilder.withPayload(event).build())).blockLast();
        }

        // States entered
        var statesEntered = listenerCollector.getStatesEnteredStateChains();
        if (!Objects.equals(statesEntered, expectedStatesEntered)) {
            throw new AssertionError("States entered was not as expected. \n\tExpected: %s\n\tActual:    %s".formatted(expectedStatesEntered, statesEntered));
        }
    }

    /**
     * Build a hierarchical state machine with 4 levels of hierarchy.
     *
     * <p>M1 is the top-level machine; M1_M2 hosts submachine M2, M2_M3 hosts M3, and M3_M4 hosts M4.
     * Transitions inside a submachine are scoped to their parent state via {@code .state(parent)}.
     *
     * @param listeners List of state machine listeners; may be null or empty
     * @return New (not yet started) state machine
     * @throws Exception if the builder configuration fails
     */
    private static StateMachine<States, Events> buildStateMachine(List<StateMachineListener<States, Events>> listeners) throws Exception {
        StateMachineBuilder.Builder<States, Events> builder = StateMachineBuilder.builder();

        // Listeners
        if (listeners != null && !listeners.isEmpty()) {
            for (var listener : listeners) {
                builder.configureConfiguration().withConfiguration().listener(listener);
            }
        }

        builder.configureStates()
            .withStates()
            .initial(States.M1_START)
            .state(States.M1_START)
            .state(States.M1_MIDDLE)
            .state(States.M1_M2)
            .state(States.M1_END)
            .and()
            .withStates()
                .parent(States.M1_M2)
                .initial(States.M2_START)
                .state(States.M2_START)
                .state(States.M2_MIDDLE)
                .state(States.M2_M3)
                .state(States.M2_END)
                .and()
                .withStates()
                    .parent(States.M2_M3)
                    .initial(States.M3_START)
                    .state(States.M3_START)
                    .state(States.M3_MIDDLE)
                    .state(States.M3_M4)
                    .state(States.M3_END)
                    .and()
                    .withStates()
                        .parent(States.M3_M4)
                        .initial(States.M4_START)
                        .state(States.M4_START)
                        .state(States.M4_MIDDLE)
                        .state(States.M4_END);

        // M1 (top level) transitions
        builder.configureTransitions()
            .withExternal()
                .source(States.M1_START)
                .target(States.M1_MIDDLE)
                .event(Events.M1_START_TO_MIDDLE)
            .and()
            .withExternal()
                .source(States.M1_MIDDLE)
                .target(States.M1_M2)
                .event(Events.M1_MIDDLE_TO_M2)
            .and()
            .withExternal()
                .source(States.M1_M2)
                .target(States.M1_END)
                .event(Events.M1_M2_TO_END);

        // M2 transitions (inside M1_M2)
        builder.configureTransitions()
            .withExternal()
                .state(States.M1_M2) // parent
                .source(States.M2_START)
                .target(States.M2_MIDDLE)
                .event(Events.M2_START_TO_MIDDLE)
            .and()
            .withExternal()
                .state(States.M1_M2) // parent
                .source(States.M2_MIDDLE)
                .target(States.M2_M3)
                .event(Events.M2_MIDDLE_TO_M3)
            .and()
            .withExternal()
                .state(States.M1_M2) // parent
                .source(States.M2_M3)
                .target(States.M2_END)
                .event(Events.M2_M3_TO_END);

        // M3 transitions (inside M2_M3)
        builder.configureTransitions()
            .withExternal()
                .state(States.M2_M3) // parent
                .source(States.M3_START)
                .target(States.M3_MIDDLE)
                .event(Events.M3_START_TO_MIDDLE)
            .and()
            .withExternal()
                .state(States.M2_M3) // parent
                .source(States.M3_MIDDLE)
                .target(States.M3_M4)
                .event(Events.M3_MIDDLE_TO_M4)
            .and()
            .withExternal()
                .state(States.M2_M3) // parent
                .source(States.M3_M4)
                .target(States.M3_END)
                .event(Events.M3_M4_TO_END);

        // M4 transitions (inside M3_M4)
        builder.configureTransitions()
            .withExternal()
                .state(States.M3_M4) // parent
                .source(States.M4_START)
                .target(States.M4_MIDDLE)
                .event(Events.M4_START_TO_MIDDLE)
            .and()
            .withExternal()
                .state(States.M3_M4) // parent
                .source(States.M4_MIDDLE)
                .target(States.M4_END)
                .event(Events.M4_MIDDLE_TO_END);

        return builder.build();
    }
}

Test Case 3) 3 Levels of Hierarchy, with persist and restore between each event (PASSES)

When there are only three levels of submachine hierarchy, the issue with restoring to the correct state does not appear.

hierarchical-sm-3-levels

package com.example.demo;

import org.junit.jupiter.api.Test;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.statemachine.StateMachine;
import org.springframework.statemachine.config.StateMachineBuilder;
import org.springframework.statemachine.listener.StateMachineListener;
import org.springframework.statemachine.persist.DefaultStateMachinePersister;
import org.springframework.statemachine.persist.StateMachinePersister;
import reactor.core.publisher.Mono;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;

import static com.example.demo.HierarchicalStateMachineTestHelper.getCurrentStateChain;
import static com.example.demo.HierarchicalStateMachineTestHelper.paddedPrintln;

/**
 * Test class for hierarchical state machines with 3 levels of hierarchy
 */
public class HierarchicalSM3LevelsPersistRestoreTest {

    private enum States {
        // region M1
        M1_START,
        M1_MIDDLE,
        M1_M2,
        M1_END,
        // endregion

        // region M2
        M2_START,
        M2_MIDDLE,
        M2_M3,
        M2_END,
        // endregion

        // region M3
        M3_START,
        M3_MIDDLE,
        M3_END
        // endregion
    }

    private enum Events {
        // region M1
        M1_START_TO_MIDDLE,
        M1_MIDDLE_TO_M2,
        M1_M2_TO_END,
        // endregion

        // region M2
        M2_START_TO_MIDDLE,
        M2_MIDDLE_TO_M3,
        M2_M3_TO_END,
        // endregion

        // region M3
        M3_START_TO_MIDDLE,
        M3_MIDDLE_TO_END
        // endregion
    }

    /**
     * Sends each event in turn and, after every event, persists the machine, restores the persisted
     * context into a freshly built machine, and checks that the active state chain survived the
     * round trip. Finally checks the full sequence of state entries recorded by the listener.
     *
     * <p>Checks throw {@link AssertionError} explicitly instead of using the {@code assert}
     * keyword, so they still run when the JVM is started without {@code -ea}.
     */
    @Test
    public void whenStateMachineHas3LevelsOfHierarchy_shouldPersistAndRestoreToSameState() throws Exception {
        // Constants
        final String machineName = "my-machine-name";

        List<Events> events = Arrays.asList(
            Events.M1_START_TO_MIDDLE,  // M1: start        -> M1: middle
            Events.M1_MIDDLE_TO_M2,     // M1: middle       -> M1: M2 (M2: start)
            Events.M2_START_TO_MIDDLE,  // M2: start        -> M2: middle
            Events.M2_MIDDLE_TO_M3,     // M2: middle       -> M2: M3 (M3: start)
            Events.M3_START_TO_MIDDLE,  // M3: start        -> M3: middle
            Events.M3_MIDDLE_TO_END,    // M3: middle       -> M3: end
            Events.M2_M3_TO_END,        // M2: M3 (M3: end) -> M2: end
            Events.M1_M2_TO_END         // M1: M2 (M2: end) -> M1: end
        );

        List<List<States>> expectedStatesEntered = List.of(
            List.of(States.M1_START),
            List.of(States.M1_MIDDLE),
            List.of(States.M1_M2),
            List.of(States.M1_M2, States.M2_START),
            List.of(States.M1_M2, States.M2_MIDDLE),
            List.of(States.M1_M2, States.M2_M3),
            List.of(States.M1_M2, States.M2_M3, States.M3_START),
            List.of(States.M1_M2, States.M2_M3, States.M3_MIDDLE),
            List.of(States.M1_M2, States.M2_M3, States.M3_END),
            List.of(States.M1_M2, States.M2_END),
            List.of(States.M1_END)
        );

        // Build listeners. The same collector is registered on every machine built below (original
        // and each restored copy), so it accumulates entries across all of them. The expected list
        // above therefore assumes a restore does not record any extra state entries.
        var listenerCollector = new HierarchicalStateMachineTestHelper.StateMachineStateEntryCollectorListener<States, Events>();

        // Build persistence
        var persist = new HierarchicalStateMachineTestHelper.InMemoryStateMachinePersist<States, Events, String>();
        StateMachinePersister<States, Events, String> persister = new DefaultStateMachinePersister<>(persist);

        // Build state machine
        StateMachine<States, Events> stateMachine = buildStateMachine(List.of(listenerCollector));
        stateMachine.startReactively().block();

        for (Events event: events) {
            // Send event
            stateMachine.sendEvent(Mono.just(MessageBuilder.withPayload(event).build())).blockLast();

            // Persist
            List<States> currentStateChainBeforePersist = getCurrentStateChain(stateMachine);
            persister.persist(stateMachine, machineName);

            // Restore into a brand-new machine; subsequent events go to the restored instance
            StateMachine<States, Events> newStateMachine = buildStateMachine(List.of(listenerCollector));
            stateMachine = persister.restore(newStateMachine, machineName);
            List<States> currentStateChainAfterRestore = getCurrentStateChain(stateMachine);

            // Compare
            if (!Objects.equals(currentStateChainBeforePersist, currentStateChainAfterRestore)) {
                throw new AssertionError("State machine state chain before and after persist/restore should be the same, but was not. \n\tBefore persist: %s\n\tAfter restore:  %s".formatted(currentStateChainBeforePersist, currentStateChainAfterRestore));
            }
        }

        // States entered
        var statesEntered = listenerCollector.getStatesEnteredStateChains();
        if (!Objects.equals(statesEntered, expectedStatesEntered)) {
            throw new AssertionError("States entered was not as expected. \n\tExpected: %s\n\tActual:    %s".formatted(expectedStatesEntered, statesEntered));
        }
    }

    /**
     * Build a hierarchical state machine with 3 levels of hierarchy.
     *
     * <p>M1 is the top-level machine; M1_M2 hosts submachine M2 and M2_M3 hosts M3. Transitions
     * inside a submachine are scoped to their parent state via {@code .state(parent)}.
     *
     * @param listeners List of state machine listeners; may be null or empty
     * @return New (not yet started) state machine
     * @throws Exception if the builder configuration fails
     */
    private static StateMachine<States, Events> buildStateMachine(List<StateMachineListener<States, Events>> listeners) throws Exception {
        StateMachineBuilder.Builder<States, Events> builder = StateMachineBuilder.builder();

        // Listeners
        if (listeners != null && !listeners.isEmpty()) {
            for (var listener : listeners) {
                builder.configureConfiguration().withConfiguration().listener(listener);
            }
        }

        builder.configureStates()
            .withStates()
            .initial(States.M1_START)
            .state(States.M1_START)
            .state(States.M1_MIDDLE)
            .state(States.M1_M2)
            .state(States.M1_END)
            .and()
            .withStates()
                .parent(States.M1_M2)
                .initial(States.M2_START)
                .state(States.M2_START)
                .state(States.M2_MIDDLE)
                .state(States.M2_M3)
                .state(States.M2_END)
                .and()
                .withStates()
                    .parent(States.M2_M3)
                    .initial(States.M3_START)
                    .state(States.M3_START)
                    .state(States.M3_MIDDLE)
                    .state(States.M3_END);

        // M1 (top level) transitions
        builder.configureTransitions()
            .withExternal()
                .source(States.M1_START)
                .target(States.M1_MIDDLE)
                .event(Events.M1_START_TO_MIDDLE)
            .and()
            .withExternal()
                .source(States.M1_MIDDLE)
                .target(States.M1_M2)
                .event(Events.M1_MIDDLE_TO_M2)
            .and()
            .withExternal()
                .source(States.M1_M2)
                .target(States.M1_END)
                .event(Events.M1_M2_TO_END);

        // M2 transitions (inside M1_M2)
        builder.configureTransitions()
            .withExternal()
                .state(States.M1_M2) // parent
                .source(States.M2_START)
                .target(States.M2_MIDDLE)
                .event(Events.M2_START_TO_MIDDLE)
            .and()
            .withExternal()
                .state(States.M1_M2) // parent
                .source(States.M2_MIDDLE)
                .target(States.M2_M3)
                .event(Events.M2_MIDDLE_TO_M3)
            .and()
            .withExternal()
                .state(States.M1_M2) // parent
                .source(States.M2_M3)
                .target(States.M2_END)
                .event(Events.M2_M3_TO_END);

        // M3 transitions (inside M2_M3)
        builder.configureTransitions()
            .withExternal()
                .state(States.M2_M3) // parent
                .source(States.M3_START)
                .target(States.M3_MIDDLE)
                .event(Events.M3_START_TO_MIDDLE)
            .and()
            .withExternal()
                .state(States.M2_M3) // parent
                .source(States.M3_MIDDLE)
                .target(States.M3_END)
                .event(Events.M3_MIDDLE_TO_END);

        return builder.build();
    }
}

Helper Class used throughout test cases

The helper class contains a static helper method and a few static nested classes used throughout the test cases:

package com.example.demo;

import org.springframework.statemachine.StateContext;
import org.springframework.statemachine.StateMachine;
import org.springframework.statemachine.StateMachineContext;
import org.springframework.statemachine.StateMachinePersist;
import org.springframework.statemachine.listener.StateMachineListenerAdapter;

import java.util.*;

/**
 * Shared helpers for the hierarchical state machine tests: reading the active state chain, an
 * in-memory {@link StateMachinePersist}, and a listener that records every state entry.
 */
public class HierarchicalStateMachineTestHelper {

    /** Holder for static helpers and nested classes only; not meant to be instantiated. */
    private HierarchicalStateMachineTestHelper() {
    }

    /**
     * Helper method to get the current state chain of a state machine.
     * The "state chain" is the list of state IDs, one for each machine and submachine,
     * from the top-level state machine to the bottom-level submachine whose state we are in.
     * @param stateMachine the state machine
     * @return the current state chain as an unmodifiable list, or {@code null} if the machine has
     *         no current state or that state reports no IDs
     * @param <S> the state type
     */
    public static <S> List<S> getCurrentStateChain(StateMachine<S, ?> stateMachine) {
        var currentState = stateMachine.getState();

        if (currentState == null) {
            return null;
        }

        return currentState.getIds() == null ? null : currentState.getIds().stream().toList();
    }

    /**
     * Prints a message to standard output surrounded by blank lines so it stands out in test
     * output.
     *
     * <p>The 3-level persist/restore test statically imports a member named {@code paddedPrintln}
     * from this class. A static import of a member that does not exist is a compile-time error, so
     * this method must be present for that test to compile.
     * @param message the message to print; {@code null} prints {@code "null"}
     */
    public static void paddedPrintln(Object message) {
        System.out.println();
        System.out.println(message);
        System.out.println();
    }

    /**
     * In-memory implementation of state machine persistence.
     * Keeps only the most recently written context per context object (a later write overwrites
     * an earlier one). Backed by a plain {@link HashMap}, so it is not synchronized.
     * @param <S> the state type
     * @param <E> the event type
     * @param <T> the context object type
     */
    public static class InMemoryStateMachinePersist<S, E, T> implements StateMachinePersist<S, E, T> {

        /**
         * The map of context objects to state machine contexts
         */
        private final Map<T, StateMachineContext<S, E>> contexts = new HashMap<>();

        @Override
        public void write(StateMachineContext<S, E> context, T contextObj) {
            contexts.put(contextObj, context);
        }

        /**
         * @return the last context written for {@code contextObj}, or {@code null} if none was
         */
        @Override
        public StateMachineContext<S, E> read(T contextObj) {
            return contexts.get(contextObj);
        }
    }

    /**
     * State machine listener that collects the states entered.
     * On every {@code STATE_ENTRY} stage it records the full ID chain of the machine's current state
     * at that moment.
     * @param <S> the state type
     * @param <E> the event type
     */
    public static class StateMachineStateEntryCollectorListener<S, E> extends StateMachineListenerAdapter<S, E> {
        private final List<List<S>> statesEnteredStateChains = new ArrayList<>();

        /**
         * Get the state chains of states entered, in the order they were recorded.
         * @return a read-only live view of the recorded state chains; callers cannot mutate the
         *         listener's internal list
         */
        public List<List<S>> getStatesEnteredStateChains() {
            return Collections.unmodifiableList(statesEnteredStateChains);
        }

        @Override
        public void stateContext(StateContext<S, E> stateContext) {
            if (Objects.equals(stateContext.getStage(), StateContext.Stage.STATE_ENTRY)) {
                statesEnteredStateChains.add(stateContext.getStateMachine().getState().getIds().stream().toList());
            }
        }
    }
}

Please let me know if you notice anything incorrect in the way I configure the state machines and submachines, or in the way I persist and restore them.