forked from ConvLab/ConvLab
-
Notifications
You must be signed in to change notification settings - Fork 0
/
base.py
176 lines (155 loc) · 6.11 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# Modified by Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABC, abstractmethod
from convlab.agent.net import net_util
from convlab.lib import logger, util
from convlab.lib.decorator import lab_api
import numpy as np
import pydash as ps
# Module-level logger for this file.
# NOTE: this rebinds the name `logger` (imported above from convlab.lib) to whatever
# get_logger(__name__) returns, so the imported module is no longer reachable by that name below.
logger = logger.get_logger(__name__)
class Algorithm(ABC):
    '''
    Abstract class ancestor to all Algorithms,
    specifies the necessary design blueprint for agent to work in Lab.
    Mostly, implement just the abstract methods and properties.

    Subclasses must implement: init_algorithm_params, init_nets, sample, train, update.
    The space_* methods are optional multi-body wrappers that call act/sample/train/update
    once per body while temporarily pointing self.body at that body.
    '''

    def __init__(self, agent, global_nets=None):
        '''
        Store references from the agent and its spec, then run the subclass initializers.
        @param {*} agent is the container for algorithm and related components, and interfaces with env.
        @param {*} global_nets is passed through unchanged to init_nets()
        '''
        self.agent = agent
        self.algorithm_spec = agent.agent_spec['algorithm']
        self.name = self.algorithm_spec['name']
        self.memory_spec = agent.agent_spec['memory']
        # 'net' is the only optional spec section; it defaults to None
        self.net_spec = agent.agent_spec.get('net', None)
        self.body = self.agent.body
        # params are initialized before nets, so init_nets() may read them
        self.init_algorithm_params()
        self.init_nets(global_nets)
        logger.info(util.self_desc(self))

    @abstractmethod
    @lab_api
    def init_algorithm_params(self):
        '''Initialize other algorithm parameters'''
        raise NotImplementedError

    @abstractmethod
    @lab_api
    def init_nets(self, global_nets=None):
        '''Initialize the neural network from the spec'''
        raise NotImplementedError

    @lab_api
    def post_init_nets(self):
        '''
        Method to conditionally load models.
        Call at the end of init_nets() after setting self.net_names
        In eval lab modes the saved models are loaded via self.load(); otherwise only a log line is emitted.
        '''
        assert hasattr(self, 'net_names'), 'self.net_names must be set in init_nets() before calling post_init_nets()'
        if util.in_eval_lab_modes():
            logger.info(f'Loaded algorithm models for lab_mode: {util.get_lab_mode()}')
            self.load()
        else:
            logger.info(f'Initialized algorithm models for lab_mode: {util.get_lab_mode()}')

    @lab_api
    def calc_pdparam(self, x, evaluate=True, net=None):
        '''
        To get the pdparam for action policy sampling, do a forward pass of the appropriate net, and pick the correct outputs.
        The pdparam will be the logits for discrete prob. dist., or the mean and std for continuous prob. dist.
        Not implemented in this base class.
        '''
        raise NotImplementedError

    def nanflat_to_data_a(self, data_name, nanflat_data_a):
        '''
        Reshape nanflat_data_a, e.g. action_a, from a single pass back into the API-conforming data_a
        @param {str} data_name name of the data to initialize in the aeb_space
        @param {iterable} nanflat_data_a one item per body, paired positionally with self.agent.nanflat_body_a
        @returns {*} data_a with each item written at its body's (e, b) index
        '''
        data_names = (data_name,)
        data_a, = self.agent.agent_space.aeb_space.init_data_s(data_names, a=self.agent.a)
        for body, data in zip(self.agent.nanflat_body_a, nanflat_data_a):
            data_a[(body.e, body.b)] = data
        return data_a

    @lab_api
    def act(self, state):
        '''
        Standard act method.
        Not implemented in this base class; implementations are expected to return the action for `state`.
        '''
        raise NotImplementedError

    @abstractmethod
    @lab_api
    def sample(self):
        '''
        Samples a batch from memory
        Implementations are expected to return the batch.
        '''
        raise NotImplementedError

    @abstractmethod
    @lab_api
    def train(self):
        '''
        Implement algorithm train, or throw NotImplementedError
        This base implementation returns np.nan in eval lab modes and raises otherwise.
        '''
        if util.in_eval_lab_modes():
            return np.nan
        raise NotImplementedError

    @abstractmethod
    @lab_api
    def update(self):
        '''Implement algorithm update, or throw NotImplementedError'''
        raise NotImplementedError

    @lab_api
    def save(self, ckpt=None):
        '''
        Save net models for algorithm given the required property self.net_names
        @param {*} ckpt is forwarded to net_util.save_algorithm
        '''
        if not hasattr(self, 'net_names'):
            logger.info('No net declared in self.net_names in init_nets(); no models to save.')
        else:
            net_util.save_algorithm(self, ckpt=ckpt)

    @lab_api
    def load(self):
        '''
        Load net models for algorithm given the required property self.net_names
        Afterwards, for every instance attribute named `<var>_scheduler` whose value has an `end_val`,
        set `self.body.<var>` to that end_val.
        NOTE(review): the scheduler loop is placed after the if/else (it runs even when net_names is
        missing); the scraped source lost its indentation, so confirm against the upstream file.
        '''
        if not hasattr(self, 'net_names'):
            logger.info('No net declared in self.net_names in init_nets(); no models to load.')
        else:
            net_util.load_algorithm(self)
        # set decayable variables to final values
        suffix = '_scheduler'
        for k, v in vars(self).items():
            if k.endswith(suffix) and hasattr(v, 'end_val'):
                # strip only the trailing suffix; str.replace would also drop inner occurrences
                var_name = k[:-len(suffix)]
                setattr(self.body, var_name, v.end_val)

    def _call_per_body(self, method):
        '''
        Call the zero-argument `method` once per body in self.agent.nanflat_body_a,
        with self.body pointing at that body during the call.
        self.body is reset to the default (first) body afterwards, even if a call raises.
        @returns {list} the results, in nanflat_body_a order
        '''
        results = []
        try:
            for body in self.agent.nanflat_body_a:
                self.body = body
                results.append(method())
        finally:
            # set body reference back to default
            self.body = self.agent.nanflat_body_a[0]
        return results

    # NOTE optional extension for multi-agent-env

    @lab_api
    def space_act(self, state_a):
        '''Interface-level agent act method for all its bodies. Resolves state to state; get action and compose into action.'''
        data_names = ('action',)
        action_a, = self.agent.agent_space.aeb_space.init_data_s(data_names, a=self.agent.a)
        try:
            for eb, body in util.ndenumerate_nonan(self.agent.body_a):
                state = state_a[eb]
                self.body = body
                action_a[eb] = self.act(state)
        finally:
            # set body reference back to default, also when act() raises
            self.body = self.agent.nanflat_body_a[0]
        return action_a

    @lab_api
    def space_sample(self):
        '''Samples a batch from memory for every body, then concatenates them and converts to a torch batch'''
        batches = self._call_per_body(self.sample)
        batch = util.concat_batches(batches)
        # self.body is the default body again here, so its memory decides is_episodic
        batch = util.to_torch_batch(batch, self.net.device, self.body.memory.is_episodic)
        return batch

    @lab_api
    def space_train(self):
        '''
        Train once per body and return the losses in API-conforming data_a form.
        Returns np.nan without training in eval lab modes.
        '''
        if util.in_eval_lab_modes():
            return np.nan
        losses = self._call_per_body(self.train)
        loss_a = self.nanflat_to_data_a('loss', losses)
        return loss_a

    @lab_api
    def space_update(self):
        '''Update once per body and return the resulting explore_var values in API-conforming data_a form.'''
        explore_vars = self._call_per_body(self.update)
        explore_var_a = self.nanflat_to_data_a('explore_var', explore_vars)
        return explore_var_a