This RMarkdown notebook demonstrates KNN regression using the RStudio integration with Databricks. For more information, please refer to Databricks RStudio Integration.
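Before running it, install the plotly, combinat, and FNN R packages on all of the worker nodes of your Databricks cluster via Upload a CRAN library.
# If set, SparkR connects to this already-running SparkR backend instead of starting a new one
Sys.getenv("EXISTING_SPARKR_BACKEND_PORT")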
## [1] "38855"
library(SparkR)
##
## Attaching package: 'SparkR'
## The following objects are masked from 'package:stats':
##
## cov, filter, lag, na.omit, predict, sd, var, window
## The following objects are masked from 'package:base':
##
## as.data.frame, colnames, colnames<-, drop, endsWith,
## intersect, rank, rbind, sample, startsWith, subset, summary,
## transform, union
sparkR.session()
## Spark package found in SPARK_HOME: /databricks/spark
## Java ref type org.apache.spark.sql.SparkSession id 22
library(ggplot2)
# Save ggplot2's diamonds data under /dbfs/tmp so that every worker can read the same file
write.csv(diamonds, "/dbfs/tmp/diamonds.csv", row.names = FALSE)
The diamonds dataset from ggplot2 contains 10 variables, including price, for 53,940 diamonds. We use the other variables to predict price with a well-known non-parametric model, k-nearest neighbors (KNN) regression, as implemented in the FNN package. We first create a library in the Databricks workspace and attach it to our serverless pool so that FNN is available on all workers.
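To make the prediction rule concrete, here is a minimal base-R sketch of KNN regression, written for illustration only and not taken from FNN: each query point is predicted as the mean response of its k nearest training rows under Euclidean distance. The function name knn.predict and the toy data are hypothetical.

```r
# Hypothetical helper, not FNN's implementation: predict each query row as the
# mean response of its k nearest training rows (Euclidean distance)
knn.predict <- function(train, y, query, k) {
  train <- as.matrix(train)
  query <- as.matrix(query)
  apply(query, 1, function(q) {
    # Distance from the query point q to every training row
    d <- sqrt(colSums((t(train) - q)^2))
    # Average the responses of the k closest rows
    mean(y[order(d)[seq_len(k)]])
  })
}

# Toy one-feature example (hypothetical values)
train <- matrix(c(1, 2, 3, 10, 11), ncol = 1)
y <- c(100, 200, 300, 1000, 1100)
pred <- knn.predict(train, y, query = matrix(c(2.2, 10.4), ncol = 1), k = 2)
pred  # 250 1050
```

Because the prediction depends only on distances, predictors on large numeric scales dominate the neighbourhood, which is one reason the choice of column subset matters in the search below.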
library(magrittr)
##
## Attaching package: 'magrittr'
## The following object is masked from 'package:SparkR':
##
## not
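# stringsAsFactors = TRUE keeps cut, color, and clarity as factors (the default changed in R 4.0)
diamonds <- read.csv("/dbfs/tmp/diamonds.csv", header = TRUE, stringsAsFactors = TRUE)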
temp.data <- subset(diamonds, select = -price)
summary(diamonds)
## carat cut color clarity
## Min. :0.2000 Fair : 1610 D: 6775 SI1 :13065
## 1st Qu.:0.4000 Good : 4906 E: 9797 VS2 :12258
## Median :0.7000 Ideal :21551 F: 9542 SI2 : 9194
## Mean :0.7979 Premium :13791 G:11292 VS1 : 8171
## 3rd Qu.:1.0400 Very Good:12082 H: 8304 VVS2 : 5066
## Max. :5.0100 I: 5422 VVS1 : 3655
## J: 2808 (Other): 2531
## depth table price x
## Min. :43.00 Min. :43.00 Min. : 326 Min. : 0.000
## 1st Qu.:61.00 1st Qu.:56.00 1st Qu.: 950 1st Qu.: 4.710
## Median :61.80 Median :57.00 Median : 2401 Median : 5.700
## Mean :61.75 Mean :57.46 Mean : 3933 Mean : 5.731
## 3rd Qu.:62.50 3rd Qu.:59.00 3rd Qu.: 5324 3rd Qu.: 6.540
## Max. :79.00 Max. :95.00 Max. :18823 Max. :10.740
##
## y z
## Min. : 0.000 Min. : 0.000
## 1st Qu.: 4.720 1st Qu.: 2.910
## Median : 5.710 Median : 3.530
## Mean : 5.735 Mean : 3.539
## 3rd Qu.: 6.540 3rd Qu.: 4.040
## Max. :58.900 Max. :31.800
##
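This function fits a KNN regression for a given K and a given subset of columns of the diamonds dataset. Because no test set is passed to FNN::knn.reg, the model is evaluated with leave-one-out cross-validation (LOOCV). The function returns the LOOCV sum of squared residuals (PRESS) together with the predictive R-squared, the number of columns, K, and an id for the column subset.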
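A minimal base-R sketch of what these LOOCV quantities mean, assuming that knn.reg excludes each observation from its own neighbour set and defines predictive R-squared as 1 - PRESS / total sum of squares. The helper loocv.knn and the toy data are hypothetical, ties may be broken differently than in FNN, and the dense distance matrix makes it suitable only for small inputs, not for the full 53,940-row dataset.

```r
# Hypothetical helper: leave-one-out CV for KNN regression in base R
loocv.knn <- function(train, y, k) {
  train <- as.matrix(train)
  d <- as.matrix(dist(train))  # n x n Euclidean distance matrix
  diag(d) <- Inf               # a point may not be its own neighbour
  # Predict each point from its k nearest *other* points
  pred <- sapply(seq_along(y), function(i) mean(y[order(d[i, ])[seq_len(k)]]))
  press <- sum((y - pred)^2)   # predicted residual sum of squares
  c(PRESS = press, R2Pred = 1 - press / sum((y - mean(y))^2))
}

# Toy example: every held-out point is off by exactly 100
res <- loocv.knn(c(1, 2, 4, 10, 11), c(100, 200, 300, 1000, 1100), k = 1)
res[["PRESS"]]  # 50000
```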
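run.knn <- function(params) {
  # Load packages inside the function so they are available on the Spark workers
  library(FNN)
  library(magrittr)
  k <- params$k
  columns <- unlist(params$columns[[1]])
  # Build an id such as "134" from the sorted column indices
  id <- paste0(sort(columns), collapse = '')
  diamonds <- read.csv("/dbfs/tmp/diamonds.csv", header = TRUE, stringsAsFactors = TRUE)
  # Convert every predictor to numeric (factors become their integer level codes)
  temp.train.df <- diamonds %>% subset(select = -price) %>% sapply(as.numeric) %>% as.data.frame
  # drop = FALSE keeps a one-column selection as a data frame
  train.df <- temp.train.df[, columns, drop = FALSE]
  response <- diamonds$price
  # With no test set, knn.reg performs leave-one-out cross-validation
  knn.fit <- knn.reg(train = train.df, y = response, k = k)
  # c() coerces every element to character because id is a string
  c(PRESS = sum(knn.fit$PRESS), R2pred = knn.fit$R2Pred, numColumns = length(columns), k = k, id = id)
}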
We use a grid search on top of Apache Spark to find the optimal configuration. To do so, we first build a grid of every non-empty subset of the nine predictor columns (2^9 - 1 = 511 subsets) and vary K from 1 to 25, which yields 511 × 25 = 12,775 combinations.
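The count can be checked with base R alone. This sketch calls utils::combn() explicitly, because combinat::combn() masks it after library(combinat); the variable names are illustrative.

```r
n.features <- 9  # carat, cut, color, clarity, depth, table, x, y, z
# All non-empty subsets of the predictor indices, as a flat list of vectors
subsets <- unlist(lapply(seq_len(n.features), function(m) {
  utils::combn(n.features, m, simplify = FALSE)
}), recursive = FALSE)
length(subsets)  # 511
# Each subset is paired with every K from 1 to 25
n.configs <- length(subsets) * length(1:25)
n.configs  # 12775
```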
library(combinat)
##
## Attaching package: 'combinat'
## The following object is masked from 'package:utils':
##
## combn
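# Every non-empty subset of the 9 predictor columns, as a list of index vectors
col.search <- sapply(1:ncol(temp.data), function(n) {
  combn(1:ncol(temp.data), n) %>% as.data.frame %>% t %>% split(seq(nrow(.)))
}) %>% unlist(recursive = FALSE)
k <- 1:25
# Cross every subset with every K, then split into one-row parameter sets
grid <- expand.grid(columns = col.search, k = k) %>% split(seq(nrow(.)))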
length(grid)
## [1] 12775
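To compare the test error across all of these possibilities, we parallelize the search with SparkR::spark.lapply(), which applies run.knn to each element of grid on the cluster's workers and returns the results to the driver as a list. We then construct a data.frame of the input parameters and the corresponding test-error values.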
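The shape of the results explains the as.numeric() conversions used when building plot.data: because run.knn() combines numbers with the string id via c(), every returned element is character. The sketch below uses base lapply() as a local stand-in for spark.lapply() (both return a list), a hypothetical fake.run() in place of run.knn(), and plain lists instead of one-row data frames for the parameters; it shows one way to turn such results into a typed data.frame.

```r
# Hypothetical stand-in for run.knn(): same return shape, fake numbers
fake.run <- function(params) {
  c(PRESS = 100 * params$k, R2pred = 0.5, numColumns = length(params$columns),
    k = params$k, id = paste0(sort(params$columns), collapse = ""))
}
toy.grid <- list(list(columns = c(3, 1), k = 1), list(columns = 2, k = 5))
results <- lapply(toy.grid, fake.run)  # local stand-in for spark.lapply()

# Stack the named character vectors into rows, then restore numeric types
res.df <- as.data.frame(do.call(rbind, results), stringsAsFactors = FALSE)
num.cols <- c("PRESS", "R2pred", "numColumns", "k")
res.df[num.cols] <- lapply(res.df[num.cols], as.numeric)
```

Accessing fields by name rather than by position (x[[1]] … x[[5]]) also guards against reordering the return vector.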
#library(SparkR)
#sparkR.session()
grid.search <- spark.lapply(grid, run.knn)
id <- sapply(grid.search, function(x) { x[[5]] })
k <- sapply(grid.search, function(x) { x[[4]] })
n <- sapply(grid.search, function(x) { x[[3]] })
r2 <- sapply(grid.search, function(x) { x[[2]] })
press <- sapply(grid.search, function(x) { x[[1]] })
A key part of identifying optimal values is visually exploring the results, which we can do with plotly and ggplot2.
plot.data <- data.frame(
id = as.factor(id),
k = as.numeric(k),
n = as.factor(n),
press = as.numeric(press))
library(plotly)
##
## Attaching package: 'plotly'
## The following object is masked from 'package:ggplot2':
##
## last_plot
## The following objects are masked from 'package:SparkR':
##
## arrange, distinct, filter, group_by, mutate, rename, schema,
## select
## The following object is masked from 'package:stats':
##
## filter
## The following object is masked from 'package:graphics':
##
## layout
plot_ly(plot.data, x = ~k, y = ~n, z = ~press, color = ~n) %>% add_markers()
## Warning in RColorBrewer::brewer.pal(N, "Set2"): n too large, allowed maximum for palette Set2 is 8
## Returning the palette you asked for with that many colors
library(plyr)
##
## Attaching package: 'plyr'
## The following objects are masked from 'package:plotly':
##
## arrange, mutate, rename, summarise
## The following objects are masked from 'package:SparkR':
##
## arrange, count, desc, join, mutate, rename, summarize, take
min.surface <- ddply(plot.data, .(k, n), summarize, press = min(press))
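The ddply() call keeps, for every (k, n) pair, the smallest PRESS over all feature subsets of size n, i.e. the best error achievable for each model size and K. As an illustration on hypothetical toy values, the same reduction can be written with base R's aggregate().

```r
# Hypothetical toy results: two subsets of size 2 and one of size 3 at k = 1
toy <- data.frame(
  k = c(1, 1, 1, 2),
  n = factor(c(2, 2, 3, 2)),
  press = c(50, 40, 70, 30))
# Minimum PRESS for every (k, n) combination
toy.surface <- aggregate(press ~ k + n, data = toy, FUN = min)
toy.surface
```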
plot_ly(min.surface, x = ~k, y = ~n, z = ~press, color = ~n) %>% add_markers()
## Warning in RColorBrewer::brewer.pal(N, "Set2"): n too large, allowed maximum for palette Set2 is 8
## Returning the palette you asked for with that many colors
ggplot(min.surface, aes(k, press, color = n)) +
geom_point() +
geom_line(aes(group = n)) +
geom_label(data = subset(min.surface, k == 1), aes(label = n)) +
theme_bw() + geom_hline(yintercept = min(plot.data$press), size = 1.2)
Finally, we look up the row with the minimum PRESS to identify the optimal combination of columns and K.
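A small sketch on hypothetical toy values of two ways to do this lookup: which.min() returns only the first row that attains the minimum, whereas which(press == min(press)) returns every tied row, and sorting with order() shows how close the runners-up are.

```r
# Hypothetical toy results with a tie at the minimum
toy.data <- data.frame(
  id = factor(c("1", "134", "13", "134")),
  k = c(5, 7, 7, 9),
  n = factor(c(1, 3, 2, 3)),
  press = c(3.0e10, 1.8e10, 2.1e10, 1.8e10))
# which.min() keeps only the first row attaining the minimum ...
best.first <- toy.data[which.min(toy.data$press), ]
# ... whereas which(press == min(press)) keeps all tied rows
best.all <- toy.data[which(toy.data$press == min(toy.data$press)), ]
# Rank every configuration to see how close the runners-up are
top3 <- head(toy.data[order(toy.data$press), ], 3)
```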
plot.data[which(plot.data$press == min(plot.data$press)), ]
## id k n press
## 3119 134 7 3 17976133574
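The winning id 134 denotes predictor columns 1, 3, and 4, i.e. carat, color, and clarity, used with K = 7:
diamonds[, c(1, 3, 4)] %>% head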
## carat color clarity
## 1 0.23 E SI2
## 2 0.21 E SI1
## 3 0.23 E VS1
## 4 0.29 I VS2
## 5 0.31 J SI2
## 6 0.24 J VVS2
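To put the minimum PRESS on the price scale, one can divide it by the number of LOOCV predictions (one per diamond, 53,940 in total) and take the square root, giving a root-mean-squared prediction error in US dollars. This is a back-of-the-envelope check using the numbers printed above.

```r
press <- 17976133574  # minimum LOOCV PRESS reported above
n.obs <- 53940        # rows in ggplot2::diamonds, one held-out prediction each
rmse <- sqrt(press / n.obs)
round(rmse, 1)  # 577.3
```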