-
Notifications
You must be signed in to change notification settings - Fork 76
/
api_tuning.R
78 lines (78 loc) · 2.41 KB
/
api_tuning.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#' @title Draw one random set of hyper-parameters
#'
#' @description
#' Evaluates every entry of \code{params} (as produced by
#' \code{sits_tuning_hparams}) inside this function's frame, where the
#' sampling helpers \code{uniform()}, \code{choice()}, \code{normal()},
#' \code{lognormal()}, \code{loguniform()} and \code{beta()} are defined.
#' Each helper draws a single value, so every expression yields one random
#' draw. The \code{samples} entry is removed from the result.
#'
#' @keywords internal
#' @noRd
#' @param trial  Current trial. The body does not use it directly, but it
#'               is visible by name to the evaluated expressions.
#' @param params Hyperparameters: a list of expressions (or constants)
#'               to be evaluated.
#' @return A list with one random value per hyperparameter (minus
#'         \code{samples}). \code{choice()} returns the selected option
#'         unevaluated, so a picked call such as \code{c(64, 64)} stays a
#'         call.
#'
.tuning_pick_random <- function(trial, params) {
    # --- sampling helpers: each returns exactly one draw -----------------
    uniform <- function(min = 0, max = 1) {
        stats::runif(n = 1, min = min, max = max)
    }
    choice <- function(..., replace = TRUE) {
        # capture the candidate options as unevaluated expressions
        candidates <- as.list(substitute(list(...)))[-1L]
        picked <- sample(x = candidates, size = 1, replace = replace)
        if (length(picked) == 1L) {
            picked <- picked[[1L]]
        }
        unlist(picked)
    }
    normal <- function(mean = 0, sd = 1) {
        stats::rnorm(n = 1, mean = mean, sd = sd)
    }
    lognormal <- function(meanlog = 0, sdlog = 1) {
        stats::rlnorm(n = 1, meanlog = meanlog, sdlog = sdlog)
    }
    loguniform <- function(minlog = 0, maxlog = 1) {
        # sample uniformly on the natural-log scale (bounds may be given
        # in either order), then map the draw back with exp()
        lower <- log(min(c(minlog, maxlog)), exp(1))
        upper <- log(max(c(minlog, maxlog)), exp(1))
        exp(stats::runif(1, lower, upper))
    }
    beta <- function(shape1, shape2) {
        stats::rbeta(n = 1, shape1 = shape1, shape2 = shape2)
    }
    # --- evaluate every specification in this frame ----------------------
    # environment() is this function's frame, so the helpers above (and
    # `trial`) are in scope for the user-supplied expressions
    drawn <- lapply(as.list(params), eval, envir = environment())
    drawn[["samples"]] <- NULL
    drawn
}
#' @title Convert hyper-parameters list to a tibble
#' @name .tuning_params_as_tibble
#' @keywords internal
#' @noRd
#' @description
#' Generate a tibble holding all model parameters of one trial, with one
#' column per element of \code{params}. Every column has length one, so a
#' non-empty \code{params} yields a single row. Conversion rules:
#' \itemize{
#'   \item atomic values of length one are stored as-is;
#'   \item other atomic values (length zero or more than one) are wrapped
#'         in a list-column;
#'   \item nested lists are converted recursively and stored as a
#'         list-column holding a tibble;
#'   \item language objects (calls, symbols) are deparsed into a single
#'         character string;
#'   \item anything else is wrapped in a list-column.
#' }
#' @param params hyperparams from sits_tuning function
#' @return A tibble with one column per parameter (one row when
#'         \code{params} is non-empty).
#'
.tuning_params_as_tibble <- function(params) {
    params <- lapply(params, function(x) {
        if (purrr::is_atomic(x)) {
            # non-scalar vectors go into a list-column; left bare they
            # would be recycled into extra rows
            if (length(x) != 1) {
                return(list(x))
            }
            return(x)
        }
        if (purrr::is_list(x)) {
            return(list(.tuning_params_as_tibble(x)))
        }
        if (is.language(x)) {
            # deparse() splits long expressions into several strings,
            # which tibble() would turn into several rows; collapse them
            # into a single string (same output for short expressions)
            return(paste(deparse(x, width.cutoff = 500L), collapse = " "))
        }
        return(list(x))
    })
    return(tibble::tibble(!!!params))
}