diff --git a/R/missRanger.R b/R/missRanger.R index abfc9df..f1e3c13 100644 --- a/R/missRanger.R +++ b/R/missRanger.R @@ -375,8 +375,14 @@ convert <- function(X, check = FALSE) { stopifnot(is.data.frame(X)) if (!ncol(X)) { - return(list(X = X, bad = character(0), vars = character(0), - types = character(0), classes = character(0))) + out <- list( + X = X, + bad = character(0), + vars = character(0), + types = character(0), + classes = character(0) + ) + return(out) } types <- vapply(X, typeof2, FUN.VALUE = "") diff --git a/R/pmm.R b/R/pmm.R index 3734772..70ec638 100644 --- a/R/pmm.R +++ b/R/pmm.R @@ -22,12 +22,14 @@ #' pmm(xtrain = c("A", "A", "B"), xtest = "A", ytrain = c(2, 2, 4), k = 2) # 2 #' pmm(xtrain = factor(c("A", "B")), xtest = factor("C"), ytrain = 1:2) # 2 pmm <- function(xtrain, xtest, ytrain, k = 1L, seed = NULL) { - stopifnot(length(xtrain) == length(ytrain), - sum(ok <- !is.na(xtrain) & !is.na(ytrain)) >= 1L, - (nt <- length(xtest)) >= 1L, !anyNA(xtest), - mode(xtrain) %in% c("logical", "numeric", "character"), - mode(xtrain) == mode(xtest), - k >= 1L) + stopifnot( + length(xtrain) == length(ytrain), + sum(ok <- !is.na(xtrain) & !is.na(ytrain)) >= 1L, + (nt <- length(xtest)) >= 1L, !anyNA(xtest), + mode(xtrain) %in% c("logical", "numeric", "character"), + mode(xtrain) == mode(xtest), + k >= 1L + ) xtrain <- xtrain[ok] ytrain <- ytrain[ok] diff --git a/tests/testthat/test-missRanger.R b/tests/testthat/test-missRanger.R index 06e9bf6..ecb07ba 100644 --- a/tests/testthat/test-missRanger.R +++ b/tests/testthat/test-missRanger.R @@ -208,3 +208,26 @@ test_that("Extremely wide datasets are handled", { data[5L, 5L] <- NA expect_no_error(missRanger(as.data.frame(data), num.trees = 3L, verbose = 0L)) }) + +test_that("Convert and revert do what they should", { + n <- 20L + X <- data.frame( + x1 = seq_len(n), + x2 = rep(LETTERS[1:4], n %/% 4), + x3 = factor(rep(LETTERS[1:2], n %/% 2)), + x4 = seq_len(n) > n %/% 3, + x5 = seq(as.Date("2008-11-01"), as.Date("2008-11-21"), length.out = n) + ) + X <- generateNA(X, seed = 1L) + conv <- convert(X) + reve <- revert(conv) + + expect_equal(X, reve) + expect_equal( + unname(sapply(conv$X, class)), + c("integer", "factor", "factor", "factor", "numeric") + ) + + Ximp <- missRanger(X, seed = 1L, verbose = FALSE, pmm.k = 3L) + expect_equal(sapply(Ximp, class), sapply(X, class)) +})