Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/sentry/flags/endpoints/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,19 @@ def serialize(self, obj, attrs, user, **kwargs) -> FlagWebhookSigningSecretRespo

class FlagWebhookSigningSecretValidator(serializers.Serializer):
provider = serializers.ChoiceField(
choices=["launchdarkly", "generic", "unleash"], required=True
choices=["launchdarkly", "generic", "unleash", "statsig"], required=True
)
secret = serializers.CharField(required=True, max_length=32, min_length=32)
secret = serializers.CharField(required=True)

def validate_secret(self, value):
if self.initial_data.get("provider") == "statsig":
if not value.startswith("webhook-"):
raise serializers.ValidationError(
"Ensure this field is of the format webhook-<hash>"
)
return serializers.CharField(min_length=32, max_length=64).run_validation(value)

return serializers.CharField(min_length=32, max_length=32).run_validation(value)


@region_silo_endpoint
Expand Down
154 changes: 146 additions & 8 deletions src/sentry/flags/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
FlagWebHookSigningSecretModel,
)
from sentry.silo.base import SiloLimit
from sentry.utils.safe import get_path


def write(rows: list["FlagAuditLogRow"]) -> None:
Expand Down Expand Up @@ -54,7 +55,7 @@ class ProviderProtocol(Protocol[T]):
provider_name: str
signature: str | None

def __init__(self, organization_id: int, signature: str | None) -> None: ...
def __init__(self, organization_id: int, signature: str | None, **kwargs) -> None: ...
def handle(self, message: T) -> list[FlagAuditLogRow]: ...
def validate(self, message_bytes: bytes) -> bool: ...

