-
-
Notifications
You must be signed in to change notification settings - Fork 2
/
parameters.jl
125 lines (113 loc) · 3.48 KB
/
parameters.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
    DeviceParams(device_id)

Select the compute device for a run. A non-negative `device_id` requests the
CUDA GPU with that id (when `CUDA.functional()` is true); otherwise the run
falls back to the CPU and `device_id` is normalized to `-1`.

Fields:
- `device_id::Int` — resolved device id (`-1` means CPU).
- `device::Function` — transfer function applied to models/data
  (expected to be `Flux.gpu` or `Flux.cpu`).
"""
struct DeviceParams
    device_id::Int
    device::Function # expected Flux.cpu or Flux.gpu
    function DeviceParams(device_id)
        if device_id >= 0 && CUDA.functional()
            CUDA.device!(device_id)
            device = gpu
            @info "Set device: GPU with device_id=$(device_id)"
        else
            # Warn whenever a GPU was requested (device_id >= 0) but CUDA is
            # unusable. The previous check `device_id > 0` silently skipped
            # the warning for the common request device_id == 0.
            if device_id >= 0
                @warn "You've set device_id = $device_id, but CUDA.functional() is $(CUDA.functional())"
            end
            @info "Set device: CPU"
            device_id = -1
            device = cpu
        end
        new(device_id, device)
    end
end
# Physical parameters of the lattice system; keyword-constructed via
# Parameters.@with_kw (e.g. PhysicalParams(L=8, Nd=2, M2=-4.0, lam=8.0)).
@with_kw struct PhysicalParams
    L::Int          # linear lattice extent per dimension
    Nd::Int         # number of lattice dimensions
    M2::Float64     # mass-squared parameter (also readable as pp.m²)
    lam::Float64    # coupling constant (also readable as pp.λ)
end
# Expose virtual properties on PhysicalParams:
#   pp.lattice_shape — an Nd-tuple (L, L, ..., L)
#   pp.m²            — alias for the stored field M2
#   pp.λ             — alias for the stored field lam
# All other names fall through to the real struct fields.
function Base.getproperty(pp::PhysicalParams, s::Symbol)
    if s === :lattice_shape
        return ntuple(_ -> getfield(pp, :L), getfield(pp, :Nd))
    elseif s === :m²
        return getfield(pp, :M2)
    elseif s === :λ
        return getfield(pp, :lam)
    else
        return getfield(pp, s)
    end
end
# Neural-network architecture hyperparameters; keyword-constructed via
# Parameters.@with_kw, with the defaults shown below.
@with_kw struct ModelParams
    seed::Int = 2021                    # RNG seed for model initialization
    n_layers::Int = 16                  # number of layers in the model
    hidden_sizes::Vector{Int} = [8, 8]  # widths of hidden channels/units
    kernel_size::Int = 3                # convolution kernel size
    inC::Int = 1                        # input channel count
    outC::Int = 2                       # output channel count
    use_final_tanh::Bool = true         # apply tanh after the final layer
    use_bn::Bool = false                # enable batch normalization
end
# Training-loop hyperparameters; keyword-constructed via Parameters.@with_kw.
# String-typed fields hold specifications that are interpreted elsewhere
# (optimizer name, prior distribution expression, scheduler, checkpoint path).
@with_kw struct TrainingParams
    seed::Int = 12345          # RNG seed for the training loop
    batchsize::Int = 64
    epochs::Int = 40
    iterations::Int = 100      # iterations per epoch
    base_lr::Float64 = 0.001   # base learning rate
    opt::String = "Adam"       # optimizer identifier
    prior::String = "Normal{Float32}(0.f0, 1.f0)"  # prior distribution spec
    lr_scheduler::String = ""  # scheduler spec; empty means none
    pretrained::String = ""    # path to a pretrained model; empty means none
end
# Aggregate of every configuration section plus the output location.
struct HyperParams
    configversion::VersionNumber  # schema version, from config["config"]["version"]
    dp::DeviceParams              # device selection (CPU/GPU)
    tp::TrainingParams            # "training" section
    pp::PhysicalParams            # "physical" section
    mp::ModelParams               # "model" section
    result_dir::String            # absolute directory where results are written
end
"""
    load_hyperparams(config::Dict, output_dirname; device_id=nothing,
                     pretrained=nothing, result="result") -> HyperParams

Build a `HyperParams` from a parsed configuration dictionary.

- `device_id`: when given, overrides `config["device"]["device_id"]`.
- `pretrained`: when given, overrides `config["training"]["pretrained"]`
  (mutates `config` in place).
- `result`: parent directory for outputs; the result dir is
  `abspath(joinpath(result, output_dirname))`.

Missing `config["model"]["use_bn"]` defaults to `false` so configs written
before that flag existed still load.
"""
function load_hyperparams(
    config::Dict,
    output_dirname::String;
    device_id::Union{Nothing,Int}=nothing,
    pretrained::Union{Nothing,String}=nothing,
    result::AbstractString="result",
)::HyperParams
    configversion = VersionNumber(string(config["config"]["version"]))
    if isnothing(device_id)
        device_id = config["device"]["device_id"]
    else
        @info "override device id $(device_id)"
    end
    if !isnothing(pretrained)
        @info "restore model from $(pretrained)"
        config["training"]["pretrained"] = pretrained
    end
    dp = DeviceParams(device_id)
    tp = ToStruct.tostruct(TrainingParams, config["training"])
    pp = ToStruct.tostruct(PhysicalParams, config["physical"])
    # Backward compatibility: older config files predate the use_bn flag.
    if !haskey(config["model"], "use_bn")
        config["model"]["use_bn"] = false
    end
    mp = ToStruct.tostruct(ModelParams, config["model"])
    result_dir = abspath(joinpath(result, output_dirname))
    return HyperParams(configversion, dp, tp, pp, mp, result_dir)
end
# Derive the default output directory name from a config file path:
# the file's base name with its extension stripped
# (e.g. "configs/run01.toml" -> "run01").
function _d(configpath::AbstractString)
    name, _ext = splitext(basename(configpath))
    return name
end
"""
    load_hyperparams(configpath::AbstractString,
                     output_dirname=_d(configpath), args...; kwargs...)

Parse the TOML file at `configpath` and delegate to the `Dict`-based
`load_hyperparams`. The output directory name defaults to the config
file's base name without extension.
"""
function load_hyperparams(
    configpath::AbstractString,
    output_dirname::String=_d(configpath),
    args...;
    kwargs...,
)
    parsed = TOML.parsefile(configpath)
    return load_hyperparams(parsed, output_dirname, args...; kwargs...)
end
"""
    hp2toml(hp::HyperParams, fname::AbstractString)

Serialize `hp` to a TOML file at `fname`, mirroring the input config layout:
sections "config" (version), "device" (device_id), then "model", "physical",
and "training", each written field-by-field from the corresponding struct.
`result_dir` and the resolved `device` function are intentionally not written.
"""
function hp2toml(hp::HyperParams, fname::AbstractString)
    data = OrderedDict{String,Any}()
    data["config"] = OrderedDict{String,Any}("version" => string(hp.configversion))
    data["device"] = OrderedDict{String,Any}("device_id" => hp.dp.device_id)
    for (sym, itemname) in [(:mp, "model"), (:pp, "physical"), (:tp, "training")]
        obj = getfield(hp, sym)
        # fieldnames requires the type, not the instance.
        v = OrderedDict(key => getfield(obj, key) for key in fieldnames(typeof(obj)))
        data[itemname] = v
    end
    open(fname, "w") do io
        TOML.print(io, data)
    end
end