2 Tutorial 2: Deep neural network with dataset functions and callbacks (binary classification)
Aims: To predict heart disease using a deep neural network with custom dataset functions and training callbacks.
Data: heartdisease data from the MLDataR package.
Code description: This code demonstrates the use of torch with custom dataset functions, dataset subsets, and training callbacks, including early stopping and best-model checkpointing.
Packages
library(torch)
library(luz)
library(tidyverse)
library(tidymodels)
library(MLDataR)
Data
heart_df <-
  heartdisease %>%
  mutate(across(c(Sex, RestingECG, Angina), as.factor))
Explore data.
skimr::skim(heart_df)
| Name | heart_df |
| Number of rows | 918 |
| Number of columns | 10 |
| _______________________ | |
| Column type frequency: | |
| factor | 3 |
| numeric | 7 |
| ________________________ | |
| Group variables | None |
Variable type: factor
| skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
|---|---|---|---|---|---|
| Sex | 0 | 1 | FALSE | 2 | M: 725, F: 193 |
| RestingECG | 0 | 1 | FALSE | 3 | Nor: 552, LVH: 188, ST: 178 |
| Angina | 0 | 1 | FALSE | 2 | N: 547, Y: 371 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| Age | 0 | 1 | 53.51 | 9.43 | 28.0 | 47.00 | 54.0 | 60.0 | 77.0 | ▁▅▇▆▁ |
| RestingBP | 0 | 1 | 132.40 | 18.51 | 0.0 | 120.00 | 130.0 | 140.0 | 200.0 | ▁▁▃▇▁ |
| Cholesterol | 0 | 1 | 198.80 | 109.38 | 0.0 | 173.25 | 223.0 | 267.0 | 603.0 | ▃▇▇▁▁ |
| FastingBS | 0 | 1 | 0.23 | 0.42 | 0.0 | 0.00 | 0.0 | 0.0 | 1.0 | ▇▁▁▁▂ |
| MaxHR | 0 | 1 | 136.81 | 25.46 | 60.0 | 120.00 | 138.0 | 156.0 | 202.0 | ▁▃▇▆▂ |
| HeartPeakReading | 0 | 1 | 0.89 | 1.07 | -2.6 | 0.00 | 0.6 | 1.5 | 6.2 | ▁▇▆▁▁ |
| HeartDisease | 0 | 1 | 0.55 | 0.50 | 0.0 | 0.00 | 1.0 | 1.0 | 1.0 | ▆▁▁▁▇ |
Dataset function
heart_dataset <- dataset(
initialize = function(df) {
# Pre-process and store as tensors
self$x_num <- df %>%
select(Age, RestingBP, Cholesterol, FastingBS, MaxHR, HeartPeakReading) %>%
mutate(across(everything(), scale)) %>%
as.matrix() %>%
torch_tensor(dtype = torch_float())
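    # One-hot encode the factor predictors, dropping the intercept column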
self$x_cat <- model.matrix(~ Sex + RestingECG + Angina, data = df)[, -1] %>%
as.matrix() %>%
torch_tensor(dtype = torch_float())
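    # Outcome as a float tensor, as required by nn_bce_loss()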
self$y <- torch_tensor(as.matrix(df$HeartDisease), dtype = torch_float())
},
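  # Return a single item: numeric and categorical features plus the label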
.getitem = function(i) {
list(x = list(self$x_num[i, ], self$x_cat[i, ]), y = self$y[i])
},
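  # Number of observations in the dataset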
.length = function() {
self$y$size(1)
}
)
# Convert to torch dataset
ds_tensor <- heart_dataset(heart_df)
ds_tensor[1]$x
$x[[1]]
torch_tensor
-1.4324
0.4107
0.8246
-0.5510
1.3822
-0.8320
[ CPUFloatType{6} ]
$x[[2]]
torch_tensor
1
1
0
0
[ CPUFloatType{4} ]
$y
torch_tensor
0
[ CPUFloatType{1} ]
Split data with dataset subsets
set.seed(123)
n <- nrow(heart_df)
train_size <- floor(0.6 * n)
valid_size <- floor(0.2 * n)
# Create indices
all_indices <- 1:n
train_indices <- sample(all_indices, size = train_size)
remaining_indices <- setdiff(all_indices, train_indices)
valid_indices <- sample(remaining_indices, size = valid_size)
test_indices <- setdiff(remaining_indices, valid_indices)
# Create Subsets
train_ds <- dataset_subset(ds_tensor, train_indices)
valid_ds <- dataset_subset(ds_tensor, valid_indices)
test_ds <- dataset_subset(ds_tensor, test_indices)
Convert to dataloader
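# Shuffle only the training batches; keep validation and test in order so
# predictions line up with the true labels later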
train_dl <- train_ds %>%
dataloader(batch_size = 10, shuffle = TRUE)
valid_dl <- valid_ds %>%
dataloader(batch_size = 10, shuffle = FALSE)
test_dl <- test_ds %>%
dataloader(batch_size = 10, shuffle = FALSE)
Specify the model
net <-
nn_module(
initialize = function(d_in){
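    # Feed-forward network: two hidden layers with ReLU and dropout,
    # ending in a sigmoid for a probability output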
self$net <- nn_sequential(
nn_linear(d_in, 32),
nn_relu(),
nn_dropout(0.5),
nn_linear(32, 64),
nn_relu(),
nn_dropout(0.5),
nn_linear(64, 1),
nn_sigmoid()
)
},
forward = function(x){
# x is currently a list of two tensors (numeric and categorical)
# Concatenate them along the feature dimension (dim=2)
input <- torch_cat(x, dim = 2)
self$net(input)
}
)
Fit the model
Set parameters
d_in <- length(ds_tensor[1]$x[[1]]) + length(ds_tensor[1]$x[[2]]) # total number of features
Fit with callbacks
fitted <-
net %>%
setup(
loss = nn_bce_loss(),
optimizer = optim_adam,
metrics = list(
luz_metric_binary_accuracy(),
luz_metric_binary_auroc()
)
) %>%
set_hparams(d_in = d_in) %>%
fit(
train_dl,
epochs = 50,
valid_data = valid_dl,
callbacks = list(
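      # Stop training if validation loss does not improve for 10 epochs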
luz_callback_early_stopping(patience = 10),
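      # Keep, and restore at the end of training, the weights from the
      # epoch with the best validation loss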
luz_callback_keep_best_model()
)
)
Training plot
fitted %>% plot()
Better plot
hist <- get_metrics(fitted)
optimal_epoch <- hist %>%
filter(metric == "loss", set == "valid") %>%
slice_min(value, n = 1) %>%
pull(epoch)
hist %>%
ggplot(aes(x = epoch, y = value, color = set)) +
geom_line(linewidth = 1) +
geom_point(size = 1.5) +
facet_wrap(~ metric, scales = "free_y", ncol = 1) +
theme_minimal() +
labs(
title = "Training vs Validation Metrics",
subtitle = paste("Optimal epoch:", optimal_epoch),
y = "Value",
x = "Epoch",
color = "Dataset"
)
Re-fit the model
Note: There is no need to refit manually, because luz_callback_keep_best_model() already restores the weights from the epoch with the best validation loss at the end of training.
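If you were not using that callback, a minimal sketch of the manual alternative is to retrain the same setup for optimal_epoch epochs, the best validation epoch computed in the plotting step above (refitted is just an illustrative name):
# Hypothetical manual refit up to the best validation epoch
refitted <-
  net %>%
  setup(
    loss = nn_bce_loss(),
    optimizer = optim_adam,
    metrics = list(
      luz_metric_binary_accuracy(),
      luz_metric_binary_auroc()
    )
  ) %>%
  set_hparams(d_in = d_in) %>%
  fit(train_dl, epochs = optimal_epoch, valid_data = valid_dl)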
Predict testing set
y_pred <- fitted %>% predict(test_dl)
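# Recover the true labels for the test subset via the indices stored by dataset_subset()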
y_true <- ds_tensor$y[test_ds$indices] %>% as_array()
dat_pred <-
  y_pred %>%
  as_array() %>%
  as.data.frame() %>%
  as_tibble() %>%
  rename(prob = V1) %>%
  mutate(
    pred = factor(ifelse(prob > 0.5, 1, 0)),
    true = factor(y_true)
  )
dat_pred
# A tibble: 185 × 3
prob pred true
<dbl> <fct> <fct>
1 0.101 0 0
2 0.238 0 0
3 0.0764 0 0
4 0.968 1 1
5 0.0822 0 0
6 0.172 0 0
7 0.288 0 0
8 0.323 0 0
9 0.846 1 0
10 0.128 0 1
# ℹ 175 more rows
Evaluate
fitted %>% evaluate(test_dl)
A `luz_module_evaluation`
── Results ─────────────────────────────────────────────────────────────────────
loss: 0.4093
acc: 0.8378
auc: 0.8844
Confusion matrix
dat_pred %>%
conf_mat(true, pred) %>%
autoplot("heatmap")
Accuracy
dat_pred %>%
accuracy(truth = true, estimate = pred)
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 accuracy binary 0.838
Plot ROC
dat_pred %>%
roc_curve(true, prob, event_level = "second") %>%
autoplot()
# ROC-AUC
dat_pred %>%
roc_auc(true, prob, event_level = "second")
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 roc_auc binary 0.885