pyspark.sql.DataFrame.observe
DataFrame.observe(observation, *exprs)
Define (named) metrics to observe on the DataFrame. This method returns an ‘observed’ DataFrame that returns the same result as the input, with the following guarantees:
- It will compute the defined aggregates (metrics) on all the data that is flowing through the Dataset at that point.
- It will report the value of the defined aggregate columns as soon as we reach a completion point. A completion point is either the end of a query (batch mode) or the end of a streaming epoch. The value of the aggregates only reflects the data processed since the previous completion point.
The metrics columns must either contain a literal (e.g. lit(42)), or should contain one or more aggregate functions (e.g. sum(a) or sum(a + b) + avg(c) - lit(1)). Expressions that contain references to the input Dataset’s columns must always be wrapped in an aggregate function.
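For instance, given a hypothetical DataFrame df with numeric columns a, b, and c (the same illustrative column names as above, not part of this API), the following minimal sketch shows expressions that satisfy this rule and one that does not:

>>> from pyspark.sql.functions import avg, col, lit, sum
>>> # Allowed: a literal, and aggregate functions over input columns
>>> observed_df = df.observe(
...     "my_metrics",
...     lit(42).alias("answer"),
...     sum(col("a")).alias("sum_a"),
...     (sum(col("a") + col("b")) + avg(col("c")) - lit(1)).alias("combined"))
>>> # Not allowed: a bare column reference is neither a literal nor wrapped in
>>> # an aggregate function, so it would be rejected
>>> # df.observe("my_metrics", col("a"))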
A user can observe these metrics by adding Python's StreamingQueryListener, Scala/Java's org.apache.spark.sql.streaming.StreamingQueryListener, or Scala/Java's org.apache.spark.sql.util.QueryExecutionListener to the spark session.

New in version 3.3.0.
Changed in version 3.5.0: Supports Spark Connect.
Parameters
- observation : Observation or str
  str to specify the name, or an Observation instance to obtain the metric.

  Changed in version 3.4.0: Added support for str in this parameter.
- exprs : Column
  column expressions (Column).
Returns
DataFrame
  the observed DataFrame.
Notes
When observation is an Observation instance, this method only supports batch queries. When observation is a string, this method works for both batch and streaming queries. Continuous execution is not currently supported.

Examples
When observation is an Observation instance, only batch queries work as below.

>>> from pyspark.sql.functions import col, count, lit, max
>>> from pyspark.sql import Observation
>>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])
>>> observation = Observation("my metrics")
>>> observed_df = df.observe(observation, count(lit(1)).alias("count"), max(col("age")))
>>> observed_df.count()
2
>>> observation.get
{'count': 2, 'max(age)': 5}
When observation is a string, streaming queries also work as below.

>>> from pyspark.sql.streaming import StreamingQueryListener
>>> import time
>>> class MyErrorListener(StreamingQueryListener):
...     def onQueryStarted(self, event):
...         pass
...
...     def onQueryProgress(self, event):
...         row = event.progress.observedMetrics.get("my_event")
...         # Trigger if the number of errors exceeds 5 percent
...         num_rows = row.rc
...         num_error_rows = row.erc
...         ratio = num_error_rows / num_rows
...         if ratio > 0.05:
...             # Trigger alert
...             pass
...
...     def onQueryIdle(self, event):
...         pass
...
...     def onQueryTerminated(self, event):
...         pass
...
>>> error_listener = MyErrorListener()
>>> spark.streams.addListener(error_listener)
>>> sdf = spark.readStream.format("rate").load().withColumn(
...     "error", col("value")
... )
>>> # Observe row count (rc) and error row count (erc) in the streaming Dataset
... observed_ds = sdf.observe(
...     "my_event",
...     count(lit(1)).alias("rc"),
...     count(col("error")).alias("erc"))
>>> try:
...     q = observed_ds.writeStream.format("console").start()
...     time.sleep(5)
...
... finally:
...     q.stop()
...     spark.streams.removeListener(error_listener)
...