Python 用户定义的表函数 (UDDF)

重要

此功能在 Databricks Runtime 14.3 LTS 及更高版本中作为公共预览版提供。

用户定义的表函数(UDTF)允许注册返回表的函数,而不是标量值。 与从每个调用返回单个结果值的标量函数不同,每个 UDTF 在 SQL 语句的 FROM 子句中调用,并将整个表作为输出返回。

每个 UDTF 调用都可以接受零个或多个参数。 这些参数可以是表示整个输入表的标量表达式或表参数。

基本 UDTF 语法

Apache Spark 通过必需的 eval 方法使用 yield 发出输出行来将 Python UDDF 实现为 Python 类。

若要将类用作 UDTF,必须导入 PySpark udtf 函数。 Databricks 建议将此函数用作修饰器,并使用 returnType 选项显式指定字段名称和类型(除非类定义 analyze 方法,如后面的部分所述)。

以下 UDTF 使用一组固定的两个整数参数创建一个表:

from pyspark.sql.functions import lit, udtf

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, x: int, y: int):
        yield x + y, x - y

GetSumDiff(lit(1), lit(2)).show()
+----+-----+
| sum| diff|
+----+-----+
|   3|   -1|
+----+-----+

注册 UDTF

UDTF 将注册到本地 SparkSession,并在笔记本或作业级别隔离。

不能将 UDF 注册为 Unity 目录中的对象,UDDF 不能用于 SQL 仓库。

可以将 UDTF 注册到当前 SparkSession,以便使用函数 spark.udtf.register() 进行 SQL 查询。 提供 SQL 函数和 Python UDTF 类的名称。

spark.udtf.register("get_sum_diff", GetSumDiff)

调用已注册的 UDTF

注册后,可以使用 %sql magic 命令或 spark.sql() 函数在 SQL 中使用 UDTF:

spark.udtf.register("get_sum_diff", GetSumDiff)
spark.sql("SELECT * FROM get_sum_diff(1,2);")
%sql
SELECT * FROM get_sum_diff(1,2);

使用 Apache Arrow

如果 UDTF 接收少量数据作为输入,但输出大型表,Databricks 建议使用 Apache Arrow。 可以通过在声明 UDTF 时指定 useArrow 参数来启用它:

@udtf(returnType="c1: int, c2: int", useArrow=True)

变量参数列表 - *args 和 **kwargs

可以使用 Python *args**kwargs 语法并实现逻辑来处理未指定数量的输入值。

以下示例返回相同的结果,同时显式检查参数的输入长度和类型:

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, *args):
        assert(len(args) == 2)
        assert(isinstance(arg, int) for arg in args)
        x = args[0]
        y = args[1]
        yield x + y, x - y

GetSumDiff(lit(1), lit(2)).show()

下面是相同的示例,但使用关键字参数:

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, **kwargs):
        x = kwargs["x"]
        y = kwargs["y"]
        yield x + y, x - y

GetSumDiff(x=lit(1), y=lit(2)).show()

在注册时定义静态架构

UDTF 返回带有输出架构的行,该架构由列名和类型的有序序列组成。 如果 UDTF 架构在所有查询中始终保持不变,则可以在 @udtf 修饰器后指定静态固定架构。 它必须是 StructType

StructType().add("c1", StringType())

或表示结构类型的 DDL 字符串:

c1: string

在函数调用时计算动态架构

UDDF 还可以根据输入参数的值以编程方式计算每个调用的输出架构。 为此,请定义一个名为 analyze 的静态方法,该方法接受与提供给特定 UDTF 调用的参数相对应的零个或多个参数。

analyze 方法的每个参数都是 AnalyzeArgument 类的实例,该类包含以下字段:

AnalyzeArgument 类字段 描述
dataType 输入参数的类型是 DataType。 对于输入表参数,这是表示表列的 StructType
value 作为 Optional[Any] 的输入参数的值。 对于非常数表参数或文本标量参数,这是 None
isTable 输入参数是否是作为 BooleanType 的表。
isConstantExpression 输入参数是否为常量可折叠表达式,并以 BooleanType的形式呈现。

analyze 方法返回 AnalyzeResult 类的实例,该实例包括结果表的架构作为 StructType 加上一些可选字段。 如果 UDTF 接受输入表参数,则 AnalyzeResult 还可以包含一种请求的方法,以便跨多个 UDTF 调用对输入表的行进行分区和排序,如下所述。

AnalyzeResult 类字段 描述
schema 作为 StructType 的结果表的架构。
withSinglePartition 是否将所有输入行发送到与 BooleanType相同的 UDTF 类实例。
partitionBy 如果设置为非空,则每个分区表达式值的唯一组合的所有行将由 UDTF 类的单独实例处理。
orderBy 如果设置为非空,则指定每个分区中的行的顺序。
select 如果设置为非空,表示这是一个由 UDTF 指定的表达式序列,Catalyst 将根据这些表达式对输入的 TABLE 参数中的列进行计算。 UDTF 按列出的顺序接收列表中的每个名称的一个输入属性。

