1+ from __future__ import annotations
12from decimal import Decimal
2- from typing import cast
33
44from opentelemetry import trace
55from opentelemetry .sdk .trace .sampling import Sampler , SamplingResult , Decision
2121from typing import TYPE_CHECKING
2222
2323if TYPE_CHECKING :
24- from typing import Any , Optional , Sequence , Union
24+ from typing import Any , Optional , Sequence
2525 from opentelemetry .context import Context
2626 from opentelemetry .trace import Link , SpanKind
2727 from opentelemetry .trace .span import SpanContext
2828 from opentelemetry .util .types import Attributes
2929
3030
31- def get_parent_sampled (parent_context , trace_id ):
32- # type: (Optional[SpanContext], int) -> Optional[bool]
31+ def get_parent_sampled (
32+ parent_context : Optional [SpanContext ], trace_id : int
33+ ) -> Optional [bool ]:
3334 if parent_context is None :
3435 return None
3536
@@ -54,8 +55,9 @@ def get_parent_sampled(parent_context, trace_id):
5455 return None
5556
5657
57- def get_parent_sample_rate (parent_context , trace_id ):
58- # type: (Optional[SpanContext], int) -> Optional[float]
58+ def get_parent_sample_rate (
59+ parent_context : Optional [SpanContext ], trace_id : int
60+ ) -> Optional [float ]:
5961 if parent_context is None :
6062 return None
6163
@@ -74,8 +76,9 @@ def get_parent_sample_rate(parent_context, trace_id):
7476 return None
7577
7678
77- def get_parent_sample_rand (parent_context , trace_id ):
78- # type: (Optional[SpanContext], int) -> Optional[Decimal]
79+ def get_parent_sample_rand (
80+ parent_context : Optional [SpanContext ], trace_id : int
81+ ) -> Optional [Decimal ]:
7982 if parent_context is None :
8083 return None
8184
@@ -91,8 +94,12 @@ def get_parent_sample_rand(parent_context, trace_id):
9194 return None
9295
9396
94- def dropped_result (span_context , attributes , sample_rate = None , sample_rand = None ):
95- # type: (SpanContext, Attributes, Optional[float], Optional[Decimal]) -> SamplingResult
97+ def dropped_result (
98+ span_context : SpanContext ,
99+ attributes : Attributes ,
100+ sample_rate : Optional [float ] = None ,
101+ sample_rand : Optional [Decimal ] = None ,
102+ ) -> SamplingResult :
96103 """
97104 React to a span getting unsampled and return a DROP SamplingResult.
98105
@@ -129,8 +136,12 @@ def dropped_result(span_context, attributes, sample_rate=None, sample_rand=None)
129136 )
130137
131138
132- def sampled_result (span_context , attributes , sample_rate = None , sample_rand = None ):
133- # type: (SpanContext, Attributes, Optional[float], Optional[Decimal]) -> SamplingResult
139+ def sampled_result (
140+ span_context : SpanContext ,
141+ attributes : Attributes ,
142+ sample_rate : Optional [float ] = None ,
143+ sample_rand : Optional [Decimal ] = None ,
144+ ) -> SamplingResult :
134145 """
135146 React to a span being sampled and return a sampled SamplingResult.
136147
@@ -151,8 +162,12 @@ def sampled_result(span_context, attributes, sample_rate=None, sample_rand=None)
151162 )
152163
153164
154- def _update_trace_state (span_context , sampled , sample_rate = None , sample_rand = None ):
155- # type: (SpanContext, bool, Optional[float], Optional[Decimal]) -> TraceState
165+ def _update_trace_state (
166+ span_context : SpanContext ,
167+ sampled : bool ,
168+ sample_rate : Optional [float ] = None ,
169+ sample_rand : Optional [Decimal ] = None ,
170+ ) -> TraceState :
156171 trace_state = span_context .trace_state
157172
158173 sampled = "true" if sampled else "false"
@@ -175,15 +190,14 @@ def _update_trace_state(span_context, sampled, sample_rate=None, sample_rand=Non
175190class SentrySampler (Sampler ):
176191 def should_sample (
177192 self ,
178- parent_context , # type: Optional[Context]
179- trace_id , # type: int
180- name , # type: str
181- kind = None , # type: Optional[SpanKind]
182- attributes = None , # type: Attributes
183- links = None , # type: Optional[Sequence[Link]]
184- trace_state = None , # type: Optional[TraceState]
185- ):
186- # type: (...) -> SamplingResult
193+ parent_context : Optional [Context ],
194+ trace_id : int ,
195+ name : str ,
196+ kind : Optional [SpanKind ] = None ,
197+ attributes : Attributes = None ,
198+ links : Optional [Sequence [Link ]] = None ,
199+ trace_state : Optional [TraceState ] = None ,
200+ ) -> SamplingResult :
187201 client = sentry_sdk .get_client ()
188202
189203 parent_span_context = trace .get_current_span (parent_context ).get_span_context ()
@@ -209,13 +223,12 @@ def should_sample(
209223 sample_rand = parent_sample_rand
210224 else :
211225 # We are the head SDK and we need to generate a new sample_rand
212- sample_rand = cast ( Decimal , _generate_sample_rand (str (trace_id ), (0 , 1 ) ))
226+ sample_rand = _generate_sample_rand (str (trace_id ), (0 , 1 ))
213227
214228 # Explicit sampled value provided at start_span
215- custom_sampled = cast (
216- "Optional[bool]" , attributes .get (SentrySpanAttribute .CUSTOM_SAMPLED )
217- )
218- if custom_sampled is not None :
229+ custom_sampled = attributes .get (SentrySpanAttribute .CUSTOM_SAMPLED )
230+
231+ if custom_sampled is not None and isinstance (custom_sampled , bool ):
219232 if is_root_span :
220233 sample_rate = float (custom_sampled )
221234 if sample_rate > 0 :
@@ -262,7 +275,8 @@ def should_sample(
262275 sample_rate_to_propagate = sample_rate
263276
264277 # If the sample rate is invalid, drop the span
265- if not is_valid_sample_rate (sample_rate , source = self .__class__ .__name__ ):
278+ sample_rate = is_valid_sample_rate (sample_rate , source = self .__class__ .__name__ )
279+ if sample_rate is None :
266280 logger .warning (
267281 f"[Tracing.Sampler] Discarding { name } because of invalid sample rate."
268282 )
@@ -275,7 +289,6 @@ def should_sample(
275289 sample_rate_to_propagate = sample_rate
276290
277291 # Compare sample_rand to sample_rate to make the final sampling decision
278- sample_rate = float (cast ("Union[bool, float, int]" , sample_rate ))
279292 sampled = sample_rand < Decimal .from_float (sample_rate )
280293
281294 if sampled :
@@ -307,9 +320,13 @@ def get_description(self) -> str:
307320 return self .__class__ .__name__
308321
309322
310- def create_sampling_context (name , attributes , parent_span_context , trace_id ):
311- # type: (str, Attributes, Optional[SpanContext], int) -> dict[str, Any]
312- sampling_context = {
323+ def create_sampling_context (
324+ name : str ,
325+ attributes : Attributes ,
326+ parent_span_context : Optional [SpanContext ],
327+ trace_id : int ,
328+ ) -> dict [str , Any ]:
329+ sampling_context : dict [str , Any ] = {
313330 "transaction_context" : {
314331 "name" : name ,
315332 "op" : attributes .get (SentrySpanAttribute .OP ) if attributes else None ,
@@ -318,7 +335,7 @@ def create_sampling_context(name, attributes, parent_span_context, trace_id):
318335 ),
319336 },
320337 "parent_sampled" : get_parent_sampled (parent_span_context , trace_id ),
321- } # type: dict[str, Any]
338+ }
322339
323340 if attributes is not None :
324341 sampling_context .update (attributes )
0 commit comments