# Source code for notdiamond.toolkit.openai

"""
Tools for working directly with OpenAI's various models.
"""
import logging
from typing import List, Union

from notdiamond import NotDiamond
from notdiamond.llms.providers import NDLLMProviders
from notdiamond.settings import NOTDIAMOND_API_KEY, OPENAI_API_KEY

LOGGER = logging.getLogger(__name__)
# NOTE(review): setting a level on a library logger overrides whatever the
# application configured for it; consider leaving this to the application.
LOGGER.setLevel(logging.INFO)

# Constructor keyword arguments consumed by the Not Diamond client only; they
# are filtered out of the kwargs handed to the wrapped openai client.
_ND_PARAMS = [
    "llm_configs", "default", "max_model_depth", "latency_tracking",
    "hash_content", "tradeoff", "preference_id", "tools", "callbacks",
    "nd_api_url", "nd_api_key", "user_agent",
]
# Constructor keyword arguments forwarded to *both* clients.
_SHARED_PARAMS = ["timeout", "max_retries"]


[docs] class _OpenAIBase: """ Base class which wraps both an openai client and Not Diamond retry / fallback logic. """ def __init__(self, oai_client_cls, *args, **kwargs): nd_kwargs = { k: v for k, v in kwargs.items() if k in _ND_PARAMS + _SHARED_PARAMS } # TODO [a9] remove llm_configs as valid constructor arg for ND client self._nd_client = NotDiamond( api_key=nd_kwargs.get("nd_api_key", NOTDIAMOND_API_KEY), llm_configs=["openai/gpt-3.5-turbo"], *args, **nd_kwargs, ) # Create a OpenAI client with a dummy model - will ignore this during routing oai_kwargs = {k: v for k, v in kwargs.items() if k not in _ND_PARAMS} self._oai_client = oai_client_cls( *args, api_key=OPENAI_API_KEY, **oai_kwargs ) def __getattr__(self, name): return getattr(self._oai_client, name) def __call__(self, *args, **kwargs): return self._oai_client(*args, **kwargs) def __dir__(self): return dir(self._oai_client) @property def chat(self): class ChatCompletions: def __init__(self, parent): self.parent = parent @property def completions(self): return self def create(self, *args, **kwargs): return self.parent.create(*args, **kwargs) return ChatCompletions(self) def _create_prep(self, model: Union[str, List], **kwargs): model = kwargs.get("model", model) if model is None: LOGGER.info( "No LLM configs provided. Not Diamond will route to all OpenAI models." ) llm_configs = [ str(p) for p in NDLLMProviders if p.provider == "openai" ] elif isinstance(model, str): llm_configs = model.split(",") elif isinstance(model, list): llm_configs = self._nd_client._parse_llm_configs_data(model) if "messages" not in kwargs: raise ValueError("'messages' argument is required") return llm_configs
class OpenAI(_OpenAIBase):
    """
    Encapsulating class for an openai.OpenAI client.

    This supports the same methods as the openai package, while also supporting
    routed prompts with calls to ``create`` (also reachable as
    ``chat.completions.create``).
    """

    def __init__(self, *args, **kwargs):
        # Imported inside __init__ rather than at module level; presumably so
        # importing this module does not itself require `openai`.
        from openai import OpenAI as OpenAIClient

        super().__init__(OpenAIClient, *args, **kwargs)

    def create(self, *args, model: Union[str, List, None] = None, **kwargs):
        """
        Perform chat completion using OpenAI's API, after routing the prompt to
        a specific LLM via Not Diamond.

        Args:
            *args: forwarded to both ``model_select`` and the openai
                ``chat.completions.create`` call.
            model: candidate models to route between. ``None`` routes between
                the OpenAI providers known to Not Diamond, a string is a
                comma-separated list of configs, and a list is parsed by the
                Not Diamond client.
            **kwargs: forwarded to both calls; must include ``messages``.

        Returns:
            The response returned by the wrapped client's
            ``chat.completions.create``, invoked with the selected model.

        Raises:
            ValueError: if ``messages`` is not provided.
        """
        llm_configs = self._create_prep(model, **kwargs)
        session_id, best_llm = self._nd_client.model_select(
            *args, model=llm_configs, **kwargs
        )
        response = self._oai_client.chat.completions.create(
            *args, model=str(best_llm.model), **kwargs
        )
        LOGGER.info("Routed prompt to %s for session ID %s", best_llm, session_id)
        return response
class AsyncOpenAI(_OpenAIBase):
    """
    Encapsulating class for an openai.AsyncOpenAI client.

    This supports the same methods as the openai package, while also supporting
    routed prompts with awaited calls to ``create`` (also reachable as
    ``chat.completions.create``).
    """

    def __init__(self, *args, **kwargs):
        # Imported inside __init__ rather than at module level; presumably so
        # importing this module does not itself require `openai`.
        from openai import AsyncOpenAI as OpenAIClient

        super().__init__(OpenAIClient, *args, **kwargs)

    async def create(self, *args, model: Union[str, List, None] = None, **kwargs):
        """
        Perform async chat completion using OpenAI's API, after routing the
        prompt to a specific LLM via Not Diamond.

        Args:
            *args: forwarded to both ``amodel_select`` and the openai
                ``chat.completions.create`` call.
            model: candidate models to route between. ``None`` routes between
                the OpenAI providers known to Not Diamond, a string is a
                comma-separated list of configs, and a list is parsed by the
                Not Diamond client.
            **kwargs: forwarded to both calls; must include ``messages``.

        Returns:
            The awaited response of the wrapped client's
            ``chat.completions.create``, invoked with the selected model.

        Raises:
            ValueError: if ``messages`` is not provided.
        """
        llm_configs = self._create_prep(model, **kwargs)
        session_id, best_llm = await self._nd_client.amodel_select(
            *args, model=llm_configs, **kwargs
        )
        response = await self._oai_client.chat.completions.create(
            *args, model=str(best_llm.model), **kwargs
        )
        # Logged at INFO to match the synchronous OpenAI.create.
        LOGGER.info("Routed prompt to %s for session ID %s", best_llm, session_id)
        return response