analyze 示例为输入字符串参数中的每个单词返回一个输出列。

@udtf
class MyUDTF:
  @staticmethod
  def analyze(text: AnalyzeArgument) -> AnalyzeResult:
    schema = StructType()
    for index, word in enumerate(sorted(list(set(text.value.split(" "))))):
      schema = schema.add(f"word_{index}", IntegerType())
    return AnalyzeResult(schema=schema)

  def eval(self, text: str):
    counts = {}
    for word in text.split(" "):
      if word not in counts:
            counts[word] = 0
      counts[word] += 1
    result = []
    for word in sorted(list(set(text.split(" ")))):
      result.append(counts[word])
    yield result
['word_0', 'word_1']

将状态转发到将来的 eval 调用

analyze 方法可以用作执行初始化的一个便捷位置,然后将结果传递给同一 UDTF 调用的后续 eval 方法调用。

为此,请创建 AnalyzeResult 的子类,并从 analyze 方法返回子类的实例。 然后,将附加参数添加到 __init__ 方法以接受该实例。

analyze 示例返回常量输出架构,但在结果元数据中添加自定义信息,供将来 __init__ 方法调用使用:

@dataclass
class AnalyzeResultWithBuffer(AnalyzeResult):
    buffer: str = ""

@udtf
class TestUDTF:
  def __init__(self, analyze_result=None):
    self._total = 0
    if analyze_result is not None:
      self._buffer = analyze_result.buffer
    else:
      self._buffer = ""

  @staticmethod
  def analyze(argument, _) -> AnalyzeResult:
    if (
      argument.value is None
      or argument.isTable
      or not isinstance(argument.value, str)
      or len(argument.value) == 0
    ):
      raise Exception("The first argument must be a non-empty string")
    assert argument.dataType == StringType()
    assert not argument.isTable
    return AnalyzeResultWithBuffer(
      schema=StructType()
        .add("total", IntegerType())
        .add("buffer", StringType()),
      withSinglePartition=True,
      buffer=argument.value,
    )

  def eval(self, argument, row: Row):
    self._total += 1

  def terminate(self):
    yield self._total, self._buffer

self.spark.udtf.register("test_udtf", TestUDTF)

spark.sql(
  """
  WITH t AS (
    SELECT id FROM range(1, 21)
  )
  SELECT total, buffer
  FROM test_udtf("abc", TABLE(t))
  """
).show()
+-------+-------+
| count | buffer|
+-------+-------+
|    20 |  "abc"|
+-------+-------+

生成输出行

eval 方法针对输入表参数的每一行运行一次(如果未提供任何表参数,则仅运行一次),随后将在末尾调用 terminate 方法。 任一方法通过生成元组、列表或 pyspark.sql.Row 对象,输出符合结果架构的零行或多行。

此示例通过提供一个包含三个元素的元组来返回一行:

def eval(self, x, y, z):
  yield (x, y, z)

还可以省略括号:

def eval(self, x, y, z):
  yield x, y, z

添加尾随逗号以返回仅包含一列的行:

def eval(self, x, y, z):
  yield x,

还可以生成 pyspark.sql.Row 对象。

def eval(self, x, y, z)
  from pyspark.sql.types import Row
  yield Row(x, y, z)

此示例使用 Python 列表从 terminate 方法生成输出行。 为此目的,你可以在 UDTF 评估的早期步骤中,将状态存储在类内。

def terminate(self):
  yield [self.x, self.y, self.z]

将标量参数传递给 UDTF

可以将标量参数作为由文本值或基于其的函数组成的常量表达式传递给 UDTF。 例如:

SELECT * FROM udtf(42, group => upper("finance_department"));

将表参数传递给 UDTF

除了标量输入参数外,Python UDF 还可以接受输入表作为参数。 单个 UDTF 还可以接受表参数和多个标量参数。

然后,任何 SQL 查询都可以使用 TABLE 关键字提供输入表,后跟括号中相应的表标识符,例如 TABLE(t)。 或者,可以传递表子查询,例如 TABLE(SELECT a, b, c FROM t)TABLE(SELECT t1.a, t2.b FROM t1 INNER JOIN t2 USING (key))

然后,输入表参数表示为 eval 方法的 pyspark.sql.Row 参数,对输入表中每一行的 eval 方法进行一次调用。 可以使用标准的 PySpark 列字段标注来与每行中的列进行交互。 以下示例演示如何显式导入 PySpark Row 类型,然后在 id 字段中筛选传递的表:

from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

@udtf(returnType="id: int")
class FilterUDTF:
    def eval(self, row: Row):
        if row["id"] > 5:
            yield row["id"],

spark.udtf.register("filter_udtf", FilterUDTF)

若要查询函数,请使用 TABLE SQL 关键字:

SELECT * FROM filter_udtf(TABLE(SELECT * FROM range(10)));
+---+
| id|
+---+
|  6|
|  7|
|  8|
|  9|
+---+

