注释
本文介绍 Databricks Connect for Databricks Runtime 14.1 及更高版本。
Databricks Connect for Scala 支持从本地开发环境在 Databricks 群集上运行用户定义的函数(UDF)。
本页介绍如何使用用于 Scala 的 Databricks Connect 执行用户定义的函数。
有关本文的 Python 版本,请参阅 Databricks Connect for Python 中的用户定义的函数。
上传已编译的类和 JAR
若要使 UDF 正常工作,必须使用 API 将已编译的类和 JAR 上传到群集 addCompiledArtifacts() 。
注释
客户端使用的 Scala 必须与 Azure Databricks 群集上的 Scala 版本匹配。 若要查看群集的 Scala 版本,请参阅 Databricks Runtime 发行说明版本和兼容性中群集 Databricks Runtime 版本的“系统环境”部分。
以下 Scala 程序设置了一个简单的 UDF,用于对列中的值进行平方运算。
import com.databricks.connect.DatabricksSession
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}
object Main {
def main(args: Array[String]): Unit = {
val spark = getSession()
val squared = udf((x: Long) => x * x)
spark.range(3)
.withColumn("squared", squared(col("id")))
.select("squared")
.show()
}
}
def getSession(): SparkSession = {
if (sys.env.contains("DATABRICKS_RUNTIME_VERSION")) {
// On a Databricks cluster — reuse the active session
SparkSession.active
} else {
// Locally with Databricks Connect — upload local JARs and classes
DatabricksSession
.builder()
.addCompiledArtifacts(
Main.getClass.getProtectionDomain.getCodeSource.getLocation.toURI
)
.getOrCreate()
}
}
}
Main.getClass.getProtectionDomain.getCodeSource.getLocation.toURI 指向与项目的编译输出相同的位置(例如,目标/类或生成的 JAR)。 所有编译的类都上传到 Databricks,而不仅仅是 Main。
target/scala-2.13/classes/
├── com/
│ ├── examples/
│ │ ├── Main.class
│ │ └── MyUdfs.class
│ └── utils/
│ └── Helper.class
初始化 Spark 会话后,可以使用 spark.addArtifact() API 上传进一步编译的类和 JAR。
注释
上传 JAR 时,必须包括所有可传递依赖项的 JAR。 API 不会对传递性依赖执行任何自动检测。
具有第三方依赖项的 UDF
如果你已经在build.sbt中添加了一个 Maven 依赖项,并且在 UDF 中使用,但它在 Databricks 群集上不可用,例如:
// In build.sbt
libraryDependencies += "org.apache.commons" % "commons-text" % "1.10.0"
// In your code
import org.apache.commons.text.StringEscapeUtils
// ClassNotFoundException thrown during UDF execution of this function on the server side
val escapeUdf = udf((text: String) => {
StringEscapeUtils.escapeHtml4(text)
})
用spark.addArtifact()和ivy://从 Maven 下载依赖项
将
oro库添加到build.sbt文件libraryDependencies ++= Seq( "org.apache.commons" % "commons-text" % "1.10.0" % Provided, "oro" % "oro" % "2.0.8" // Required for ivy:// to work )使用
addArtifact()API 创建会话后添加项目:def getSession(): SparkSession = { if (sys.env.contains("DATABRICKS_RUNTIME_VERSION")) { SparkSession.active } else { val spark = DatabricksSession.builder() .addCompiledArtifacts(Main.getClass.getProtectionDomain.getCodeSource.getLocation.toURI) .getOrCreate() // Convert Maven coordinates to ivy:// format // From: "org.apache.commons" % "commons-text" % "1.10.0" // To: ivy://org.apache.commons:commons-text:1.10.0 spark.addArtifact("ivy://org.apache.commons:commons-text:1.10.0") spark } }
类型化数据集 API
类型化数据集 API 允许对生成的数据集运行转换,例如map(),filter()mapPartitions()和聚合。 使用 addCompiledArtifacts() API 将编译的类和 JAR 上传到群集也适用于这些类,因此代码的行为必须有所不同,具体取决于其运行位置:
- 使用 Databricks Connect 进行本地开发:将项目上传到远程群集。
- 部署在群集上运行的 Databricks 上:无需上传任何内容,因为类已存在。
以下 Scala 应用程序使用 map() API 将结果列中的数字修改为前缀字符串。
import com.databricks.connect.DatabricksSession
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}
object Main {
def main(args: Array[String]): Unit = {
val sourceLocation = getClass.getProtectionDomain.getCodeSource.getLocation.toURI
val spark = DatabricksSession.builder()
.addCompiledArtifacts(sourceLocation)
.getOrCreate()
spark.range(3).map(f => s"row-$f").show()
}
}
外部 JAR 依赖项
如果使用不在群集上的专用或第三方库:
import com.mycompany.privatelib.DataProcessor
// ClassNotFoundException thrown during UDF execution of this function on the server side
val myUdf = udf((data: String) => {
DataProcessor.process(data)
})
创建会话时,从文件夹 lib/ 上传外部 JAR:
def getSession(): SparkSession = {
if (sys.env.contains("DATABRICKS_RUNTIME_VERSION")) {
SparkSession.active
} else {
val builder = DatabricksSession.builder()
.addCompiledArtifacts(Main.getClass.getProtectionDomain.getCodeSource.getLocation.toURI)
// Add all JARs from lib/ folder
val libFolder = new java.io.File("lib")
builder.addCompiledArtifacts(libFolder.toURI)
builder.getOrCreate()
}
}
这会在本地运行时自动将 lib/ 目录中的所有 JAR 上传到 Databricks。
具有多个模块的项目
在多模块 SBT 项目中, getClass.getProtectionDomain.getCodeSource.getLocation.toURI 仅返回当前模块的位置。 如果 UDF 使用来自其他模块的类,你将获得 ClassNotFoundException。
my-project/
├── module-a/ (main application)
├── module-b/ (utilities - module-a depends on this)
使用 getClass 来从每个模块中的类中获取所有位置,并分别上传它们。
// In module-a/src/main/scala/Main.scala
import com.company.moduleb.DataProcessor // From module-b
def getSession(): SparkSession = {
if (sys.env.contains("DATABRICKS_RUNTIME_VERSION")) {
SparkSession.active
} else {
// Get location using a class FROM module-a
val moduleALocation = Main.getClass
.getProtectionDomain.getCodeSource.getLocation.toURI
// Get location using a class FROM module-b
val moduleBLocation = DataProcessor.getClass
.getProtectionDomain.getCodeSource.getLocation.toURI
DatabricksSession.builder()
.addCompiledArtifacts(moduleALocation) // Upload module-a
.addCompiledArtifacts(moduleBLocation) // Upload module-b
.getOrCreate()
}
}