--- title: "Saving and Reloading Fitted Kerasnip Workflows" output: rmarkdown::html_vignette vignette: > %\VignetteIndexEntry{Saving and Reloading Fitted Kerasnip Workflows} %\VignetteEngine{knitr::rmarkdown} %\VignetteEncoding{UTF-8} --- ## Overview Keras models are backed by Python objects managed by TensorFlow/JAX. These objects live in the current R session and are represented as external pointers (`externalptr`) that become invalid as soon as the session ends, or even within the same session after `saveRDS()` / `readRDS()`. `kerasnip` handles this transparently so that fitted workflows can be saved, reloaded, and used for prediction without any manual restoration steps. ## Quick workflow example Before discussing the details, here is the full persistence workflow: ``` r library(kerasnip) library(tidymodels) #> ── Attaching packages ────────────────────────────────────────────────────────────────────────────── tidymodels 1.5.0 ── #> ✔ broom 1.0.12 ✔ recipes 1.3.2 #> ✔ dials 1.4.3 ✔ rsample 1.3.2 #> ✔ dplyr 1.2.1 ✔ tailor 0.1.0 #> ✔ ggplot2 4.0.3 ✔ tidyr 1.3.2 #> ✔ infer 1.1.0 ✔ tune 2.1.0 #> ✔ modeldata 1.5.1 ✔ workflows 1.3.0 #> ✔ parsnip 1.5.0 ✔ workflowsets 1.1.1 #> ✔ purrr 1.2.2 ✔ yardstick 1.4.0 #> ── Conflicts ───────────────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ── #> ✖ purrr::discard() masks scales::discard() #> ✖ dplyr::filter() masks stats::filter() #> ✖ dplyr::lag() masks stats::lag() #> ✖ recipes::step() masks stats::step() library(keras3) #> #> Attaching package: 'keras3' #> The following object is masked from 'package:yardstick': #> #> get_weights #> The following object is masked from 'package:infer': #> #> generate # 1. Define Layer Blocks (Required by kerasnip) # The first block must initialize the sequential model input_block <- function(model, input_shape) { keras_model_sequential(input_shape = input_shape) } # Hidden layer block dense_block <- function(model, units = 32) { model |> layer_dense(units = units, activation = "relu") } # Output layer block (units = 1 for regression) output_block <- function(model, num_classes) { model |> layer_dense(units = 1) } # 2. Generate the parsnip specification create_keras_sequential_spec( model_name = "my_mlp", layer_blocks = list( input = input_block, hidden = dense_block, output = output_block ), mode = "regression" ) # 3. Use the newly created 'my_mlp' function mod_spec <- my_mlp(fit_epochs = 10) |> set_engine("keras") # 4. Standard tidymodels workflow rec_spec <- recipe(mpg ~ ., data = mtcars) |> step_normalize(all_predictors()) fit_wf <- workflow() |> add_recipe(rec_spec) |> add_model(mod_spec) |> fit(data = mtcars) # Predict new_data <- mtcars[1:3, ] predict(fit_wf, new_data) #> 1/1 - 0s - 116ms/step #> # A tibble: 3 × 1 #> .pred #> #> 1 1.00 #> 2 0.956 #> 3 3.03 ``` The first call to predict() detects that the Python pointer is invalid and restores the model from the stored bytes automatically. ## What kerasnip does behind the scenes `kerasnip` handles persistence automatically: - At **fit time**, the Keras model is serialized to a raw byte vector (`.keras`format) and stored alongside the parsnip `model_fit` object. - At **predict time**, if the Python pointer is detected as invalid, `predict()` automatically restores the model from those bytes before dispatching. - The parsnip model specification is also re-registered if it is missing from the session (e.g. after a fresh start). This means you can use the persistence strategy that best suits your workflow without any extra boilerplate. ## Strategy 1: Plain `saveRDS()` / `readRDS()` For most use cases: sharing a model file with a colleague, caching a fit between R sessions, or checkpointing during development; plain RDS is the simplest approach. ``` r library(kerasnip) library(workflows) library(parsnip) library(recipes) # --- Save --- saveRDS(fit_wf, "my_model.rds") # --- Reload in the same or a new R session --- library(kerasnip) fit_wf <- readRDS("my_model.rds") # predict() restores the Keras model from bytes automatically predictions <- predict(fit_wf, new_data = new_data) #> 1/1 - 0s - 118ms/step predictions #> # A tibble: 3 × 1 #> .pred #> #> 1 1.00 #> 2 0.956 #> 3 3.03 ``` There is nothing special to do after `readRDS()`. The first call to `predict()` detects the invalid pointer, restores the model from the stored bytes, and then proceeds normally. ## Strategy 2: `bundle` / `unbundle` The [`bundle`](https://rstudio.github.io/bundle/) package provides a standardized serialization interface used by `vetiver`, `plumber`, and other MLOps tools. It is the right choice when: - You are deploying a model to a `vetiver` API or a Docker container. - You want a self-contained, version-controlled artifact that does not rely on any R session state. - You are sharing a model across machines with different R library paths. ``` r library(kerasnip) library(bundle) library(workflows) # --- Save --- bundled <- bundle(fit_wf) saveRDS(bundled, "my_model_bundle.rds") # --- Reload in any R session --- library(kerasnip) library(bundle) bundled <- readRDS("my_model_bundle.rds") fit_wf <- unbundle(bundled) predictions <- predict(fit_wf, new_data = new_data) #> 1/1 - 0s - 110ms/step predictions #> # A tibble: 3 × 1 #> .pred #> #> 1 1.00 #> 2 0.956 #> 3 3.03 ``` ## Comparison | | `saveRDS` / `readRDS` | `bundle` / `unbundle` | |---|---|---| | Works across sessions | ✅ | ✅ | | Works across machines | ✅ (same R library) | ✅ | | `vetiver` / Docker compatible | ❌ | ✅ | | Extra dependency needed | ❌ | `bundle` package | | Code complexity | Minimal | Minimal | ## What happens under the hood When `kerasnip` fits a model, the generic fit function calls `keras_model_to_bytes()`, which writes the model to a temporary `.keras` file using `keras3::save_model()` and reads the bytes back into R: ``` r # Simplified version of what happens inside generic_sequential_fit() keras_bytes <- keras_model_to_bytes(model) # keras_bytes is a raw vector stored in object$fit$keras_bytes ``` When `predict()` is called on a reloaded object, `predict.kerasnip_model_fit()` runs: ``` r # Simplified version of predict.kerasnip_model_fit() if (!is.null(object$fit$keras_bytes)) { is_valid <- tryCatch( { reticulate::py_validate_xptr(object$fit$fit) TRUE }, error = function(e) FALSE ) if (!is_valid) { object$fit$fit <- keras_model_from_bytes(object$fit$keras_bytes) } } ``` If `keras_model_to_bytes()` fails (e.g. if the model was compiled with a non-serialisable custom object), a warning is issued at fit time and `keras_bytes` is set to `NULL`. In that case, `predict()` after reload will fail with a clear error.