指定从函数调用中得到的输入行的分区

使用表参数调用 UDTF 时,任何 SQL 查询都可以根据一个或多个输入表列的值跨多个 UDTF 调用对输入表进行分区。

若要指定分区,请在 TABLE 参数后在函数调用中使用 PARTITION BY 子句。 这可以保证具有分区列值的每个唯一组合的所有输入行都由 UDTF 类的一个实例使用。

请注意,除了简单的列引用外,PARTITION BY 子句还接受基于输入表列的任意表达式。 例如,可以指定字符串的 LENGTH、从日期提取月份或连接两个值。

还可以指定 WITH SINGLE PARTITION 而不是 PARTITION BY 以仅请求一个分区,其中所有输入行必须由 UDTF 类的一个实例使用。

在每个分区中,可以选择指定 UDTF 的 eval 方法使用输入行时所需的顺序。 为此,在上述 PARTITION BYWITH SINGLE PARTITION 子句后面提供 ORDER BY 子句。

例如,请考虑以下 UDTF:

from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

@udtf(returnType="a: string, b: int")
class FilterUDTF:
  def __init__(self):
    self.key = ""
    self.max = 0

  def eval(self, row: Row):
    self.key = row["a"]
    self.max = max(self.max, row["b"])

  def terminate(self):
    yield self.key, self.max

spark.udtf.register("filter_udtf", FilterUDTF)

可以通过多种方法在输入表上调用 UDTF 时指定分区选项:

-- Create an input table with some example values.
DROP TABLE IF EXISTS values_table;
CREATE TABLE values_table (a STRING, b INT);
INSERT INTO values_table VALUES ('abc', 2), ('abc', 4), ('def', 6), ('def', 8)";
SELECT * FROM values_table;
+-------+----+
|     a |  b |
+-------+----+
| "abc" | 2  |
| "abc" | 4  |
| "def" | 6  |
| "def" | 8  |
+-------+----+
-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique value in the `a` column are processed by the same
-- instance of the UDTF class. Within each partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY a ORDER BY b) ORDER BY 1;
+-------+----+
|     a |  b |
+-------+----+
| "abc" | 4  |
| "def" | 8  |
+-------+----+

-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique result of evaluating the "LENGTH(a)" expression are
-- processed by the same instance of the UDTF class. Within each partition, the rows are ordered
-- by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY LENGTH(a) ORDER BY b) ORDER BY 1;
+-------+---+
|     a | b |
+-------+---+
| "def" | 8 |
+-------+---+
-- Query the UDTF with the input table as an argument and a directive to consider all the input
-- rows in one single partition such that exactly one instance of the UDTF class consumes all of
-- the input rows. Within each partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) WITH SINGLE PARTITION ORDER BY b) ORDER BY 1;
+-------+----+
|     a |  b |
+-------+----+
| "def" | 8 |
+-------+----+

通过 analyze 方法指定输入行的分区

请注意,对于在 SQL 查询中调用 UDF 时对输入表进行分区的每种方法,UDTF 的 analyze 方法都有相应的方法来自动指定相同的分区方法。

  • 与其使用 SELECT * FROM udtf(TABLE(t) PARTITION BY a)调用 UDTF,不如更新 analyze 方法以设置字段 partitionBy=[PartitioningColumn("a")],然后使用 SELECT * FROM udtf(TABLE(t))调用函数。
  • 同样地,您可以让 analyze 设置字段 withSinglePartition=trueorderBy=[OrderingColumn("b")],然后只需传递 TABLE(t),而不是在 SQL 查询中指定 TABLE(t) WITH SINGLE PARTITION ORDER BY b
  • 可以将 analyze 设置为 select=[SelectedColumn("a")],然后仅传递 TABLE(t),而不是在 SQL 查询中传递 TABLE(SELECT a FROM t)

在以下示例中,analyze 返回常量输出架构,从输入表中选择列的子集,并指定根据 date 列的值,将输入表分区到多个 UDTF 调用中:

@staticmethod
def analyze(*args) -> AnalyzeResult:
  """
  The input table will be partitioned across several UDTF calls based on the monthly
  values of each `date` column. The rows within each partition will arrive ordered by the `date`
  column. The UDTF will only receive the `date` and `word` columns from the input table.
  """
  from pyspark.sql.functions import (
    AnalyzeResult,
    OrderingColumn,
    PartitioningColumn,
  )

  assert len(args) == 1, "This function accepts one argument only"
  assert args[0].isTable, "Only table arguments are supported"
  return AnalyzeResult(
    schema=StructType()
      .add("month", DateType())
      .add('longest_word", IntegerType()),
    partitionBy=[
      PartitioningColumn("extract(month from date)")],
    orderBy=[
      OrderingColumn("date")],
    select=[
      SelectedColumn("date"),
      SelectedColumn(
        name="length(word),
        alias="length_word")])