Skip to content

Commit

Permalink
fix M3 chip torch bug
Browse files Browse the repository at this point in the history
  • Loading branch information
gilbertocamara committed May 22, 2024
1 parent 7e98e88 commit abb39b5
Show file tree
Hide file tree
Showing 12 changed files with 33 additions and 1 deletion.
6 changes: 6 additions & 0 deletions R/sits_lighttae.R
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,11 @@ sits_lighttae <- function(samples = NULL,
return(out)
}
)
# torch 12.0 not working with Apple MPS
if (torch::backends_mps_is_available())
cpu_train <- TRUE
else
cpu_train <- FALSE
# Train the model using luz
torch_model <-
luz::setup(
Expand Down Expand Up @@ -300,6 +305,7 @@ sits_lighttae <- function(samples = NULL,
gamma = lr_decay_rate
)
),
accelerator = luz::accelerator(cpu = cpu_train),
dataloader_options = list(batch_size = batch_size),
verbose = verbose
)
Expand Down
7 changes: 7 additions & 0 deletions R/sits_mlp.R
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,11 @@ sits_mlp <- function(samples = NULL,
self$model(x)
}
)
# torch 12.0 not working with Apple MPS
if (torch::backends_mps_is_available())
cpu_train <- TRUE
else
cpu_train <- FALSE
# Train the model using luz
torch_model <-
luz::setup(
Expand All @@ -262,6 +267,8 @@ sits_mlp <- function(samples = NULL,
patience = patience,
min_delta = min_delta
)),
dataloader_options = list(batch_size = batch_size),
accelerator = luz::accelerator(cpu = cpu_train),
verbose = verbose
)
# Serialize model
Expand Down
6 changes: 6 additions & 0 deletions R/sits_resnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,11 @@ sits_resnet <- function(samples = NULL,
self$softmax()
}
)
# torch 12.0 not working with Apple MPS
if (torch::backends_mps_is_available())
cpu_train <- TRUE
else
cpu_train <- FALSE
# train the model using luz
torch_model <-
luz::setup(
Expand Down Expand Up @@ -354,6 +359,7 @@ sits_resnet <- function(samples = NULL,
gamma = lr_decay_rate
)
),
accelerator = luz::accelerator(cpu = cpu_train),
dataloader_options = list(batch_size = batch_size),
verbose = verbose
)
Expand Down
6 changes: 6 additions & 0 deletions R/sits_tae.R
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@ sits_tae <- function(samples = NULL,
return(x)
}
)
# torch 12.0 not working with Apple MPS
if (torch::backends_mps_is_available())
cpu_train <- TRUE
else
cpu_train <- FALSE
# train the model using luz
torch_model <-
luz::setup(
Expand Down Expand Up @@ -273,6 +278,7 @@ sits_tae <- function(samples = NULL,
gamma = lr_decay_rate
)
),
accelerator = luz::accelerator(cpu = cpu_train),
dataloader_options = list(batch_size = batch_size),
verbose = verbose
)
Expand Down
9 changes: 8 additions & 1 deletion R/sits_tempcnn.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@
#' @examples
#' if (sits_run_examples()) {
#' # create a TempCNN model
#' torch_model <- sits_train(samples_modis_ndvi, sits_tempcnn())
#' torch_model <- sits_train(samples_modis_ndvi,
#' sits_tempcnn(verbose = TRUE))
#' # plot the model
#' plot(torch_model)
#' # create a data cube from local files
Expand Down Expand Up @@ -285,6 +286,11 @@ sits_tempcnn <- function(samples = NULL,
self$softmax()
}
)
# torch 12.0 not working with Apple MPS
if (torch::backends_mps_is_available())
cpu_train <- TRUE
else
cpu_train <- FALSE
# Train the model using luz
torch_model <-
luz::setup(
Expand Down Expand Up @@ -323,6 +329,7 @@ sits_tempcnn <- function(samples = NULL,
gamma = lr_decay_rate
)
),
accelerator = luz::accelerator(cpu = cpu_train),
dataloader_options = list(batch_size = batch_size),
verbose = verbose
)
Expand Down
Binary file added mnist-cnn.pt
Binary file not shown.
Binary file added mnist/mnist/processed/test.rds
Binary file not shown.
Binary file added mnist/mnist/processed/training.rds
Binary file not shown.
Binary file added mnist/mnist/raw/t10k-images-idx3-ubyte.gz
Binary file not shown.
Binary file added mnist/mnist/raw/t10k-labels-idx1-ubyte.gz
Binary file not shown.
Binary file added mnist/mnist/raw/train-images-idx3-ubyte.gz
Binary file not shown.
Binary file added mnist/mnist/raw/train-labels-idx1-ubyte.gz
Binary file not shown.

0 comments on commit abb39b5

Please sign in to comment.