Skip to content

Commit

Permalink
Update GPU syntax to allow for backend choices
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Oct 18, 2023
1 parent 79f83a6 commit ae425b5
Show file tree
Hide file tree
Showing 3 changed files with 735 additions and 696 deletions.
23 changes: 22 additions & 1 deletion R/diffeqr.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,32 @@ jitoptimize_sde <- function (de,prob){
#' }
#'
#' @export
diffeqgpu_setup <- function (){
diffeqgpu_setup <- function (backend){
JuliaCall::julia_install_package_if_needed("DiffEqGPU")
JuliaCall::julia_library("DiffEqGPU")
functions <- JuliaCall::julia_eval("filter(isascii, replace.(string.(propertynames(DiffEqGPU)),\"!\"=>\"_bang\"))")
degpu <- julia_pkg_import("DiffEqGPU",functions)

if (backend == "CUDA") {
JuliaCall::julia_install_package_if_needed("CUDA")
JuliaCall::julia_library("CUDA")
backend <- julia_pkg_import("CUDA",c("CUDABackend"))
} else if (backend == "AMDGPU") {
JuliaCall::julia_install_package_if_needed("AMDGPU")
JuliaCall::julia_library("AMDGPU")
backend <- julia_pkg_import("AMDGPU",c("AMDGPUBackend"))
} else if (backend == "Metal") {
JuliaCall::julia_install_package_if_needed("Metal")
JuliaCall::julia_library("Metal")
backend <- julia_pkg_import("Metal",c("MetalBackend"))
} else if (backend == "oneAPI") {
JuliaCall::julia_install_package_if_needed("oneAPI")
JuliaCall::julia_library("oneAPI")
backend <- julia_pkg_import("oneAPI",c("oneAPIBackend"))
} else {
stop(paste("Illegal backend choice found. Allowed choices: CUDA, AMDGPU, Metal, and oneAPI. Chosen backend: ", backend)
}
list(degpu, backend)
}

julia_function <- function(func_name, pkg_name = "Main",
Expand Down
Loading

0 comments on commit ae425b5

Please sign in to comment.