feat: add cross-correlation pipe op (#109)
* feat: draft for correlation pipe op

* docs: remove some old args

* chore: add grid as param

* docs: update description for cor

* feat: bring tfr to common domain

* perf: only reshape function when outside domain

* feat: transform to common domain via min max scaling

* feat: more changes for cor

* tests: simplify tests for cross-correlation

* tests: restructure

* refactor: move helpers back since only used in extract

* refactor: reuse ncol

* fix: missing document

* docs: remove stats tag and rebuild readme

* tests: test for non matching domain

* refactor: revert var rename

* docs: update news

* tests: more tests

* chore: update .Rbuildignore

* tests: use expect_identical

* docs(cor): missing indent

* tests: adding withr for random seed usage

* tests: add error message for test

* tests: add message for warning

* feat: add mlr3pipelines to depends

* chore: remove unneeded mlr3 and mlr3pipelines import in tests

* fix: still import

* docs: some more description

* docs: update news
m-muecke committed Jun 27, 2024
1 parent b0e60ff commit 7c74f8e
Showing 26 changed files with 270 additions and 83 deletions.
8 changes: 4 additions & 4 deletions .Rbuildignore
@@ -1,13 +1,12 @@
^.*\.Rproj$
^.Renviron$
^\.pre-commit-config\.yaml$
^CITATION.cff$
^CITATION\.cff$
^CRAN-RELEASE$
^CRAN-SUBMISSION$
^LICENSE$
^Meta$
^README.html$
^README\.Rmd$
^README\.html$
^\.Renviron$
^\.Rproj\.user$
^\.ccache$
^\.editorconfig$
@@ -16,6 +15,7 @@
^\.httr-oauth$
^\.ignore$
^\.lintr$
^\.pre-commit-config\.yaml$
^\.vscode$
^_pkgdown\.yml$
^attic$
6 changes: 4 additions & 2 deletions DESCRIPTION
@@ -24,26 +24,28 @@ URL: https://mlr3fda.mlr-org.com, https://github.com/mlr-org/mlr3fda
BugReports: https://github.com/mlr-org/mlr3fda/issues
Depends:
mlr3 (>= 0.14.0),
mlr3pipelines (>= 0.5.2),
R (>= 3.1.0)
Imports:
checkmate,
data.table,
lgr,
mlr3misc (>= 0.14.0),
mlr3pipelines (>= 0.5.2),
paradox,
R6,
tf (>= 0.3.4)
Suggests:
rpart,
testthat (>= 3.0.0)
testthat (>= 3.0.0),
withr
Config/testthat/edition: 3
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.3.1
Collate:
'zzz.R'
'PipeOpFDACor.R'
'PipeOpFDAExtract.R'
'PipeOpFDAFlatten.R'
'PipeOpFDAInterpol.R'
1 change: 1 addition & 0 deletions NAMESPACE
@@ -2,6 +2,7 @@

S3method(hash_input,tfd_irreg)
S3method(hash_input,tfd_reg)
export(PipeOpFDACor)
export(PipeOpFDAExtract)
export(PipeOpFDAFlatten)
export(PipeOpFDAInterpol)
3 changes: 3 additions & 0 deletions NEWS.md
@@ -1,5 +1,8 @@
# dev

* New PipeOp `PipeOpFDACor`
* mlr3fda now depends on mlr3pipelines instead of importing it

# mlr3fda 0.1.1

* Fix CRAN issues
76 changes: 76 additions & 0 deletions R/PipeOpFDACor.R
@@ -0,0 +1,76 @@
#' @title Cross-Correlation of Functional Data
#' @name mlr_pipeops_fda.cor
#'
#' @description
#' Calculates the cross-correlation between two functional vectors using [tf::tf_crosscor()].
#' Note that it only operates on regular data and that the cross-correlation assumes that each column
#' has the same domain.
#'
#' @section Parameters:
#' The parameters are the parameters inherited from [`PipeOpTaskPreprocSimple`], as well as the following
#' parameters:
#' * `arg` :: `numeric()` \cr
#' Grid to use for the cross-correlation.
#'
#' @export
#' @examples
#' set.seed(1234L)
#' dt = data.table(y = 1:100, x1 = tf::tf_rgp(100L), x2 = tf::tf_rgp(100L))
#' task = as_task_regr(dt, target = "y")
#' po_cor = po("fda.cor")
#' task_cor = po_cor$train(list(task))[[1L]]
#' task_cor
PipeOpFDACor = R6Class("PipeOpFDACor",
  inherit = mlr3pipelines::PipeOpTaskPreprocSimple,
  public = list(
    #' @description Initializes a new instance of this Class.
    #' @param id (`character(1)`)\cr
    #'   Identifier of resulting object, default `"fda.cor"`.
    #' @param param_vals (named `list`)\cr
    #'   List of hyperparameter settings, overwriting the hyperparameter settings that would
    #'   otherwise be set during construction. Default `list()`.
    initialize = function(id = "fda.cor", param_vals = list()) {
      param_set = ps(
        arg = p_uty(tags = c("train", "predict"), custom_check = check_numeric)
      )

      super$initialize(
        id = id,
        param_set = param_set,
        param_vals = param_vals,
        packages = c("mlr3fda", "mlr3pipelines", "tf"),
        feature_types = "tfd_reg",
        tags = "fda"
      )
    }
  ),
  private = list(
    .transform_dt = function(dt, levels) {
      pars = self$param_set$get_values()

      k = ncol(dt)
      if (k < 2L) {
        warningf("task has less than 2 columns")
        return(dt)
      }

      nms = names(dt)
      res = list()
      for (i in 2:k) {
        for (j in 1:(i - 1L)) {
          x = dt[[i]]
          y = dt[[j]]
          if (!all(tf::tf_domain(x) == tf::tf_domain(y))) {
            stopf("Domain of %s and %s do not match", nms[[j]], nms[[i]])
          }
          nm = sprintf("%s_%s_cor", nms[[j]], nms[[i]])
          res[[nm]] = invoke(tf::tf_crosscor, x = x, y = y, .args = pars)
        }
      }
      setDT(res)
    }
  )
)

#' @include zzz.R
register_po("fda.cor", PipeOpFDACor)
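
The nested loop in `.transform_dt` visits every unordered pair of functional columns once and names each result `<earlier>_<later>_cor`. The following is a minimal base-R sketch of that naming scheme only. The column names `x1`, `x2`, `x3` and the variable `pair_names` are hypothetical and not part of the commit, and no cross-correlation is computed here.

``` r
# Hypothetical column names; any character vector of length >= 2 works.
nms = c("x1", "x2", "x3")
k = length(nms)

# Same loop order as .transform_dt: pair each column i with every earlier column j.
pair_names = character()
for (i in 2:k) {
  for (j in 1:(i - 1L)) {
    pair_names = c(pair_names, sprintf("%s_%s_cor", nms[[j]], nms[[i]]))
  }
}

print(pair_names)
# The number of pairs equals choose(k, 2).
print(length(pair_names) == choose(k, 2L))
```

With three columns this gives three output names, so the number of generated features grows quadratically with the number of functional columns in the task.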
2 changes: 0 additions & 2 deletions R/PipeOpFDAExtract.R
@@ -35,8 +35,6 @@
#'
#' @export
#' @examples
#' library(mlr3pipelines)
#'
#' task = tsk("fuel")
#' po_fmean = po("fda.extract", features = "mean")
#' task_fmean = po_fmean$train(list(task))[[1L]]
2 changes: 0 additions & 2 deletions R/PipeOpFDAFlatten.R
@@ -18,8 +18,6 @@
#'
#' @export
#' @examples
#' library(mlr3pipelines)
#'
#' task = tsk("fuel")
#' pop = po("fda.flatten")
#' task_flat = pop$train(list(task))
2 changes: 0 additions & 2 deletions R/PipeOpFDAInterpol.R
@@ -41,8 +41,6 @@
#'
#' @export
#' @examples
#' library(mlr3pipelines)
#'
#' task = tsk("fuel")
#' pop = po("fda.interpol")
#' task_interpol = pop$train(list(task))[[1L]]
2 changes: 0 additions & 2 deletions R/PipeOpFDASmooth.R
@@ -26,8 +26,6 @@
#'
#' @export
#' @examples
#' library(mlr3pipelines)
#'
#' task = tsk("fuel")
#' po_smooth = po("fda.smooth", method = "rollmean", args = list(k = 5))
#' task_smooth = po_smooth$train(list(task))[[1L]]
2 changes: 0 additions & 2 deletions R/PipeOpFPCA.R
@@ -24,8 +24,6 @@
#'
#' @export
#' @examples
#' library(mlr3pipelines)
#'
#' task = tsk("fuel")
#' po_fpca = po("fda.fpca", n_components = 3L)
#' task_fpca = po_fpca$train(list(task))[[1L]]
2 changes: 0 additions & 2 deletions README.Rmd
@@ -74,8 +74,6 @@ To train a model on this task we first need to extract scalar features from the
We illustrate this below by extracting the mean value.

```{r fda.extract, fig.width = 5, fig.height = 3}
library(mlr3pipelines)
po_fmean = po("fda.extract", features = "mean")
task_fmean = po_fmean$train(list(task))[[1L]]
17 changes: 8 additions & 9 deletions README.md
@@ -76,8 +76,6 @@ from the functions. We illustrate this below by extracting the mean
value.

``` r
library(mlr3pipelines)

po_fmean = po("fda.extract", features = "mean")

task_fmean = po_fmean$train(list(task))[[1L]]
@@ -122,13 +120,14 @@ glrn$predict(task, row_ids = ids$test)

## Implemented PipeOps

| Key | Label | Packages | Tags |
|:-------------------------------------------------------------------------------|:-------------------------------------------------|:---------------------------------------------------|:--------------------|
| [fda.extract](https://mlr3fda.mlr-org.com/reference/mlr_pipeops_fda.extract) | Extracts Simple Features from Functional Columns | [tf](https://cran.r-project.org/package=tf) | fda, data transform |
| [fda.flatten](https://mlr3fda.mlr-org.com/reference/mlr_pipeops_fda.flatten) | Flattens Functional Columns | [tf](https://cran.r-project.org/package=tf) | fda, data transform |
| [fda.fpca](https://mlr3fda.mlr-org.com/reference/mlr_pipeops_fda.fpca) | Functional Principal Component Analysis | [tf](https://cran.r-project.org/package=tf) | fda, data transform |
| [fda.interpol](https://mlr3fda.mlr-org.com/reference/mlr_pipeops_fda.interpol) | Interpolate Functional Columns | [tf](https://cran.r-project.org/package=tf) | fda, data transform |
| [fda.smooth](https://mlr3fda.mlr-org.com/reference/mlr_pipeops_fda.smooth) | Smoothing Functional Columns | [tf](https://cran.r-project.org/package=tf), stats | fda, data transform |
| Key | Label | Packages | Tags |
|:---|:---|:---|:---|
| [fda.cor](https://mlr3fda.mlr-org.com/reference/mlr_pipeops_fda.cor) | Cross-Correlation of Functional Data | [tf](https://cran.r-project.org/package=tf) | fda, data transform |
| [fda.extract](https://mlr3fda.mlr-org.com/reference/mlr_pipeops_fda.extract) | Extracts Simple Features from Functional Columns | [tf](https://cran.r-project.org/package=tf) | fda, data transform |
| [fda.flatten](https://mlr3fda.mlr-org.com/reference/mlr_pipeops_fda.flatten) | Flattens Functional Columns | [tf](https://cran.r-project.org/package=tf) | fda, data transform |
| [fda.fpca](https://mlr3fda.mlr-org.com/reference/mlr_pipeops_fda.fpca) | Functional Principal Component Analysis | [tf](https://cran.r-project.org/package=tf) | fda, data transform |
| [fda.interpol](https://mlr3fda.mlr-org.com/reference/mlr_pipeops_fda.interpol) | Interpolate Functional Columns | [tf](https://cran.r-project.org/package=tf) | fda, data transform |
| [fda.smooth](https://mlr3fda.mlr-org.com/reference/mlr_pipeops_fda.smooth) | Smoothing Functional Columns | [tf](https://cran.r-project.org/package=tf), stats | fda, data transform |

## Bugs, Questions, Feedback

88 changes: 88 additions & 0 deletions man/mlr_pipeops_fda.cor.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions man/mlr_pipeops_fda.extract.Rd

2 changes: 0 additions & 2 deletions man/mlr_pipeops_fda.flatten.Rd

2 changes: 0 additions & 2 deletions man/mlr_pipeops_fda.fpca.Rd

2 changes: 0 additions & 2 deletions man/mlr_pipeops_fda.interpol.Rd

2 changes: 0 additions & 2 deletions man/mlr_pipeops_fda.smooth.Rd
