Skip to content

Commit

Permalink
add missile penalty
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackory committed Apr 18, 2022
1 parent 15b628a commit e8d4559
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 12 deletions.
1 change: 1 addition & 0 deletions envs/JSBSim/reward_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .relative_altitude_reward import RelativeAltitudeReward
from .heading_reward import HeadingReward
from .missile_posture_reward import MissilePostureReward
from .shoot_penalty_reward import ShootPenaltyReward
2 changes: 1 addition & 1 deletion envs/JSBSim/reward_functions/event_driven_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ def get_reward(self, task, env, agent_id):
reward -= 200
for missile in env.agents[agent_id].launch_missiles:
if missile.is_success:
reward += 80
reward += 100
return self._process(reward, agent_id)
32 changes: 32 additions & 0 deletions envs/JSBSim/reward_functions/shoot_penalty_reward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from .reward_function_base import BaseRewardFunction


class ShootPenaltyReward(BaseRewardFunction):
    """
    ShootPenaltyReward
    Apply a -10 penalty for each missile launched,
    to discourage the agent from launching all missiles at once.
    """
    def __init__(self, config):
        super().__init__(config)

    def reset(self, task, env):
        # Snapshot each agent's missile count so get_reward can detect launches
        # by comparing against the task's current remaining count.
        self.pre_remaining_missiles = {agent_id: agent.num_missiles for agent_id, agent in env.agents.items()}
        return super().reset(task, env)

    def get_reward(self, task, env, agent_id):
        """
        Penalize -10 for every missile launched since the previous call.

        Args:
            task: task instance (must expose `remaining_missiles[agent_id]`)
            env: environment instance (unused; kept for interface consistency)
            agent_id: id of the agent being rewarded

        Returns:
            (float): processed reward (0 when no missile was fired)
        """
        reward = 0
        # NOTE(review): the original only penalized when the count dropped by
        # exactly 1, so a step firing multiple missiles escaped the penalty
        # and the tracker silently resynced; penalize per missile fired instead.
        fired = self.pre_remaining_missiles[agent_id] - task.remaining_missiles[agent_id]
        if fired > 0:
            reward -= 10 * fired
        self.pre_remaining_missiles[agent_id] = task.remaining_missiles[agent_id]
        return self._process(reward, agent_id)
18 changes: 10 additions & 8 deletions envs/JSBSim/tasks/singlecombat_with_missle_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import deque

from .singlecombat_task import SingleCombatTask, HierarchicalSingleCombatTask
from ..reward_functions import AltitudeReward, PostureReward, MissilePostureReward, EventDrivenReward
from ..reward_functions import AltitudeReward, PostureReward, MissilePostureReward, EventDrivenReward, ShootPenaltyReward
from ..core.simulatior import MissileSimulator
from ..utils.utils import LLA2NEU, get_AO_TA_R

Expand Down Expand Up @@ -165,6 +165,7 @@ def __init__(self, config):
PostureReward(self.config),
AltitudeReward(self.config),
EventDrivenReward(self.config),
ShootPenaltyReward(self.config)
]

def load_observation_space(self):
Expand All @@ -179,22 +180,22 @@ def get_obs(self, env, agent_id):

def normalize_action(self, env, agent_id, action):
self._shoot_action[agent_id] = action[-1]
return super().normalize_action(env, agent_id, action[:-1])
return super().normalize_action(env, agent_id, action[:-1].astype(np.int32))

def reset(self, env):
self._shoot_action = {agent_id: 0 for agent_id in env.agents.keys()}
self._remaining_missiles = {agent_id: agent.num_missiles for agent_id, agent in env.agents.items()}
self.remaining_missiles = {agent_id: agent.num_missiles for agent_id, agent in env.agents.items()}
super().reset(env)

def step(self, env):
for agent_id, agent in env.agents.items():
# [RL-based missile launch with limited condition]
shoot_flag = agent.is_alive and self._shoot_action[agent_id] and self._remaining_missiles[agent_id] > 0
shoot_flag = agent.is_alive and self._shoot_action[agent_id] and self.remaining_missiles[agent_id] > 0
if shoot_flag:
new_missile_uid = agent_id + str(self._remaining_missiles[agent_id])
new_missile_uid = agent_id + str(self.remaining_missiles[agent_id])
env.add_temp_simulator(
MissileSimulator.create(parent=agent, target=agent.enemies[0], uid=new_missile_uid))
self._remaining_missiles[agent_id] -= 1
self.remaining_missiles[agent_id] -= 1


class HierarchicalSingleCombatShootTask(HierarchicalSingleCombatTask, SingleCombatShootMissileTask):
Expand All @@ -203,7 +204,8 @@ def __init__(self, config: str):
self.reward_functions = [
PostureReward(self.config),
AltitudeReward(self.config),
EventDrivenReward(self.config)
EventDrivenReward(self.config),
ShootPenaltyReward(self.config)
]

def load_observation_space(self):
Expand All @@ -220,7 +222,7 @@ def normalize_action(self, env, agent_id, action):
"""Convert high-level action into low-level action.
"""
self._shoot_action[agent_id] = action[-1]
return HierarchicalSingleCombatTask.normalize_action(self, env, agent_id, action[:-1])
return HierarchicalSingleCombatTask.normalize_action(self, env, agent_id, action[:-1].astype(np.int32))

def reset(self, env):
self._inner_rnn_states = {agent_id: np.zeros((1, 1, 128)) for agent_id in env.agents.keys()}
Expand Down
2 changes: 1 addition & 1 deletion scripts/train_selfplay.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
env="SingleCombat"
scenario="1v1/ShootMissile/HierarchySelfplay"
algo="ppo"
exp="beta_shoot_wp"
exp="penalty_shoot"
seed=1

echo "env is ${env}, scenario is ${scenario}, algo is ${algo}, exp is ${exp}, seed is ${seed}"
Expand Down
6 changes: 4 additions & 2 deletions test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
def convert(sample):
return np.concatenate((sample[0], np.expand_dims(sample[1], axis=0)))

episode_reward = 0
while True:
actions = np.array([[convert(envs.action_space.sample()) for _ in range(envs.num_agents)] for _ in range(parallel_num)])
obss, rewards, dones, infos = envs.step(actions)

episode_reward += rewards[:,0,:]
envs.render(mode='txt', filepath='JSBSimRecording.txt.acmi')
# terminate only when all of the parallel envs are done
if np.any(dones):
if np.all(dones):
print(episode_reward)
break
envs.close()

0 comments on commit e8d4559

Please sign in to comment.