diff --git a/sentry_sdk/integrations/spark/spark_driver.py b/sentry_sdk/integrations/spark/spark_driver.py
index b22dc2c807..3581dea1a1 100644
--- a/sentry_sdk/integrations/spark/spark_driver.py
+++ b/sentry_sdk/integrations/spark/spark_driver.py
@@ -11,6 +11,8 @@ from sentry_sdk._types import Event, Hint
     from pyspark import SparkContext
 
 
+_spark_context_class = None
+
 class SparkIntegration(Integration):
     identifier = "spark"
 
@@ -100,10 +102,19 @@ def _activate_integration(sc):
 
 def _patch_spark_context_init():
     # type: () -> None
-    from pyspark import SparkContext
+    global _spark_context_class
+    if _spark_context_class is None:
+        from pyspark import SparkContext
+
+        _spark_context_class = SparkContext
+    else:
+        SparkContext = _spark_context_class
 
     spark_context_init = SparkContext._do_init
 
+    if getattr(spark_context_init, "_sentry_patched", False):
+        return
+
     @ensure_integration_enabled(SparkIntegration, spark_context_init)
     def _sentry_patched_spark_context_init(self, *args, **kwargs):
         # type: (SparkContext, *Any, **Any) -> Optional[Any]
@@ -111,12 +122,19 @@ def _sentry_patched_spark_context_init(self, *args, **kwargs):
         _activate_integration(self)
         return rv
 
+    _sentry_patched_spark_context_init._sentry_patched = True
     SparkContext._do_init = _sentry_patched_spark_context_init
 
 
 def _setup_sentry_tracing():
     # type: () -> None
-    from pyspark import SparkContext
+    global _spark_context_class
+    if _spark_context_class is None:
+        from pyspark import SparkContext
+
+        _spark_context_class = SparkContext
+    else:
+        SparkContext = _spark_context_class
 
     if SparkContext._active_spark_context is not None:
         _activate_integration(SparkContext._active_spark_context)
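
For context, the `getattr(spark_context_init, "_sentry_patched", False)` check together with setting `_sentry_patched = True` on the wrapper is the usual idempotent monkey-patch idiom: a second call to the patcher returns early instead of wrapping the already-wrapped init again. A minimal standalone sketch of that idiom, outside the SDK (the `Target` class and `patch_init` function are illustrative names, not part of this PR):

```python
import functools


class Target:
    """Stand-in for the class whose init method gets patched."""

    def _do_init(self):
        print("real init")


def patch_init():
    original = Target._do_init

    # If a previous call already installed the wrapper, bail out so the
    # original method is never wrapped twice.
    if getattr(original, "_sentry_patched", False):
        return

    @functools.wraps(original)
    def patched(self, *args, **kwargs):
        rv = original(self, *args, **kwargs)
        print("integration activated")
        return rv

    # Mark the wrapper so repeated patch_init() calls become no-ops.
    patched._sentry_patched = True
    Target._do_init = patched


patch_init()
patch_init()          # returns early: wrapper is already installed
Target()._do_init()   # prints "real init", then "integration activated", once
```

Without the sentinel, calling the patcher twice would stack two wrappers and activate the integration twice per init; with it, re-running setup is safe.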