---
title: "Tidymodels Workflow with Functional Keras Models (Multi-Input)"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Tidymodels Workflow with Functional Keras Models (Multi-Input)}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

## Introduction

This vignette demonstrates a complete `tidymodels` workflow for a regression task using a Keras functional model defined with `kerasnip`. We will use the Ames Housing dataset to predict house prices. A key feature of this example is the use of a multi-input Keras model, where numerical and categorical features are processed through separate input branches.

`kerasnip` allows you to define complex Keras architectures, including those with multiple inputs, and integrate them seamlessly into the `tidymodels` ecosystem for robust modeling and tuning.

## Setup

First, we load the necessary packages.

``` r
library(kerasnip)
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────────────────────────────────────────────── tidymodels 1.5.0 ──
#> ✔ broom        1.0.12     ✔ recipes      1.3.2
#> ✔ dials        1.4.3      ✔ rsample      1.3.2
#> ✔ dplyr        1.2.1      ✔ tailor       0.1.0
#> ✔ ggplot2      4.0.3      ✔ tidyr        1.3.2
#> ✔ infer        1.1.0      ✔ tune         2.1.0
#> ✔ modeldata    1.5.1      ✔ workflows    1.3.0
#> ✔ parsnip      1.5.0      ✔ workflowsets 1.1.1
#> ✔ purrr        1.2.2      ✔ yardstick    1.4.0
#> ── Conflicts ───────────────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ purrr::discard() masks scales::discard()
#> ✖ dplyr::filter()  masks stats::filter()
#> ✖ dplyr::lag()     masks stats::lag()
#> ✖ recipes::step()  masks stats::step()
library(keras3)
#>
#> Attaching package: 'keras3'
#> The following object is masked from 'package:yardstick':
#>
#>     get_weights
#> The following object is masked from 'package:infer':
#>
#>     generate
library(dplyr) # For data manipulation
library(ggplot2) # For plotting
library(future) # For parallel processing
#>
#> Attaching package: 'future'
#> The following object is masked from 'package:keras3':
#>
#>     %<-%
library(finetune) # For racing
```

## Data Preparation

We'll use the Ames Housing dataset, which is available in the `modeldata` package. We will then split the data into training and testing sets.

``` r
# Select relevant columns and remove rows with missing values
ames_df <- ames |>
  select(
    Sale_Price,
    Gr_Liv_Area,
    Year_Built,
    Neighborhood,
    Bldg_Type,
    Overall_Cond,
    Total_Bsmt_SF,
    contains("SF")
  ) |>
  na.omit()

# Split data into training and testing sets
set.seed(123)
ames_split <- initial_split(ames_df, prop = 0.8, strata = Sale_Price)
ames_train <- training(ames_split)
ames_test <- testing(ames_split)

# Create cross-validation folds for tuning
ames_folds <- vfold_cv(ames_train, v = 5, strata = Sale_Price)
```

## Recipe for Preprocessing

We will create a `recipes` object to preprocess our data. This recipe will:

* Predict `Sale_Price` using all other variables.
* Normalize all numerical predictors.
* Create dummy variables for categorical predictors.
* Collapse each group of predictors into a single matrix column using `step_collapse()`.

This final step is crucial for the multi-input Keras model, as the `kerasnip` functional API expects a list of matrices for multiple inputs, where each matrix corresponds to a distinct input layer.

``` r
ames_recipe <- recipe(Sale_Price ~ ., data = ames_train) |>
  step_normalize(all_numeric_predictors()) |>
  step_collapse(all_numeric_predictors(), new_col = "numerical_input") |>
  step_dummy(Neighborhood) |>
  step_collapse(starts_with("Neighborhood"), new_col = "neighborhood_input") |>
  step_dummy(Bldg_Type) |>
  step_collapse(starts_with("Bldg_Type"), new_col = "bldg_input") |>
  step_dummy(Overall_Cond) |>
  step_collapse(starts_with("Overall_Cond"), new_col = "condition_input")
```

## Define Keras Functional Model with `kerasnip`

Now, we define our Keras functional model using `kerasnip`'s layer blocks.
This model will have four distinct input layers: one for numerical features and three for categorical features. These branches will be processed separately and then concatenated before the final output layer.

``` r
# Define layer blocks for multi-input functional model

# Input blocks for numerical and categorical features
input_numerical <- function(input_shape) {
  layer_input(shape = input_shape, name = "numerical_input")
}

input_neighborhood <- function(input_shape) {
  layer_input(shape = input_shape, name = "neighborhood_input")
}

input_bldg <- function(input_shape) {
  layer_input(shape = input_shape, name = "bldg_input")
}

input_condition <- function(input_shape) {
  layer_input(shape = input_shape, name = "condition_input")
}

# Processing blocks for each input type
dense_numerical <- function(tensor, units = 32, activation = "relu") {
  tensor |>
    layer_dense(units = units, activation = activation)
}

dense_categorical <- function(tensor, units = 16, activation = "relu") {
  tensor |>
    layer_dense(units = units, activation = activation)
}

# Concatenation block
concatenate_features <- function(numeric, neighborhood, bldg, condition) {
  layer_concatenate(list(numeric, neighborhood, bldg, condition))
}

# Output block for regression
output_regression <- function(tensor) {
  layer_dense(tensor, units = 1, name = "output")
}

# Create the kerasnip model specification function
create_keras_functional_spec(
  model_name = "ames_functional_mlp",
  layer_blocks = list(
    numerical_input = input_numerical,
    neighborhood_input = input_neighborhood,
    bldg_input = input_bldg,
    condition_input = input_condition,
    processed_numerical = inp_spec(dense_numerical, "numerical_input"),
    processed_neighborhood = inp_spec(dense_categorical, "neighborhood_input"),
    processed_bldg = inp_spec(dense_categorical, "bldg_input"),
    processed_condition = inp_spec(dense_categorical, "condition_input"),
    combined_features = inp_spec(
      concatenate_features,
      c(
        numeric = "processed_numerical",
        neighborhood = "processed_neighborhood",
        bldg = "processed_bldg",
        condition = "processed_condition"
      )
    ),
    output = inp_spec(output_regression, "combined_features")
  ),
  mode = "regression"
)
```

## Model Specification

We'll define our `ames_functional_mlp` model specification and set some hyperparameters to `tune()`. Note how the arguments are prefixed with their corresponding block names (e.g., `processed_numerical_units`).

``` r
# Define the tunable model specification
functional_mlp_spec <- ames_functional_mlp(
  # Tunable parameters for numerical branch
  processed_numerical_units = tune(),
  # Tunable parameters for categorical branch
  processed_neighborhood_units = tune(),
  processed_bldg_units = tune(),
  processed_condition_units = tune(),
  # Fixed compilation and fitting parameters
  compile_loss = "mean_squared_error",
  compile_optimizer = "adam",
  compile_metrics = c("mean_absolute_error"),
  fit_epochs = 50,
  fit_batch_size = 32,
  fit_validation_split = 0.2,
  fit_callbacks = list(
    callback_early_stopping(monitor = "val_loss", patience = 5)
  )
) |>
  set_engine("keras")

print(functional_mlp_spec)
#> ames functional mlp Model Specification (regression)
#>
#> Main Arguments:
#>   num_numerical_input = structure(list(), class = "rlang_zap")
#>   num_neighborhood_input = structure(list(), class = "rlang_zap")
#>   num_bldg_input = structure(list(), class = "rlang_zap")
#>   num_condition_input = structure(list(), class = "rlang_zap")
#>   num_processed_numerical = structure(list(), class = "rlang_zap")
#>   num_processed_neighborhood = structure(list(), class = "rlang_zap")
#>   num_processed_bldg = structure(list(), class = "rlang_zap")
#>   num_processed_condition = structure(list(), class = "rlang_zap")
#>   num_combined_features = structure(list(), class = "rlang_zap")
#>   num_output = structure(list(), class = "rlang_zap")
#>   processed_numerical_units = tune()
#>   processed_numerical_activation = structure(list(), class = "rlang_zap")
#>   processed_neighborhood_units = tune()
#>   processed_neighborhood_activation = structure(list(), class = "rlang_zap")
#>   processed_bldg_units = tune()
#>   processed_bldg_activation = structure(list(), class = "rlang_zap")
#>   processed_condition_units = tune()
#>   processed_condition_activation = structure(list(), class = "rlang_zap")
#>   learn_rate = structure(list(), class = "rlang_zap")
#>   fit_batch_size = 32
#>   fit_epochs = 50
#>   fit_callbacks = list(callback_early_stopping(monitor = "val_loss", patience = 5))
#>   fit_validation_split = 0.2
#>   fit_validation_data = structure(list(), class = "rlang_zap")
#>   fit_shuffle = structure(list(), class = "rlang_zap")
#>   fit_class_weight = structure(list(), class = "rlang_zap")
#>   fit_sample_weight = structure(list(), class = "rlang_zap")
#>   fit_initial_epoch = structure(list(), class = "rlang_zap")
#>   fit_steps_per_epoch = structure(list(), class = "rlang_zap")
#>   fit_validation_steps = structure(list(), class = "rlang_zap")
#>   fit_validation_batch_size = structure(list(), class = "rlang_zap")
#>   fit_validation_freq = structure(list(), class = "rlang_zap")
#>   fit_verbose = structure(list(), class = "rlang_zap")
#>   fit_view_metrics = structure(list(), class = "rlang_zap")
#>   compile_optimizer = adam
#>   compile_loss = mean_squared_error
#>   compile_metrics = c("mean_absolute_error")
#>   compile_loss_weights = structure(list(), class = "rlang_zap")
#>   compile_weighted_metrics = structure(list(), class = "rlang_zap")
#>   compile_run_eagerly = structure(list(), class = "rlang_zap")
#>   compile_steps_per_execution = structure(list(), class = "rlang_zap")
#>   compile_jit_compile = structure(list(), class = "rlang_zap")
#>   compile_auto_scale_loss = structure(list(), class = "rlang_zap")
#>
#> Computational engine: keras
```

## Create Workflow

A `workflow` combines the recipe and the model specification.

``` r
ames_wf <- workflow() |>
  add_recipe(ames_recipe) |>
  add_model(functional_mlp_spec)

print(ames_wf)
#> ══ Workflow ════════════════════════════════════════════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: ames_functional_mlp()
#>
#> ── Preprocessor ────────────────────────────────────────────────────────────────────────────────────────────────────────
#> 8 Recipe Steps
#>
#> • step_normalize()
#> • step_collapse()
#> • step_dummy()
#> • step_collapse()
#> • step_dummy()
#> • step_collapse()
#> • step_dummy()
#> • step_collapse()
#>
#> ── Model ───────────────────────────────────────────────────────────────────────────────────────────────────────────────
#> ames functional mlp Model Specification (regression)
#>
#> Main Arguments:
#>   num_numerical_input = structure(list(), class = "rlang_zap")
#>   num_neighborhood_input = structure(list(), class = "rlang_zap")
#>   num_bldg_input = structure(list(), class = "rlang_zap")
#>   num_condition_input = structure(list(), class = "rlang_zap")
#>   num_processed_numerical = structure(list(), class = "rlang_zap")
#>   num_processed_neighborhood = structure(list(), class = "rlang_zap")
#>   num_processed_bldg = structure(list(), class = "rlang_zap")
#>   num_processed_condition = structure(list(), class = "rlang_zap")
#>   num_combined_features = structure(list(), class = "rlang_zap")
#>   num_output = structure(list(), class = "rlang_zap")
#>   processed_numerical_units = tune()
#>   processed_numerical_activation = structure(list(), class = "rlang_zap")
#>   processed_neighborhood_units = tune()
#>   processed_neighborhood_activation = structure(list(), class = "rlang_zap")
#>   processed_bldg_units = tune()
#>   processed_bldg_activation = structure(list(), class = "rlang_zap")
#>   processed_condition_units = tune()
#>   processed_condition_activation = structure(list(), class = "rlang_zap")
#>   learn_rate = structure(list(), class = "rlang_zap")
#>   fit_batch_size = 32
#>   fit_epochs = 50
#>   fit_callbacks = list(callback_early_stopping(monitor = "val_loss", patience = 5))
#>   fit_validation_split = 0.2
#>   fit_validation_data = structure(list(), class = "rlang_zap")
#>   fit_shuffle = structure(list(), class = "rlang_zap")
#>   fit_class_weight = structure(list(), class = "rlang_zap")
#>   fit_sample_weight = structure(list(), class = "rlang_zap")
#>   fit_initial_epoch = structure(list(), class = "rlang_zap")
#>   fit_steps_per_epoch = structure(list(), class = "rlang_zap")
#>   fit_validation_steps = structure(list(), class = "rlang_zap")
#>   fit_validation_batch_size = structure(list(), class = "rlang_zap")
#>   fit_validation_freq = structure(list(), class = "rlang_zap")
#>   fit_verbose = structure(list(), class = "rlang_zap")
#>   fit_view_metrics = structure(list(), class = "rlang_zap")
#>   compile_optimizer = adam
#>   compile_loss = mean_squared_error
#>   compile_metrics = c("mean_absolute_error")
#>   compile_loss_weights = structure(list(), class = "rlang_zap")
#>   compile_weighted_metrics = structure(list(), class = "rlang_zap")
#>   compile_run_eagerly = structure(list(), class = "rlang_zap")
#>   compile_steps_per_execution = structure(list(), class = "rlang_zap")
#>   compile_jit_compile = structure(list(), class = "rlang_zap")
#>   compile_auto_scale_loss = structure(list(), class = "rlang_zap")
#>
#> Computational engine: keras
```

## Define Tuning Grid

We will create a regular grid for our hyperparameters.
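
A regular grid crosses every level of every tuned parameter, so its size grows multiplicatively: four parameters at three levels each give 3^4 = 81 candidates. The following is a minimal base-R sketch of that arithmetic, assuming evenly spaced levels over the unit ranges used in this vignette. `expand.grid()` and the name `sketch_grid` are stand-ins for illustration only; the vignette itself builds its grid with `grid_regular()`.

``` r
# Illustrative only: mimic a 3-level regular grid over four unit ranges
# using base R. This is not how dials constructs the grid internally.
levels_per_param <- 3

# Evenly spaced levels across each range, endpoints included
numerical_units <- seq(32, 128, length.out = levels_per_param)
categorical_units <- seq(16, 64, length.out = levels_per_param)

# Cross all levels of all four parameters
sketch_grid <- expand.grid(
  processed_numerical_units = numerical_units,
  processed_neighborhood_units = categorical_units,
  processed_bldg_units = categorical_units,
  processed_condition_units = categorical_units
)

# Number of candidate configurations: 3^4
nrow(sketch_grid)
```

With the five resamples defined earlier, evaluating every candidate on every fold would mean 81 × 5 = 405 model fits. Racing methods such as `tune_race_anova()` are designed to reduce that cost by dropping clearly worse candidates early.
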

``` r
# Define the tuning grid
params <- extract_parameter_set_dials(ames_wf) |>
  update(
    processed_numerical_units = hidden_units(range = c(32, 128)),
    processed_neighborhood_units = hidden_units(range = c(16, 64)),
    processed_bldg_units = hidden_units(range = c(16, 64)),
    processed_condition_units = hidden_units(range = c(16, 64))
  )

functional_mlp_grid <- grid_regular(params, levels = 3)

print(functional_mlp_grid)
#> # A tibble: 81 × 4
#>    processed_numerical_units processed_neighborhood_units processed_bldg_units processed_condition_units
#>                        <int>                        <int>                <int>                     <int>
#>  1                        32                           16                   16                        16
#>  2                        80                           16                   16                        16
#>  3                       128                           16                   16                        16
#>  4                        32                           40                   16                        16
#>  5                        80                           40                   16                        16
#>  6                       128                           40                   16                        16
#>  7                        32                           64                   16                        16
#>  8                        80                           64                   16                        16
#>  9                       128                           64                   16                        16
#> 10                        32                           16                   40                        16
#> # ℹ 71 more rows
```

## Tune Model

Now, we'll use `tune_race_anova()` to perform cross-validation and find the best hyperparameters.

``` r
# Note: Parallel processing with `plan(multisession)` is currently not working
# with Keras models due to backend conflicts

set.seed(123)
ames_tune_results <- tune_race_anova(
  ames_wf,
  resamples = ames_folds,
  grid = functional_mlp_grid,
  metrics = metric_set(rmse, mae, rsq),
  control = control_race(save_pred = TRUE, save_workflow = TRUE)
)
```

## Inspect Tuning Results

We can inspect the tuning results to see which hyperparameter combinations performed best.

``` r
# Show the best performing models based on RMSE
show_best(ames_tune_results, metric = "rmse", n = 5)
#> # A tibble: 2 × 10
#>   processed_numerical_units processed_neighborho…¹ processed_bldg_units processed_condition_…² .metric .estimator   mean
#>                       <int>                  <int>                <int>                  <int> <chr>   <chr>       <dbl>
#> 1                       128                     64                   64                     64 rmse    standard   53524.
#> 2                       128                     64                   40                     64 rmse    standard   54215.
#> # ℹ abbreviated names: ¹processed_neighborhood_units, ²processed_condition_units
#> # ℹ 3 more variables: n <int>, std_err <dbl>, .config <chr>

# Autoplot the results
# Currently does not work due to a label issue: autoplot(ames_tune_results)

# Select the best hyperparameters
best_functional_mlp_params <- select_best(ames_tune_results, metric = "rmse")
print(best_functional_mlp_params)
#> # A tibble: 1 × 5
#>   processed_numerical_units processed_neighborhood_units processed_bldg_units processed_condition_units .config
#>                       <int>                        <int>                <int>                     <int> <chr>
#> 1                       128                           64                   64                        64 pre0_mod81_post0
```

## Finalize Workflow and Fit Model

Once we have the best hyperparameters, we finalize the workflow and fit the model on the entire training dataset.

``` r
# Finalize the workflow with the best hyperparameters
final_ames_wf <- finalize_workflow(ames_wf, best_functional_mlp_params)

# Fit the final model on the full training data
final_ames_fit <- fit(final_ames_wf, data = ames_train)

print(final_ames_fit)
#> ══ Workflow [trained] ══════════════════════════════════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: ames_functional_mlp()
#>
#> ── Preprocessor ────────────────────────────────────────────────────────────────────────────────────────────────────────
#> 8 Recipe Steps
#>
#> • step_normalize()
#> • step_collapse()
#> • step_dummy()
#> • step_collapse()
#> • step_dummy()
#> • step_collapse()
#> • step_dummy()
#> • step_collapse()
#>
#> ── Model ───────────────────────────────────────────────────────────────────────────────────────────────────────────────
#> $fit
#> Model: "functional_262"
#> ┌───────────────────────────────────┬──────────────────────────────┬───────────────────┬───────────────────────────────
#> │ Layer (type)                      │ Output Shape                 │           Param # │ Connected to
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ numerical_input (InputLayer)      │ (None, 1, 10)                │                 0 │ -
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ neighborhood_input (InputLayer)   │ (None, 1, 28)                │                 0 │ -
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ bldg_input (InputLayer)           │ (None, 1, 4)                 │                 0 │ -
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ condition_input (InputLayer)      │ (None, 1, 9)                 │                 0 │ -
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ dense_1033 (Dense)                │ (None, 1, 128)               │             1,408 │ numerical_input[0][0]
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ dense_1034 (Dense)                │ (None, 1, 64)                │             1,856 │ neighborhood_input[0][0]
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ dense_1035 (Dense)                │ (None, 1, 64)                │               320 │ bldg_input[0][0]
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ dense_1036 (Dense)                │ (None, 1, 64)                │               640 │ condition_input[0][0]
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ concatenate_258 (Concatenate)     │ (None, 1, 320)               │                 0 │ dense_1033[0][0],
#> │                                   │                              │                   │ dense_1034[0][0],
#> │                                   │                              │                   │ dense_1035[0][0],
#> │                                   │                              │                   │ dense_1036[0][0]
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ output (Dense)                    │ (None, 1, 1)                 │               321 │ concatenate_258[0][0]
#> └───────────────────────────────────┴──────────────────────────────┴───────────────────┴───────────────────────────────
#>  Total params: 13,637 (53.27 KB)
#>  Trainable params: 4,545 (17.75 KB)
#>  Non-trainable params: 0 (0.00 B)
#>  Optimizer params: 9,092 (35.52 KB)
#>
#> $keras_bytes
#>   [1] 50 4b 03 04 14 00 00 00 00 00 00 00 21 00 39 22 4e 35 40 00 00 00 40 00 00 00 0d 00 00 00 6d 65 74 61 64 61 74
#>  [38] 61 2e 6a 73 6f 6e 7b 22 6b 65 72 61 73 5f 76 65 72 73 69 6f 6e 22 3a 20 22 33 2e 31 34 2e 30 22 2c 20 22 64 61
#>  [75] 74 65 5f 73 61 76 65 64 22 3a 20 22 32 30 32 36 2d 30 35 2d 30 31 40 31 32 3a 35 38 3a 31 31 22 7d 50 4b 03 04
#> [112] 14 00 00 00 00 00 00 00 21 00 c3 b3 02 c8 41 20 00 00 41 20 00 00 0b 00 00 00 63 6f 6e 66 69 67 2e 6a 73 6f 6e
#> [149] 7b 22 6d 6f 64 75 6c 65 22 3a 20 22 6b 65 72 61 73 2e 73 72 63 2e 6d 6f 64 65 6c 73 2e 66 75 6e 63 74 69 6f 6e
#> [186] 61 6c 22 2c 20 22 63 6c 61 73 73 5f 6e 61 6d 65 22 3a 20 22 46 75 6e 63 74 69 6f 6e 61 6c 22 2c 20 22 63 6f 6e
#> [223] 66 69 67 22 3a 20 7b 22 6e 61 6d 65 22 3a 20 22 66 75 6e 63 74 69 6f 6e 61 6c 5f 32 36 32 22 2c 20 22 74 72 61
#> [260] 69 6e 61 62 6c 65 22 3a 20 74 72 75 65 2c 20 22 6c 61 79 65 72 73 22 3a 20 5b 7b 22 6d 6f 64 75 6c 65 22 3a 20
#> [297] 22 6b 65 72 61 73 2e 6c 61 79 65 72 73 22 2c 20 22 63 6c 61 73 73 5f 6e 61 6d 65 22 3a 20 22 49 6e 70 75 74 4c
#> [334] 61 79 65 72 22 2c 20 22 63 6f 6e 66 69 67 22 3a 20 7b 22 62 61 74 63 68 5f 73 68 61 70 65 22 3a 20 5b 6e 75 6c
#> [371] 6c 2c 20 31 2c 20 31 30 5d 2c 20 22 64 74 79 70 65 22 3a 20 22 66 6c 6f 61 74 33 32 22 2c 20 22 73 70 61 72 73
#> [408] 65 22 3a 20 66 61 6c 73 65 2c 20 22 72 61 67 67 65 64 22 3a 20 66 61 6c 73 65 2c 20 22 6e 61 6d 65 22 3a 20 22
#> [445] 6e 75 6d 65 72 69 63 61 6c 5f 69 6e 70 75 74 22 2c 20 22 6f 70 74 69 6f 6e 61 6c 22 3a 20 66 61 6c 73 65 7d 2c
#> [482] 20 22 72 65 67 69 73 74 65 72 65 64 5f 6e 61 6d 65 22 3a 20 6e 75 6c 6c 2c 20 22 6e 61 6d 65 22 3a 20 22 6e 75
#> [519] 6d 65 72 69 63 61 6c 5f 69 6e 70 75 74 22 2c 20 22 69 6e 62 6f 75 6e 64 5f 6e 6f 64 65 73 22 3a 20 5b 5d 7d 2c
#> [556] 20 7b 22 6d 6f 64 75 6c 65 22 3a 20 22 6b 65 72 61 73 2e 6c 61 79 65 72 73 22 2c 20 22 63 6c 61 73 73 5f 6e 61
#>
#> ...
#> and 2790 more lines.
```

### Inspect Final Model

You can extract the underlying Keras model and its training history for further inspection.

``` r
# Extract the Keras model summary
final_ames_fit |>
  extract_fit_parsnip() |>
  extract_keras_model() |>
  summary()
#> Model: "functional_262"
#> ┌───────────────────────────────────┬──────────────────────────────┬───────────────────┬───────────────────────────────
#> │ Layer (type)                      │ Output Shape                 │           Param # │ Connected to
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ numerical_input (InputLayer)      │ (None, 1, 10)                │                 0 │ -
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ neighborhood_input (InputLayer)   │ (None, 1, 28)                │                 0 │ -
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ bldg_input (InputLayer)           │ (None, 1, 4)                 │                 0 │ -
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ condition_input (InputLayer)      │ (None, 1, 9)                 │                 0 │ -
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ dense_1033 (Dense)                │ (None, 1, 128)               │             1,408 │ numerical_input[0][0]
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ dense_1034 (Dense)                │ (None, 1, 64)                │             1,856 │ neighborhood_input[0][0]
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ dense_1035 (Dense)                │ (None, 1, 64)                │               320 │ bldg_input[0][0]
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ dense_1036 (Dense)                │ (None, 1, 64)                │               640 │ condition_input[0][0]
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ concatenate_258 (Concatenate)     │ (None, 1, 320)               │                 0 │ dense_1033[0][0],
#> │                                   │                              │                   │ dense_1034[0][0],
#> │                                   │                              │                   │ dense_1035[0][0],
#> │                                   │                              │                   │ dense_1036[0][0]
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ output (Dense)                    │ (None, 1, 1)                 │               321 │ concatenate_258[0][0]
#> └───────────────────────────────────┴──────────────────────────────┴───────────────────┴───────────────────────────────
#>  Total params: 13,637 (53.27 KB)
#>  Trainable params: 4,545 (17.75 KB)
#>  Non-trainable params: 0 (0.00 B)
#>  Optimizer params: 9,092 (35.52 KB)
```

``` r
# Plot the Keras model
final_ames_fit |>
  extract_fit_parsnip() |>
  extract_keras_model() |>
  plot(show_shapes = TRUE)
```

![Model](images/model_plot_shapes_wf.png){fig-alt="A picture showing the model shape"}

``` r
# Plot the training history
final_ames_fit |>
  extract_fit_parsnip() |>
  extract_keras_history() |>
  plot()
```

![plot of chunk inspect-final-keras-model-history](figure/inspect-final-keras-model-history-1.png)

## Make Predictions and Evaluate

Finally, we will make predictions on the test set and evaluate the model's performance.

``` r
# Make predictions on the test set
ames_test_pred <- predict(final_ames_fit, new_data = ames_test)
#> 19/19 - 0s - 10ms/step

# Combine predictions with actuals
ames_results <- tibble::tibble(
  Sale_Price = ames_test$Sale_Price,
  .pred = ames_test_pred$.pred
)

print(head(ames_results))
#> # A tibble: 6 × 2
#>   Sale_Price   .pred
#>        <int>   <dbl>
#> 1     189900 193909.
#> 2     195500 195484.
#> 3     236500 234049.
#> 4     212000 217096.
#> 5     210000 241706.
#> 6     142000 126019.

# Evaluate performance using yardstick metrics
metrics_results <- metric_set(
  rmse,
  mae,
  rsq
)(
  ames_results,
  truth = Sale_Price,
  estimate = .pred
)

print(metrics_results)
#> # A tibble: 3 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 rmse    standard   44687.
#> 2 mae     standard   27910.
#> 3 rsq     standard       0.767
```

## Saving and Reloading Your Model

`kerasnip` serializes the Keras model weights to bytes at fit time and stores them alongside the workflow object. This means plain `saveRDS()` / `readRDS()` **works out of the box**: the underlying Keras model is restored automatically the first time `predict()` is called on the reloaded object.

``` r
# Save the FINAL fitted workflow
saveRDS(final_ames_fit, "ames_model.rds")

# Reload; no extra steps needed
final_ames_fit_loaded <- readRDS("ames_model.rds")

# Make predictions again to prove it works
predict(final_ames_fit_loaded, new_data = ames_test) |> head()
#> 19/19 - 0s - 11ms/step
#> # A tibble: 6 × 1
#>     .pred
#>     <dbl>
#> 1 193909.
#> 2 195484.
#> 3 234049.
#> 4 217096.
#> 5 241706.
#> 6 126019.
```

If you need a fully self-contained bundle suitable for deployment with `vetiver` or other MLOps tools, use `bundle::bundle()` instead:

``` r
library(bundle)

# Save as a portable bundle
bundled <- bundle(final_ames_fit)
saveRDS(bundled, "ames_model_bundle.rds")

# Reload in any R session
library(kerasnip)
library(bundle)
final_ames_fit_loaded <- unbundle(readRDS("ames_model_bundle.rds"))
predict(final_ames_fit_loaded, new_data = ames_test) |> head()
#> 19/19 - 0s - 9ms/step
#> # A tibble: 6 × 1
#>     .pred
#>     <dbl>
#> 1 193909.
#> 2 195484.
#> 3 234049.
#> 4 217096.
#> 5 241706.
#> 6 126019.
```

See `vignette("saving_and_reloading")` for a detailed comparison of both approaches.
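
The idea of carrying a model as raw bytes inside an ordinary R object, so that `saveRDS()` can round-trip it, can be illustrated without Keras. The following is a toy sketch in base R and is not `kerasnip`'s actual implementation: `toy_fit`, its `live_model` and `bytes` fields, and `restore_toy()` are hypothetical names, and `serialize()` merely stands in for whatever Keras serialization `kerasnip` uses internally.

``` r
# Toy illustration (NOT kerasnip internals): keep a model's state as raw
# bytes inside a plain list so that saveRDS()/readRDS() can round-trip it.
weights <- c(0.5, -1.2, 3.0) # stand-in for trained weights

toy_fit <- list(
  live_model = NULL, # stand-in for a live handle that is lost on save
  bytes = serialize(weights, connection = NULL) # raw vector, safe to save
)

# Hypothetical helper: rebuild the "model" from its bytes on first use
restore_toy <- function(fit) {
  if (is.null(fit$live_model)) {
    fit$live_model <- unserialize(fit$bytes)
  }
  fit
}

# Round trip through an .rds file in a temporary location
path <- tempfile(fileext = ".rds")
saveRDS(toy_fit, path)
reloaded <- restore_toy(readRDS(path))

# The restored state matches the original weights
identical(reloaded$live_model, weights)
```

The design point is that only plain R data (a raw vector) is written to disk, and the live object is rebuilt lazily when it is first needed.
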