Skip to content

Commit

Permalink
add reinforce
Browse files Browse the repository at this point in the history
  • Loading branch information
sungjinl committed Jun 1, 2019
1 parent ca49490 commit a2ead35
Showing 1 changed file with 193 additions and 13 deletions.
206 changes: 193 additions & 13 deletions convlab/spec/demo.json
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@
"name": "DQN",
"action_pdtype": "Argmax",
"action_policy": "rule_guide_epsilon_greedy",
"rule_guide_max_epi": 300,
"rule_guide_max_epi": 1,
"rule_guide_frequency": 3,
"explore_var_spec": {
"name": "linear_decay",
Expand Down Expand Up @@ -232,6 +232,104 @@
}
}
},
"onenet_rule_reinforce_template": {
"agent": [{
"name": "DialogAgent",
"nlu": {
"name": "OneNetLU",
"model_file": "https://convlab.blob.core.windows.net/models/onenet.tar.gz"
},
"dst": {
"name": "RuleDST"
},
"nlg": {
"name": "MultiwozTemplateNLG",
"is_user": false
},
"state_encoder": {
"name": "MultiWozStateEncoder"
},
"action_decoder": {
"name": "MultiWozVocabActionDecoder"
},
"algorithm": {
"name": "Reinforce",
"action_pdtype": "default",
"action_policy": "rule_guide_default",
"rule_guide_max_epi": 1,
"rule_guide_frequency": 3,
"explore_var_spec": null,
"gamma": 0.98,
"entropy_coef_spec": {
"name": "linear_decay",
"start_val": 0.01,
"end_val": 0.001,
"start_step": 1000,
"end_step": 50000
},
"training_frequency": 1
},
"memory": {
"name": "OnPolicyReplay"
},
"net": {
"type": "MLPNet",
"hid_layers": [100],
"hid_layers_activation": "relu",
"clip_grad_val": null,
"loss_spec": {
"name": "MSELoss"
},
"optim_spec": {
"name": "Adam",
"lr": 0.001
},
"lr_scheduler_spec": {
"name": "StepLR",
"step_size": 1000,
"gamma": 0.999,
},
"gpu": false
}
}],
"env": [{
"name": "multiwoz",
"action_dim": 300,
"observation_dim": 393,
"max_t": 40,
"max_frame": 50000,
"nlu": {
"name": "OneNetLU",
"model_file": "https://convlab.blob.core.windows.net/models/onenet.tar.gz"
},
"user_policy": {
"name": "UserPolicyAgendaMultiWoz"
},
"sys_policy": {
"name": "RuleBasedMultiwozBot"
},
"nlg": {
"name": "MultiwozTemplateNLG",
"is_user": true
}
}],
"body": {
"product": "outer",
"num": 1
},
"meta": {
"distributed": false,
"num_eval": 100,
"eval_frequency": 1000,
"max_tick_unit": "total_t",
"max_trial": 1,
"max_session": 1,
"resources": {
"num_cpus": 1,
"num_gpus": 0
}
}
},
"rule_rule": {
"agent": [{
"name": "DialogAgent",
Expand Down Expand Up @@ -297,25 +395,25 @@
"name": "DQN",
"action_pdtype": "Argmax",
"action_policy": "rule_guide_epsilon_greedy",
"rule_guide_max_epi": 100,
"rule_guide_max_epi": 1,
"rule_guide_frequency": 3,
"explore_var_spec": {
"name": "linear_decay",
"start_val": 0.0,
"end_val": 0.0,
"start_val": 0.1,
"end_val": 0.05,
"start_step": 0,
"end_step": 800,
"end_step": 50000,
},
"gamma": 0.9,
"training_batch_iter": 100,
"training_iter": 3,
"training_frequency": 50,
"training_batch_iter": 1000,
"training_iter": 1,
"training_frequency": 100,
"training_start_step": 32
},
"memory": {
"name": "Replay",
"batch_size": 16,
"max_size": 10000,
"max_size": 50000,
"use_cer": false
},
"net": {
Expand All @@ -336,7 +434,7 @@
"gamma": 0.999,
},
"update_type": "replace",
"update_frequency": 50,
"update_frequency": 300,
"polyak_coef": 0,
"gpu": false
}
Expand All @@ -346,7 +444,7 @@
"action_dim": 300,
"observation_dim": 393,
"max_t": 40,
"max_frame": 500,
"max_frame": 50000,
"user_policy": {
"name": "UserPolicyAgendaMultiWoz"
},
Expand All @@ -360,8 +458,90 @@
},
"meta": {
"distributed": false,
"num_eval": 10,
"eval_frequency": 100,
"num_eval": 100,
"eval_frequency": 1000,
"max_tick_unit": "total_t",
"max_trial": 1,
"max_session": 1,
"resources": {
"num_cpus": 1,
"num_gpus": 0
}
}
},
"rule_reinforce": {
"agent": [{
"name": "DialogAgent",
"dst": {
"name": "RuleDST"
},
"state_encoder": {
"name": "MultiWozStateEncoder"
},
"action_decoder": {
"name": "MultiWozVocabActionDecoder"
},
"algorithm": {
"name": "Reinforce",
"action_pdtype": "default",
"action_policy": "rule_guide_default",
"rule_guide_max_epi": 1,
"rule_guide_frequency": 3,
"explore_var_spec": null,
"gamma": 0.98,
"entropy_coef_spec": {
"name": "linear_decay",
"start_val": 0.01,
"end_val": 0.001,
"start_step": 1000,
"end_step": 50000
},
"training_frequency": 1
},
"memory": {
"name": "OnPolicyReplay"
},
"net": {
"type": "MLPNet",
"hid_layers": [100],
"hid_layers_activation": "relu",
"clip_grad_val": null,
"loss_spec": {
"name": "MSELoss"
},
"optim_spec": {
"name": "Adam",
"lr": 0.001
},
"lr_scheduler_spec": {
"name": "StepLR",
"step_size": 1000,
"gamma": 0.999,
},
"gpu": false
}
}],
"env": [{
"name": "multiwoz",
"action_dim": 300,
"observation_dim": 393,
"max_t": 40,
"max_frame": 50000,
"user_policy": {
"name": "UserPolicyAgendaMultiWoz"
},
"sys_policy": {
"name": "RuleBasedMultiwozBot"
},
}],
"body": {
"product": "outer",
"num": 1
},
"meta": {
"distributed": false,
"num_eval": 100,
"eval_frequency": 1000,
"max_tick_unit": "total_t",
"max_trial": 1,
"max_session": 1,
Expand Down

0 comments on commit a2ead35

Please sign in to comment.