This article contains code examples for custom stateful applications. Databricks recommends using built-in stateful methods for common operations such as aggregations and joins.
The patterns in this article use the transformWithState operator and related classes, available in Databricks Runtime 16.2 and above. See Build custom stateful applications.
Note
Python supports both the row-based transformWithState API (available in micro-batch and real-time mode) and the Pandas-based transformWithStateInPandas operator. The following examples provide code for transformWithStateInPandas in Python and transformWithState in Scala.
Requirements
The transformWithState operator and related APIs and classes have the following requirements:
- Available in Databricks Runtime 16.2 and above.
- Standard access mode is supported for Python (transformWithStateInPandas and row-based transformWithState) in Databricks Runtime 16.3 and above, and for Scala (transformWithState) in Databricks Runtime 17.3 and above.
- You must use the RocksDB state store provider. Databricks recommends enabling RocksDB as part of compute configuration.
Note
To enable the RocksDB state store provider for the current session, run the following:
spark.conf.set("spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
Slowly changing dimension (SCD) type 1
The following code is an example of using transformWithState to implement SCD type 1. SCD type 1 only tracks the most recent value for a given field.
Note
You can implement SCD type 1 or type 2 using Delta Lake-backed tables with streaming tables and AUTO CDC ... INTO. This example implements SCD type 1 in the state store, which provides lower latency for near-real-time applications.
Python
# Import the necessary libraries
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType, StringType
from typing import Iterator
# Set the state store provider to RocksDB
spark.conf.set("spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
# Define the output schema for the streaming query
output_schema = StructType([
StructField("user", StringType(), True),
StructField("time", LongType(), True),
StructField("location", StringType(), True)
])
# Define a custom StatefulProcessor for slowly changing dimension type 1 (SCD1) operations
class SCDType1StatefulProcessor(StatefulProcessor):
def init(self, handle: StatefulProcessorHandle) -> None:
self.handle = handle
# Define the schema for the state value
value_state_schema = StructType([
StructField("user", StringType(), True),
StructField("time", LongType(), True),
StructField("location", StringType(), True)
])
# Initialize the state to store the latest location for each user
self.latest_location = handle.getValueState("latestLocation", value_state_schema)
def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
# Find the row with the maximum time value
max_row = None
max_time = float('-inf')
for pdf in rows:
for _, pd_row in pdf.iterrows():
time_value = pd_row["time"]
if time_value > max_time:
max_time = time_value
max_row = tuple(pd_row)
# Check whether state exists and update if necessary
exists = self.latest_location.exists()
if not exists or max_row[1] > self.latest_location.get()[1]:
# Update the state with the new max row
self.latest_location.update(max_row)
# Yield the updated row
yield pd.DataFrame(
{"user": (max_row[0],), "time": (max_row[1],), "location": (max_row[2],)}
)
# Yield an empty DataFrame if no update is needed
yield pd.DataFrame()
def close(self) -> None:
# No cleanup needed
pass
# Apply the stateful transformation to the input DataFrame
(df.groupBy("user")
.transformWithStateInPandas(
statefulProcessor=SCDType1StatefulProcessor(),
outputStructType=output_schema,
outputMode="Update",
timeMode="None",
)
.writeStream... # Continue with stream writing configuration
)
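The Python example above assumes a streaming DataFrame named df with user, time, and location columns, and leaves the sink configuration elided. The following minimal sketch (a hypothetical wiring, not part of the original example) shows one way the query might be defined end to end; the Delta source path, checkpoint location, and console sink are illustrative assumptions.
# Hypothetical end-to-end wiring for the SCD type 1 example above.
# The source path, checkpoint location, and console sink are placeholders.
df = (spark.readStream
    .format("delta")
    .load("/path/to/user_location_events")  # expects columns: user, time, location
)

query = (df.groupBy("user")
    .transformWithStateInPandas(
        statefulProcessor=SCDType1StatefulProcessor(),
        outputStructType=output_schema,
        outputMode="Update",
        timeMode="None",
    )
    .writeStream
    .outputMode("update")
    .option("checkpointLocation", "/tmp/checkpoints/scd1")
    .format("console")
    .start()
)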
Scala
// Import the stateful processor classes
import org.apache.spark.sql.streaming._

// Define a case class to represent user location data
case class UserLocation(
user: String,
time: Long,
location: String)
// Define a stateful processor for slowly changing dimension type 1 (SCD1) operations
class SCDType1StatefulProcessor extends StatefulProcessor[String, UserLocation, UserLocation] {
import org.apache.spark.sql.{Encoders}
// Transient value state to store the latest location for each user
@transient private var _latestLocation: ValueState[UserLocation] = _
private val userLocationEncoder = Encoders.product[UserLocation]
// Initialize the state store
override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
// Create a value state named "locationState" using UserLocation encoder
// TTLConfig.NONE means the state has no expiration
_latestLocation = getHandle.getValueState[UserLocation]("locationState",
userLocationEncoder, TTLConfig.NONE)
}
// Process input rows and update state
override def handleInputRows(
key: String,
inputRows: Iterator[UserLocation],
timerValues: TimerValues): Iterator[UserLocation] = {
// Find the location with the maximum timestamp from input rows
val maxNewLocation = inputRows.maxBy(_.time)
// Update state and emit output if:
// 1. No previous state exists, or
// 2. New location has a more recent timestamp than the stored one
if (_latestLocation.getOption().isEmpty || maxNewLocation.time > _latestLocation.get().time) {
_latestLocation.update(maxNewLocation)
Iterator.single(maxNewLocation) // Emit the updated location
} else {
Iterator.empty // No update needed, emit nothing
}
}
}
Slowly changing dimension (SCD) type 2
The following notebooks contain an example showing how to implement SCD type 2 using transformWithState in Python or Scala; a simplified sketch follows the notebook links.
SCD type 2 Python
SCD type 2 Scala
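For orientation, the following minimal sketch (an illustrative assumption, not the notebook code) shows one way SCD type 2 could be expressed with transformWithStateInPandas: a ValueState holds the currently open version for each key, and when the tracked value changes, the processor emits the previous version with a closing valid_to timestamp. The column names, the epoch-millisecond time column, and the state name are assumptions.
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, StringType, LongType
from typing import Iterator

# Illustrative sketch only; see the notebooks above for the full example.
# Emitted rows are closed-out historical versions; the open version lives in state.
scd2_history_schema = StructType([
    StructField("user", StringType(), True),
    StructField("location", StringType(), True),
    StructField("valid_from", LongType(), True),
    StructField("valid_to", LongType(), True)
])

class SCDType2SketchProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        # Currently open version for each user: (location, valid_from)
        current_schema = StructType([
            StructField("location", StringType(), True),
            StructField("valid_from", LongType(), True)
        ])
        self.current = handle.getValueState("current", current_schema)

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        closed = []
        for pdf in rows:
            # Process events in time order (assumes an epoch-millisecond "time" column)
            for _, row in pdf.sort_values("time").iterrows():
                if self.current.exists():
                    location, valid_from = self.current.get()
                    if location != row["location"]:
                        # Close out the previous version and open a new one
                        closed.append((key[0], location, valid_from, int(row["time"])))
                        self.current.update((row["location"], int(row["time"])))
                else:
                    self.current.update((row["location"], int(row["time"])))
        if closed:
            yield pd.DataFrame(closed, columns=["user", "location", "valid_from", "valid_to"])

    def close(self) -> None:
        pass
Applying the processor follows the same pattern as the SCD type 1 example above, typically with outputMode="Append" because history rows are only ever added.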
Downtime detector
transformWithState implements timers so you can take action based on elapsed time, even if no records for a given key are processed in a micro-batch.
The following example implements a downtime detector pattern. Each time a new value is seen for a given key, it updates the lastSeen state value, clears any existing timers, and resets a timer.
When a timer expires, the application emits the elapsed time since the last observed event for the key. It then sets a new timer to emit an update 10 seconds later.
Python
from datetime import datetime

import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, TimestampType
from typing import Iterator
class DownTimeDetectorStatefulProcessor(StatefulProcessor):
def init(self, handle: StatefulProcessorHandle) -> None:
# Define the schema for the state value (timestamp)
state_schema = StructType([StructField("value", TimestampType(), True)])
self.handle = handle
# Initialize state to store the last seen timestamp for each key
self.last_seen = handle.getValueState("last_seen", state_schema)
def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]:
# Retrieve the timestamp of the last observed event for this key
latest_from_existing = self.last_seen.get()[0]
# Calculate the downtime duration as the time elapsed since the last observed event
downtime_duration = timerValues.getCurrentProcessingTimeInMs() - int(latest_from_existing.timestamp() * 1000)
# Register a new timer for 10 seconds in the future
self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 10000)
# Yield a DataFrame with the key and downtime duration
yield pd.DataFrame(
{
"id": key,
"timeValues": str(downtime_duration),
}
)
def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
# Find the row with the maximum timestamp
max_row = max((tuple(pdf.iloc[0]) for pdf in rows), key=lambda row: row[1])
# Get the latest timestamp from the existing state or use epoch start if a timestamp doesn't exist
if self.last_seen.exists():
latest_from_existing = self.last_seen.get()[0]
else:
latest_from_existing = datetime.fromtimestamp(0)
# If the new data is more recent than the existing state
if latest_from_existing < max_row[1]:
# Delete all existing timers
for timer in self.handle.listTimers():
self.handle.deleteTimer(timer)
# Update the last seen timestamp
self.last_seen.update((max_row[1],))
# Register a new timer for 5 seconds in the future
self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 5000)
# Get current processing time in milliseconds
timestamp_in_millis = str(timerValues.getCurrentProcessingTimeInMs())
# Yield a DataFrame with the key and current timestamp
yield pd.DataFrame({"id": key, "timeValues": timestamp_in_millis})
def close(self) -> None:
# No cleanup needed
pass
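As with the other Python examples, the processor is applied with transformWithStateInPandas. The following minimal sketch (assumed schema, source path, and sink, not part of the original example) shows how the downtime detector might be wired up. Note that timeMode must be set to ProcessingTime so the registered timers fire.
from pyspark.sql.types import StructType, StructField, StringType

# Hypothetical output schema and source; the path and sink below are placeholders.
downtime_output_schema = StructType([
    StructField("id", StringType(), True),
    StructField("timeValues", StringType(), True)
])

events = (spark.readStream
    .format("delta")
    .load("/path/to/sensor_events")  # expects columns: id, time
)

query = (events.groupBy("id")
    .transformWithStateInPandas(
        statefulProcessor=DownTimeDetectorStatefulProcessor(),
        outputStructType=downtime_output_schema,
        outputMode="Update",
        timeMode="ProcessingTime",  # required for timers to fire
    )
    .writeStream
    .outputMode("update")
    .option("checkpointLocation", "/tmp/checkpoints/downtime")
    .format("console")
    .start()
)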
Scala
import java.sql.Timestamp
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming._
// The (String, Timestamp) schema represents an (id, time). We want to do downtime
// detection on every single unique sensor, where each sensor has a sensor ID.
class DowntimeDetector(duration: Duration) extends
StatefulProcessor[String, (String, Timestamp), (String, Duration)] {
@transient private var _lastSeen: ValueState[Timestamp] = _
private val timestampEncoder = Encoders.TIMESTAMP
override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
_lastSeen = getHandle.getValueState[Timestamp]("lastSeen", timestampEncoder, TTLConfig.NONE)
}
// The logic here is as follows: find the largest timestamp seen so far. Set a timer for
// the duration later.
override def handleInputRows(
key: String,
inputRows: Iterator[(String, Timestamp)],
timerValues: TimerValues): Iterator[(String, Duration)] = {
val latestRecordFromNewRows = inputRows.maxBy(_._2.getTime)
// Use getOrElse to initiate state variable if it doesn't exist
val latestTimestampFromExistingRows = _lastSeen.getOption().getOrElse(new Timestamp(0))
val latestTimestampFromNewRows = latestRecordFromNewRows._2
if (latestTimestampFromNewRows.after(latestTimestampFromExistingRows)) {
// Cancel the one existing timer, since we have a new latest timestamp.
// We call "listTimers()" because we don't know ahead of time what
// the timestamp of the existing timer will be.
getHandle.listTimers().foreach(timer => getHandle.deleteTimer(timer))
_lastSeen.update(latestTimestampFromNewRows)
// Use timerValues to schedule a timer using processing time.
getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + duration.toMillis)
} else {
// No new latest timestamp, so there is no need to update the state or set a timer.
}
Iterator.empty
}
override def handleExpiredTimer(
key: String,
timerValues: TimerValues,
expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Duration)] = {
val latestTimestamp = _lastSeen.get()
val downtimeDuration = new Duration(
timerValues.getCurrentProcessingTimeInMs() - latestTimestamp.getTime)
// Register another timer that will fire in 10 seconds.
// Timers can be registered anywhere but init()
getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 10000)
Iterator((key, downtimeDuration))
}
}
Migrate existing state information
The following example demonstrates how to implement a stateful application that accepts an initial state. You can add initial state handling to any stateful application, but the initial state can only be set when first initializing the application.
This example uses the statestore reader to load existing state information from a checkpoint path. An example use case for this pattern is migrating from a legacy stateful application to transformWithState.
Python
# Import the necessary libraries
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType, StringType, IntegerType
from typing import Iterator
# Set RocksDB as the state store provider for better performance
spark.conf.set("spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
"""
Input schema is as below
input_schema = StructType([
    StructField("id", StringType(), True),
    StructField("value", StringType(), True)
])
"""
# Define the output schema for the streaming query
output_schema = StructType([
StructField("id", StringType(), True),
StructField("accumulated", StringType(), True)
])
class AccumulatedCounterStatefulProcessorWithInitialState(StatefulProcessor):
def init(self, handle: StatefulProcessorHandle) -> None:
# Define the schema for the state value (integer)
state_schema = StructType([StructField("value", IntegerType(), True)])
# Initialize state to store the accumulated counter for each id
self.counter_state = handle.getValueState("counter_state", state_schema)
self.handle = handle
def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
# Check if state exists for the current key
exists = self.counter_state.exists()
if exists:
value_row = self.counter_state.get()
existing_value = value_row[0]
else:
existing_value = 0
accumulated_value = existing_value
# Process input rows and accumulate values
for pdf in rows:
value = pdf["value"].astype(int).sum()
accumulated_value += value
# Update the state with the new accumulated value
self.counter_state.update((accumulated_value,))
# Yield a DataFrame with the key and accumulated value
yield pd.DataFrame({"id": key, "accumulated": str(accumulated_value)})
def handleInitialState(self, key, initialState, timerValues) -> None:
# Initialize the state with the provided initial value
init_val = initialState.at[0, "initVal"]
self.counter_state.update((init_val,))
def close(self) -> None:
# No cleanup needed
pass
# Load initial state from a checkpoint directory
initial_state = (spark.read.format("statestore")
.option("path", "$checkpointsDir")
.load()
)
# Apply the stateful transformation to the input DataFrame
df.groupBy("id")
.transformWithStateInPandas(
statefulProcessor=AccumulatedCounterStatefulProcessorWithInitialState(),
outputStructType=output_schema,
outputMode="Update",
timeMode="None",
initialState=initial_state,
)
.writeStream... # Continue with stream writing configuration
Scala
// Import the necessary libraries
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.{Dataset, Encoder, Encoders, DataFrame}
import org.apache.spark.sql.types._
// Define a stateful processor that can handle the initial state
class InitialStateStatefulProcessor extends StatefulProcessorWithInitialState[String, (String, String, String), (String, String), (String, Int)] {
// Transient value state to store the accumulated value
@transient protected var valueState: ValueState[Int] = _
private val intEncoder = Encoders.scalaInt
// Initialize the state store
override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
// Create a value state named "valueState" using Int encoder
// TTLConfig.NONE means the state has no automatic expiration
valueState = getHandle.getValueState[Int]("valueState",
intEncoder, TTLConfig.NONE)
}
// Process input rows and update state
override def handleInputRows(
key: String,
inputRows: Iterator[(String, String, String)],
timerValues: TimerValues): Iterator[(String, String)] = {
var existingValue = 0
// Retrieve existing value from state if it exists
if (valueState.exists()) {
existingValue += valueState.get()
}
var accumulatedValue = existingValue
// Accumulate values from input rows
for (row <- inputRows) {
accumulatedValue += row._2.toInt
}
// Update the state with the new accumulated value
valueState.update(accumulatedValue)
// Return the key and accumulated value as a string
Iterator((key, accumulatedValue.toString))
}
// Handle initial state when provided
override def handleInitialState(
key: String, initialState: (String, Int), timerValues: TimerValues): Unit = {
// Update the state with the initial value
valueState.update(initialState._2)
}
}
Migrate a Delta table into the state store for initialization
The following notebooks contain an example of initializing state store values from a Delta table using transformWithState in Python or Scala; a simplified sketch follows the notebook links.
Initialize state from Delta Python
Initialize state from Delta Scala
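For orientation, the following minimal sketch (an illustrative assumption, not the notebook code) shows the general shape of the pattern: read the existing values from a Delta table as a batch DataFrame, group it on the same key as the stream, and pass it as initialState so that handleInitialState runs once per key when the query first starts. The table name, column names, and sink are assumptions, and the sketch reuses the AccumulatedCounterStatefulProcessorWithInitialState processor from the previous section.
# Hypothetical sketch: initialize state from an existing Delta table.
# Assumes a table with columns id and initVal, matching handleInitialState above.
initial_state = (spark.read
    .table("existing_counters")
    .groupBy("id")
)

query = (df.groupBy("id")
    .transformWithStateInPandas(
        statefulProcessor=AccumulatedCounterStatefulProcessorWithInitialState(),
        outputStructType=output_schema,
        outputMode="Update",
        timeMode="None",
        initialState=initial_state,
    )
    .writeStream
    .outputMode("update")
    .option("checkpointLocation", "/tmp/checkpoints/delta_init")  # placeholder path
    .format("console")
    .start()
)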
Session tracking
The following notebooks contain a session tracking example using transformWithState in Python or Scala; a simplified sketch follows the notebook links.
Session tracking Python
Session tracking Scala
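For orientation, the following minimal sketch (an illustrative assumption, not the notebook code) shows a typical session tracking shape: the processor keeps the session start time, last-seen time, and event count in a ValueState and registers a processing-time timer that closes the session after a period of inactivity. The 30-second timeout, column names, and state name are assumptions.
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType
from typing import Iterator

SESSION_TIMEOUT_MS = 30 * 1000  # illustrative inactivity threshold

class SessionTrackerSketch(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        self.handle = handle
        # Track (start_ms, last_ms, count) for the open session of each key
        session_schema = StructType([
            StructField("start_ms", LongType(), True),
            StructField("last_ms", LongType(), True),
            StructField("count", LongType(), True)
        ])
        self.session = handle.getValueState("session", session_schema)

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        now_ms = timerValues.getCurrentProcessingTimeInMs()
        new_events = sum(len(pdf) for pdf in rows)
        if self.session.exists():
            start_ms, _, count = self.session.get()
        else:
            start_ms, count = now_ms, 0
        self.session.update((start_ms, now_ms, count + new_events))
        # Reset the inactivity timer for this key
        for timer in self.handle.listTimers():
            self.handle.deleteTimer(timer)
        self.handle.registerTimer(now_ms + SESSION_TIMEOUT_MS)
        return iter([])  # sessions are emitted only when they close

    def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]:
        # The inactivity timeout elapsed: emit the closed session and clear state
        start_ms, last_ms, count = self.session.get()
        self.session.clear()
        yield pd.DataFrame({"user": [key[0]], "duration_ms": [last_ms - start_ms], "events": [count]})

    def close(self) -> None:
        pass
Applying this sketch requires timeMode="ProcessingTime" and an output schema with user, duration_ms, and events columns.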
Custom stream-stream join using transformWithState
The following code demonstrates a custom stream-stream join across multiple streams using transformWithState. You might use this approach instead of a built-in join operator for the following reasons:
- You need to use the update output mode, which does not support stream-stream joins. This is especially useful for lower-latency applications.
- You need to continue performing joins for late-arriving rows (after watermark expiration).
- You need to perform a many-to-many stream-stream join.
This example gives the user full control over state expiration logic, allowing dynamic retention-period extension to handle out-of-order events, even after the watermark.
Python
# Import the necessary libraries
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, StringType, TimestampType
from typing import Iterator
# Define output schema for the joined data
output_schema = StructType([
StructField("user_id", StringType(), True),
StructField("event_type", StringType(), True),
StructField("timestamp", TimestampType(), True),
StructField("profile_name", StringType(), True),
StructField("email", StringType(), True),
StructField("preferred_category", StringType(), True)
])
class CustomStreamJoinProcessor(StatefulProcessor):
# Initialize stateful storage for user profiles, preferences, and event tracking.
def init(self, handle: StatefulProcessorHandle) -> None:
# Define schemas for different types of state data
profile_schema = StructType([
StructField("name", StringType(), True),
StructField("email", StringType(), True),
StructField("updated_at", TimestampType(), True)
])
preferences_schema = StructType([
StructField("preferred_category", StringType(), True),
StructField("updated_at", TimestampType(), True)
])
activity_schema = StructType([
StructField("event_type", StringType(), True),
StructField("timestamp", TimestampType(), True)
])
map_state_key_schema = StructType([
StructField("user_id", StringType(), True)
])
self.handle = handle
# Initialize state storage for user profiles, preferences, and activity
self.profile_state = handle.getMapState("user_profiles", map_state_key_schema, profile_schema)
self.preferences_state = handle.getMapState("user_preferences", map_state_key_schema, preferences_schema)
self.activity_state = handle.getMapState("user_activity", map_state_key_schema, activity_schema)
# Process incoming events and update the state
def handleInputRows(self, key, rows: Iterator[pd.DataFrame], timerValues) -> Iterator[pd.DataFrame]:
df = pd.concat(rows, ignore_index=True)
output_rows = []
for _, row in df.iterrows():
user_id = row["user_id"]
if "event_type" in row: # User activity event
self.activity_state.updateValue(user_id, row.to_dict())
# Set a timer to process this event after a 10-second delay
self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + (10 * 1000))
elif "name" in row: # Profile update
self.profile_state.updateValue(user_id, row.to_dict())
elif "preferred_category" in row: # Preference update
self.preferences_state.updateValue(user_id, row.to_dict())
# No immediate output; processing will happen when the timer expires
return iter([])
# Perform lookup after delay, handling out-of-order and late-arriving events.
def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]:
# Retrieve stored state for the user
user_activity = self.activity_state.getValue(key)
user_profile = self.profile_state.getValue(key)
user_preferences = self.preferences_state.getValue(key)
if user_activity:
# Combine data from different states into a single output row
output_row = {
"user_id": key,
"event_type": user_activity["event_type"],
"timestamp": user_activity["timestamp"],
"profile_name": user_profile.get("name") if user_profile else None,
"email": user_profile.get("email") if user_profile else None,
"preferred_category": user_preferences.get("preferred_category") if user_preferences else None
}
return iter([pd.DataFrame([output_row])])
return iter([])
def close(self) -> None:
# No cleanup needed
pass
# Apply transformWithState to the input DataFrame
(df.groupBy("user_id")
.transformWithStateInPandas(
statefulProcessor=CustomStreamJoinProcessor(),
outputStructType=output_schema,
outputMode="Append",
timeMode="ProcessingTime"
)
.writeStream... # Continue with stream writing configuration
)
Scala
// Import the necessary libraries
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.types.TimestampType
import java.sql.Timestamp
// Define a case class for enriched user events, combining user activity with profile and preference data
case class EnrichedUserEvent(
user_id: String,
event_type: String,
timestamp: Timestamp,
profile_name: Option[String],
email: Option[String],
preferred_category: Option[String]
)
// Custom stateful processor for stream-stream join
class CustomStreamJoinProcessor extends StatefulProcessor[String, UserEvent, EnrichedUserEvent] {
// Transient state variables to store user profiles, preferences, and activities
@transient private var _profileState: MapState[String, UserProfile] = _
@transient private var _preferencesState: MapState[String, UserPreferences] = _
@transient private var _activityState: MapState[String, UserEvent] = _
private val userProfileEncoder = Encoders.product[UserProfile]
private val userPreferencesEncoder = Encoders.product[UserPreferences]
private val userEventEncoder = Encoders.product[UserEvent]
// Initialize state stores
override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
_profileState = getHandle.getMapState[String, UserProfile]("profileState", userProfileEncoder, TTLConfig.NONE)
_preferencesState = getHandle.getMapState[String, UserPreferences]("preferencesState", userPreferencesEncoder, TTLConfig.NONE)
_activityState = getHandle.getMapState[String, UserEvent]("activityState", userEventEncoder, TTLConfig.NONE)
}
// Handle incoming user events
override def handleInputRows(
key: String,
inputRows: Iterator[UserEvent],
timerValues: TimerValues): Iterator[EnrichedUserEvent] = {
inputRows.foreach { event =>
if (event.event_type.nonEmpty) {
// Update activity state and set a timer for 10 seconds in the future
_activityState.update(key, event)
getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 10000)
}
}
Iterator.empty
}
// Handle expired timers to produce enriched events
override def handleExpiredTimer(
key: String,
timerValues: TimerValues,
expiredTimerInfo: ExpiredTimerInfo): Iterator[EnrichedUserEvent] = {
// Retrieve user data from state stores
val userEvent = _activityState.getOption(key)
val userProfile = _profileState.getOption(key)
val userPreferences = _preferencesState.getOption(key)
if (userEvent.isDefined) {
// Create and return an enriched event if user activity exists
val enrichedEvent = EnrichedUserEvent(
user_id = key,
event_type = userEvent.get.event_type,
timestamp = userEvent.get.timestamp,
profile_name = userProfile.map(_.name),
email = userProfile.map(_.email),
preferred_category = userPreferences.map(_.preferred_category)
)
Iterator.single(enrichedEvent)
} else {
Iterator.empty
}
}
}
// Apply the custom stateful processor to the input DataFrame
val enrichedStream = df
.groupByKey(_.user_id)
.transformWithState(
new CustomStreamJoinProcessor(),
TimeMode.ProcessingTime(),
OutputMode.Append()
)
// Write the enriched stream to Delta Lake
enrichedStream.writeStream
.format("delta")
.outputMode("append")
.option("checkpointLocation", "/mnt/delta/checkpoints")
.start("/mnt/delta/enriched_events")
Top-K computation
The following example uses ListState with a priority queue to maintain and update the top K elements in a data stream for each grouping key in near real time.