# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from __future__ import annotations
from contextlib import contextmanager
from contextvars import Token
from typing import Optional, Dict, Sequence, cast, Callable, Iterator, TYPE_CHECKING

from opentelemetry import context as otel_context_module, trace
from opentelemetry.trace import (
    Span,
    SpanKind as OpenTelemetrySpanKind,
    Link as OpenTelemetryLink,
    StatusCode,
)
from opentelemetry.trace.propagation import get_current_span as get_current_span_otel
from opentelemetry.propagate import extract, inject

# Prefer the key object exported by ``opentelemetry.context``. It is a private
# (underscore-prefixed) name, so fall back to a literal key string if the import fails.
try:
    from opentelemetry.context import _SUPPRESS_HTTP_INSTRUMENTATION_KEY  # type: ignore[attr-defined]
except ImportError:
    _SUPPRESS_HTTP_INSTRUMENTATION_KEY = "suppress_http_instrumentation"

from .._version import VERSION
from ._models import (
    Attributes,
    SpanKind as _SpanKind,
)

if TYPE_CHECKING:
    from azure.core.tracing import Link, SpanKind


# Fallback values used by ``OpenTelemetryTracer.__init__`` when the caller does not
# supply a schema URL or library name.
_DEFAULT_SCHEMA_URL = "https://opentelemetry.io/schemas/1.23.1"
_DEFAULT_MODULE_NAME = "azure-core"

# Translation table from this package's ``SpanKind`` members to OpenTelemetry's ``SpanKind``.
# UNSPECIFIED is mapped to INTERNAL, which is also the value ``start_span`` falls back to
# for any kind that has no entry in this table.
_KIND_MAPPINGS = {
    _SpanKind.CLIENT: OpenTelemetrySpanKind.CLIENT,
    _SpanKind.CONSUMER: OpenTelemetrySpanKind.CONSUMER,
    _SpanKind.PRODUCER: OpenTelemetrySpanKind.PRODUCER,
    _SpanKind.SERVER: OpenTelemetrySpanKind.SERVER,
    _SpanKind.INTERNAL: OpenTelemetrySpanKind.INTERNAL,
    _SpanKind.UNSPECIFIED: OpenTelemetrySpanKind.INTERNAL,
}


