Python 用户定义的表函数 (UDDF)
重要
此功能在 Databricks Runtime 14.3 LTS 及更高版本中作为公共预览版提供。
用户定义的表函数(UDTF)允许注册返回表的函数,而不是标量值。 与从每个调用返回单个结果值的标量函数不同,每个 UDTF 在 SQL 语句的 FROM
子句中调用,并将整个表作为输出返回。
每个 UDTF 调用都可以接受零个或多个参数。 这些参数可以是表示整个输入表的标量表达式或表参数。
基本 UDTF 语法
Apache Spark 通过必需的 eval
方法使用 yield
发出输出行来将 Python UDDF 实现为 Python 类。
若要将类用作 UDTF,必须导入 PySpark udtf
函数。 Databricks 建议将此函数用作修饰器,并使用 returnType
选项显式指定字段名称和类型(除非类定义 analyze
方法,如后面的部分所述)。
以下 UDTF 使用一组固定的两个整数参数创建一个表:
from pyspark.sql.functions import lit, udtf
@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
def eval(self, x: int, y: int):
yield x + y, x - y
GetSumDiff(lit(1), lit(2)).show()
+----+-----+
| sum| diff|
+----+-----+
| 3| -1|
+----+-----+
注册 UDTF
UDTF 将注册到本地 SparkSession
,并在笔记本或作业级别隔离。
不能将 UDF 注册为 Unity 目录中的对象,UDDF 不能用于 SQL 仓库。
可以将 UDTF 注册到当前 SparkSession
,以便使用函数 spark.udtf.register()
进行 SQL 查询。 提供 SQL 函数和 Python UDTF 类的名称。
spark.udtf.register("get_sum_diff", GetSumDiff)
调用已注册的 UDTF
注册后,可以使用 %sql
magic 命令或 spark.sql()
函数在 SQL 中使用 UDTF:
spark.udtf.register("get_sum_diff", GetSumDiff)
spark.sql("SELECT * FROM get_sum_diff(1,2);")
%sql
SELECT * FROM get_sum_diff(1,2);
使用 Apache Arrow
如果 UDTF 接收少量数据作为输入,但输出大型表,Databricks 建议使用 Apache Arrow。 可以通过在声明 UDTF 时指定 useArrow
参数来启用它:
@udtf(returnType="c1: int, c2: int", useArrow=True)
变量参数列表 - *args 和 **kwargs
可以使用 Python *args
或 **kwargs
语法并实现逻辑来处理未指定数量的输入值。
以下示例返回相同的结果,同时显式检查参数的输入长度和类型:
@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
def eval(self, *args):
assert(len(args) == 2)
assert(isinstance(arg, int) for arg in args)
x = args[0]
y = args[1]
yield x + y, x - y
GetSumDiff(lit(1), lit(2)).show()
下面是相同的示例,但使用关键字参数:
@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
def eval(self, **kwargs):
x = kwargs["x"]
y = kwargs["y"]
yield x + y, x - y
GetSumDiff(x=lit(1), y=lit(2)).show()
在注册时定义静态架构
UDTF 返回带有输出架构的行,该架构由列名和类型的有序序列组成。 如果 UDTF 架构在所有查询中始终保持不变,则可以在 @udtf
修饰器后指定静态固定架构。 它必须是 StructType
:
StructType().add("c1", StringType())
或表示结构类型的 DDL 字符串:
c1: string
在函数调用时计算动态架构
UDDF 还可以根据输入参数的值以编程方式计算每个调用的输出架构。 为此,请定义一个名为 analyze
的静态方法,该方法接受与提供给特定 UDTF 调用的参数相对应的零个或多个参数。
analyze
方法的每个参数都是 AnalyzeArgument
类的实例,该类包含以下字段:
AnalyzeArgument 类字段 |
描述 |
---|---|
dataType |
输入参数的类型是 DataType 。 对于输入表参数,这是表示表列的 StructType 。 |
value |
作为 Optional[Any] 的输入参数的值。 对于非常数表参数或文本标量参数,这是 None 。 |
isTable |
输入参数是否是作为 BooleanType 的表。 |
isConstantExpression |
输入参数是否为常量可折叠表达式,并以 BooleanType 的形式呈现。 |
analyze
方法返回 AnalyzeResult
类的实例,该实例包括结果表的架构作为 StructType
加上一些可选字段。 如果 UDTF 接受输入表参数,则 AnalyzeResult
还可以包含一种请求的方法,以便跨多个 UDTF 调用对输入表的行进行分区和排序,如下所述。
AnalyzeResult 类字段 |
描述 |
---|---|
schema |
作为 StructType 的结果表的架构。 |
withSinglePartition |
是否将所有输入行发送到与 BooleanType 相同的 UDTF 类实例。 |
partitionBy |
如果设置为非空,则每个分区表达式值的唯一组合的所有行将由 UDTF 类的单独实例处理。 |
orderBy |
如果设置为非空,则指定每个分区中的行的顺序。 |
select |
如果设置为非空,表示这是一个由 UDTF 指定的表达式序列,Catalyst 将根据这些表达式对输入的 TABLE 参数中的列进行计算。 UDTF 按列出的顺序接收列表中的每个名称的一个输入属性。 |
本 analyze
示例为输入字符串参数中的每个单词返回一个输出列。
@udtf
class MyUDTF:
@staticmethod
def analyze(text: AnalyzeArgument) -> AnalyzeResult:
schema = StructType()
for index, word in enumerate(sorted(list(set(text.value.split(" "))))):
schema = schema.add(f"word_{index}", IntegerType())
return AnalyzeResult(schema=schema)
def eval(self, text: str):
counts = {}
for word in text.split(" "):
if word not in counts:
counts[word] = 0
counts[word] += 1
result = []
for word in sorted(list(set(text.split(" ")))):
result.append(counts[word])
yield result
['word_0', 'word_1']
将状态转发到将来的 eval
调用
analyze
方法可以用作执行初始化的一个便捷位置,然后将结果传递给同一 UDTF 调用的后续 eval
方法调用。
为此,请创建 AnalyzeResult
的子类,并从 analyze
方法返回子类的实例。
然后,将附加参数添加到 __init__
方法以接受该实例。
此 analyze
示例返回常量输出架构,但在结果元数据中添加自定义信息,供将来 __init__
方法调用使用:
@dataclass
class AnalyzeResultWithBuffer(AnalyzeResult):
buffer: str = ""
@udtf
class TestUDTF:
def __init__(self, analyze_result=None):
self._total = 0
if analyze_result is not None:
self._buffer = analyze_result.buffer
else:
self._buffer = ""
@staticmethod
def analyze(argument, _) -> AnalyzeResult:
if (
argument.value is None
or argument.isTable
or not isinstance(argument.value, str)
or len(argument.value) == 0
):
raise Exception("The first argument must be a non-empty string")
assert argument.dataType == StringType()
assert not argument.isTable
return AnalyzeResultWithBuffer(
schema=StructType()
.add("total", IntegerType())
.add("buffer", StringType()),
withSinglePartition=True,
buffer=argument.value,
)
def eval(self, argument, row: Row):
self._total += 1
def terminate(self):
yield self._total, self._buffer
self.spark.udtf.register("test_udtf", TestUDTF)
spark.sql(
"""
WITH t AS (
SELECT id FROM range(1, 21)
)
SELECT total, buffer
FROM test_udtf("abc", TABLE(t))
"""
).show()
+-------+-------+
| count | buffer|
+-------+-------+
| 20 | "abc"|
+-------+-------+
生成输出行
eval
方法针对输入表参数的每一行运行一次(如果未提供任何表参数,则仅运行一次),随后将在末尾调用 terminate
方法。 任一方法通过生成元组、列表或 pyspark.sql.Row
对象,输出符合结果架构的零行或多行。
此示例通过提供一个包含三个元素的元组来返回一行:
def eval(self, x, y, z):
yield (x, y, z)
还可以省略括号:
def eval(self, x, y, z):
yield x, y, z
添加尾随逗号以返回仅包含一列的行:
def eval(self, x, y, z):
yield x,
还可以生成 pyspark.sql.Row
对象。
def eval(self, x, y, z)
from pyspark.sql.types import Row
yield Row(x, y, z)
此示例使用 Python 列表从 terminate
方法生成输出行。 为此目的,你可以在 UDTF 评估的早期步骤中,将状态存储在类内。
def terminate(self):
yield [self.x, self.y, self.z]
将标量参数传递给 UDTF
可以将标量参数作为由文本值或基于其的函数组成的常量表达式传递给 UDTF。 例如:
SELECT * FROM udtf(42, group => upper("finance_department"));
将表参数传递给 UDTF
除了标量输入参数外,Python UDF 还可以接受输入表作为参数。 单个 UDTF 还可以接受表参数和多个标量参数。
然后,任何 SQL 查询都可以使用 TABLE
关键字提供输入表,后跟括号中相应的表标识符,例如 TABLE(t)
。 或者,可以传递表子查询,例如 TABLE(SELECT a, b, c FROM t)
或 TABLE(SELECT t1.a, t2.b FROM t1 INNER JOIN t2 USING (key))
。
然后,输入表参数表示为 eval
方法的 pyspark.sql.Row
参数,对输入表中每一行的 eval
方法进行一次调用。 可以使用标准的 PySpark 列字段标注来与每行中的列进行交互。 以下示例演示如何显式导入 PySpark Row
类型,然后在 id
字段中筛选传递的表:
from pyspark.sql.functions import udtf
from pyspark.sql.types import Row
@udtf(returnType="id: int")
class FilterUDTF:
def eval(self, row: Row):
if row["id"] > 5:
yield row["id"],
spark.udtf.register("filter_udtf", FilterUDTF)
若要查询函数,请使用 TABLE
SQL 关键字:
SELECT * FROM filter_udtf(TABLE(SELECT * FROM range(10)));
+---+
| id|
+---+
| 6|
| 7|
| 8|
| 9|
+---+
指定从函数调用中得到的输入行的分区
使用表参数调用 UDTF 时,任何 SQL 查询都可以根据一个或多个输入表列的值跨多个 UDTF 调用对输入表进行分区。
若要指定分区,请在 TABLE
参数后在函数调用中使用 PARTITION BY
子句。
这可以保证具有分区列值的每个唯一组合的所有输入行都由 UDTF 类的一个实例使用。
请注意,除了简单的列引用外,PARTITION BY
子句还接受基于输入表列的任意表达式。 例如,可以指定字符串的 LENGTH
、从日期提取月份或连接两个值。
还可以指定 WITH SINGLE PARTITION
而不是 PARTITION BY
以仅请求一个分区,其中所有输入行必须由 UDTF 类的一个实例使用。
在每个分区中,可以选择指定 UDTF 的 eval
方法使用输入行时所需的顺序。 为此,在上述 PARTITION BY
或 WITH SINGLE PARTITION
子句后面提供 ORDER BY
子句。
例如,请考虑以下 UDTF:
from pyspark.sql.functions import udtf
from pyspark.sql.types import Row
@udtf(returnType="a: string, b: int")
class FilterUDTF:
def __init__(self):
self.key = ""
self.max = 0
def eval(self, row: Row):
self.key = row["a"]
self.max = max(self.max, row["b"])
def terminate(self):
yield self.key, self.max
spark.udtf.register("filter_udtf", FilterUDTF)
可以通过多种方法在输入表上调用 UDTF 时指定分区选项:
-- Create an input table with some example values.
DROP TABLE IF EXISTS values_table;
CREATE TABLE values_table (a STRING, b INT);
INSERT INTO values_table VALUES ('abc', 2), ('abc', 4), ('def', 6), ('def', 8)";
SELECT * FROM values_table;
+-------+----+
| a | b |
+-------+----+
| "abc" | 2 |
| "abc" | 4 |
| "def" | 6 |
| "def" | 8 |
+-------+----+
-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique value in the `a` column are processed by the same
-- instance of the UDTF class. Within each partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY a ORDER BY b) ORDER BY 1;
+-------+----+
| a | b |
+-------+----+
| "abc" | 4 |
| "def" | 8 |
+-------+----+
-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique result of evaluating the "LENGTH(a)" expression are
-- processed by the same instance of the UDTF class. Within each partition, the rows are ordered
-- by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY LENGTH(a) ORDER BY b) ORDER BY 1;
+-------+---+
| a | b |
+-------+---+
| "def" | 8 |
+-------+---+
-- Query the UDTF with the input table as an argument and a directive to consider all the input
-- rows in one single partition such that exactly one instance of the UDTF class consumes all of
-- the input rows. Within each partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) WITH SINGLE PARTITION ORDER BY b) ORDER BY 1;
+-------+----+
| a | b |
+-------+----+
| "def" | 8 |
+-------+----+
通过 analyze
方法指定输入行的分区
请注意,对于在 SQL 查询中调用 UDF 时对输入表进行分区的每种方法,UDTF 的 analyze
方法都有相应的方法来自动指定相同的分区方法。
- 与其使用
SELECT * FROM udtf(TABLE(t) PARTITION BY a)
调用 UDTF,不如更新analyze
方法以设置字段partitionBy=[PartitioningColumn("a")]
,然后使用SELECT * FROM udtf(TABLE(t))
调用函数。 - 同样地,您可以让
analyze
设置字段withSinglePartition=true
和orderBy=[OrderingColumn("b")]
,然后只需传递TABLE(t)
,而不是在 SQL 查询中指定TABLE(t) WITH SINGLE PARTITION ORDER BY b
。 - 可以将
analyze
设置为select=[SelectedColumn("a")]
,然后仅传递TABLE(t)
,而不是在 SQL 查询中传递TABLE(SELECT a FROM t)
。
在以下示例中,analyze
返回常量输出架构,从输入表中选择列的子集,并指定根据 date
列的值,将输入表分区到多个 UDTF 调用中:
@staticmethod
def analyze(*args) -> AnalyzeResult:
"""
The input table will be partitioned across several UDTF calls based on the monthly
values of each `date` column. The rows within each partition will arrive ordered by the `date`
column. The UDTF will only receive the `date` and `word` columns from the input table.
"""
from pyspark.sql.functions import (
AnalyzeResult,
OrderingColumn,
PartitioningColumn,
)
assert len(args) == 1, "This function accepts one argument only"
assert args[0].isTable, "Only table arguments are supported"
return AnalyzeResult(
schema=StructType()
.add("month", DateType())
.add('longest_word", IntegerType()),
partitionBy=[
PartitioningColumn("extract(month from date)")],
orderBy=[
OrderingColumn("date")],
select=[
SelectedColumn("date"),
SelectedColumn(
name="length(word),
alias="length_word")])