Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,6 @@ exclude = [
"grpc_test_service_pb2.py",
"grpc_test_service_pb2_grpc.py",
]
per-file-ignores = [
"sentry_sdk/integrations/spark/*:N802,N803",
]
92 changes: 36 additions & 56 deletions sentry_sdk/integrations/celery/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import sys
from collections.abc import Mapping
from functools import wraps
Expand Down Expand Up @@ -62,11 +63,10 @@ class CeleryIntegration(Integration):

def __init__(
self,
propagate_traces=True,
monitor_beat_tasks=False,
exclude_beat_tasks=None,
):
# type: (bool, bool, Optional[List[str]]) -> None
propagate_traces: bool = True,
monitor_beat_tasks: bool = False,
exclude_beat_tasks: Optional[List[str]] = None,
) -> None:
self.propagate_traces = propagate_traces
self.monitor_beat_tasks = monitor_beat_tasks
self.exclude_beat_tasks = exclude_beat_tasks
Expand All @@ -76,8 +76,7 @@ def __init__(
_setup_celery_beat_signals(monitor_beat_tasks)

@staticmethod
def setup_once():
# type: () -> None
def setup_once() -> None:
_check_minimum_version(CeleryIntegration, CELERY_VERSION)

_patch_build_tracer()
Expand All @@ -97,16 +96,14 @@ def setup_once():
ignore_logger("celery.redirected")


def _set_status(status: str) -> None:
    """
    Set ``status`` on the current span, if there is one.

    The span lookup and the status update both run inside
    ``capture_internal_exceptions()``. When ``get_current_span()`` returns
    ``None`` nothing is changed.
    """
    with capture_internal_exceptions():
        span = sentry_sdk.get_current_span()
        if span is not None:
            span.set_status(status)


def _capture_exception(task, exc_info):
# type: (Any, ExcInfo) -> None
def _capture_exception(task: Any, exc_info: ExcInfo) -> None:
client = sentry_sdk.get_client()
if client.get_integration(CeleryIntegration) is None:
return
Expand All @@ -129,10 +126,10 @@ def _capture_exception(task, exc_info):
sentry_sdk.capture_event(event, hint=hint)


def _make_event_processor(task, uuid, args, kwargs, request=None):
# type: (Any, Any, Any, Any, Optional[Any]) -> EventProcessor
def event_processor(event, hint):
# type: (Event, Hint) -> Optional[Event]
def _make_event_processor(
task: Any, uuid: Any, args: Any, kwargs: Any, request: Optional[Any] = None
) -> EventProcessor:
def event_processor(event: Event, hint: Hint) -> Optional[Event]:

with capture_internal_exceptions():
tags = event.setdefault("tags", {})
Expand All @@ -158,8 +155,9 @@ def event_processor(event, hint):
return event_processor


def _update_celery_task_headers(original_headers, span, monitor_beat_tasks):
# type: (dict[str, Any], Optional[Span], bool) -> dict[str, Any]
def _update_celery_task_headers(
original_headers: dict[str, Any], span: Optional[Span], monitor_beat_tasks: bool
) -> dict[str, Any]:
"""
Updates the headers of the Celery task with the tracing information
and eventually Sentry Crons monitoring information for beat tasks.
Expand Down Expand Up @@ -233,20 +231,16 @@ def _update_celery_task_headers(original_headers, span, monitor_beat_tasks):


class NoOpMgr:
    """
    Context manager that does nothing.

    ``apply_async`` uses it in place of a span when the task was started
    from beat, so the same ``with ... as span`` statement works in both
    cases. The value bound by ``as`` is ``None``.
    """

    def __enter__(self) -> None:
        return None

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Returning None (falsy) lets an exception raised in the body propagate.
        return None


def _wrap_task_run(f):
# type: (F) -> F
def _wrap_task_run(f: F) -> F:
@wraps(f)
def apply_async(*args, **kwargs):
# type: (*Any, **Any) -> Any
def apply_async(*args: Any, **kwargs: Any) -> Any:
# Note: kwargs can contain headers=None, so no setdefault!
# Unsure which backend though.
integration = sentry_sdk.get_client().get_integration(CeleryIntegration)
Expand All @@ -262,15 +256,15 @@ def apply_async(*args, **kwargs):
return f(*args, **kwargs)

if isinstance(args[0], Task):
task_name = args[0].name # type: str
task_name: str = args[0].name
elif len(args) > 1 and isinstance(args[1], str):
task_name = args[1]
else:
task_name = "<unknown Celery task>"

task_started_from_beat = sentry_sdk.get_isolation_scope()._name == "celery-beat"

span_mgr = (
span_mgr: Union[Span, NoOpMgr] = (
sentry_sdk.start_span(
op=OP.QUEUE_SUBMIT_CELERY,
name=task_name,
Expand All @@ -279,7 +273,7 @@ def apply_async(*args, **kwargs):
)
if not task_started_from_beat
else NoOpMgr()
) # type: Union[Span, NoOpMgr]
)

with span_mgr as span:
kwargs["headers"] = _update_celery_task_headers(
Expand All @@ -290,8 +284,7 @@ def apply_async(*args, **kwargs):
return apply_async # type: ignore


def _wrap_tracer(task, f):
# type: (Any, F) -> F
def _wrap_tracer(task: Any, f: F) -> F:

# Need to wrap tracer for pushing the scope before prerun is sent, and
# popping it after postrun is sent.
Expand All @@ -301,8 +294,7 @@ def _wrap_tracer(task, f):
# crashes.
@wraps(f)
@ensure_integration_enabled(CeleryIntegration, f)
def _inner(*args, **kwargs):
# type: (*Any, **Any) -> Any
def _inner(*args: Any, **kwargs: Any) -> Any:
with isolation_scope() as scope:
scope._name = "celery"
scope.clear_breadcrumbs()
Expand Down Expand Up @@ -333,8 +325,7 @@ def _inner(*args, **kwargs):
return _inner # type: ignore


def _set_messaging_destination_name(task, span):
# type: (Any, Span) -> None
def _set_messaging_destination_name(task: Any, span: Span) -> None:
"""Set "messaging.destination.name" tag for span"""
with capture_internal_exceptions():
delivery_info = task.request.delivery_info
Expand All @@ -346,8 +337,7 @@ def _set_messaging_destination_name(task, span):
span.set_attribute(SPANDATA.MESSAGING_DESTINATION_NAME, routing_key)


def _wrap_task_call(task, f):
# type: (Any, F) -> F
def _wrap_task_call(task: Any, f: F) -> F:

# Need to wrap task call because the exception is caught before we get to
# see it. Also celery's reported stacktrace is untrustworthy.
Expand All @@ -358,8 +348,7 @@ def _wrap_task_call(task, f):
# to add @functools.wraps(f) here.
# https://github.com/getsentry/sentry-python/issues/421
@ensure_integration_enabled(CeleryIntegration, f)
def _inner(*args, **kwargs):
# type: (*Any, **Any) -> Any
def _inner(*args: Any, **kwargs: Any) -> Any:
try:
with sentry_sdk.start_span(
op=OP.QUEUE_PROCESS,
Expand Down Expand Up @@ -409,14 +398,12 @@ def _inner(*args, **kwargs):
return _inner # type: ignore


def _patch_build_tracer():
# type: () -> None
def _patch_build_tracer() -> None:
import celery.app.trace as trace # type: ignore

original_build_tracer = trace.build_tracer

def sentry_build_tracer(name, task, *args, **kwargs):
# type: (Any, Any, *Any, **Any) -> Any
def sentry_build_tracer(name: Any, task: Any, *args: Any, **kwargs: Any) -> Any:
if not getattr(task, "_sentry_is_patched", False):
# determine whether Celery will use __call__ or run and patch
# accordingly
Expand All @@ -435,29 +422,25 @@ def sentry_build_tracer(name, task, *args, **kwargs):
trace.build_tracer = sentry_build_tracer


def _patch_task_apply_async() -> None:
    """Replace ``Task.apply_async`` with the ``_wrap_task_run`` wrapper around it."""
    Task.apply_async = _wrap_task_run(Task.apply_async)


def _patch_celery_send_task() -> None:
    """Replace ``Celery.send_task`` with the ``_wrap_task_run`` wrapper around it."""
    # ``Celery`` is imported at call time, as in the original block.
    from celery import Celery

    Celery.send_task = _wrap_task_run(Celery.send_task)


def _patch_worker_exit():
# type: () -> None
def _patch_worker_exit() -> None:

# Need to flush queue before worker shutdown because a crashing worker will
# call os._exit
from billiard.pool import Worker # type: ignore

original_workloop = Worker.workloop

def sentry_workloop(*args, **kwargs):
# type: (*Any, **Any) -> Any
def sentry_workloop(*args: Any, **kwargs: Any) -> Any:
try:
return original_workloop(*args, **kwargs)
finally:
Expand All @@ -471,13 +454,11 @@ def sentry_workloop(*args, **kwargs):
Worker.workloop = sentry_workloop


def _patch_producer_publish():
# type: () -> None
def _patch_producer_publish() -> None:
original_publish = Producer.publish

@ensure_integration_enabled(CeleryIntegration, original_publish)
def sentry_publish(self, *args, **kwargs):
# type: (Producer, *Any, **Any) -> Any
def sentry_publish(self: Producer, *args: Any, **kwargs: Any) -> Any:
kwargs_headers = kwargs.get("headers", {})
if not isinstance(kwargs_headers, Mapping):
# Ensure kwargs_headers is a Mapping, so we can safely call get().
Expand Down Expand Up @@ -521,8 +502,7 @@ def sentry_publish(self, *args, **kwargs):
Producer.publish = sentry_publish


def _prepopulate_attributes(task, args, kwargs):
# type: (Any, *Any, **Any) -> dict[str, str]
def _prepopulate_attributes(task: Any, args: Any, kwargs: Any) -> dict[str, str]:
attributes = {
"celery.job.task": task.name,
}
Expand Down
48 changes: 22 additions & 26 deletions sentry_sdk/integrations/celery/beat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import sentry_sdk
from sentry_sdk.crons import capture_checkin, MonitorStatus
from sentry_sdk.integrations import DidNotEnable
Expand Down Expand Up @@ -42,8 +43,7 @@
RedBeatScheduler = None


def _get_headers(task):
# type: (Task) -> dict[str, Any]
def _get_headers(task: Task) -> dict[str, Any]:
headers = task.request.get("headers") or {}

# flatten nested headers
Expand All @@ -56,12 +56,13 @@ def _get_headers(task):
return headers


def _get_monitor_config(celery_schedule, app, monitor_name):
# type: (Any, Celery, str) -> MonitorConfig
monitor_config = {} # type: MonitorConfig
schedule_type = None # type: Optional[MonitorConfigScheduleType]
schedule_value = None # type: Optional[Union[str, int]]
schedule_unit = None # type: Optional[MonitorConfigScheduleUnit]
def _get_monitor_config(
celery_schedule: Any, app: Celery, monitor_name: str
) -> MonitorConfig:
monitor_config: MonitorConfig = {}
schedule_type: Optional[MonitorConfigScheduleType] = None
schedule_value: Optional[Union[str, int]] = None
schedule_unit: Optional[MonitorConfigScheduleUnit] = None

if isinstance(celery_schedule, crontab):
schedule_type = "crontab"
Expand Down Expand Up @@ -113,8 +114,11 @@ def _get_monitor_config(celery_schedule, app, monitor_name):
return monitor_config


def _apply_crons_data_to_schedule_entry(scheduler, schedule_entry, integration):
# type: (Any, Any, sentry_sdk.integrations.celery.CeleryIntegration) -> None
def _apply_crons_data_to_schedule_entry(
scheduler: Any,
schedule_entry: Any,
integration: sentry_sdk.integrations.celery.CeleryIntegration,
) -> None:
"""
Add Sentry Crons information to the schedule_entry headers.
"""
Expand Down Expand Up @@ -158,8 +162,7 @@ def _apply_crons_data_to_schedule_entry(scheduler, schedule_entry, integration):
schedule_entry.options["headers"] = headers


def _wrap_beat_scheduler(original_function):
# type: (Callable[..., Any]) -> Callable[..., Any]
def _wrap_beat_scheduler(original_function: Callable[..., Any]) -> Callable[..., Any]:
"""
Makes sure that:
- a new Sentry trace is started for each task started by Celery Beat and
Expand All @@ -178,8 +181,7 @@ def _wrap_beat_scheduler(original_function):

from sentry_sdk.integrations.celery import CeleryIntegration

def sentry_patched_scheduler(*args, **kwargs):
# type: (*Any, **Any) -> None
def sentry_patched_scheduler(*args: Any, **kwargs: Any) -> None:
integration = sentry_sdk.get_client().get_integration(CeleryIntegration)
if integration is None:
return original_function(*args, **kwargs)
Expand All @@ -197,29 +199,25 @@ def sentry_patched_scheduler(*args, **kwargs):
return sentry_patched_scheduler


def _patch_beat_apply_entry() -> None:
    """Replace ``Scheduler.apply_entry`` with its ``_wrap_beat_scheduler`` wrapper."""
    Scheduler.apply_entry = _wrap_beat_scheduler(Scheduler.apply_entry)


def _patch_redbeat_apply_async() -> None:
    """
    Replace ``RedBeatScheduler.apply_async`` with its ``_wrap_beat_scheduler`` wrapper.

    Does nothing when the module-level ``RedBeatScheduler`` is ``None``
    (presumably the fallback used when redbeat cannot be imported; confirm
    against the import block).
    """
    if RedBeatScheduler is None:
        return

    RedBeatScheduler.apply_async = _wrap_beat_scheduler(RedBeatScheduler.apply_async)


def _setup_celery_beat_signals(monitor_beat_tasks: bool) -> None:
    """
    Connect the crons handlers to the task signals.

    When ``monitor_beat_tasks`` is true, ``crons_task_success``,
    ``crons_task_failure`` and ``crons_task_retry`` are connected to the
    ``task_success``, ``task_failure`` and ``task_retry`` signals
    respectively. When it is false, nothing is connected.
    """
    if monitor_beat_tasks:
        task_success.connect(crons_task_success)
        task_failure.connect(crons_task_failure)
        task_retry.connect(crons_task_retry)


def crons_task_success(sender, **kwargs):
# type: (Task, dict[Any, Any]) -> None
def crons_task_success(sender: Task, **kwargs: dict[Any, Any]) -> None:
logger.debug("celery_task_success %s", sender)
headers = _get_headers(sender)

Expand All @@ -243,8 +241,7 @@ def crons_task_success(sender, **kwargs):
)


def crons_task_failure(sender, **kwargs):
# type: (Task, dict[Any, Any]) -> None
def crons_task_failure(sender: Task, **kwargs: dict[Any, Any]) -> None:
logger.debug("celery_task_failure %s", sender)
headers = _get_headers(sender)

Expand All @@ -268,8 +265,7 @@ def crons_task_failure(sender, **kwargs):
)


def crons_task_retry(sender, **kwargs):
# type: (Task, dict[Any, Any]) -> None
def crons_task_retry(sender: Task, **kwargs: dict[Any, Any]) -> None:
logger.debug("celery_task_retry %s", sender)
headers = _get_headers(sender)

Expand Down
Loading
Loading