class OpenTelemetryTracer:
    """A tracer that uses OpenTelemetry to trace operations.

    :keyword library_name: The name of the library to use in the tracer.
    :paramtype library_name: str
    :keyword library_version: The version of the library to use in the tracer.
    :paramtype library_version: str
    :keyword schema_url: Specifies the Schema URL of the emitted spans. Defaults to
        "https://opentelemetry.io/schemas/1.23.1".
    :paramtype schema_url: str
    :keyword attributes: Attributes to add to the emitted spans.
    :paramtype attributes: Mapping[str, AttributeValue]
    """

    def __init__(
        self,
        *,
        library_name: Optional[str] = None,
        library_version: Optional[str] = None,
        schema_url: Optional[str] = None,
        attributes: Optional[Attributes] = None,
    ) -> None:
        self._tracer = trace.get_tracer(
            instrumenting_module_name=library_name or _DEFAULT_MODULE_NAME,
            instrumenting_library_version=library_version or VERSION,
            schema_url=schema_url or _DEFAULT_SCHEMA_URL,
            attributes=attributes,
        )

    def start_span(
        self,
        name: str,
        *,
        kind: SpanKind = _SpanKind.INTERNAL,
        attributes: Optional[Attributes] = None,
        links: Optional[Sequence[Link]] = None,
    ) -> Span:
        """Starts a span without setting it as the current span in the context.

        :param name: The name of the span
        :type name: str
        :keyword kind: The kind of the span. INTERNAL by default.
        :paramtype kind: ~azure.core.tracing.SpanKind
        :keyword attributes: Attributes to add to the span.
        :paramtype attributes: Mapping[str, AttributeValue]
        :keyword links: Links to add to the span.
        :paramtype links: list[~azure.core.tracing.Link]
        :return: The span that was started
        :rtype: ~opentelemetry.trace.Span
        """
        otel_kind = _KIND_MAPPINGS.get(kind, OpenTelemetrySpanKind.INTERNAL)
        otel_links = self._parse_links(links)

        otel_span = self._tracer.start_span(
            name,
            kind=otel_kind,
            attributes=attributes,
            links=otel_links,
            record_exception=False,
        )

        return otel_span

    @contextmanager
    def start_as_current_span(
        self,
        name: str,
        *,
        kind: SpanKind = _SpanKind.INTERNAL,
        attributes: Optional[Attributes] = None,
        links: Optional[Sequence[Link]] = None,
        end_on_exit: bool = True,
    ) -> Iterator[Span]:
        """Context manager that starts a span and sets it as the current span in the context.

        .. code:: python

            with tracer.start_as_current_span("span_name") as span:
                # Do something with the span
                span.set_attribute("key", "value")

        :param name: The name of the span
        :type name: str
        :keyword kind: The kind of the span. INTERNAL by default.
        :paramtype kind: ~azure.core.tracing.SpanKind
        :keyword attributes: Attributes to add to the span.
        :paramtype attributes: Optional[Attributes]
        :keyword links: Links to add to the span.
        :paramtype links: Optional[Sequence[Link]]
        :keyword end_on_exit: Whether to end the span when exiting the context manager. Defaults to True.
        :paramtype end_on_exit: bool
        :return: The span that was started
        :rtype: Iterator[~opentelemetry.trace.Span]
        """
        span = self.start_span(name, kind=kind, attributes=attributes, links=links)
        with trace.use_span(  # pylint: disable=not-context-manager
            span, record_exception=False, end_on_exit=end_on_exit
        ) as span:
            yield span

    @classmethod
    @contextmanager
    def use_span(cls, span: Span, *, end_on_exit: bool = True) -> Iterator[Span]:
        """Context manager that takes a non-active span and activates it in the current context.

        :param span: The span to set as the current span
        :type span: ~opentelemetry.trace.Span
        :keyword end_on_exit: Whether to end the span when exiting the context manager. Defaults to True.
        :paramtype end_on_exit: bool
        :return: The span that was activated.
        :rtype: Iterator[~opentelemetry.trace.Span]
        """
        with trace.use_span(  # pylint: disable=not-context-manager
            span, record_exception=False, end_on_exit=end_on_exit
        ) as active_span:
            yield active_span

    @staticmethod
    def set_span_error_status(span: Span, description: Optional[str] = None) -> None:
        """Mark a span as failed by setting its status to ERROR.

        :param span: The span to set the ERROR status on.
        :type span: ~opentelemetry.trace.Span
        :param description: An optional description of the error.
        :type description: Optional[str]
        """
        error_code = StatusCode.ERROR
        span.set_status(error_code, description=description)

    def _parse_links(self, links: Optional[Sequence[Link]]) -> Optional[Sequence[OpenTelemetryLink]]:
        if not links:
            return None

        try:
            otel_links = []
            for link in links:
                ctx = extract(link.headers)
                span_ctx = get_current_span_otel(ctx).get_span_context()
                otel_links.append(OpenTelemetryLink(span_ctx, link.attributes))
            return otel_links
        except AttributeError:
            # We will just send the links as is if it's not ~azure.core.tracing.Link without
            # any validation assuming the user knows what they are doing.
            return cast(Sequence[OpenTelemetryLink], links)

    @classmethod
    def get_current_span(cls) -> Span:
        """Returns the current span in the context.

        :return: The current span
        :rtype: ~opentelemetry.trace.Span
        """
        return get_current_span_otel()

    @classmethod
    def with_current_context(cls, func: Callable) -> Callable:
        """Passes the current spans to the new context the function will be run in.

        :param func: The function that will be run in the new context
        :type func: callable
        :return: The wrapped function
        :rtype: callable
        """
        current_context = otel_context_module.get_current()

        def call_with_current_context(*args, **kwargs):
            token = None
            try:
                token = otel_context_module.attach(current_context)
                return func(*args, **kwargs)
            finally:
                if token is not None:
                    otel_context_module.detach(token)

        return call_with_current_context

    @classmethod
    def get_trace_context(cls) -> Dict[str, str]:
        """Returns the Trace Context header values associated with the current span.

        These are generally the W3C Trace Context headers (i.e. "traceparent" and "tracestate").

        :return: A key value pair dictionary
        :rtype: dict[str, str]
        """
        trace_context: Dict[str, str] = {}
        inject(trace_context)
        return trace_context

    @classmethod
    def _suppress_auto_http_instrumentation(cls) -> Token:
        """Enabled automatic HTTP instrumentation suppression.

        Since azure-core already instruments HTTP calls, we need to suppress any automatic HTTP
        instrumentation provided by other libraries to prevent duplicate spans. This has no effect if no
        automatic HTTP instrumentation libraries are being used.

        :return: A token that can be used to detach the suppression key from the context
        :rtype: ~contextvars.Token
        """
        return otel_context_module.attach(otel_context_module.set_value(_SUPPRESS_HTTP_INSTRUMENTATION_KEY, True))

    @classmethod
    def _detach_from_context(cls, token: Token) -> None:
        """Detach a token from the context.

        :param token: The token to detach
        :type token: ~contextvars.Token
        """
        otel_context_module.detach(token)
