diff --git a/sentry_sdk/integrations/fastapi.py b/sentry_sdk/integrations/fastapi.py index 3572b1c07f..747c00af1e 100644 --- a/sentry_sdk/integrations/fastapi.py +++ b/sentry_sdk/integrations/fastapi.py @@ -7,6 +7,7 @@ from sentry_sdk.scope import should_send_default_pii from sentry_sdk.traces import NoOpStreamedSpan, StreamedSpan from sentry_sdk.tracing import SOURCE_FOR_STYLE, TransactionSource +from sentry_sdk.tracing_utils import has_span_streaming_enabled from sentry_sdk.utils import transaction_from_function from typing import TYPE_CHECKING @@ -19,6 +20,7 @@ from sentry_sdk.integrations.starlette import ( StarletteIntegration, StarletteRequestExtractor, + _set_request_body_data_on_streaming_segment, ) except DidNotEnable: raise DidNotEnable("Starlette is not installed") @@ -102,7 +104,8 @@ def _sentry_call(*args: "Any", **kwargs: "Any") -> "Any": old_app = old_get_request_handler(*args, **kwargs) async def _sentry_app(*args: "Any", **kwargs: "Any") -> "Any": - integration = sentry_sdk.get_client().get_integration(FastApiIntegration) + client = sentry_sdk.get_client() + integration = client.get_integration(FastApiIntegration) if integration is None: return await old_app(*args, **kwargs) @@ -137,6 +140,9 @@ def event_processor(event: "Event", hint: "Dict[str, Any]") -> "Event": _make_request_event_processor(request, integration) ) + if has_span_streaming_enabled(client.options): + _set_request_body_data_on_streaming_segment(info) + return await old_app(*args, **kwargs) return _sentry_app diff --git a/sentry_sdk/integrations/starlette.py b/sentry_sdk/integrations/starlette.py index 036b797685..a69cab668b 100644 --- a/sentry_sdk/integrations/starlette.py +++ b/sentry_sdk/integrations/starlette.py @@ -234,7 +234,7 @@ async def _sentry_send(*args: "Any", **kwargs: "Any") -> "Any": return middleware_class -def _serialize_body_data(data: "Any") -> str: +def _serialize_request_body_data(data: "Any") -> str: # data may be a JSON-serializable value, an AnnotatedValue, or a dict with AnnotatedValue values def _default(value: "Any") -> "Any": if isinstance(value, AnnotatedValue): @@ -244,6 +244,22 @@ def _default(value: "Any") -> "Any": return json.dumps(data, default=_default) +def _set_request_body_data_on_streaming_segment( + info: "Optional[Dict[str, Any]]", +) -> None: + current_span = sentry_sdk.get_current_span() + if ( + info + and "data" in info + and isinstance(current_span, StreamedSpan) + and not isinstance(current_span, NoOpStreamedSpan) + ): + current_span._segment.set_attribute( + "http.request.body.data", + _serialize_request_body_data(info["data"]), + ) + + @ensure_integration_enabled(StarletteIntegration) def _capture_exception(exception: BaseException, handled: "Any" = False) -> None: event, hint = event_from_exception( @@ -510,21 +526,8 @@ def event_processor( _make_request_event_processor(request, integration) ) - is_span_streaming_enabled = has_span_streaming_enabled(client.options) - if is_span_streaming_enabled: - current_span = sentry_sdk.get_current_span() - - if ( - info - and "data" in info - and isinstance(current_span, StreamedSpan) - and not isinstance(current_span, NoOpStreamedSpan) - ): - data = info["data"] - current_span._segment.set_attribute( - "http.request.body.data", - _serialize_body_data(data), - ) + if has_span_streaming_enabled(client.options): + _set_request_body_data_on_streaming_segment(info) return await old_func(*args, **kwargs) diff --git a/tests/integrations/fastapi/test_fastapi.py b/tests/integrations/fastapi/test_fastapi.py index d321db993c..990fd40ff4 100644 --- a/tests/integrations/fastapi/test_fastapi.py +++ b/tests/integrations/fastapi/test_fastapi.py @@ -6,6 +6,7 @@ from unittest import mock import fastapi +import starlette from fastapi import FastAPI, HTTPException, Request from fastapi.testclient import TestClient from fastapi.middleware.trustedhost import TrustedHostMiddleware @@ -20,6 +21,7 @@ FASTAPI_VERSION = parse_version(fastapi.__version__) +STARLETTE_VERSION = parse_version(starlette.__version__) from tests.integrations.conftest import parametrize_test_configurable_status_codes from tests.integrations.starlette import test_starlette @@ -245,6 +247,143 @@ def test_active_thread_id_span_streaming(sentry_init, capture_items, endpoint): assert str(data["active"]) == segments[0]["attributes"]["thread.id"] +def _post_body_fastapi_app(handler_awaitable): + app = FastAPI() + + @app.post("/body") + async def _route(request: Request): + await handler_awaitable(request) + return {"ok": True} + + return app + + +@pytest.mark.parametrize("middleware_spans", [False, True]) +def test_request_body_data_does_not_scrub_pii_span_streaming( + sentry_init, capture_items, middleware_spans +): + sentry_init( + auto_enabling_integrations=False, + integrations=[ + StarletteIntegration(middleware_spans=middleware_spans), + FastApiIntegration(middleware_spans=middleware_spans), + ], + traces_sample_rate=1.0, + _experiments={"trace_lifecycle": "stream"}, + ) + + async def _read_json(request): + await request.json() + + items = capture_items("span") + + client = TestClient(_post_body_fastapi_app(_read_json)) + response = client.post( + "/body", + json={ + "password": "ohno", + "authorization": "Bearer token", + "message": "hello", + }, + ) + assert response.status_code == 200 + + sentry_sdk.flush() + + segments = [item.payload for item in items if item.payload.get("is_segment")] + assert len(segments) == 1 + attr = segments[0]["attributes"]["http.request.body.data"] + + # Going forward, the sanitization of data will need to happen within the `before_send_span` hooks + # See https://sentry.slack.com/archives/C09RR0KD2N7/p1776951331206129?thread_ts=1776951227.440659&cid=C09RR0KD2N7 + assert "ohno" in attr + assert "Bearer token" in attr + assert "hello" in attr + + +@pytest.mark.skipif( + STARLETTE_VERSION < (0, 21), + reason="Requires Starlette >= 0.21, because earlier versions use a requests-based TestClient which does not support the 'content' kwarg", +) +@pytest.mark.parametrize("middleware_spans", [False, True]) +def test_request_body_data_annotated_value_top_level_span_streaming( + sentry_init, capture_items, middleware_spans +): + sentry_init( + auto_enabling_integrations=False, + integrations=[ + StarletteIntegration(middleware_spans=middleware_spans), + FastApiIntegration(middleware_spans=middleware_spans), + ], + traces_sample_rate=1.0, + _experiments={"trace_lifecycle": "stream"}, + ) + + async def _read_body(request): + await request.body() + + items = capture_items("span") + + client = TestClient(_post_body_fastapi_app(_read_body)) + response = client.post( + "/body", + content=b"not json and not form", + headers={"content-type": "application/octet-stream"}, + ) + assert response.status_code == 200 + + sentry_sdk.flush() + + segments = [item.payload for item in items if item.payload.get("is_segment")] + assert len(segments) == 1 + attr = segments[0]["attributes"]["http.request.body.data"] + + assert isinstance(attr, str) + assert "!raw" in attr + + +@pytest.mark.parametrize("middleware_spans", [False, True]) +def test_request_body_data_annotated_value_nested_span_streaming( + sentry_init, capture_items, middleware_spans +): + pytest.importorskip("multipart") + + sentry_init( + auto_enabling_integrations=False, + integrations=[ + StarletteIntegration(middleware_spans=middleware_spans), + FastApiIntegration(middleware_spans=middleware_spans), + ], + traces_sample_rate=1.0, + _experiments={"trace_lifecycle": "stream"}, + ) + + async def _read_form(request): + await request.form() + + items = capture_items("span") + + client = TestClient(_post_body_fastapi_app(_read_form)) + response = client.post( + "/body", + data={"name": "erica"}, + files={"avatar": ("photo.jpg", b"fake-bytes", "image/jpeg")}, + ) + assert response.status_code == 200 + + sentry_sdk.flush() + + segments = [item.payload for item in items if item.payload.get("is_segment")] + assert len(segments) == 1 + attr = segments[0]["attributes"]["http.request.body.data"] + + assert isinstance(attr, str) + parsed = json.loads(attr) + assert parsed["name"] == "erica" + assert parsed["avatar"]["metadata"]["rem"] == [["!raw", "x"]] + assert "fake-bytes" not in attr + + @pytest.mark.parametrize("span_streaming", [True, False]) @pytest.mark.asyncio async def test_original_request_not_scrubbed(