diff --git a/gs_quant/api/gs/risk.py b/gs_quant/api/gs/risk.py index 044d0cc9..808170cf 100644 --- a/gs_quant/api/gs/risk.py +++ b/gs_quant/api/gs/risk.py @@ -26,9 +26,10 @@ from typing import Iterable, Optional, Union import msgpack -from opentracing import Span +from opentracing import Span, Format from gs_quant.api.risk import RiskApi +from gs_quant.context_base import nullcontext from gs_quant.risk import RiskRequest from gs_quant.target.risk import OptimizationRequest from gs_quant.tracing import Tracer @@ -254,13 +255,14 @@ async def handle_websocket(): risk_session = cls.get_session() api_version = GsRiskApi.PRICING_API_VERSION or risk_session.api_version ws_url = f'/{api_version}/risk/calculate/results/subscribe' - async with risk_session._connect_websocket(ws_url, include_version=False) as ws: - if span: - Tracer.get_instance().scope_manager.activate(span, finish_on_close=False) - with Tracer(f'wss:/{ws_url}') as scope: + trace = Tracer(f'wss:/{ws_url}') if span else nullcontext + with trace as scope: + tracing_headers = {} + if scope and scope.span: + Tracer.inject(Format.HTTP_HEADERS, tracing_headers) + async with risk_session._connect_websocket(ws_url, tracing_headers, include_version=False) as ws: + if scope and scope.span: scope.span.set_tag('wss.host', ws.request_headers.get('host')) - error = await handle_websocket() - else: error = await handle_websocket() attempts = max_attempts @@ -278,6 +280,9 @@ async def handle_websocket(): if error != '': _logger.error(f'Fatal error with websocket: {error}') + if span: + span.set_tag('error', True) + span.log_kv({'event': 'error', 'message': error}) cls.shutdown_queue_listener(results) return error