Expand Down Expand Up @@ -82,6 +83,12 @@ def get_provider(
return GenericProvider(organization_id, signature=headers.get("X-Sentry-Signature"))
case "unleash":
return UnleashProvider(organization_id, signature=headers.get("Authorization"))
case "statsig":
return StatsigProvider(
organization_id,
signature=headers.get("X-Statsig-Signature"),
request_timestamp=headers.get("X-Statsig-Request-Timestamp"),
)
case _:
return None

Expand Down Expand Up @@ -355,6 +362,136 @@ def _handle_unleash_actions(action: str) -> int:
return ACTION_MAP["updated"]


"""Statsig provider."""

SUPPORTED_STATSIG_EVENTS = {"statsig::config_change"}

# config_change is subclassed by the type of Statsig feature. There's "Gate",
# "Experiment", and more. Feature gates are boolean release flags, but all
# other types are unstructured JSON. To reduce noise, Gate is the only type
# we audit for now.
SUPPORTED_STATSIG_TYPES = {
"Gate",
"gate", # Supporting this just in case. Statsig docs and sample events use capitalization.
}


class _StatsigEventSerializer(serializers.Serializer):
eventName = serializers.CharField(required=True)
timestamp = serializers.CharField(required=True) # Custom serializer defined below.
metadata = serializers.DictField(required=True)

user = serializers.DictField(required=False, child=serializers.CharField())
userID = serializers.CharField(required=False)
value = serializers.CharField(required=False)
statsigMetadata = serializers.DictField(required=False)
timeUUID = serializers.UUIDField(required=False)
unitID = serializers.CharField(required=False)

def validate_timestamp(self, value: str):
try:
float(value)
except ValueError:
raise serializers.ValidationError(
'"timestamp" field must be a string number, representing milliseconds since epoch.'
)
return value


class StatsigItemSerializer(serializers.Serializer):
data = serializers.ListField(child=_StatsigEventSerializer(), required=True) # type: ignore[assignment]


class StatsigProvider:
provider_name = "statsig"

def __init__(
self,
organization_id: int,
signature: str | None,
request_timestamp: str | None,
version: str = "v0",
) -> None:
self.organization_id = organization_id
self.signature = signature
self.request_timestamp = request_timestamp
self.version = version

# Strip the signature's version prefix. For example, signature format for v0 is "v0+{hash}"
prefix_len = len(version) + 1
if signature and len(signature) > prefix_len:
self.signature = signature[prefix_len:]

def handle(self, message: dict[str, Any]) -> list[FlagAuditLogRow]:
serializer = StatsigItemSerializer(data=message)
if not serializer.is_valid():
raise DeserializationError(serializer.errors)

events = serializer.validated_data["data"]
audit_logs = []
for event in events:
event_name = event["eventName"]

if event_name not in SUPPORTED_STATSIG_EVENTS:
continue

metadata = event.get("metadata") or {}
flag = metadata.get("name")
statsig_type = metadata.get("type")
action = (metadata.get("action") or "").lower()

if not flag or statsig_type not in SUPPORTED_STATSIG_TYPES or action not in ACTION_MAP:
continue

action = ACTION_MAP[action]

# Prioritize email > id > name for created_by.
if created_by := get_path(event, "user", "email"):
created_by_type = CREATED_BY_TYPE_MAP["email"]
elif created_by := event.get("userID") or get_path(event, "user", "userID"):
created_by_type = CREATED_BY_TYPE_MAP["id"]
elif created_by := get_path(event, "user", "name"):
created_by_type = CREATED_BY_TYPE_MAP["name"]
else:
created_by, created_by_type = None, None

created_at_ms = float(event["timestamp"])
created_at = datetime.datetime.fromtimestamp(created_at_ms / 1000.0, datetime.UTC)

tags = {}
if projectName := metadata.get("projectName"):
tags["projectName"] = projectName
if projectID := metadata.get("projectID"):
tags["projectID"] = projectID
if environments := metadata.get("environments"):
tags["environments"] = environments

audit_logs.append(
FlagAuditLogRow(
action=action,
created_at=created_at,
created_by=created_by,
created_by_type=created_by_type,
flag=flag,
organization_id=self.organization_id,
tags=tags,
)
)

return audit_logs

def validate(self, message_bytes: bytes) -> bool:
if self.request_timestamp is None:
return False

signature_basestring = f"{self.version}:{self.request_timestamp}:".encode() + message_bytes

validator = PayloadSignatureValidator(
self.organization_id, self.provider_name, signature_basestring, self.signature
)
return validator.validate()


"""Flagpole provider."""


Expand Down Expand Up @@ -389,10 +526,11 @@ def handle_flag_pole_event_internal(items: list[FlagAuditLogItem], organization_


class AuthTokenValidator:
"""Abstract payload validator.
"""Abstract validator for injecting dependencies in tests. Use this when a
provider does not support signing.

Similar to the SecretValidator class below, except we do not need
to validate the authorization string.
Similar to the PayloadSignatureValidator class below, except we do not
validate the authorization string with the payload.
"""

def __init__(
Expand All @@ -419,7 +557,7 @@ def validate(self) -> bool:


class PayloadSignatureValidator:
"""Abstract payload validator.
"""Abstract payload validator. Uses HMAC-SHA256 by default.

Allows us to inject dependencies for differing use cases. Specifically
the test suite.
Expand All @@ -429,14 +567,14 @@ def __init__(
self,
organization_id: int,
provider: str,
request_body: bytes,
message: bytes,
signature: str | None,
secret_finder: Callable[[int, str], Iterator[str]] | None = None,
secret_validator: Callable[[str, bytes], str] | None = None,
) -> None:
self.organization_id = organization_id
self.provider = provider
self.request_body = request_body
self.message = message
self.signature = signature
self.secret_finder = secret_finder or _query_signing_secrets
self.secret_validator = secret_validator or hmac_sha256_hex_digest
Expand All @@ -446,7 +584,7 @@ def validate(self) -> bool:
return False

for secret in self.secret_finder(self.organization_id, self.provider):
if self.secret_validator(secret, self.request_body) == self.signature:
if self.secret_validator(secret, self.message) == self.signature:
return True
return False

Expand Down
50 changes: 50 additions & 0 deletions tests/sentry/flags/endpoints/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,56 @@ def test_unleash_post_create(self, mock_incr):
)
assert FlagAuditLogModel.objects.count() == 1

def test_statsig_post_create(self, mock_incr):
request_data = {
"data": [
{
"user": {"name": "johndoe", "email": "[email protected]"},
"timestamp": 1739400185198,
"eventName": "statsig::config_change",
"metadata": {
"projectName": "sentry",
"projectID": "1",
"type": "Gate",
"name": "gate1",
"description": "Updated Config Conditions\n - Added rule Rule 1",
"environments": "development,staging,production",
"action": "updated",
"tags": [],
"targetApps": [],
},
},
]
}

secret = "webhook-Xk9pL8NQaR5Ym2cx7vHnWtBj4M3f6qyZdC12mnspk8"

FlagWebHookSigningSecretModel.objects.create(
organization=self.organization,
provider="statsig",
secret=secret,
)

request_timestamp = "1739400185400" # ms timestamp of the webhook request
signature_basestring = f"v0:{request_timestamp}:{json.dumps(request_data)}".encode()
signature = "v0=" + hmac_sha256_hex_digest(key=secret, message=signature_basestring)
headers = {
"X-Statsig-Signature": signature,
"X-Statsig-Request-Timestamp": request_timestamp,
}

with self.feature(self.features):
response = self.client.post(
reverse(self.endpoint, args=(self.organization.slug, "statsig")),
request_data,
headers=headers,
)
assert response.status_code == 200, response.content
mock_incr.assert_any_call(
"feature_flags.audit_log_event_posted", tags={"provider": "statsig"}
)
assert FlagAuditLogModel.objects.count() == 1

def test_launchdarkly_post_create(self, mock_incr):
request_data = LD_REQUEST
signature = hmac_sha256_hex_digest(key="456", message=json.dumps(request_data).encode())
Expand Down
61 changes: 60 additions & 1 deletion tests/sentry/flags/endpoints/test_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,77 @@ def test_post_unleash(self):
assert len(models) == 1
assert models[0].secret == "41271af8b9804cd99a4c787a28274991"

def test_post_statsig(self):
with self.feature(self.features):
response = self.client.post(
self.url,
data={
"secret": "webhook-Xk9pL8NQaR5Ym2cx7vHnWtBj4M3f6qyZdC12mnspk8",
"provider": "statsig",
},
)
assert response.status_code == 201, response.content

models = FlagWebHookSigningSecretModel.objects.filter(provider="statsig").all()
assert len(models) == 1
assert models[0].secret == "webhook-Xk9pL8NQaR5Ym2cx7vHnWtBj4M3f6qyZdC12mnspk8"

def test_post_disabled(self):
response = self.client.post(self.url, data={})
assert response.status_code == 404, response.content

def test_post_invalid(self):
def test_post_invalid_provider(self):
with self.feature(self.features):
url = reverse(self.endpoint, args=(self.organization.id,))
response = self.client.post(url, data={"secret": "123", "provider": "other"})
assert response.status_code == 400, response.content
assert response.json()["provider"] == ['"other" is not a valid choice.']
assert response.json()["secret"] == ["Ensure this field has at least 32 characters."]

def test_post_invalid_secret(self):
with self.feature(self.features):
for provider in ["launchdarkly", "generic", "unleash"]:
response = self.client.post(
self.url, data={"secret": "a" * 31, "provider": provider}
)
assert response.status_code == 400, response.content
assert response.json()["secret"] == [
"Ensure this field has at least 32 characters."
], provider

response = self.client.post(
self.url, data={"secret": "a" * 33, "provider": provider}
)
assert response.status_code == 400, response.content
assert response.json()["secret"] == [
"Ensure this field has no more than 32 characters."
], provider

# Statsig
response = self.client.post(self.url, data={"secret": "a" * 32, "provider": "statsig"})
assert response.status_code == 400, response.content
assert response.json()["secret"] == [
"Ensure this field is of the format webhook-<hash>"
], "statsig"

response = self.client.post(
self.url,
data={"secret": "webhook-" + "a" * (31 - len("webhook-")), "provider": "statsig"},
)
assert response.status_code == 400, response.content
assert response.json()["secret"] == [
"Ensure this field has at least 32 characters."
], "statsig"

response = self.client.post(
self.url,
data={"secret": "webhook-" + "a" * (65 - len("webhook-")), "provider": "statsig"},
)
assert response.status_code == 400, response.content
assert response.json()["secret"] == [
"Ensure this field has no more than 64 characters."
], "statsig"

def test_post_empty_request(self):
with self.feature(self.features):
response = self.client.post(self.url, data={})
Expand Down
Loading
Loading