PySpark custom data sources

A PySpark custom data source is created with the Python (PySpark) DataSource API, which enables reading from custom data sources and writing to custom data sinks in Apache Spark using Python. You can use PySpark custom data sources to define custom connections to data systems and implement additional functionality to build reusable data sources.

Note

PySpark custom data sources require Databricks Runtime 15.4 LTS or above, or serverless environment version 2.

The DataSource class

The PySpark DataSource is a base class that provides methods to create data readers and writers.

Implement a DataSource subclass

Depending on your use case, any subclass must implement the following methods to make a data source readable, writable, or both:

| Property or method | Description |
| --- | --- |
| name | Required. The name of the data source. |
| schema | Required. The schema of the data source to be read or written. |
| reader() | Must return a DataSourceReader to make the data source readable (batch). |
| writer() | Must return a DataSourceWriter to make the data sink writable (batch). |
| streamReader() or simpleStreamReader() | Must return a DataSourceStreamReader to make the data stream readable (streaming). |
| streamWriter() | Must return a DataSourceStreamWriter to make the data stream writable (streaming). |
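
As an orientation, a minimal skeleton of a batch-capable subclass might look like the following sketch, where MyDataSourceReader and MyDataSourceWriter are hypothetical classes you implement yourself:

from pyspark.sql.datasource import DataSource

class MyDataSource(DataSource):
    @classmethod
    def name(cls):
        # Short name used with spark.read.format(...) and spark.write.format(...)
        return "my_datasource_name"

    def schema(self):
        # DDL string (or StructType) describing the rows this source produces or accepts
        return "id int, value string"

    def reader(self, schema):
        # Return a DataSourceReader to make the source readable in batch queries
        return MyDataSourceReader(schema, self.options)

    def writer(self, schema, overwrite):
        # Return a DataSourceWriter to make the sink writable in batch queries
        return MyDataSourceWriter(self.options, overwrite)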

Note

User-defined DataSource, DataSourceReader, DataSourceWriter, DataSourceStreamReader, and DataSourceStreamWriter classes and their methods must be serializable. In other words, they must be dictionaries or nested dictionaries that contain primitive types.
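
In practice, this means the data source and reader classes should hold only simple, picklable state (strings, numbers, plain dictionaries), and that connections, clients, and library imports belong inside the methods that run on executors. A minimal sketch, assuming a hypothetical REST endpoint option:

from pyspark.sql.datasource import DataSourceReader

class ApiReader(DataSourceReader):
    def __init__(self, schema, options):
        # Keep only picklable state on the instance.
        self.schema = schema
        self.endpoint = options.get("endpoint", "https://example.com/items")  # hypothetical option

    def read(self, partition):
        # Create non-picklable objects (HTTP sessions, database connections) inside the
        # method that runs on the executor, and import libraries here as well.
        import requests
        session = requests.Session()
        for item in session.get(self.endpoint).json():
            yield (item.get("id"), item.get("value"))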

Register the data source

After implementing the interface, you must register it, and then you can load it or otherwise use it, as shown in the following example:

# Register the data source
spark.dataSource.register(MyDataSourceClass)

# Read from a custom data source
spark.read.format("my_datasource_name").load().show()

Example 1: Create a PySpark DataSource for a batch query

To demonstrate PySpark DataSource reader capabilities, create a data source that generates example data using the faker Python package. For more information about faker, see the Faker documentation.

Install the faker package using the following command:

%pip install faker

Step 1: Implement the reader for a batch query

First, implement the reader logic to generate example data. Use the installed faker library to populate each field in the schema.

class FakeDataSourceReader(DataSourceReader):

    def __init__(self, schema, options):
        self.schema: StructType = schema
        self.options = options

    def read(self, partition):
        # Library imports must be within the method.
        from faker import Faker
        fake = Faker()

        # Every value in this `self.options` dictionary is a string.
        num_rows = int(self.options.get("numRows", 3))
        for _ in range(num_rows):
            row = []
            for field in self.schema.fields:
                value = getattr(fake, field.name)()
                row.append(value)
            yield tuple(row)
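
The reader above runs as a single task. If you want Spark to read in parallel, a DataSourceReader can also implement partitions(); the following is a minimal sketch with a hypothetical numPartitions option, and it is not used in the rest of this example:

from pyspark.sql.datasource import DataSourceReader, InputPartition

class PartitionedFakeReader(DataSourceReader):
    def __init__(self, schema, options):
        self.schema = schema
        # Every value in the options dictionary is a string.
        self.num_rows = int(options.get("numRows", 4))
        self.num_partitions = int(options.get("numPartitions", 2))

    def partitions(self):
        # Return one InputPartition per slice of work; Spark calls read() once per
        # partition, potentially on different executors.
        return [InputPartition(i) for i in range(self.num_partitions)]

    def read(self, partition):
        from faker import Faker
        fake = Faker()
        # Generate this partition's share of rows; the partition index is available
        # as partition.value if each partition needs different inputs.
        for _ in range(self.num_rows // self.num_partitions):
            yield tuple(getattr(fake, field.name)() for field in self.schema.fields)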

Step 2: Define the example DataSource

Next, define the new PySpark DataSource as a subclass of DataSource with a name, a schema, and a reader. The reader() method must be defined to read from the data source in a batch query.

from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import StructType

class FakeDataSource(DataSource):
    """
    An example data source for batch query using the `faker` library.
    """

    @classmethod
    def name(cls):
        return "fake"

    def schema(self):
        return "name string, date string, zipcode string, state string"

    def reader(self, schema: StructType):
        return FakeDataSourceReader(schema, self.options)

Step 3: Register and use the example data source

To use the data source, register it. By default, FakeDataSource has three rows, and the schema includes the following string fields: name, date, zipcode, state. The following example registers, loads, and outputs the example data source with the defaults:

spark.dataSource.register(FakeDataSource)
spark.read.format("fake").load().show()
+-----------------+----------+-------+----------+
|             name|      date|zipcode|     state|
+-----------------+----------+-------+----------+
|Christine Sampson|1979-04-24|  79766|  Colorado|
|       Shelby Cox|2011-08-05|  24596|   Florida|
|  Amanda Robinson|2019-01-06|  57395|Washington|
+-----------------+----------+-------+----------+

Only string fields are supported, but you can specify a schema with any fields that correspond to faker package providers' fields to generate random data for testing and development. The following example loads the data source with name and company fields:

spark.read.format("fake").schema("name string, company string").load().show()
+---------------------+--------------+
|name                 |company       |
+---------------------+--------------+
|Tanner Brennan       |Adams Group   |
|Leslie Maxwell       |Santiago Group|
|Mrs. Jacqueline Brown|Maynard Inc   |
+---------------------+--------------+

To load the data source with a custom number of rows, specify the numRows option. The following example specifies 5 rows:

spark.read.format("fake").option("numRows", 5).load().show()
+--------------+----------+-------+------------+
|          name|      date|zipcode|       state|
+--------------+----------+-------+------------+
|  Pam Mitchell|1988-10-20|  23788|   Tennessee|
|Melissa Turner|1996-06-14|  30851|      Nevada|
|  Brian Ramsey|2021-08-21|  55277|  Washington|
|  Caitlin Reed|1983-06-22|  89813|Pennsylvania|
| Douglas James|2007-01-18|  46226|     Alabama|
+--------------+----------+-------+------------+

Example 2: Create a PySpark GitHub DataSource with variants

To demonstrate using variants in a PySpark DataSource, this example creates a data source that reads pull requests from GitHub.

Note

PySpark custom data sources support variants in Databricks Runtime 17.1 and above.

For information about variants, see Query variant data.

Step 1: Implement the reader to retrieve pull requests

First, implement the reader logic to retrieve pull requests from the specified GitHub repository.

class GithubVariantPullRequestReader(DataSourceReader):
    def __init__(self, options):
        self.token = options.get("token")
        self.repo = options.get("path")
        if self.repo is None:
            raise Exception("Must specify a repo in the `.load()` method.")
        # Every value in this `self.options` dictionary is a string.
        self.num_rows = int(options.get("numRows", 10))

    def read(self, partition):
        header = {
            "Accept": "application/vnd.github+json",
        }
        if self.token is not None:
            header["Authorization"] = f"Bearer {self.token}"
        url = f"https://api.github.com/repos/{self.repo}/pulls"
        response = requests.get(url, headers=header)
        response.raise_for_status()
        prs = response.json()
        for pr in prs[:self.num_rows]:
            yield Row(
                id = pr.get("number"),
                title = pr.get("title"),
                user = VariantVal.parseJson(json.dumps(pr.get("user"))),
                created_at = pr.get("created_at"),
                updated_at = pr.get("updated_at")
            )

Step 2: Define the GitHub DataSource

Next, define the new PySpark GitHub DataSource as a subclass of DataSource with a name, a schema, and a reader() method. The schema includes the following fields: id, title, user, created_at, and updated_at. The user field is defined as a variant.

import json
import requests

from pyspark.sql import Row
from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import VariantVal

class GithubVariantDataSource(DataSource):
    @classmethod
    def name(cls):
        return "githubVariant"
    def schema(self):
        return "id int, title string, user variant, created_at string, updated_at string"
    def reader(self, schema):
        return GithubVariantPullRequestReader(self.options)

Step 3: Register and use the data source

To use the data source, register it. The following example registers and then loads the data source and outputs three rows of pull request data from a GitHub repository:

spark.dataSource.register(GithubVariantDataSource)
spark.read.format("githubVariant").option("numRows", 3).load("apache/spark").display()
+---------+-----------------------------------------------------+---------------------+----------------------+----------------------+
| id      | title                                               | user                | created_at           | updated_at           |
+---------+---------------------------------------------------- +---------------------+----------------------+----------------------+
|   51293 |[SPARK-52586][SQL] Introduce AnyTimeType             |  {"avatar_url":...} | 2025-06-26T09:20:59Z | 2025-06-26T15:22:39Z |
|   51292 |[WIP][PYTHON] Arrow UDF for aggregation              |  {"avatar_url":...} | 2025-06-26T07:52:27Z | 2025-06-26T07:52:37Z |
|   51290 |[SPARK-50686][SQL] Hash to sort aggregation fallback |  {"avatar_url":...} | 2025-06-26T06:19:58Z | 2025-06-26T06:20:07Z |
+---------+-----------------------------------------------------+---------------------+----------------------+----------------------+
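
Because user is a variant column, individual fields can be extracted with a variant path expression. A minimal sketch, assuming the variant_get SQL function is available in your runtime (see Query variant data):

from pyspark.sql import functions as F

df = spark.read.format("githubVariant").option("numRows", 3).load("apache/spark")

# Pull a single field out of the `user` variant column by JSON path and target type.
df.select("id", "title", F.expr("variant_get(user, '$.login', 'string') AS author")).show()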

Example 3: Create a PySpark DataSource for streaming read and write

To demonstrate PySpark DataSource stream reader and writer capabilities, create an example data source that generates two rows in every microbatch using the faker Python package. For more information about faker, see the Faker documentation.

Install the faker package using the following command:

%pip install faker

Step 1: Implement the stream reader

First, implement an example streaming data reader that generates two rows in every microbatch. You can implement DataSourceStreamReader, or, if the data source has low throughput and doesn't require partitioning, you can implement SimpleDataSourceStreamReader instead. Either simpleStreamReader() or streamReader() must be implemented, and simpleStreamReader() is only invoked when streamReader() is not implemented.

DataSourceStreamReader implementation

The FakeStreamReader instance has an integer offset that increases by 2 in every microbatch, implemented with the DataSourceStreamReader interface.

from pyspark.sql.datasource import InputPartition
from typing import Iterator, Tuple
import os
import json

class RangePartition(InputPartition):
    def __init__(self, start, end):
        self.start = start
        self.end = end

class FakeStreamReader(DataSourceStreamReader):
    def __init__(self, schema, options):
        self.current = 0

    def initialOffset(self) -> dict:
        """
        Returns the initial start offset of the reader.
        """
        return {"offset": 0}

    def latestOffset(self) -> dict:
        """
        Returns the current latest offset that the next microbatch will read to.
        """
        self.current += 2
        return {"offset": self.current}

    def partitions(self, start: dict, end: dict):
        """
        Plans the partitioning of the current microbatch defined by start and end offset. It
        needs to return a sequence of :class:`InputPartition` objects.
        """
        return [RangePartition(start["offset"], end["offset"])]

    def commit(self, end: dict):
        """
        This is invoked when the query has finished processing data before end offset. This
        can be used to clean up the resource.
        """
        pass

    def read(self, partition) -> Iterator[Tuple]:
        """
        Takes a partition as an input and reads an iterator of tuples from the data source.
        """
        start, end = partition.start, partition.end
        for i in range(start, end):
            yield (i, str(i))

SimpleDataSourceStreamReader implementation

The SimpleStreamReader instance is the same as the FakeStreamReader instance that generates two rows in every batch, but it is implemented with the SimpleDataSourceStreamReader interface without partitioning.

class SimpleStreamReader(SimpleDataSourceStreamReader):
    def initialOffset(self):
        """
        Returns the initial start offset of the reader.
        """
        return {"offset": 0}

    def read(self, start: dict) -> (Iterator[Tuple], dict):
        """
        Takes start offset as an input, then returns an iterator of tuples and the start offset of the next read.
        """
        start_idx = start["offset"]
        it = iter([(i, str(i)) for i in range(start_idx, start_idx + 2)])
        return (it, {"offset": start_idx + 2})

    def readBetweenOffsets(self, start: dict, end: dict) -> Iterator[Tuple]:
        """
        Takes start and end offset as inputs, then reads an iterator of data deterministically.
        This is called when the query replays batches during restart or after a failure.
        """
        start_idx = start["offset"]
        end_idx = end["offset"]
        return iter([(i, str(i)) for i in range(start_idx, end_idx)])

    def commit(self, end):
        """
        This is invoked when the query has finished processing data before end offset. This can be used to clean up resources.
        """
        pass

Step 2: Implement the stream writer

Next, implement the streaming writer. This stream data writer writes the metadata of every microbatch to a local path.

from pyspark.sql.datasource import DataSourceStreamWriter, WriterCommitMessage

class SimpleCommitMessage(WriterCommitMessage):
   def __init__(self, partition_id: int, count: int):
       self.partition_id = partition_id
       self.count = count

class FakeStreamWriter(DataSourceStreamWriter):
   def __init__(self, options):
       self.options = options
       self.path = self.options.get("path")
       assert self.path is not None

   def write(self, iterator):
       """
       Writes the data and then returns the commit message for that partition. Library imports must be within the method.
       """
       from pyspark import TaskContext
       context = TaskContext.get()
       partition_id = context.partitionId()
       cnt = 0
       for row in iterator:
           cnt += 1
       return SimpleCommitMessage(partition_id=partition_id, count=cnt)

   def commit(self, messages, batchId) -> None:
       """
       Receives a sequence of :class:`WriterCommitMessage` when all write tasks have succeeded, then decides what to do with it.
       In this FakeStreamWriter, the metadata of the microbatch(number of rows and partitions) is written into a JSON file inside commit().
       """
       status = dict(num_partitions=len(messages), rows=sum(m.count for m in messages))
       with open(os.path.join(self.path, f"{batchId}.json"), "a") as file:
           file.write(json.dumps(status) + "\n")

   def abort(self, messages, batchId) -> None:
       """
       Receives a sequence of :class:`WriterCommitMessage` from successful tasks when some other tasks have failed, then decides what to do with it.
       In this FakeStreamWriter, a failure message is written into a text file inside abort().
       """
       with open(os.path.join(self.path, f"{batchId}.txt"), "w") as file:
           file.write(f"failed in batch {batchId}")

Step 3: Define the example DataSource

Now, define the new PySpark DataSource as a subclass of DataSource with a name, a schema, and the methods streamReader() and streamWriter().

from pyspark.sql.datasource import DataSource, DataSourceStreamReader, SimpleDataSourceStreamReader, DataSourceStreamWriter
from pyspark.sql.types import StructType

class FakeStreamDataSource(DataSource):
    """
    An example data source for streaming read and write using the `faker` library.
    """

    @classmethod
    def name(cls):
        return "fakestream"

    def schema(self):
        return "name string, state string"

    def streamReader(self, schema: StructType):
        return FakeStreamReader(schema, self.options)

    # If you don't need partitioning, you can implement the simpleStreamReader method instead of streamReader.
    # def simpleStreamReader(self, schema: StructType):
    #     return SimpleStreamReader()

    def streamWriter(self, schema: StructType, overwrite: bool):
        return FakeStreamWriter(self.options)

Step 4: Register and use the example data source

To use the data source, register it. After it is registered, you can use it in streaming queries as a source or a sink by passing the short name or full name to format(). The following example registers the data source and then starts a query that reads from the example data source and outputs to the console:

spark.dataSource.register(FakeStreamDataSource)
query = spark.readStream.format("fakestream").load().writeStream.format("console").start()

Alternatively, the following example uses the example stream as a sink and specifies an output path:

spark.dataSource.register(FakeStreamDataSource)

# Make sure the output directory exists and is writable
output_path = "/output_path"
dbutils.fs.mkdirs(output_path)
checkpoint_path = "/output_path/checkpoint"

query = (
    spark.readStream
    .format("fakestream")
    .load()
    .writeStream
    .format("fakestream")
    .option("path", output_path)
    .option("checkpointLocation", checkpoint_path)
    .start()
)
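
After the query has processed a few microbatches, you can inspect the metadata files that commit() produced. A minimal sketch, assuming output_path resolves to a directory readable from the driver, since commit() runs on the driver and writes with the local open():

import glob
import json

# Each file is named <batchId>.json and contains one JSON line with the partition and row counts.
for path in sorted(glob.glob(f"{output_path}/*.json")):
    with open(path) as f:
        for line in f:
            print(path, json.loads(line))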

Troubleshooting

If the output is the following error, your compute does not support PySpark custom data sources. You must use Databricks Runtime 15.2 or above.

Error: [UNSUPPORTED_FEATURE.PYTHON_DATA_SOURCE] The feature is not supported: Python data sources. SQLSTATE: 0A000
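
One way to confirm which runtime a notebook is attached to is to check the environment variable that Databricks Runtime typically sets; this is a hedged sketch, and serverless environments may report their version differently:

import os

# Typically set on Databricks Runtime clusters, for example "15.4"
print(os.environ.get("DATABRICKS_RUNTIME_VERSION"))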