# Source code for langsmith.run_trees

"""Schemas for the LangSmith API."""

from __future__ import annotations

import json
import logging
import sys
from datetime import datetime, timezone
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast
from uuid import UUID, uuid4

try:
    from pydantic.v1 import Field, root_validator  # type: ignore[import]
except ImportError:
    from pydantic import (  # type: ignore[assignment, no-redef]
        Field,
        root_validator,
    )

import threading
import urllib.parse

from langsmith import schemas as ls_schemas
from langsmith import utils
from langsmith.client import ID_TYPE, RUN_TYPE_T, Client, _dumps_json, _ensure_uuid

logger = logging.getLogger(__name__)

LANGSMITH_PREFIX = "langsmith-"
LANGSMITH_DOTTED_ORDER = sys.intern(f"{LANGSMITH_PREFIX}trace")
LANGSMITH_DOTTED_ORDER_BYTES = LANGSMITH_DOTTED_ORDER.encode("utf-8")
LANGSMITH_METADATA = sys.intern(f"{LANGSMITH_PREFIX}metadata")
LANGSMITH_TAGS = sys.intern(f"{LANGSMITH_PREFIX}tags")
LANGSMITH_PROJECT = sys.intern(f"{LANGSMITH_PREFIX}project")
_CLIENT: Optional[Client] = None
_LOCK = threading.Lock()  # Keeping around for a while for backwards compat


# Note, this is called directly by langchain. Do not remove.


def get_cached_client(**init_kwargs: Any) -> Client:
    """Return the module-level cached ``Client``, creating it on first use.

    Note: this function is called directly by langchain; keep its name and
    signature stable.

    Args:
        **init_kwargs: Keyword arguments forwarded to ``Client(...)``. They
            are only used when the cached client does not exist yet; once a
            client has been cached, later calls ignore them.

    Returns:
        Client: The shared client instance stored in ``_CLIENT``.
    """
    global _CLIENT
    # A single check is enough: no lock is taken here, so repeating the same
    # test a second time (as the previous version did) added nothing.
    if _CLIENT is None:
        _CLIENT = Client(**init_kwargs)
    return _CLIENT
