fix input_pixels argument for compatibility with old models
M3nin0 committed May 20, 2024
1 parent e3205de commit c67fadb
Showing 15 changed files with 42 additions and 42 deletions.
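
The change is a single argument rename: the internal helper .check_processed_values() takes input_pixels instead of n_input_pixels, and every call site is updated to match; the commit message attributes this to compatibility with models serialized before the rename. A minimal sketch of the pattern follows — the check body uses stopifnot() as a stand-in for the package's internal .check_set_caller()/.check_that() helpers, and the processing step is a hypothetical placeholder, not the package's actual code:

# Minimal sketch of the renamed check; stopifnot() stands in for the
# package's internal .check_set_caller()/.check_that() helpers.
.check_processed_values <- function(values, input_pixels) {
    stopifnot(!is.null(nrow(values)), nrow(values) == input_pixels)
    invisible(values)
}

# Call-site pattern repeated throughout the diff: record the number of
# input pixels before a processing step, then validate the result.
values <- matrix(runif(30), nrow = 10)  # toy input: 10 pixels, 3 columns
input_pixels <- nrow(values)
values <- values / rowSums(values)      # hypothetical processing step
.check_processed_values(values = values, input_pixels = input_pixels)
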
6 changes: 3 additions & 3 deletions R/api_check.R
@@ -1364,14 +1364,14 @@
#' @title Does the result have the same number of pixels as the input values?
#' @name .check_processed_values
#' @param values a matrix of processed values
-#' @param n_input_pixels number of pixels in input matrix
+#' @param input_pixels number of pixels in input matrix
#' @return Called for side effects.
#' @keywords internal
#' @noRd
-.check_processed_values <- function(values, n_input_pixels) {
+.check_processed_values <- function(values, input_pixels) {
.check_set_caller(".check_processed_values")
.check_that(
-!(is.null(nrow(values))) && nrow(values) == n_input_pixels
+!(is.null(nrow(values))) && nrow(values) == input_pixels
)
return(invisible(values))
}
4 changes: 2 additions & 2 deletions R/api_classify.R
@@ -115,7 +115,7 @@
# Fill with zeros remaining NA pixels
values <- C_fill_na(values, 0)
# Used to check values (below)
-n_input_pixels <- nrow(values)
+input_pixels <- nrow(values)
# Log here
.debug_log(
event = "start_block_data_classification",
@@ -127,7 +127,7 @@
# Are the results consistent with the data input?
.check_processed_values(
values = values,
-n_input_pixels = n_input_pixels
+input_pixels = input_pixels
)
# Log
.debug_log(
8 changes: 4 additions & 4 deletions R/api_combine_predictions.R
@@ -219,13 +219,13 @@
# Average probability calculation
comb_fn <- function(values, uncert_values = NULL) {
# Check values length
-n_input_pixels <- nrow(values[[1]])
+input_pixels <- nrow(values[[1]])
# Combine by average
values <- weighted_probs(values, weights)
# get the number of labels
n_labels <- length(sits_labels(cubes[[1]]))
# Are the results consistent with the data input?
-.check_processed_values(values, n_input_pixels)
+.check_processed_values(values, input_pixels)
.check_processed_labels(values, n_labels)
# Return values
values
@@ -244,13 +244,13 @@
# Average probability calculation
comb_fn <- function(values, uncert_values) {
# Check values length
-n_input_pixels <- nrow(values[[1]])
+input_pixels <- nrow(values[[1]])
# Combine by average
values <- weighted_uncert_probs(values, uncert_values)
# get the number of labels
n_labels <- length(sits_labels(cubes[[1]]))
# Are the results consistent with the data input?
-.check_processed_values(values, n_input_pixels)
+.check_processed_values(values, input_pixels)
.check_processed_labels(values, n_labels)
# Return values
values
4 changes: 2 additions & 2 deletions R/api_label_class.R
@@ -148,10 +148,10 @@
.label_fn_majority <- function() {
label_fn <- function(values) {
# Used to check values (below)
-n_input_pixels <- nrow(values)
+input_pixels <- nrow(values)
values <- C_label_max_prob(values)
# Are the results consistent with the data input?
-.check_processed_values(values, n_input_pixels)
+.check_processed_values(values, input_pixels)
# Return values
values
}
4 changes: 2 additions & 2 deletions R/api_mixture_model.R
@@ -160,15 +160,15 @@
em_mtx <- .endmembers_as_matrix(em)
mixture_fn <- function(values) {
# Check values length
-n_input_pixels <- nrow(values)
+input_pixels <- nrow(values)
# Process NNLS solver and return
values <- C_nnls_solver_batch(
x = as.matrix(values),
em = em_mtx,
rmse = rmse
)
# Are the results consistent with the data input?
-.check_processed_values(values, n_input_pixels)
+.check_processed_values(values, input_pixels)
# Return values
values
}
6 changes: 3 additions & 3 deletions R/api_reclassify.R
@@ -158,7 +158,7 @@
stop(.conf("messages", ".reclassify_fn_cube_mask"))
}
# Used to check values (below)
-n_input_pixels <- nrow(values)
+input_pixels <- nrow(values)
# Convert to character vector
values <- as.character(values)
mask_values <- as.character(mask_values)
@@ -185,12 +185,12 @@
# Get values as numeric
values <- matrix(
data = labels_code[match(values, labels)],
-nrow = n_input_pixels
+nrow = input_pixels
)
# Mask NA values
values[is.na(env[["mask"]])] <- NA
# Are the results consistent with the data input?
-.check_processed_values(values, n_input_pixels)
+.check_processed_values(values, input_pixels)
# Return values
values
}
4 changes: 2 additions & 2 deletions R/api_smooth.R
@@ -172,7 +172,7 @@
# Define smooth function
smooth_fn <- function(values, block) {
# Check values length
-n_input_pixels <- nrow(values)
+input_pixels <- nrow(values)
# Compute logit
values <- log(values / (rowSums(values) - values))
# Process Bayesian
@@ -187,7 +187,7 @@
# Compute inverse logit
values <- exp(values) / (exp(values) + 1)
# Are the results consistent with the data input?
-.check_processed_values(values, n_input_pixels)
+.check_processed_values(values, input_pixels)
# Return values
values
}
12 changes: 6 additions & 6 deletions R/api_uncertainty.R
@@ -239,12 +239,12 @@
# Define uncertainty function
uncert_fn <- function(values) {
# Used in check (below)
-n_input_pixels <- nrow(values)
+input_pixels <- nrow(values)
# Process least confidence
# return a matrix[rows(values),1]
values <- C_least_probs(values)
# Are the results consistent with the data input?
-.check_processed_values(values, n_input_pixels)
+.check_processed_values(values, input_pixels)
# Return data
values
}
@@ -260,11 +260,11 @@
# Define uncertainty function
uncert_fn <- function(values) {
# Used in check (below)
-n_input_pixels <- nrow(values)
+input_pixels <- nrow(values)
# Process least confidence
values <- C_entropy_probs(values) # return a matrix[rows(values),1]
# Are the results consistent with the data input?
-.check_processed_values(values, n_input_pixels)
+.check_processed_values(values, input_pixels)
# Return data
values
}
@@ -280,11 +280,11 @@
# Define uncertainty function
uncert_fn <- function(values) {
# Used in check (below)
-n_input_pixels <- nrow(values)
+input_pixels <- nrow(values)
# Process margin
values <- C_margin_probs(values) # return a matrix[rows(data),1]
# Are the results consistent with the data input?
-.check_processed_values(values, n_input_pixels)
+.check_processed_values(values, input_pixels)
# Return data
values
}
4 changes: 2 additions & 2 deletions R/api_variance.R
@@ -176,7 +176,7 @@
# Define smooth function
smooth_fn <- function(values, block) {
# Check values length
-n_input_pixels <- nrow(values)
+input_pixels <- nrow(values)
# Compute logit
values <- log(values / (rowSums(values) - values))
# Process variance
@@ -188,7 +188,7 @@
neigh_fraction = neigh_fraction
)
# Are the results consistent with the data input?
-.check_processed_values(values, n_input_pixels)
+.check_processed_values(values, input_pixels)
# Return values
values
}
4 changes: 2 additions & 2 deletions R/sits_lighttae.R
@@ -322,7 +322,7 @@ sits_lighttae <- function(samples = NULL,
# Unserialize model
torch_model[["model"]] <- .torch_unserialize_model(serialized_model)
# Used to check values (below)
-n_input_pixels <- nrow(values)
+input_pixels <- nrow(values)
# Transform input into a 3D tensor
# Reshape the 2D matrix into a 3D array
n_samples <- nrow(values)
@@ -356,7 +356,7 @@
)
# Are the results consistent with the data input?
.check_processed_values(
-values = values, n_input_pixels = n_input_pixels
+values = values, input_pixels = input_pixels
)
# Update the columns names to labels
colnames(values) <- labels
12 changes: 6 additions & 6 deletions R/sits_machine_learning.R
@@ -79,13 +79,13 @@ sits_rfor <- function(samples = NULL, num_trees = 100, mtry = NULL, ...) {
# Verifies if randomForest package is installed
.check_require_packages("randomForest")
# Used to check values (below)
-n_input_pixels <- nrow(values)
+input_pixels <- nrow(values)
# Do classification
values <- stats::predict(
object = model, newdata = values, type = "prob"
)
# Are the results consistent with the data input?
-.check_processed_values(values, n_input_pixels)
+.check_processed_values(values, input_pixels)
# Reorder matrix columns if needed
if (any(labels != colnames(values))) {
values <- values[, labels]
@@ -193,7 +193,7 @@ sits_svm <- function(samples = NULL, formula = sits_formula_linear(),
# Verifies if e1071 package is installed
.check_require_packages("e1071")
# Used to check values (below)
-n_input_pixels <- nrow(values)
+input_pixels <- nrow(values)
# Performs data normalization
values <- .pred_normalize(pred = values, stats = ml_stats)
# Do classification
Expand All @@ -203,7 +203,7 @@ sits_svm <- function(samples = NULL, formula = sits_formula_linear(),
# Get the predicted probabilities
values <- attr(values, "probabilities")
# Are the results consistent with the data input?
-.check_processed_values(values, n_input_pixels)
+.check_processed_values(values, input_pixels)
# Reorder matrix columns if needed
if (any(labels != colnames(values))) {
values <- values[, labels]
@@ -337,14 +337,14 @@ sits_xgboost <- function(samples = NULL, learning_rate = 0.15,
# Verifies if xgboost package is installed
.check_require_packages("xgboost")
# Used to check values (below)
-n_input_pixels <- nrow(values)
+input_pixels <- nrow(values)
# Do classification
values <- stats::predict(
object = model, as.matrix(values), ntreelimit = ntreelimit,
reshape = TRUE
)
# Are the results consistent with the data input?
-.check_processed_values(values, n_input_pixels)
+.check_processed_values(values, input_pixels)
# Update the columns names to labels
colnames(values) <- labels
return(values)
4 changes: 2 additions & 2 deletions R/sits_mlp.R
@@ -277,7 +277,7 @@ sits_mlp <- function(samples = NULL,
# Unserialize model
torch_model[["model"]] <- .torch_unserialize_model(serialized_model)
# Used to check values (below)
-n_input_pixels <- nrow(values)
+input_pixels <- nrow(values)
# Performs data normalization
values <- .pred_normalize(pred = values, stats = ml_stats)
# Transform input into matrix
@@ -305,7 +305,7 @@
)
# Are the results consistent with the data input?
.check_processed_values(
-values = values, n_input_pixels = n_input_pixels
+values = values, input_pixels = input_pixels
)
# Update the columns names to labels
colnames(values) <- labels
4 changes: 2 additions & 2 deletions R/sits_resnet.R
@@ -370,7 +370,7 @@ sits_resnet <- function(samples = NULL,
# Unserialize model
torch_model[["model"]] <- .torch_unserialize_model(serialized_model)
# Used to check values (below)
-n_input_pixels <- nrow(values)
+input_pixels <- nrow(values)
# Transform input into a 3D tensor
# Reshape the 2D matrix into a 3D array
n_samples <- nrow(values)
@@ -403,7 +403,7 @@
x = torch::torch_tensor(values, device = "cpu")
)
.check_processed_values(
-values = values, n_input_pixels = n_input_pixels
+values = values, input_pixels = input_pixels
)
# Update the columns names to labels
colnames(values) <- labels
4 changes: 2 additions & 2 deletions R/sits_tae.R
@@ -289,7 +289,7 @@ sits_tae <- function(samples = NULL,
# Unserialize model
torch_model[["model"]] <- .torch_unserialize_model(serialized_model)
# Used to check values (below)
-n_input_pixels <- nrow(values)
+input_pixels <- nrow(values)
# Transform input into a 3D tensor
# Reshape the 2D matrix into a 3D array
n_samples <- nrow(values)
@@ -323,7 +323,7 @@
)
# Are the results consistent with the data input?
.check_processed_values(
-values = values, n_input_pixels = n_input_pixels
+values = values, input_pixels = input_pixels
)
# Update the columns names to labels
colnames(values) <- labels
4 changes: 2 additions & 2 deletions R/sits_tempcnn.R
@@ -339,7 +339,7 @@ sits_tempcnn <- function(samples = NULL,
# Unserialize model
torch_model[["model"]] <- .torch_unserialize_model(serialized_model)
# Used to check values (below)
-n_input_pixels <- nrow(values)
+input_pixels <- nrow(values)
# Transform input into a 3D tensor
# Reshape the 2D matrix into a 3D array
n_samples <- nrow(values)
@@ -374,7 +374,7 @@
)
# Are the results consistent with the data input?
.check_processed_values(
-values = values, n_input_pixels = n_input_pixels
+values = values, input_pixels = input_pixels
)
# Update the columns names to labels
colnames(values) <- labels
