diff --git a/util/gridworld.py b/util/gridworld.py index 77a7bbb..21c4c00 100644 --- a/util/gridworld.py +++ b/util/gridworld.py @@ -193,7 +193,9 @@ def simulate_mdp(mdp, policy, max_iterations=20) -> list[Step]: steps = [] state = mdp.start current_iteration = 0 - while True: + while current_iteration != max_iterations and \ + not mdp.is_terminal(state) and \ + mdp.is_reachable(state): current_iteration += 1 action = policy(state) state_probs = [(s, p) for s, p in mdp.transition(state, action).items()] @@ -203,9 +205,7 @@ def simulate_mdp(mdp, policy, max_iterations=20) -> list[Step]: reward = mdp.reward(state, action, next_state) steps.append(Step(state, action, reward)) state = next_state - if current_iteration == max_iterations or mdp.is_terminal(state): - steps.append(Step(next_state, None, 0.0)) - break + steps.append(Step(state, None, 0.0)) return steps #--------------------------