다음을 통해 공유

用于 Scala 的 Databricks Connect 中的用户定义的函数

注释

本文介绍 Databricks Connect for Databricks Runtime 14.1 及更高版本。

Databricks Connect for Scala 支持从本地开发环境在 Databricks 群集上运行用户定义的函数(UDF)。

本页介绍如何使用用于 Scala 的 Databricks Connect 执行用户定义的函数。

有关本文的 Python 版本,请参阅 Databricks Connect for Python 中的用户定义的函数。

上传已编译的类和 JAR

若要使 UDF 正常工作,必须使用 API 将已编译的类和 JAR 上传到群集 addCompiledArtifacts()

注释

客户端使用的 Scala 必须与 Azure Databricks 群集上的 Scala 版本匹配。 若要查看群集的 Scala 版本,请参阅 Databricks Runtime 发行说明版本和兼容性中群集 Databricks Runtime 版本的“系统环境”部分。

以下 Scala 程序设置了一个简单的 UDF,用于对列中的值进行平方运算。

import com.databricks.connect.DatabricksSession
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object Main {
  def main(args: Array[String]): Unit = {
    val spark = getSession()

    val squared = udf((x: Long) => x * x)

    spark.range(3)
      .withColumn("squared", squared(col("id")))
      .select("squared")
      .show()

    }
  }

  def getSession(): SparkSession = {
    if (sys.env.contains("DATABRICKS_RUNTIME_VERSION")) {
      // On a Databricks cluster — reuse the active session
      SparkSession.active
    } else {
      // Locally with Databricks Connect — upload local JARs and classes
      DatabricksSession
        .builder()
        .addCompiledArtifacts(
          Main.getClass.getProtectionDomain.getCodeSource.getLocation.toURI
        )
        .getOrCreate()
    }
  }
}

Main.getClass.getProtectionDomain.getCodeSource.getLocation.toURI 指向与项目的编译输出相同的位置(例如,目标/类或生成的 JAR)。 所有编译的类都上传到 Databricks,而不仅仅是 Main

target/scala-2.13/classes/
├── com/
│   ├── examples/
│   │   ├── Main.class
│   │   └── MyUdfs.class
│   └── utils/
│       └── Helper.class

初始化 Spark 会话后,可以使用 spark.addArtifact() API 上传进一步编译的类和 JAR。

注释

上传 JAR 时,必须包括所有可传递依赖项的 JAR。 API 不会对传递性依赖执行任何自动检测。

具有第三方依赖项的 UDF

如果你已经在build.sbt中添加了一个 Maven 依赖项,并且在 UDF 中使用,但它在 Databricks 群集上不可用,例如:

// In build.sbt
libraryDependencies += "org.apache.commons" % "commons-text" % "1.10.0"
// In your code
import org.apache.commons.text.StringEscapeUtils

// ClassNotFoundException thrown during UDF execution of this function on the server side
val escapeUdf = udf((text: String) => {
  StringEscapeUtils.escapeHtml4(text)
})

spark.addArtifact()ivy://从 Maven 下载依赖项

  1. oro 库添加到 build.sbt 文件

    libraryDependencies ++= Seq(
      "org.apache.commons" % "commons-text" % "1.10.0" % Provided,
      "oro" % "oro" % "2.0.8"  // Required for ivy:// to work
    )
    
  2. 使用 addArtifact() API 创建会话后添加项目:

    def getSession(): SparkSession = {
      if (sys.env.contains("DATABRICKS_RUNTIME_VERSION")) {
        SparkSession.active
      } else {
        val spark = DatabricksSession.builder()
          .addCompiledArtifacts(Main.getClass.getProtectionDomain.getCodeSource.getLocation.toURI)
          .getOrCreate()
    
        // Convert Maven coordinates to ivy:// format
        // From: "org.apache.commons" % "commons-text" % "1.10.0"
        // To:   ivy://org.apache.commons:commons-text:1.10.0
        spark.addArtifact("ivy://org.apache.commons:commons-text:1.10.0")
    
        spark
      }
    }
    

类型化数据集 API

类型化数据集 API 允许对生成的数据集运行转换,例如map()filter()mapPartitions()和聚合。 使用 addCompiledArtifacts() API 将编译的类和 JAR 上传到群集也适用于这些类,因此代码的行为必须有所不同,具体取决于其运行位置:

  • 使用 Databricks Connect 进行本地开发:将项目上传到远程群集。
  • 部署在群集上运行的 Databricks 上:无需上传任何内容,因为类已存在。

以下 Scala 应用程序使用 map() API 将结果列中的数字修改为前缀字符串。

import com.databricks.connect.DatabricksSession
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object Main {
  def main(args: Array[String]): Unit = {
    val sourceLocation = getClass.getProtectionDomain.getCodeSource.getLocation.toURI

    val spark = DatabricksSession.builder()
      .addCompiledArtifacts(sourceLocation)
      .getOrCreate()

    spark.range(3).map(f => s"row-$f").show()
  }
}

外部 JAR 依赖项

如果使用不在群集上的专用或第三方库:

import com.mycompany.privatelib.DataProcessor

// ClassNotFoundException thrown during UDF execution of this function on the server side
val myUdf = udf((data: String) => {
  DataProcessor.process(data)
})

创建会话时,从文件夹 lib/ 上传外部 JAR:

def getSession(): SparkSession = {
  if (sys.env.contains("DATABRICKS_RUNTIME_VERSION")) {
    SparkSession.active
  } else {
    val builder = DatabricksSession.builder()
      .addCompiledArtifacts(Main.getClass.getProtectionDomain.getCodeSource.getLocation.toURI)

     // Add all JARs from lib/ folder
     val libFolder = new java.io.File("lib")
     builder.addCompiledArtifacts(libFolder.toURI)

   builder.getOrCreate()
  }
}

这会在本地运行时自动将 lib/ 目录中的所有 JAR 上传到 Databricks。

具有多个模块的项目

在多模块 SBT 项目中, getClass.getProtectionDomain.getCodeSource.getLocation.toURI 仅返回当前模块的位置。 如果 UDF 使用来自其他模块的类,你将获得 ClassNotFoundException

my-project/
├── module-a/  (main application)
├── module-b/  (utilities - module-a depends on this)

使用 getClass 来从每个模块中的类中获取所有位置,并分别上传它们。

// In module-a/src/main/scala/Main.scala
import com.company.moduleb.DataProcessor  // From module-b

def getSession(): SparkSession = {
  if (sys.env.contains("DATABRICKS_RUNTIME_VERSION")) {
    SparkSession.active
  } else {
    // Get location using a class FROM module-a
    val moduleALocation = Main.getClass
      .getProtectionDomain.getCodeSource.getLocation.toURI

    // Get location using a class FROM module-b
    val moduleBLocation = DataProcessor.getClass
      .getProtectionDomain.getCodeSource.getLocation.toURI

    DatabricksSession.builder()
      .addCompiledArtifacts(moduleALocation)  // Upload module-a
      .addCompiledArtifacts(moduleBLocation)  // Upload module-b
      .getOrCreate()
  }
}