pyspark.sql.DataFrame.observe

DataFrame.observe(observation, *exprs)

Define (named) metrics to observe on the DataFrame. This method returns an ‘observed’ DataFrame that returns the same result as the input, with the following guarantees:

  • It will compute the defined aggregates (metrics) on all the data that is flowing through the Dataset at that point.

  • It will report the value of the defined aggregate columns as soon as we reach a completion point. A completion point is either the end of a query (batch mode) or the end of a streaming epoch. The value of the aggregates only reflects the data processed since the previous completion point.
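The second guarantee can be pictured with a toy model in plain Python. This is only a conceptual sketch and not how Spark implements metric collection: the class EpochMetric and its methods are hypothetical names invented here. The running aggregate is reported at each completion point and then starts over, so each report covers only the rows seen since the previous one.

```python
class EpochMetric:
    """Toy stand-in for an observed row-count metric (hypothetical, not Spark code)."""

    def __init__(self):
        self.row_count = 0

    def process(self, rows):
        # Rows "flow through" the observation point and update the aggregate.
        self.row_count += len(rows)

    def complete(self):
        # A completion point: report the aggregate, then reset it.
        reported = {"rc": self.row_count}
        self.row_count = 0
        return reported


metric = EpochMetric()

metric.process([1, 2, 3])
first = metric.complete()   # covers the three rows of the first epoch

metric.process([4, 5])
second = metric.complete()  # covers only the two rows of the second epoch

print(first, second)
```

In a batch query there is a single completion point at the end of the query, so the one report covers all the data.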

Each metric column must either be a literal (e.g. lit(42)) or contain one or more aggregate functions (e.g. sum(a) or sum(a + b) + avg(c) - lit(1)). Expressions that contain references to the input Dataset’s columns must always be wrapped in an aggregate function.

A user can observe these metrics by adding Python’s StreamingQueryListener, Scala/Java’s org.apache.spark.sql.streaming.StreamingQueryListener or Scala/Java’s org.apache.spark.sql.util.QueryExecutionListener to the Spark session.

New in version 3.3.0.

Changed in version 3.5.0: Supports Spark Connect.

Parameters
observation : Observation or str

str to specify the name, or an Observation instance to obtain the metric.

Changed in version 3.4.0: Added support for str in this parameter.

exprs : Column

column expressions (Column).

Returns
DataFrame

the observed DataFrame.

Notes

When observation is an Observation instance, this method only supports batch queries. When observation is a string, this method works for both batch and streaming queries. Continuous execution is not yet supported.

Examples

When observation is an Observation instance, only batch queries are supported, as shown below.

>>> from pyspark.sql.functions import col, count, lit, max
>>> from pyspark.sql import Observation
>>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])
>>> observation = Observation("my metrics")
>>> observed_df = df.observe(observation, count(lit(1)).alias("count"), max(col("age")))
>>> observed_df.count()
2
>>> observation.get
{'count': 2, 'max(age)': 5}

When observation is a string, streaming queries are also supported, as shown below.

>>> from pyspark.sql.streaming import StreamingQueryListener
>>> import time
>>> class MyErrorListener(StreamingQueryListener):
...    def onQueryStarted(self, event):
...        pass
...
...    def onQueryProgress(self, event):
...        row = event.progress.observedMetrics.get("my_event")
...        # Skip progress events that carry no metrics or no rows
...        if row is None or row.rc == 0:
...            return
...        # Trigger if the number of errors exceeds 5 percent
...        num_rows = row.rc
...        num_error_rows = row.erc
...        ratio = num_error_rows / num_rows
...        if ratio > 0.05:
...            # Trigger alert
...            pass
...
...    def onQueryIdle(self, event):
...        pass
...
...    def onQueryTerminated(self, event):
...        pass
...
>>> error_listener = MyErrorListener()
>>> spark.streams.addListener(error_listener)
>>> sdf = spark.readStream.format("rate").load().withColumn(
...     "error", col("value")
... )
>>> # Observe row count (rc) and error row count (erc) in the streaming Dataset
... observed_ds = sdf.observe(
...     "my_event",
...     count(lit(1)).alias("rc"),
...     count(col("error")).alias("erc"))
>>> try:
...     q = observed_ds.writeStream.format("console").start()
...     time.sleep(5)
...
... finally:
...     q.stop()
...     spark.streams.removeListener(error_listener)
...
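
The threshold check inside onQueryProgress can be factored into plain functions so that it can be exercised without a running stream. The following is a minimal sketch under stated assumptions: error_ratio, should_alert, and MetricRow are hypothetical names introduced here, and MetricRow is only a namedtuple standing in for the metrics row, which exposes rc and erc as attributes in the listener above.

```python
from collections import namedtuple
from typing import Optional

# Hypothetical stand-in for the row found in event.progress.observedMetrics.
MetricRow = namedtuple("MetricRow", ["rc", "erc"])


def error_ratio(row) -> Optional[float]:
    """Return erc / rc, or None when there is nothing to measure."""
    if row is None or row.rc == 0:
        return None
    return row.erc / row.rc


def should_alert(row, threshold: float = 0.05) -> bool:
    """Return True only when the error ratio strictly exceeds the threshold."""
    ratio = error_ratio(row)
    return ratio is not None and ratio > threshold


print(should_alert(MetricRow(rc=100, erc=7)))  # 7 percent errors
print(should_alert(MetricRow(rc=100, erc=5)))  # exactly at the threshold
print(should_alert(None))                      # no metrics reported
```

Inside the listener, onQueryProgress would then call should_alert(event.progress.observedMetrics.get("my_event")) and raise the alert when it returns True.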