class RunTree(ls_schemas.RunBase):
    """Run Schema with back-references for posting runs."""

    name: str
    id: UUID = Field(default_factory=uuid4)
    run_type: str = Field(default="chain")
    start_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    parent_run: Optional[RunTree] = Field(default=None, exclude=True)
    child_runs: List[RunTree] = Field(
        default_factory=list,
        exclude={"__all__": {"parent_run_id"}},
    )
    session_name: str = Field(
        default_factory=lambda: utils.get_tracer_project() or "default",
        alias="project_name",
    )
    session_id: Optional[UUID] = Field(default=None, alias="project_id")
    extra: Dict = Field(default_factory=dict)
    tags: Optional[List[str]] = Field(default_factory=list)
    events: List[Dict] = Field(default_factory=list)
    """List of events associated with the run, like start and end events."""
    ls_client: Optional[Any] = Field(default=None, exclude=True)
    dotted_order: str = Field(
        default="", description="The order of the run in the tree."
    )
    trace_id: UUID = Field(default="", description="The trace id of the run.")  # type: ignore

    class Config:
        """Pydantic model configuration."""

        arbitrary_types_allowed = True
        allow_population_by_field_name = True
        extra = "ignore"

    @root_validator(pre=True)
    def infer_defaults(cls, values: dict) -> dict:
        """Fill in defaults for fields the caller did not provide.

        Runs before field validation and mutates ``values`` in place:

        * ``name`` is taken from ``serialized["name"]`` or the last element of
          ``serialized["id"]`` when missing, else ``"Unnamed"``.
        * ``client`` / ``_client`` keyword arguments are moved to ``ls_client``.
        * ``parent_run_id`` and ``trace_id`` are derived from ``parent_run``
          when one is given; otherwise ``trace_id`` defaults to the run's id.
        * ``extra``, ``events``, ``tags``, ``outputs`` and ``attachments`` get
          empty containers when absent (or ``None`` for the latter four).

        Args:
            values: The raw constructor keyword arguments.

        Returns:
            dict: The same ``values`` mapping, updated.
        """
        if values.get("name") is None and values.get("serialized") is not None:
            if "name" in values["serialized"]:
                values["name"] = values["serialized"]["name"]
            elif "id" in values["serialized"]:
                values["name"] = values["serialized"]["id"][-1]
        if values.get("name") is None:
            values["name"] = "Unnamed"
        if "client" in values:  # Handle user-constructed clients
            values["ls_client"] = values.pop("client")
        elif "_client" in values:
            values["ls_client"] = values.pop("_client")
        if not values.get("ls_client"):
            values["ls_client"] = None
        parent_run = values.get("parent_run")
        if parent_run is not None:
            values["parent_run_id"] = parent_run.id
        if "id" not in values:
            values["id"] = uuid4()
        if "trace_id" not in values:
            # Check the value, not just the key: an explicit ``parent_run=None``
            # must fall back to the run's own id instead of raising
            # AttributeError on ``None.trace_id``.
            if parent_run is not None:
                values["trace_id"] = parent_run.trace_id
            else:
                values["trace_id"] = values["id"]
        values.setdefault("extra", {})
        if values.get("events") is None:
            values["events"] = []
        if values.get("tags") is None:
            values["tags"] = []
        if values.get("outputs") is None:
            values["outputs"] = {}
        if values.get("attachments") is None:
            values["attachments"] = {}
        return values

    @root_validator(pre=False)
    def ensure_dotted_order(cls, values: dict) -> dict:
        """Ensure the dotted order of the run.

        If ``dotted_order`` is empty or whitespace, build it from
        ``start_time`` and ``id`` and, when there is a parent run, prefix it
        with the parent's dotted order and a ``"."`` separator.
        """
        current_dotted_order = values.get("dotted_order")
        if current_dotted_order and current_dotted_order.strip():
            return values
        current_dotted_order = _create_current_dotted_order(
            values["start_time"], values["id"]
        )
        if values["parent_run"]:
            values["dotted_order"] = (
                values["parent_run"].dotted_order + "." + current_dotted_order
            )
        else:
            values["dotted_order"] = current_dotted_order
        return values

    @property
    def client(self) -> Client:
        """Return the client."""
        # Lazily load the client
        # If you never use this for API calls, it will never be loaded
        if self.ls_client is None:
            self.ls_client = get_cached_client()
        return self.ls_client

    @property
    def _client(self) -> Optional[Client]:
        # For backwards compat
        return self.ls_client

    def __setattr__(self, name, value):
        """Set the _client specially."""
        # For backwards compat
        if name == "_client":
            self.ls_client = value
        else:
            return super().__setattr__(name, value)

    def add_tags(self, tags: Union[Sequence[str], str]) -> None:
        """Add tags to the run.

        Args:
            tags: A single tag or a sequence of tags to append.
        """
        if isinstance(tags, str):
            tags = [tags]
        if self.tags is None:
            self.tags = []
        self.tags.extend(tags)

    def add_metadata(self, metadata: Dict[str, Any]) -> None:
        """Add metadata to the run.

        The entries are merged into ``self.extra["metadata"]``, overwriting
        existing keys with the same name.
        """
        if self.extra is None:
            self.extra = {}
        metadata_: dict = cast(dict, self.extra).setdefault("metadata", {})
        metadata_.update(metadata)

    def add_outputs(self, outputs: Dict[str, Any]) -> None:
        """Upsert the given outputs into the run.

        Args:
            outputs (Dict[str, Any]): A dictionary containing the outputs to be added.

        Returns:
            None
        """
        if self.outputs is None:
            self.outputs = {}
        self.outputs.update(outputs)

    def add_event(
        self,
        events: Union[
            ls_schemas.RunEvent,
            Sequence[ls_schemas.RunEvent],
            Sequence[dict],
            dict,
            str,
        ],
    ) -> None:
        """Add an event to the list of events.

        Args:
            events (Union[ls_schemas.RunEvent, Sequence[ls_schemas.RunEvent],
                    Sequence[dict], dict, str]):
                The event(s) to be added. It can be a single event, a sequence
                of events, a sequence of dictionaries, a dictionary, or a
                string. A string is wrapped in an event dict named ``"event"``
                with the current UTC time.

        Returns:
            None
        """
        if self.events is None:
            self.events = []
        if isinstance(events, dict):
            self.events.append(events)  # type: ignore[arg-type]
        elif isinstance(events, str):
            self.events.append(
                {
                    "name": "event",
                    "time": datetime.now(timezone.utc).isoformat(),
                    "message": events,
                }
            )
        else:
            self.events.extend(events)  # type: ignore[arg-type]

    def end(
        self,
        *,
        outputs: Optional[Dict] = None,
        error: Optional[str] = None,
        end_time: Optional[datetime] = None,
        events: Optional[Sequence[ls_schemas.RunEvent]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Set the end time of the run and record its final state.

        Args:
            outputs: Outputs to store; merged into existing outputs if any.
            error: Error string to record on the run.
            end_time: Explicit end time; defaults to the current UTC time.
            events: Events to append via ``add_event``.
            metadata: Metadata to merge via ``add_metadata``.
        """
        self.end_time = end_time or datetime.now(timezone.utc)
        if outputs is not None:
            if not self.outputs:
                self.outputs = outputs
            else:
                self.outputs.update(outputs)
        if error is not None:
            self.error = error
        if events is not None:
            self.add_event(events)
        if metadata is not None:
            self.add_metadata(metadata)

    def create_child(
        self,
        name: str,
        run_type: RUN_TYPE_T = "chain",
        *,
        run_id: Optional[ID_TYPE] = None,
        serialized: Optional[Dict] = None,
        inputs: Optional[Dict] = None,
        outputs: Optional[Dict] = None,
        error: Optional[str] = None,
        reference_example_id: Optional[UUID] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        tags: Optional[List[str]] = None,
        extra: Optional[Dict] = None,
        attachments: Optional[ls_schemas.Attachments] = None,
    ) -> RunTree:
        """Add a child run to the run tree.

        The child inherits this run's project name and client, has
        ``parent_run`` set to this run, and is appended to ``child_runs``.

        Returns:
            RunTree: The newly created child run.
        """
        serialized_ = serialized or {"name": name}
        run = RunTree(
            name=name,
            id=_ensure_uuid(run_id),
            serialized=serialized_,
            inputs=inputs or {},
            outputs=outputs or {},
            error=error,
            run_type=run_type,
            reference_example_id=reference_example_id,
            start_time=start_time or datetime.now(timezone.utc),
            end_time=end_time,
            extra=extra or {},
            parent_run=self,
            project_name=self.session_name,
            ls_client=self.ls_client,
            tags=tags,
            attachments=attachments or {},
        )
        self.child_runs.append(run)
        return run

    def _get_dicts_safe(self):
        """Return the run as a dict without deep-copying inputs/outputs."""
        # Things like generators cannot be copied
        self_dict = self.dict(
            exclude={"child_runs", "inputs", "outputs"}, exclude_none=True
        )
        if self.inputs is not None:
            # shallow copy. deep copying will occur in the client
            self_dict["inputs"] = self.inputs.copy()
        if self.outputs is not None:
            # shallow copy; deep copying will occur in the client
            self_dict["outputs"] = self.outputs.copy()
        return self_dict

    def post(self, exclude_child_runs: bool = True) -> None:
        """Post the run tree to the API asynchronously.

        Args:
            exclude_child_runs: When False, every child run is posted
                recursively as well.
        """
        kwargs = self._get_dicts_safe()
        self.client.create_run(**kwargs)
        if attachments := kwargs.get("attachments"):
            keys = [str(name) for name in attachments]
            # Remember which attachments were sent so ``patch`` can skip them.
            self.events.append(
                {
                    "name": "uploaded_attachment",
                    "time": datetime.now(timezone.utc).isoformat(),
                    "message": set(keys),
                }
            )
        if not exclude_child_runs:
            for child_run in self.child_runs:
                child_run.post(exclude_child_runs=False)

    def patch(self) -> None:
        """Patch the run tree to the API in a background thread.

        Ends the run first if it has no end time. Attachments named in a
        previously recorded ``"uploaded_attachment"`` event are not sent again.
        """
        if not self.end_time:
            self.end()
        attachments = self.attachments
        try:
            # Avoid loading the same attachment twice
            if attachments:
                uploaded = next(
                    (
                        ev
                        for ev in self.events
                        if ev.get("name") == "uploaded_attachment"
                    ),
                    None,
                )
                if uploaded:
                    attachments = {
                        a: v
                        for a, v in attachments.items()
                        if a not in uploaded["message"]
                    }
        except Exception as e:
            # Best effort: on failure fall through with whatever
            # ``attachments`` currently holds.
            logger.warning(f"Error filtering attachments to upload: {e}")
        self.client.update_run(
            name=self.name,
            run_id=self.id,
            outputs=self.outputs.copy() if self.outputs else None,
            error=self.error,
            parent_run_id=self.parent_run_id,
            reference_example_id=self.reference_example_id,
            end_time=self.end_time,
            dotted_order=self.dotted_order,
            trace_id=self.trace_id,
            events=self.events,
            tags=self.tags,
            extra=self.extra,
            attachments=attachments,
        )

    def wait(self) -> None:
        """Wait for all _futures to complete.

        Currently a no-op.
        """
        pass

    def get_url(self) -> str:
        """Return the URL of the run."""
        return self.client.get_run_url(run=self)

    @classmethod
    def from_dotted_order(
        cls,
        dotted_order: str,
        **kwargs: Any,
    ) -> RunTree:
        """Create a new 'child' span from the provided dotted order.

        Args:
            dotted_order: The dotted order string identifying the run.
            **kwargs: Extra constructor arguments passed to ``from_headers``.

        Returns:
            RunTree: The new span.
        """
        headers = {
            LANGSMITH_DOTTED_ORDER: dotted_order,
        }
        return cast(RunTree, cls.from_headers(headers, **kwargs))  # type: ignore[arg-type]

    @classmethod
    def from_runnable_config(
        cls,
        config: Optional[dict],
        **kwargs: Any,
    ) -> Optional[RunTree]:
        """Create a new 'child' span from the provided runnable config.

        Requires langchain to be installed.

        Args:
            config: A runnable config dict, or None to use the ambient config
                returned by ``ensure_config``.
            **kwargs: Extra constructor arguments for the resulting run.

        Returns:
            Optional[RunTree]: The new span or None if
                no parent span information is found.

        Raises:
            ImportError: If ``langchain-core`` is not installed.
        """
        try:
            from langchain_core.callbacks.manager import (
                AsyncCallbackManager,
                CallbackManager,
            )
            from langchain_core.runnables import RunnableConfig, ensure_config
            from langchain_core.tracers.langchain import LangChainTracer
        except ImportError as e:
            raise ImportError(
                "RunTree.from_runnable_config requires langchain-core to be installed. "
                "You can install it with `pip install langchain-core`."
            ) from e
        if config is None:
            # Same call the previous conditional expression always reduced to
            # in this branch.
            config_ = ensure_config(None)
        else:
            config_ = cast(RunnableConfig, config)

        if (
            (cb := config_.get("callbacks"))
            and isinstance(cb, (CallbackManager, AsyncCallbackManager))
            and cb.parent_run_id
            and (
                tracer := next(
                    (t for t in cb.handlers if isinstance(t, LangChainTracer)),
                    None,
                )
            )
        ):
            if (run := tracer.run_map.get(str(cb.parent_run_id))) and run.dotted_order:
                dotted_order = run.dotted_order
                kwargs["run_type"] = run.run_type
                kwargs["inputs"] = run.inputs
                kwargs["outputs"] = run.outputs
                kwargs["start_time"] = run.start_time
                kwargs["end_time"] = run.end_time
                # Parenthesise the left operand: ``+`` binds tighter than
                # ``or``, so without the parentheses the caller's tags were
                # dropped whenever ``run.tags`` was non-empty.
                kwargs["tags"] = sorted(
                    set((run.tags or []) + list(kwargs.get("tags") or []))
                )
                kwargs["name"] = run.name
                extra_ = kwargs.setdefault("extra", {})
                metadata_ = extra_.setdefault("metadata", {})
                metadata_.update(run.metadata)
            elif hasattr(tracer, "order_map") and cb.parent_run_id in tracer.order_map:
                dotted_order = tracer.order_map[cb.parent_run_id][1]
            else:
                return None
            kwargs["client"] = tracer.client
            kwargs["project_name"] = tracer.project_name
            return RunTree.from_dotted_order(dotted_order, **kwargs)
        return None

    @classmethod
    def from_headers(
        cls, headers: Mapping[Union[str, bytes], Union[str, bytes]], **kwargs: Any
    ) -> Optional[RunTree]:
        """Create a new 'parent' span from the provided headers.

        Extracts parent span information from the headers and creates a new span.
        Metadata and tags are extracted from the baggage header.
        The dotted order and trace id are extracted from the trace header.

        Args:
            headers: Request headers keyed by ``str`` or ``bytes``.
            **kwargs: Extra constructor arguments. Metadata given here takes
                precedence over metadata from the baggage header; tags are
                merged with the baggage tags.

        Returns:
            Optional[RunTree]: The new span or None if
                no parent span information is found.
        """
        init_args = kwargs.copy()

        langsmith_trace = cast(Optional[str], headers.get(LANGSMITH_DOTTED_ORDER))
        if not langsmith_trace:
            langsmith_trace_bytes = cast(
                Optional[bytes], headers.get(LANGSMITH_DOTTED_ORDER_BYTES)
            )
            if not langsmith_trace_bytes:
                return None
            langsmith_trace = langsmith_trace_bytes.decode("utf-8")

        parent_dotted_order = langsmith_trace.strip()
        parsed_dotted_order = _parse_dotted_order(parent_dotted_order)
        trace_id = parsed_dotted_order[0][1]
        init_args["trace_id"] = trace_id
        init_args["id"] = parsed_dotted_order[-1][1]
        init_args["dotted_order"] = parent_dotted_order
        if len(parsed_dotted_order) >= 2:
            # Has a parent
            init_args["parent_run_id"] = parsed_dotted_order[-2][1]
        # All placeholders. We assume the source process
        # handles the life-cycle of the run.
        init_args["start_time"] = init_args.get("start_time") or datetime.now(
            timezone.utc
        )
        init_args["run_type"] = init_args.get("run_type") or "chain"
        init_args["name"] = init_args.get("name") or "parent"

        baggage = _Baggage.from_headers(headers)
        if baggage.metadata or baggage.tags:
            # Build a fresh ``extra`` so the caller's dict is not mutated, and
            # tolerate ``extra=None`` / ``metadata=None`` / ``tags=None``.
            extra = dict(init_args.get("extra") or {})
            extra["metadata"] = {
                **baggage.metadata,
                **(extra.get("metadata") or {}),
            }
            init_args["extra"] = extra
            init_args["tags"] = sorted(
                set(baggage.tags + list(init_args.get("tags") or []))
            )
        if baggage.project_name:
            init_args["project_name"] = baggage.project_name

        return RunTree(**init_args)

    def to_headers(self) -> Dict[str, str]:
        """Return the RunTree as a dictionary of headers.

        The trace header carries ``dotted_order`` (only when ``trace_id`` is
        set); the ``baggage`` header carries metadata, tags and project name.
        """
        headers: Dict[str, str] = {}
        if self.trace_id:
            headers[LANGSMITH_DOTTED_ORDER] = self.dotted_order
        baggage = _Baggage(
            metadata=self.extra.get("metadata", {}),
            tags=self.tags,
            project_name=self.session_name,
        )
        headers["baggage"] = baggage.to_header()
        return headers

    def __repr__(self):
        """Return a string representation of the RunTree object."""
        return (
            f"RunTree(id={self.id}, name='{self.name}', "
            f"run_type='{self.run_type}', dotted_order='{self.dotted_order}')"
        )
class _Baggage:
    """Baggage header information."""

    def __init__(
        self,
        metadata: Optional[Dict[str, str]] = None,
        tags: Optional[List[str]] = None,
        project_name: Optional[str] = None,
    ):
        """Initialize the Baggage object.

        Args:
            metadata: Metadata mapping; ``None`` becomes an empty dict.
            tags: List of tags; ``None`` becomes an empty list.
            project_name: Optional project name.
        """
        self.metadata = metadata or {}
        self.tags = tags or []
        self.project_name = project_name

    @classmethod
    def from_header(cls, header_value: Optional[str]) -> _Baggage:
        """Create a Baggage object from the given header value.

        The header is a comma-separated list of ``key=value`` members. Only
        the langsmith metadata, tags and project members are read; other
        members are ignored. A member that cannot be parsed is logged and
        skipped so that it does not prevent later members from being read.

        Args:
            header_value: Raw ``baggage`` header value, or None/empty.

        Returns:
            _Baggage: The parsed baggage (empty when nothing could be parsed).
        """
        if not header_value:
            return cls()
        metadata = {}
        tags = []
        project_name = None
        for item in header_value.split(","):
            try:
                key, value = item.split("=", 1)
                # Members may be separated by ", " (optional whitespace), so
                # strip before comparing keys. Values written by ``to_header``
                # are percent-encoded, so stripping literal whitespace is safe.
                key = key.strip()
                value = value.strip()
                if key == LANGSMITH_METADATA:
                    metadata = json.loads(urllib.parse.unquote(value))
                elif key == LANGSMITH_TAGS:
                    tags = urllib.parse.unquote(value).split(",")
                elif key == LANGSMITH_PROJECT:
                    project_name = urllib.parse.unquote(value)
            except Exception as e:
                logger.warning(f"Error parsing baggage header: {e}")

        return cls(metadata=metadata, tags=tags, project_name=project_name)

    @classmethod
    def from_headers(cls, headers: Mapping[Union[str, bytes], Any]) -> _Baggage:
        """Create a Baggage object from a headers mapping.

        Looks up ``"baggage"`` first, then the bytes key ``b"baggage"`` (whose
        value is decoded as UTF-8). Returns an empty baggage if neither exists.
        """
        if "baggage" in headers:
            return cls.from_header(headers["baggage"])
        elif b"baggage" in headers:
            return cls.from_header(cast(bytes, headers[b"baggage"]).decode("utf-8"))
        else:
            return cls.from_header(None)

    def to_header(self) -> str:
        """Return the Baggage object as a header value.

        Each non-empty field becomes one percent-encoded ``key=value`` member;
        members are joined with commas.
        """
        items = []
        if self.metadata:
            serialized_metadata = _dumps_json(self.metadata)
            items.append(
                f"{LANGSMITH_METADATA}={urllib.parse.quote(serialized_metadata)}"
            )
        if self.tags:
            serialized_tags = ",".join(self.tags)
            items.append(f"{LANGSMITH_TAGS}={urllib.parse.quote(serialized_tags)}")
        if self.project_name:
            items.append(
                f"{LANGSMITH_PROJECT}={urllib.parse.quote(self.project_name)}"
            )
        return ",".join(items)


def _parse_dotted_order(dotted_order: str) -> List[Tuple[datetime, UUID]]:
    """Parse the dotted order string.

    Each ``.``-separated part is a ``%Y%m%dT%H%M%S%fZ`` timestamp immediately
    followed by a 36-character UUID string.

    Args:
        dotted_order: The dotted order to parse.

    Returns:
        List[Tuple[datetime, UUID]]: One ``(timestamp, run id)`` pair per
            part, in the order they appear.

    Raises:
        ValueError: If a part's timestamp or UUID cannot be parsed.
    """
    parts = dotted_order.split(".")
    return [
        # The last 36 characters are the UUID; everything before is the time.
        (datetime.strptime(part[:-36], "%Y%m%dT%H%M%S%fZ"), UUID(part[-36:]))
        for part in parts
    ]


def _create_current_dotted_order(
    start_time: Optional[datetime], run_id: Optional[UUID]
) -> str:
    """Create the current dotted order.

    Args:
        start_time: Run start time; defaults to the current UTC time.
        run_id: Run id; defaults to a freshly generated UUID4.

    Returns:
        str: The ``%Y%m%dT%H%M%S%fZ`` timestamp followed by the id string.
    """
    st = start_time or datetime.now(timezone.utc)
    id_ = run_id or uuid4()
    return st.strftime("%Y%m%dT%H%M%S%fZ") + str(id_)


__all__ = ["RunTree"]