This R Markdown notebook demonstrates KNN regression using the RStudio integration with Databricks. For more information, refer to Databricks RStudio Integration.

Set up your RStudio and Databricks integration

Identify the existing SparkR backend port

Sys.getenv("EXISTING_SPARKR_BACKEND_PORT")
## [1] "38855"

Attach the SparkR library

library(SparkR)
## 
## Attaching package: 'SparkR'
## The following objects are masked from 'package:stats':
## 
##     cov, filter, lag, na.omit, predict, sd, var, window
## The following objects are masked from 'package:base':
## 
##     as.data.frame, colnames, colnames<-, drop, endsWith,
##     intersect, rank, rbind, sample, startsWith, subset, summary,
##     transform, union

Establish a SparkR session

sparkR.session()
## Spark package found in SPARK_HOME: /databricks/spark
## Java ref type org.apache.spark.sql.SparkSession id 22

Run the KNN regression demo

Prepare input data

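library(ggplot2)
# Write the ggplot2 diamonds data to DBFS; the /dbfs FUSE path is readable from the driver and the workers
write.csv(diamonds, "/dbfs/tmp/diamonds.csv", row.names = FALSE)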

Review the diamonds dataset

The diamonds dataset contains 10 variables, including price, for 53,940 diamonds. We will use the other nine features to predict price with a well-known non-parametric model, k-nearest neighbors (KNN) regression, as implemented in the FNN package. We first create a library for FNN in the Databricks workspace and attach it to our serverless pool so that FNN is available on all workers.

library(magrittr)
## 
## Attaching package: 'magrittr'
## The following object is masked from 'package:SparkR':
## 
##     not
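# stringsAsFactors = TRUE keeps cut, color and clarity as factors (its default became FALSE in R 4.0.0)
diamonds <- read.csv("/dbfs/tmp/diamonds.csv", header = TRUE, stringsAsFactors = TRUE)
# Predictor columns only
temp.data <- subset(diamonds, select = -price)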
summary(diamonds)
##      carat               cut        color        clarity     
##  Min.   :0.2000   Fair     : 1610   D: 6775   SI1    :13065  
##  1st Qu.:0.4000   Good     : 4906   E: 9797   VS2    :12258  
##  Median :0.7000   Ideal    :21551   F: 9542   SI2    : 9194  
##  Mean   :0.7979   Premium  :13791   G:11292   VS1    : 8171  
##  3rd Qu.:1.0400   Very Good:12082   H: 8304   VVS2   : 5066  
##  Max.   :5.0100                     I: 5422   VVS1   : 3655  
##                                     J: 2808   (Other): 2531  
##      depth           table           price             x         
##  Min.   :43.00   Min.   :43.00   Min.   :  326   Min.   : 0.000  
##  1st Qu.:61.00   1st Qu.:56.00   1st Qu.:  950   1st Qu.: 4.710  
##  Median :61.80   Median :57.00   Median : 2401   Median : 5.700  
##  Mean   :61.75   Mean   :57.46   Mean   : 3933   Mean   : 5.731  
##  3rd Qu.:62.50   3rd Qu.:59.00   3rd Qu.: 5324   3rd Qu.: 6.540  
##  Max.   :79.00   Max.   :95.00   Max.   :18823   Max.   :10.740  
##                                                                  
##        y                z         
##  Min.   : 0.000   Min.   : 0.000  
##  1st Qu.: 4.720   1st Qu.: 2.910  
##  Median : 5.710   Median : 3.530  
##  Mean   : 5.735   Mean   : 3.539  
##  3rd Qu.: 6.540   3rd Qu.: 4.040  
##  Max.   :58.900   Max.   :31.800  
## 
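To make the model concrete, here is a minimal base-R sketch of KNN regression. It is an illustration, not the FNN implementation: each query point is predicted as the mean response of its k nearest training rows under Euclidean distance. The function name `knn_predict` and the toy numbers are hypothetical.

```r
# Minimal KNN regression sketch (hypothetical helper, not FNN::knn.reg):
# predict each query row as the mean response of its k nearest training rows.
knn_predict <- function(train, y, query, k) {
  train <- as.matrix(train)
  query <- as.matrix(query)
  apply(query, 1, function(q) {
    # Squared Euclidean distance from q to every training row
    d <- colSums((t(train) - q)^2)
    nearest <- order(d)[seq_len(k)]
    mean(y[nearest])
  })
}

# Toy example with made-up numbers (not the diamonds data)
train <- matrix(c(1, 2, 3, 10, 11), ncol = 1)
y <- c(100, 200, 300, 1000, 1100)
knn_predict(train, y, query = matrix(c(2.1, 10.4), ncol = 1), k = 2)
# returns c(250, 1050)
```

Like `run.knn`, this sketch does not standardize the features, so columns with larger numeric ranges dominate the distance.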

Create a function using the KNN algorithm

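This simple function fits a KNN regression for a given K and a subset of columns of the diamonds dataset. Because no test set is passed to FNN::knn.reg, it performs leave-one-out cross-validation (LOOCV), and the function returns the LOOCV sum of squared residuals (PRESS) together with the predicted R-squared, the number of columns, K, and an id built from the column indices.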

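run.knn <- function(params) {
  # Packages are loaded inside the function because it runs on the Spark workers
  library(FNN)
  library(magrittr)
  
  k <- params$k
  columns <- unlist(params$columns[[1]])
  # Encode the selected column indices as a string id, e.g. c(1, 3, 4) -> "134"
  id <- paste0(sort(columns), collapse = '')
  
  diamonds <- read.csv("/dbfs/tmp/diamonds.csv", header = TRUE, stringsAsFactors = TRUE)
  # Convert every predictor to numeric; factors become their (alphabetical) level codes
  temp.train.df <- diamonds %>% subset(select = -price) %>% sapply(as.numeric) %>% as.data.frame
  # drop = FALSE keeps a one-column selection as a data.frame
  train.df <- temp.train.df[, columns, drop = FALSE]
  
  response <- diamonds$price
  
  # Without a test set, knn.reg performs leave-one-out cross-validation
  knn.fit <- knn.reg(train = train.df, y = response, k = k)
  c(PRESS = sum(knn.fit$PRESS), R2pred = knn.fit$R2Pred, numColumns = length(columns), k = k, id = id)
}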
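The LOOCV error that `run.knn` reports can be spelled out in a few lines of base R. This is a sketch under stated assumptions (Euclidean distance, ties broken by row order, predicted R-squared taken as 1 - PRESS / total sum of squares); FNN's internal tie handling may differ. `knn_loocv` is a hypothetical name and the inputs are toy values.

```r
# LOOCV sketch for KNN regression (hypothetical helper): each observation is
# predicted from its k nearest *other* observations.
knn_loocv <- function(train, y, k) {
  train <- as.matrix(train)
  n <- nrow(train)
  pred <- vapply(seq_len(n), function(i) {
    d <- colSums((t(train) - train[i, ])^2)
    d[i] <- Inf                        # exclude the held-out point itself
    mean(y[order(d)[seq_len(k)]])      # ties broken by row order
  }, numeric(1))
  press <- sum((y - pred)^2)           # sum of squared LOOCV residuals
  r2pred <- 1 - press / sum((y - mean(y))^2)
  list(pred = pred, PRESS = press, R2Pred = r2pred)
}

# Toy data: predictions are c(20, 10, 20, 40), so PRESS = 100 + 100 + 400 + 1600 = 2200
res <- knn_loocv(train = c(1, 2, 4, 8), y = c(10, 20, 40, 80), k = 1)
res$PRESS
```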

What is the optimal subset of features and number of neighbors for this model?

We will use a grid search on top of Apache Spark to find the optimal solution. To do so, we first build a grid of all non-empty subsets of the nine predictors (2^9 - 1 = 511) and vary K from 1 to 25, which gives 511 × 25 = 12,775 combinations.

library(combinat)
## 
## Attaching package: 'combinat'
## The following object is masked from 'package:utils':
## 
##     combn
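# For each subset size n = 1..9, turn every combination of n column indices into its own vector
col.search <- sapply(1:ncol(temp.data), function(n) {
  combn(1:ncol(temp.data), n) %>% as.data.frame %>% t %>% split(seq(nrow(.)))
}) %>% unlist(recursive = FALSE)
k <- 1:25
# Cross every column subset with every K; each grid element is a one-row data.frame
grid <- expand.grid(columns = col.search, k = k) %>% split(seq(nrow(.)))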

length(grid)
## [1] 12775
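The count of 12,775 can be checked directly. This sketch swaps in base R's `utils::combn(..., simplify = FALSE)` for the combinat package (called with an explicit namespace because combinat masks `combn`); the names `subsets` and `grid_index` are hypothetical.

```r
# Enumerate all non-empty subsets of 9 predictor columns, smallest subsets first
n_features <- 9
subsets <- unlist(
  lapply(seq_len(n_features), function(m) {
    utils::combn(seq_len(n_features), m, simplify = FALSE)
  }),
  recursive = FALSE
)
length(subsets)                  # 511, i.e. 2^9 - 1 = sum(choose(9, 1:9))

# Pair every subset with every K; expand.grid varies subset_id fastest
k_values <- 1:25
grid_index <- expand.grid(subset_id = seq_along(subsets), k = k_values)
nrow(grid_index)                 # 12775 = 511 * 25
```

Because `combn` lists combinations in lexicographic order within each size, subset 53 is `c(1, 3, 4)`, and row (7 - 1) * 511 + 53 = 3119 pairs it with K = 7, which lines up with the row index reported for the optimum at the end of the notebook.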

Compare test error

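To compare the test error across all these possibilities, we parallelize the search with the SparkR::spark.lapply() function, which applies run.knn to each grid element on the Spark workers and returns the results to the driver as a list. We then extract the input parameters and the corresponding test-error values, which we combine into a data.frame.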

# SparkR is already attached and the session is running (see the setup steps above)
grid.search <- spark.lapply(grid, run.knn)

# Each result is a named vector returned by run.knn; extract each field by name
id <- sapply(grid.search, function(x) x[["id"]])
k <- sapply(grid.search, function(x) x[["k"]])
n <- sapply(grid.search, function(x) x[["numColumns"]])
r2 <- sapply(grid.search, function(x) x[["R2pred"]])
press <- sapply(grid.search, function(x) x[["PRESS"]])
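spark.lapply(list, func) is conceptually lapply(list, func) with the calls executed on the Spark workers, and all results come back to the driver as a single R list, so they must fit in its memory. That makes a local `lapply()` over a few grid elements a cheap dry run before launching the distributed job. The sketch below assumes a hypothetical stand-in worker `toy_fit` and a toy grid, and binds the results into one data.frame.

```r
# Hypothetical stand-in for run.knn that returns a named list of scalars
toy_fit <- function(params) {
  list(PRESS = params$k * 10,
       numColumns = length(params$columns),
       k = params$k,
       id = paste0(sort(params$columns), collapse = ""))
}

# Toy grid: each element holds one set of input parameters
toy_grid <- list(
  list(columns = c(1, 3, 4), k = 7),
  list(columns = c(2), k = 1)
)

# Local dry run; on a cluster this line would be spark.lapply(toy_grid, toy_fit)
results <- lapply(toy_grid, toy_fit)

# One row per grid element
results_df <- do.call(rbind, lapply(results, as.data.frame))
results_df
```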

Visualize the results

Visual exploration of the results is a key part of identifying optimal values, and we can do it with plotly and ggplot2.

plot.data <- data.frame(
  id = as.factor(id),
  k = as.numeric(k),
  n = as.factor(n),
  press = as.numeric(press))

library(plotly)
## 
## Attaching package: 'plotly'
## The following object is masked from 'package:ggplot2':
## 
##     last_plot
## The following objects are masked from 'package:SparkR':
## 
##     arrange, distinct, filter, group_by, mutate, rename, schema,
##     select
## The following object is masked from 'package:stats':
## 
##     filter
## The following object is masked from 'package:graphics':
## 
##     layout
plot_ly(plot.data, x = ~k, y = ~n, z = ~press, color = ~n) %>% add_markers()
## Warning in RColorBrewer::brewer.pal(N, "Set2"): n too large, allowed maximum for palette Set2 is 8
## Returning the palette you asked for with that many colors
library(plyr)
## 
## Attaching package: 'plyr'
## The following objects are masked from 'package:plotly':
## 
##     arrange, mutate, rename, summarise
## The following objects are masked from 'package:SparkR':
## 
##     arrange, count, desc, join, mutate, rename, summarize, take
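# For each (k, n) pair, keep the lowest PRESS over all feature subsets of size n.
# plyr::summarize is named explicitly because SparkR also exports summarize.
min.surface <- ddply(plot.data, .(k, n), plyr::summarize, press = min(press))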
plot_ly(min.surface, x = ~k, y = ~n, z = ~press, color = ~n) %>% add_markers()
## Warning in RColorBrewer::brewer.pal(N, "Set2"): n too large, allowed maximum for palette Set2 is 8
## Returning the palette you asked for with that many colors
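# Minimum PRESS per subset size n versus k; the horizontal line marks the overall minimum
# (linewidth replaces the deprecated size aesthetic for lines in ggplot2 >= 3.4.0)
ggplot(min.surface, aes(k, press, color = n)) +
  geom_point() +
  geom_line(aes(group = n)) +
  geom_label(data = subset(min.surface, k == 1), aes(label = n)) +
  theme_bw() +
  geom_hline(yintercept = min(plot.data$press), linewidth = 1.2)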
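The `ddply()` step keeps, for each (k, n) pair, the lowest PRESS over all feature subsets of size n. A base-R alternative is `aggregate()` with a formula; the sketch below applies it to a made-up data frame `toy` that has the same columns as `plot.data`.

```r
# Made-up stand-in for plot.data (same column names, toy values)
toy <- data.frame(
  id = c("1", "2", "12", "1", "2", "12"),
  k = c(1, 1, 1, 2, 2, 2),
  n = factor(c(1, 1, 2, 1, 1, 2)),
  press = c(50, 40, 30, 45, 42, 25)
)

# Lowest PRESS for every (k, n) combination
min_surface <- aggregate(press ~ k + n, data = toy, FUN = min)
min_surface
```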

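Finally, we look up the row with the minimum PRESS to identify the optimal combination of columns and K. Because id concatenates the sorted column indices, the optimum found below (id 134 with K = 7) refers to columns 1, 3 and 4, which we print to see their names.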

plot.data[which(plot.data$press == min(plot.data$press)), ]
##       id k n       press
## 3119 134 7 3 17976133574
diamonds[, c(1, 3, 4)] %>% head
##   carat color clarity
## 1  0.23     E     SI2
## 2  0.21     E     SI1
## 3  0.23     E     VS1
## 4  0.29     I     VS2
## 5  0.31     J     SI2
## 6  0.24     J    VVS2
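Since each id concatenates sorted single-digit column indices, it can be decoded back into predictor names without indexing the data. This sketch assumes the predictor order shown in the summary above with price removed; `predictors` and `decode_id` are hypothetical helpers.

```r
# Predictor order in temp.data (diamonds without price)
predictors <- c("carat", "cut", "color", "clarity", "depth", "table", "x", "y", "z")

# Split an id such as "134" into digits and map them to column names.
# This is unambiguous only because there are at most 9 predictors.
decode_id <- function(id, cols = predictors) {
  idx <- as.integer(strsplit(as.character(id), "")[[1]])
  cols[idx]
}

decode_id("134")   # "carat" "color" "clarity"
```

`as.character()` lets it accept the factor ids stored in `plot.data`. Note that `which.min(plot.data$press)` would return only the first minimum, whereas the `which(... == min(...))` form used above returns every tied row.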