Skip to content

Commit

Permalink
improve kfold validate
Browse files Browse the repository at this point in the history
  • Loading branch information
gilbertocamara committed Jun 17, 2024
1 parent 87a94c3 commit 4f8a72c
Showing 1 changed file with 15 additions and 17 deletions.
32 changes: 15 additions & 17 deletions R/sits_validate.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,31 +89,29 @@ sits_kfold_validate <- function(samples,
.check_that(!("NoClass" %in% labels),
msg = .conf("messages", "sits_kfold_validate_samples")
)
# start parallel process
multicores <- min(multicores, folds)
.parallel_start(workers = multicores)
on.exit(.parallel_stop())
# Create partitions different splits of the input data
samples <- .samples_create_folds(samples, folds = folds)
# Do parallel process
conf_lst <- .parallel_map(seq_len(folds), function(k) {
conf_lst <- purrr::map(seq_len(folds), function(k) {
# Split data into training and test data sets
data_train <- samples[samples[["folds"]] != k, ]
data_test <- samples[samples[["folds"]] == k, ]
# Create a machine learning model
ml_model <- sits_train(samples = data_train, ml_method = ml_method)
ml_model <- sits_train(
samples = data_train,
ml_method = ml_method
)
# classify test values
values <- sits_classify(
data = data_test,
ml_model = ml_model,
multicores = multicores
)
pred <- tidyr::unnest(values, "predicted")[["class"]]
# Convert samples time series in predictors and preprocess data
pred_test <- .predictors(samples = data_test, ml_model = ml_model)
# Get predictors features to classify
values <- .pred_features(pred_test)
# Classify the test data
values <- ml_model(values)
# Extract classified labels (majority probability)
values <- labels[C_label_max_prob(as.matrix(values))]
# Removes 'ml_model' variable
remove(ml_model)
return(list(pred = values, ref = .pred_references(pred_test)))
}, n_retries = 0, progress = FALSE)
ref <- values[["label"]]
return(list(pred = pred, ref = ref))
})
# create predicted and reference vectors
pred <- unlist(lapply(conf_lst, function(x) x[["pred"]]))
ref <- unlist(lapply(conf_lst, function(x) x[["ref"]]))
Expand Down

0 comments on commit 4f8a72c

Please sign in to comment.