Stateful application examples

This article contains code examples for custom stateful applications. Databricks recommends using built-in stateful methods for common operations such as aggregations and joins.

The patterns in this article use the transformWithState operator and associated classes, which are available in Public Preview in Databricks Runtime 16.2 and above. See Build custom stateful applications.

Note

Python provides the same functionality using the transformWithStateInPandas operator. The following examples provide code in both Python and Scala.

Requirements

The transformWithState operator and the related APIs and classes have the following requirements:

  • Available in Databricks Runtime 16.2 and above.
  • Compute must use dedicated or no-isolation access mode.
  • You must use the RocksDB state store provider. Databricks recommends enabling RocksDB as part of the compute configuration.
  • transformWithStateInPandas supports standard access mode in Databricks Runtime 16.3 and above.

Note

To enable the RocksDB state store provider for the current session, run the following:

spark.conf.set("spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")

Slowly changing dimension (SCD) type 1

The following code is an example of implementing SCD type 1 with transformWithState. SCD type 1 tracks only the latest value for a given field.

Note

You can implement SCD type 1 or type 2 with tables backed by Delta Lake by using streaming tables and APPLY CHANGES INTO. This example implements SCD type 1 in the state store, which provides lower latency for near real-time applications.

Python

# Import the necessary libraries
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType, StringType
from typing import Iterator

# Set the state store provider to RocksDB
spark.conf.set("spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")

# Define the output schema for the streaming query
output_schema = StructType([
    StructField("user", StringType(), True),
    StructField("time", LongType(), True),
    StructField("location", StringType(), True)
])

# Define a custom StatefulProcessor for slowly changing dimension type 1 (SCD1) operations
class SCDType1StatefulProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        self.handle = handle
        # Define the schema for the state value
        value_state_schema = StructType([
            StructField("user", StringType(), True),
            StructField("time", LongType(), True),
            StructField("location", StringType(), True)
        ])
        # Initialize the state to store the latest location for each user
        self.latest_location = handle.getValueState("latestLocation", value_state_schema)

    def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
        # Find the row with the maximum time value
        max_row = None
        max_time = float('-inf')
        for pdf in rows:
            for _, pd_row in pdf.iterrows():
                time_value = pd_row["time"]
                if time_value > max_time:
                    max_time = time_value
                    max_row = tuple(pd_row)

        # Check whether state exists and update if necessary
        exists = self.latest_location.exists()
        if not exists or max_row[1] > self.latest_location.get()[1]:
            # Update the state with the new max row
            self.latest_location.update(max_row)
            # Yield the updated row
            yield pd.DataFrame(
                {"user": (max_row[0],), "time": (max_row[1],), "location": (max_row[2],)}
            )
        # Yield an empty DataFrame if no update is needed
        yield pd.DataFrame()

    def close(self) -> None:
        # No cleanup needed
        pass

# Apply the stateful transformation to the input DataFrame
(df.groupBy("user")
  .transformWithStateInPandas(
      statefulProcessor=SCDType1StatefulProcessor(),
      outputStructType=output_schema,
      outputMode="Update",
      timeMode="None",
  )
  .writeStream...  # Continue with stream writing configuration
)

Scala

// Import the necessary libraries
import org.apache.spark.sql.streaming._

// Define a case class to represent user location data
case class UserLocation(
    user: String,
    time: Long,
    location: String)

// Define a stateful processor for slowly changing dimension type 1 (SCD1) operations
class SCDType1StatefulProcessor extends StatefulProcessor[String, UserLocation, UserLocation] {
  import org.apache.spark.sql.{Encoders}

  // Transient value state to store the latest location for each user
  @transient private var _latestLocation: ValueState[UserLocation] = _

  // Initialize the state store
  override def init(
      outputMode: OutputMode,
      timeMode: TimeMode): Unit = {
    // Create a value state named "locationState" using UserLocation encoder
    // TTLConfig.NONE means the state has no expiration
    _latestLocation = getHandle.getValueState[UserLocation]("locationState",
      Encoders.product[UserLocation], TTLConfig.NONE)
  }

  // Process input rows and update state
  override def handleInputRows(
      key: String,
      inputRows: Iterator[UserLocation],
      timerValues: TimerValues): Iterator[UserLocation] = {
    // Find the location with the maximum timestamp from input rows
    val maxNewLocation = inputRows.maxBy(_.time)

    // Update state and emit output if:
    // 1. No previous state exists, or
    // 2. New location has a more recent timestamp than the stored one
    if (_latestLocation.getOption().isEmpty || maxNewLocation.time > _latestLocation.get().time) {
      _latestLocation.update(maxNewLocation)
      Iterator.single(maxNewLocation)  // Emit the updated location
    } else {
      Iterator.empty  // No update needed, emit nothing
    }
  }
}

Slowly changing dimension (SCD) type 2

The following notebooks contain an example that shows how to implement SCD type 2 using transformWithState in Python or Scala. A minimal illustrative sketch follows the notebook links.

SCD type 2 Python

Get notebook

SCD type 2 Scala

Get notebook
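The notebooks are the reference implementation. As a rough illustration of the idea only, the following minimal Python sketch (not taken from the notebooks) keeps the currently open record for each user in a ValueState and emits a closed history record whenever the location changes. It assumes the same user, time, and location input columns as the SCD type 1 example above; the processor, schema, and state names are hypothetical.

# Minimal SCD type 2 sketch (illustrative only, not the notebook code).
# Assumes input columns "user", "time", and "location" as in the SCD type 1 example.
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType, StringType
from typing import Iterator

# Closed history records: one row per (user, location) interval
scd2_output_schema = StructType([
    StructField("user", StringType(), True),
    StructField("location", StringType(), True),
    StructField("start_time", LongType(), True),
    StructField("end_time", LongType(), True)
])

class SCDType2StatefulProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        current_schema = StructType([
            StructField("location", StringType(), True),
            StructField("start_time", LongType(), True)
        ])
        # State holding the currently open record for each user
        self.current = handle.getValueState("currentRecord", current_schema)

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        closed_records = []
        for pdf in rows:
            # Process events in time order so records close in the right sequence
            for _, row in pdf.sort_values("time").iterrows():
                if self.current.exists():
                    location, start_time = self.current.get()
                    if location != row["location"]:
                        # Close the open record and start a new one at the event time
                        closed_records.append((key[0], location, start_time, int(row["time"])))
                        self.current.update((row["location"], int(row["time"])))
                else:
                    # First event for this user opens the first record
                    self.current.update((row["location"], int(row["time"])))
        if closed_records:
            yield pd.DataFrame(
                closed_records, columns=["user", "location", "start_time", "end_time"])

    def close(self) -> None:
        pass

The grouped call has the same shape as the SCD type 1 example, for example df.groupBy("user").transformWithStateInPandas(statefulProcessor=SCDType2StatefulProcessor(), outputStructType=scd2_output_schema, outputMode="Append", timeMode="None"). The open record stays in state until the next change; the notebooks show a complete implementation.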

Downtime detector

transformWithState implements timers so that you can take action based on elapsed time, even if no records for a given key are processed in a micro-batch.

The following example shows an implementation of the downtime detector pattern. Each time a new value is seen for a given key, it updates the lastSeen state value, clears any existing timers, and resets a timer for the future.

When a timer expires, the application emits the elapsed time since the last event was observed for the key. It then sets a new timer to emit another update 10 seconds later.

Python

# Import the necessary libraries
import time
from datetime import datetime

import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, TimestampType
from typing import Iterator

class DownTimeDetectorStatefulProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        # Define the schema for the state value (timestamp)
        state_schema = StructType([StructField("value", TimestampType(), True)])
        self.handle = handle
        # Initialize state to store the last seen timestamp for each key
        self.last_seen = handle.getValueState("last_seen", state_schema)

    def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]:
        latest_from_existing = self.last_seen.get()
        # Calculate downtime duration as the elapsed time since the last seen timestamp
        downtime_duration = timerValues.getCurrentProcessingTimeInMs() - int(latest_from_existing[0].timestamp() * 1000)
        # Register a new timer for 10 seconds in the future
        self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 10000)
        # Yield a DataFrame with the key and downtime duration
        yield pd.DataFrame(
            {
                "id": key,
                "timeValues": str(downtime_duration),
            }
        )

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        # Find the row with the maximum timestamp
        max_row = max((tuple(pdf.iloc[0]) for pdf in rows), key=lambda row: row[1])

        # Get the latest timestamp from the existing state or use epoch start if a timestamp doesn't exist
        if self.last_seen.exists():
            latest_from_existing = self.last_seen.get()[0]
        else:
            latest_from_existing = datetime.fromtimestamp(0)

        # If the new data is more recent than the existing state
        if latest_from_existing < max_row[1]:
            # Delete all existing timers
            for timer in self.handle.listTimers():
                self.handle.deleteTimer(timer)
            # Update the last seen timestamp
            self.last_seen.update((max_row[1],))

        # Register a new timer for 5 seconds in the future
        self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 5000)

        # Get current processing time in milliseconds
        timestamp_in_millis = str(timerValues.getCurrentProcessingTimeInMs())

        # Yield a DataFrame with the key and current timestamp
        yield pd.DataFrame({"id": key, "timeValues": timestamp_in_millis})

    def close(self) -> None:
        # No cleanup needed
        pass

Scala

// Import the necessary libraries
import java.sql.Timestamp
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming._

// The (String, Timestamp) schema represents an (id, time). We want to do downtime
// detection on every single unique sensor, where each sensor has a sensor ID.
class DowntimeDetector(duration: Duration) extends
  StatefulProcessor[String, (String, Timestamp), (String, Duration)] {

  @transient private var _lastSeen: ValueState[Timestamp] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    _lastSeen = getHandle.getValueState[Timestamp]("lastSeen", Encoders.TIMESTAMP, TTLConfig.NONE)
  }

  // The logic here is as follows: find the largest timestamp seen so far. Set a timer for
  // the duration later.
  override def handleInputRows(
      key: String,
      inputRows: Iterator[(String, Timestamp)],
      timerValues: TimerValues): Iterator[(String, Duration)] = {
    val latestRecordFromNewRows = inputRows.maxBy(_._2.getTime)

    // Use getOrElse to initiate state variable if it doesn't exist
    val latestTimestampFromExistingRows = _lastSeen.getOption().getOrElse(new Timestamp(0))
    val latestTimestampFromNewRows = latestRecordFromNewRows._2

    if (latestTimestampFromNewRows.after(latestTimestampFromExistingRows)) {
      // Cancel the one existing timer, since we have a new latest timestamp.
      // We call "listTimers()" because we don't know ahead of time what
      // the timestamp of the existing timer will be.
      getHandle.listTimers().foreach(timer => getHandle.deleteTimer(timer))

      _lastSeen.update(latestTimestampFromNewRows)
      // Use timerValues to schedule a timer using processing time.
      getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + duration.toMillis)
    } else {
      // No new latest timestamp, so there is no need to update the state or set a timer.
    }

    Iterator.empty
  }

  override def handleExpiredTimer(
    key: String,
    timerValues: TimerValues,
    expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Duration)] = {
      val latestTimestamp = _lastSeen.get()
      val downtimeDuration = new Duration(
        timerValues.getCurrentProcessingTimeInMs() - latestTimestamp.getTime)

      // Register another timer that will fire in 10 seconds.
      // Timers can be registered anywhere but init()
      getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 10000)

      Iterator((key, downtimeDuration))
  }
}

Migrate existing state information

The following example demonstrates how to implement a stateful application that accepts an initial state. You can add initial state handling to any stateful application, but the initial state can only be set when the application is first initialized.

This example uses the statestore reader to load existing state information from a checkpoint path. An example use case for this pattern is migrating from a legacy stateful application to transformWithState.

Python

# Import the necessary libraries
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType, StringType, IntegerType
from typing import Iterator

# Set RocksDB as the state store provider for better performance
spark.conf.set("spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")

"""
Input schema is as below

input_schema = StructType([
    StructField("id", StringType(), True),
    StructField("value", StringType(), True)
])
"""

# Define the output schema for the streaming query
output_schema = StructType([
    StructField("id", StringType(), True),
    StructField("accumulated", StringType(), True)
])

class AccumulatedCounterStatefulProcessorWithInitialState(StatefulProcessor):

    def init(self, handle: StatefulProcessorHandle) -> None:
        # Define the schema for the state value (integer)
        state_schema = StructType([StructField("value", IntegerType(), True)])
        # Initialize state to store the accumulated counter for each id
        self.counter_state = handle.getValueState("counter_state", state_schema)
        self.handle = handle

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        # Check if state exists for the current key
        exists = self.counter_state.exists()
        if exists:
            value_row = self.counter_state.get()
            existing_value = value_row[0]
        else:
            existing_value = 0

        accumulated_value = existing_value

        # Process input rows and accumulate values
        for pdf in rows:
            value = pdf["value"].astype(int).sum()
            accumulated_value += value

        # Update the state with the new accumulated value
        self.counter_state.update((accumulated_value,))

        # Yield a DataFrame with the key and accumulated value
        yield pd.DataFrame({"id": key, "accumulated": str(accumulated_value)})

    def handleInitialState(self, key, initialState, timerValues) -> None:
        # Initialize the state with the provided initial value
        init_val = initialState.at[0, "initVal"]
        self.counter_state.update((init_val,))

    def close(self) -> None:
        # No cleanup needed
        pass

# Load initial state from a checkpoint directory
initial_state = spark.read.format("statestore")
  .option("path", "$checkpointsDir")
  .load()

# Apply the stateful transformation to the input DataFrame
df.groupBy("id")
  .transformWithStateInPandas(
      statefulProcessor=AccumulatedCounterStatefulProcessorWithInitialState(),
      outputStructType=output_schema,
      outputMode="Update",
      timeMode="None",
      initialState=initial_state,
  )
  .writeStream...  # Continue with stream writing configuration

Scala

// Import the necessary libraries
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.{Dataset, Encoder, Encoders, DataFrame}
import org.apache.spark.sql.types._

// Define a stateful processor that can handle the initial state
class InitialStateStatefulProcessor extends StatefulProcessorWithInitialState[String, (String, String, String), (String, String), (String, Int)] {
  // Transient value state to store the accumulated value
  @transient protected var valueState: ValueState[Int] = _

  // Initialize the state store
  override def init(
      outputMode: OutputMode,
      timeMode: TimeMode): Unit = {
    // Create a value state named "valueState" using Int encoder
    // TTLConfig.NONE means the state has no automatic expiration
    valueState = getHandle.getValueState[Int]("valueState",
      Encoders.scalaInt, TTLConfig.NONE)
  }

  // Process input rows and update state
  override def handleInputRows(
      key: String,
      inputRows: Iterator[(String, String, String)],
      timerValues: TimerValues): Iterator[(String, String)] = {
    var existingValue = 0
    // Retrieve existing value from state if it exists
    if (valueState.exists()) {
      existingValue += valueState.get()
    }
    var accumulatedValue = existingValue
    // Accumulate values from input rows
    for (row <- inputRows) {
      accumulatedValue += row._2.toInt
    }
    // Update the state with the new accumulated value
    valueState.update(accumulatedValue)
    // Return the key and accumulated value as a string
    Iterator((key, accumulatedValue.toString))
  }

  // Handle initial state when provided
  override def handleInitialState(
      key: String, initialState: (String, Int), timerValues: TimerValues): Unit = {
    // Update the state with the initial value
    valueState.update(initialState._2)
  }
}

Migrate a Delta table to the state store for initialization

The following notebooks contain an example of using transformWithState to initialize state store values from a Delta table in Python or Scala. A minimal illustrative sketch follows the notebook links.

Initialize state from Delta Python

Get notebook

Initialize state from Delta Scala

Get notebook
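The notebooks contain the complete pattern. At a high level, it replaces the statestore read from the previous example with a Delta table read. The following minimal sketch illustrates the idea; the table name and its id and initVal columns are assumptions chosen to match the AccumulatedCounterStatefulProcessorWithInitialState processor above.

# Minimal sketch: load initial values from a Delta table instead of the state store.
# The table name and the "id"/"initVal" columns are assumptions for illustration.
initial_state = (
    spark.read.table("main.default.accumulated_counts")  # hypothetical Delta table
      .select("id", "initVal")
)

# Pass the DataFrame as the initial state, exactly as in the previous example
(df.groupBy("id")
  .transformWithStateInPandas(
      statefulProcessor=AccumulatedCounterStatefulProcessorWithInitialState(),
      outputStructType=output_schema,
      outputMode="Update",
      timeMode="None",
      initialState=initial_state,
  )
  .writeStream...  # Continue with stream writing configuration
)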

Session tracking

The following notebooks contain a session tracking example that uses transformWithState in Python or Scala. A minimal illustrative sketch follows the notebook links.

Session tracking Python

Get notebook

Session tracking Scala

Get notebook
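The notebooks are the reference implementation. As a rough sketch of one common approach (not the notebook code), the following Python processor tracks a session per user with a ValueState and uses a processing-time timer to close the session after a fixed inactivity gap. The 30-second gap, the state and schema names, and the input columns (user, time as epoch milliseconds) are assumptions for illustration.

# Minimal session-tracking sketch (illustrative only, not the notebook code).
# Assumes input columns "user" and "time" (epoch milliseconds) and a fixed 30s gap.
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType, StringType
from typing import Iterator

SESSION_GAP_MS = 30 * 1000

session_output_schema = StructType([
    StructField("user", StringType(), True),
    StructField("session_start", LongType(), True),
    StructField("session_end", LongType(), True)
])

class SessionTrackingProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        session_schema = StructType([
            StructField("session_start", LongType(), True),
            StructField("last_seen", LongType(), True)
        ])
        self.handle = handle
        # State holding the open session (start and last activity) for each user
        self.session = handle.getValueState("session", session_schema)

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        # Latest event time for this key in this micro-batch
        max_time = max(int(pdf["time"].max()) for pdf in rows)
        if self.session.exists():
            session_start, _ = self.session.get()
        else:
            session_start = max_time
        self.session.update((session_start, max_time))

        # Reset the inactivity timer: clear old timers and register a new one
        for timer in self.handle.listTimers():
            self.handle.deleteTimer(timer)
        self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + SESSION_GAP_MS)
        return iter([])

    def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]:
        # No activity during the gap: emit the closed session and clear the state
        session_start, last_seen = self.session.get()
        self.session.clear()
        yield pd.DataFrame(
            {"user": (key[0],), "session_start": (session_start,), "session_end": (last_seen,)})

    def close(self) -> None:
        pass

Because the sketch relies on timers, the transformWithStateInPandas call would use timeMode="ProcessingTime".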

Custom stream-stream join using transformWithState

The following code demonstrates a custom stream-stream join across multiple streams using transformWithState. You might use this approach instead of the built-in join operators for the following reasons:

  • You need to use the update output mode, which stream-stream joins do not support. This is especially useful for lower-latency applications.
  • You need to keep performing joins for late-arriving rows (after watermark expiration).
  • You need to perform a many-to-many stream-stream join.

This example gives the user full control over state expiration logic, which allows dynamic retention period extension so that out-of-order events can be handled even after the watermark.

Python

# Import the necessary libraries
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, StringType, TimestampType
from typing import Iterator

# Define output schema for the joined data
output_schema = StructType([
    StructField("user_id", StringType(), True),
    StructField("event_type", StringType(), True),
    StructField("timestamp", TimestampType(), True),
    StructField("profile_name", StringType(), True),
    StructField("email", StringType(), True),
    StructField("preferred_category", StringType(), True)
])

class CustomStreamJoinProcessor(StatefulProcessor):
    # Initialize stateful storage for user profiles, preferences, and event tracking.
    def init(self, handle: StatefulProcessorHandle) -> None:
        # Keep a reference to the handle so timers can be registered later
        self.handle = handle

        # Define schemas for different types of state data
        profile_schema = StructType([
            StructField("name", StringType(), True),
            StructField("email", StringType(), True),
            StructField("updated_at", TimestampType(), True)
        ])
        preferences_schema = StructType([
            StructField("preferred_category", StringType(), True),
            StructField("updated_at", TimestampType(), True)
        ])
        activity_schema = StructType([
            StructField("event_type", StringType(), True),
            StructField("timestamp", TimestampType(), True)
        ])

        # Initialize state storage for user profiles, preferences, and activity
        self.profile_state = handle.getMapState("user_profiles", "string", profile_schema)
        self.preferences_state = handle.getMapState("user_preferences", "string", preferences_schema)
        self.activity_state = handle.getMapState("user_activity", "string", activity_schema)

    # Process incoming events and update the state
    def handleInputRows(self, key, rows: Iterator[pd.DataFrame], timer_values) -> Iterator[pd.DataFrame]:
        df = pd.concat(rows, ignore_index=True)
        output_rows = []

        for _, row in df.iterrows():
            user_id = row["user_id"]

            if "event_type" in row:  # User activity event
                self.activity_state.updateValue(user_id, row.to_dict())
                # Set a timer to process this event after a 10-second delay
                self.handle.registerTimer(timer_values.getCurrentProcessingTimeInMs() + (10 * 1000))

            elif "name" in row:  # Profile update
                self.profile_state.updateValue(user_id, row.to_dict())

            elif "preferred_category" in row:  # Preference update
                self.preferences_state.updateValue(user_id, row.to_dict())

        # No immediate output; processing will happen when the timer expires
        return iter([])

    # Perform lookup after delay, handling out-of-order and late-arriving events.
    def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]:

        # Retrieve stored state for the user
        user_activity = self.activity_state.getValue(key)
        user_profile = self.profile_state.getValue(key)
        user_preferences = self.preferences_state.getValue(key)

        if user_activity:
            # Combine data from different states into a single output row
            output_row = {
                "user_id": key,
                "event_type": user_activity["event_type"],
                "timestamp": user_activity["timestamp"],
                "profile_name": user_profile.get("name") if user_profile else None,
                "email": user_profile.get("email") if user_profile else None,
                "preferred_category": user_preferences.get("preferred_category") if user_preferences else None
            }
            return iter([pd.DataFrame([output_row])])

        return iter([])

    def close(self) -> None:
        # No cleanup needed
        pass

# Apply transformWithState to the input DataFrame
(df.groupBy("user_id")
  .transformWithStateInPandas(
      statefulProcessor=CustomStreamJoinProcessor(),
      outputStructType=output_schema,
      outputMode="Append",
      timeMode="ProcessingTime"
  )
  .writeStream...  # Continue with stream writing configuration
)

Scala

// Import the necessary libraries
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.types.TimestampType
import java.sql.Timestamp

// Input case classes for the three joined streams.
// The field sets shown here are inferred from how the processor below uses them.
case class UserEvent(user_id: String, event_type: String, timestamp: Timestamp)
case class UserProfile(name: String, email: String)
case class UserPreferences(preferred_category: String)

// Define a case class for enriched user events, combining user activity with profile and preference data
case class EnrichedUserEvent(
    user_id: String,
    event_type: String,
    timestamp: Timestamp,
    profile_name: Option[String],
    email: Option[String],
    preferred_category: Option[String]
)

// Custom stateful processor for stream-stream join
class CustomStreamJoinProcessor extends StatefulProcessor[String, UserEvent, EnrichedUserEvent] {
  // Transient state variables to store user profiles, preferences, and activities
  @transient private var _profileState: MapState[String, UserProfile] = _
  @transient private var _preferencesState: MapState[String, UserPreferences] = _
  @transient private var _activityState: MapState[String, UserEvent] = _

  // Initialize state stores
  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    _profileState = getHandle.getMapState[String, UserProfile]("profileState", Encoders.product[UserProfile], TTLConfig.NONE)
    _preferencesState = getHandle.getMapState[String, UserPreferences]("preferencesState", Encoders.product[UserPreferences], TTLConfig.NONE)
    _activityState = getHandle.getMapState[String, UserEvent]("activityState", Encoders.product[UserEvent], TTLConfig.NONE)
  }

  // Handle incoming user events
  override def handleInputRows(
      key: String,
      inputRows: Iterator[UserEvent],
      timerValues: TimerValues): Iterator[EnrichedUserEvent] = {

    inputRows.foreach { event =>
      if (event.event_type.nonEmpty) {
        // Update activity state and set a timer for 10 seconds in the future
        _activityState.update(key, event)
        getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 10000)
      }
    }
    Iterator.empty
  }

  // Handle expired timers to produce enriched events
  override def handleExpiredTimer(
      key: String,
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[EnrichedUserEvent] = {

    // Retrieve user data from state stores
    val userEvent = _activityState.getOption(key)
    val userProfile = _profileState.getOption(key)
    val userPreferences = _preferencesState.getOption(key)

    if (userEvent.isDefined) {
      // Create and return an enriched event if user activity exists
      val enrichedEvent = EnrichedUserEvent(
        user_id = key,
        event_type = userEvent.get.event_type,
        timestamp = userEvent.get.timestamp,
        profile_name = userProfile.map(_.name),
        email = userProfile.map(_.email),
        preferred_category = userPreferences.map(_.preferred_category)
      )
      Iterator.single(enrichedEvent)
    } else {
      Iterator.empty
    }
  }
}

// Apply the custom stateful processor to the input DataFrame
val enrichedStream = df
  .groupByKey(_.user_id)
  .transformWithState(
    new CustomStreamJoinProcessor(),
    TimeMode.ProcessingTime(),
    OutputMode.Append()
  )

// Write the enriched stream to Delta Lake
enrichedStream.writeStream
  .format("delta")
  .outputMode("append")
  .option("checkpointLocation", "/mnt/delta/checkpoints")
  .start("/mnt/delta/enriched_events")

Top-K computation

The following example uses ListState with a priority queue to maintain and update the top K elements in a data stream for each group key in near real time. A minimal illustrative sketch follows the notebook links.

Top-K Python

Get notebook

Top-K Scala

Get notebook
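The notebooks contain the full implementation. As a rough sketch of the idea (not the notebook code), the following Python processor keeps the current top K values for each key in a ListState and merges them with new values using heapq.nlargest. The value of K, the state name, and the key and value input columns are assumptions for illustration.

# Minimal Top-K sketch (illustrative only, not the notebook code).
# Assumes input columns "key" and "value"; K is fixed at 3 for illustration.
import heapq
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, StringType, DoubleType
from typing import Iterator

K = 3

topk_output_schema = StructType([
    StructField("key", StringType(), True),
    StructField("value", DoubleType(), True)
])

class TopKStatefulProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        value_schema = StructType([StructField("value", DoubleType(), True)])
        # ListState holding the current top-K values for each key
        self.topk = handle.getListState("topK", value_schema)

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        # Merge the stored top-K values with the new values from this micro-batch
        current = [row[0] for row in self.topk.get()] if self.topk.exists() else []
        new_values = [float(v) for pdf in rows for v in pdf["value"]]
        top = heapq.nlargest(K, current + new_values)

        # Persist the updated top-K list and emit it
        self.topk.put([(v,) for v in top])
        yield pd.DataFrame({"key": [key[0]] * len(top), "value": top})

    def close(self) -> None:
        pass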