diff --git a/google/cloud/aiplatform/telemetry.py b/google/cloud/aiplatform/telemetry.py
index d3ada51ed8..552c2d61fb 100644
--- a/google/cloud/aiplatform/telemetry.py
+++ b/google/cloud/aiplatform/telemetry.py
@@ -16,14 +16,63 @@
 #
 
 import contextlib
 
-from google.cloud.aiplatform import base
+import logging
+import os
+import traceback
+import types
+from typing import Collection
+
+from wrapt import wrap_function_wrapper
+
+from opentelemetry import context as context_api
+from opentelemetry.trace import get_tracer, SpanKind
+from opentelemetry.trace.status import Status, StatusCode
+from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
+from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY, unwrap
+from opentelemetry.semconv_ai import (
+    SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY,
+    SpanAttributes,
+    LLMRequestTypeValues,
+)
+
+from google.cloud.aiplatform import base, version
 
 _LOGGER = base.Logger(__name__)
 
+_instruments = ("google-cloud-aiplatform >= 1.38.1",)
 _tool_names_to_append = []
 
 
+class Config:
+    exception_logger = None
+
+
+def dont_throw(func):
+    """
+    A decorator that wraps the passed-in function and logs exceptions instead of raising them.
+
+    @param func: The function to wrap
+    @return: The wrapper function
+    """
+    # Obtain a logger specific to the function's module
+    logger = logging.getLogger(func.__module__)
+
+    def wrapper(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except Exception as e:
+            logger.debug(
+                "Failed to trace in %s, error: %s",
+                func.__name__,
+                traceback.format_exc(),
+            )
+            if Config.exception_logger:
+                Config.exception_logger(e)
+
+    return wrapper
+
+
 @contextlib.contextmanager
 def tool_context_manager(tool_name: str) -> None:
     """Context manager for appending tool name to client instantiations.
@@ -65,3 +114,347 @@ def _pop_tool_name(tool_name: str) -> None:
         )
 
     return _tool_names_to_append.pop()
+
+
+WRAPPED_METHODS = [
+    {
+        "package": "vertexai.generative_models",
+        "object": "GenerativeModel",
+        "method": "generate_content",
+        "span_name": "vertexai.generate_content",
+        "is_async": False,
+    },
+    {
+        "package": "vertexai.generative_models",
+        "object": "GenerativeModel",
+        "method": "generate_content_async",
+        "span_name": "vertexai.generate_content_async",
+        "is_async": True,
+    },
+    {
+        "package": "vertexai.preview.generative_models",
+        "object": "GenerativeModel",
+        "method": "generate_content",
+        "span_name": "vertexai.generate_content",
+        "is_async": False,
+    },
+    {
+        "package": "vertexai.preview.generative_models",
+        "object": "GenerativeModel",
+        "method": "generate_content_async",
+        "span_name": "vertexai.generate_content_async",
+        "is_async": True,
+    },
+    {
+        "package": "vertexai.language_models",
+        "object": "TextGenerationModel",
+        "method": "predict",
+        "span_name": "vertexai.predict",
+        "is_async": False,
+    },
+    {
+        "package": "vertexai.language_models",
+        "object": "TextGenerationModel",
+        "method": "predict_async",
+        "span_name": "vertexai.predict_async",
+        "is_async": True,
+    },
+    {
+        "package": "vertexai.language_models",
+        "object": "TextGenerationModel",
+        "method": "predict_streaming",
+        "span_name": "vertexai.predict_streaming",
+        "is_async": False,
+    },
+    {
+        "package": "vertexai.language_models",
+        "object": "TextGenerationModel",
+        "method": "predict_streaming_async",
+        "span_name": "vertexai.predict_streaming_async",
+        "is_async": True,
+    },
+    {
+        "package": "vertexai.language_models",
+        "object": "ChatSession",
+        "method": "send_message",
+        "span_name": "vertexai.send_message",
+        "is_async": False,
+    },
+    {
+        "package": "vertexai.language_models",
+        "object": "ChatSession",
+        "method": "send_message_streaming",
+        "span_name": "vertexai.send_message_streaming",
+        "is_async": False,
+    },
+]
+
+
+def should_send_prompts():
+    return (
+        os.getenv("TRACELOOP_TRACE_CONTENT") or "true"
+    ).lower() == "true" or context_api.get_value("override_enable_content_tracing")
+
+
+def is_streaming_response(response):
+    return isinstance(response, types.GeneratorType)
+
+
+def is_async_streaming_response(response):
+    return isinstance(response, types.AsyncGeneratorType)
+
+
+def _set_span_attribute(span, name, value):
+    # Only record attributes with meaningful values; skip None and "".
+    if value is not None and value != "":
+        span.set_attribute(name, value)
+    return
+
+
+def _set_input_attributes(span, args, kwargs, llm_model):
+    if should_send_prompts() and args is not None and len(args) > 0:
+        prompt = ""
+        for arg in args:
+            if isinstance(arg, str):
+                prompt = f"{prompt}{arg}\n"
+            elif isinstance(arg, list):
+                for subarg in arg:
+                    prompt = f"{prompt}{subarg}\n"
+
+        _set_span_attribute(
+            span,
+            f"{SpanAttributes.LLM_PROMPTS}.0.user",
+            prompt,
+        )
+
+    _set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, llm_model)
+    _set_span_attribute(
+        span, f"{SpanAttributes.LLM_PROMPTS}.0.user", kwargs.get("prompt")
+    )
+    _set_span_attribute(
+        span, SpanAttributes.LLM_REQUEST_TEMPERATURE, kwargs.get("temperature")
+    )
+    _set_span_attribute(
+        span, SpanAttributes.LLM_REQUEST_MAX_TOKENS, kwargs.get("max_output_tokens")
+    )
+    _set_span_attribute(span, SpanAttributes.LLM_REQUEST_TOP_P, kwargs.get("top_p"))
+    _set_span_attribute(span, SpanAttributes.LLM_TOP_K, kwargs.get("top_k"))
+    _set_span_attribute(
+        span, SpanAttributes.LLM_PRESENCE_PENALTY, kwargs.get("presence_penalty")
+    )
+    _set_span_attribute(
+        span,
+        SpanAttributes.LLM_FREQUENCY_PENALTY, kwargs.get("frequency_penalty")
+    )
+
+    return
+
+
+@dont_throw
+def _set_response_attributes(span, llm_model, generation_text, token_usage):
+    _set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, llm_model)
+
+    if token_usage:
+        _set_span_attribute(
+            span,
+            SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
+            token_usage.total_token_count,
+        )
+        _set_span_attribute(
+            span,
+            SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
+            token_usage.candidates_token_count,
+        )
+        _set_span_attribute(
+            span,
+            SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
+            token_usage.prompt_token_count,
+        )
+
+    _set_span_attribute(span, f"{SpanAttributes.LLM_COMPLETIONS}.0.role", "assistant")
+    _set_span_attribute(
+        span,
+        f"{SpanAttributes.LLM_COMPLETIONS}.0.content",
+        generation_text,
+    )
+
+
+def _build_from_streaming_response(span, response, llm_model):
+    complete_response = ""
+    token_usage = None
+    for item in response:
+        item_to_yield = item
+        complete_response += str(item.text)
+        if item.usage_metadata:
+            token_usage = item.usage_metadata
+
+        yield item_to_yield
+
+    _set_response_attributes(span, llm_model, complete_response, token_usage)
+
+    span.set_status(Status(StatusCode.OK))
+    span.end()
+
+
+async def _abuild_from_streaming_response(span, response, llm_model):
+    complete_response = ""
+    token_usage = None
+    async for item in response:
+        item_to_yield = item
+        complete_response += str(item.text)
+        if item.usage_metadata:
+            token_usage = item.usage_metadata
+
+        yield item_to_yield
+
+    _set_response_attributes(span, llm_model, complete_response, token_usage)
+
+    span.set_status(Status(StatusCode.OK))
+    span.end()
+
+
+@dont_throw
+def _handle_request(span, args, kwargs, llm_model):
+    if span.is_recording():
+        _set_input_attributes(span, args, kwargs, llm_model)
+
+
+@dont_throw
+def _handle_response(span, response, llm_model):
+    if span.is_recording():
+        _set_response_attributes(
+            span, llm_model, response.candidates[0].text, response.usage_metadata
+        )
+
+        span.set_status(Status(StatusCode.OK))
+
+
+def _with_tracer_wrapper(func):
+    """Decorator that supplies the tracer and method config to a wrapt wrapper."""
+
+    def _with_tracer(tracer, to_wrap):
+        def wrapper(wrapped, instance, args, kwargs):
+            return func(tracer, to_wrap, wrapped, instance, args, kwargs)
+
+        return wrapper
+
+    return _with_tracer
+
+
+@_with_tracer_wrapper
+async def _awrap(tracer, to_wrap, wrapped, instance, args, kwargs):
+    """Instruments an async method from WRAPPED_METHODS and traces the call."""
+    if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY) or context_api.get_value(
+        SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY
+    ):
+        return await wrapped(*args, **kwargs)
+
+    llm_model = "unknown"
+    if hasattr(instance, "_model_id"):
+        llm_model = instance._model_id
+    if hasattr(instance, "_model_name"):
+        llm_model = instance._model_name.replace("publishers/google/models/", "")
+
+    name = to_wrap.get("span_name")
+    span = tracer.start_span(
+        name,
+        kind=SpanKind.CLIENT,
+        attributes={
+            SpanAttributes.LLM_SYSTEM: "VertexAI",
+            SpanAttributes.LLM_REQUEST_TYPE: LLMRequestTypeValues.COMPLETION.value,
+        },
+    )
+
+    _handle_request(span, args, kwargs, llm_model)
+
+    response = await wrapped(*args, **kwargs)
+
+    if response:
+        if is_streaming_response(response):
+            return _build_from_streaming_response(span, response, llm_model)
+        elif is_async_streaming_response(response):
+            return _abuild_from_streaming_response(span, response, llm_model)
+        else:
+            _handle_response(span, response, llm_model)
+
+    span.end()
+    return response
+
+
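+# NOTE: _wrap below mirrors _awrap for synchronous methods; the two bodies are
+# kept in sync manually and differ only in awaiting the wrapped call.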
+@_with_tracer_wrapper
+def _wrap(tracer, to_wrap, wrapped, instance, args, kwargs):
+    """Instruments a sync method from WRAPPED_METHODS and traces the call."""
+    if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY) or context_api.get_value(
+        SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY
+    ):
+        return wrapped(*args, **kwargs)
+
+    llm_model = "unknown"
+    if hasattr(instance, "_model_id"):
+        llm_model = instance._model_id
+    if hasattr(instance, "_model_name"):
+        llm_model = instance._model_name.replace("publishers/google/models/", "")
+
+    name = to_wrap.get("span_name")
+    span = tracer.start_span(
+        name,
+        kind=SpanKind.CLIENT,
+        attributes={
+            SpanAttributes.LLM_SYSTEM: "VertexAI",
+            SpanAttributes.LLM_REQUEST_TYPE: LLMRequestTypeValues.COMPLETION.value,
+        },
+    )
+
+    _handle_request(span, args, kwargs, llm_model)
+
+    response = wrapped(*args, **kwargs)
+
+    if response:
+        if is_streaming_response(response):
+            return _build_from_streaming_response(span, response, llm_model)
+        elif is_async_streaming_response(response):
+            return _abuild_from_streaming_response(span, response, llm_model)
+        else:
+            _handle_response(span, response, llm_model)
+
+    span.end()
+    return response
+
+
+class VertexAIInstrumentor(BaseInstrumentor):
+    """An instrumentor for the Vertex AI client library."""
+
+    def __init__(self, exception_logger=None):
+        super().__init__()
+        Config.exception_logger = exception_logger
+
+    def instrumentation_dependencies(self) -> Collection[str]:
+        return _instruments
+
+    def _instrument(self, **kwargs):
+        tracer_provider = kwargs.get("tracer_provider")
+        tracer = get_tracer(__name__, version.__version__, tracer_provider)
+        for wrapped_method in WRAPPED_METHODS:
+            wrap_package = wrapped_method.get("package")
+            wrap_object = wrapped_method.get("object")
+            wrap_method = wrapped_method.get("method")
+
+            wrap_function_wrapper(
+                wrap_package,
+                f"{wrap_object}.{wrap_method}",
+                (
+                    _awrap(tracer, wrapped_method)
+                    if wrapped_method.get("is_async")
+                    else _wrap(tracer, wrapped_method)
+                ),
+            )
+
+    def _uninstrument(self, **kwargs):
+        for wrapped_method in WRAPPED_METHODS:
+            wrap_package = wrapped_method.get("package")
+            wrap_object = wrapped_method.get("object")
+            unwrap(
+                f"{wrap_package}.{wrap_object}",
+                wrapped_method.get("method", ""),
+            )
diff --git a/setup.py b/setup.py
index 21db80de42..af2a24a3f0 100644
--- a/setup.py
+++ b/setup.py
@@ -252,6 +252,11 @@
         "google-cloud-bigquery >= 1.15.0, < 4.0.0dev, !=3.20.0",
         "google-cloud-resource-manager >= 1.3.3, < 3.0.0dev",
         "shapely < 3.0.0dev",
+        "opentelemetry-sdk < 2",
+        "opentelemetry-api >= 1.28.0, < 2",
+        "opentelemetry-instrumentation",
+        "opentelemetry-semantic-conventions",
+        "opentelemetry-semantic-conventions-ai",
     )
     + genai_requires,
     extras_require={