"""NotDiamond client class"""
import inspect
import logging
import time
import warnings
from enum import Enum
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from litellm import token_counter
# Details: https://python.langchain.com/v0.1/docs/guides/development/pydantic_compatibility/
from pydantic.v1 import BaseModel
from pydantic_partial import create_partial_model
from notdiamond import settings
from notdiamond._utils import _module_check
from notdiamond.exceptions import (
ApiError,
CreateUnavailableError,
MissingLLMConfigs,
)
from notdiamond.llms.config import LLMConfig
from notdiamond.llms.providers import is_o1_model
from notdiamond.llms.request import (
amodel_select,
create_preference_id,
model_select,
report_latency,
)
from notdiamond.metrics.metric import Metric
from notdiamond.prompts import (
_curly_escape,
inject_system_prompt,
o1_system_prompt_translate,
)
from notdiamond.types import NDApiKeyValidator
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.INFO)
class _NDClientTarget(Enum):
ROUTER = "router"
INVOKER = "invoker"
def _ndllm_factory(import_target: Optional[_NDClientTarget] = None):
_invoke_error_msg_tmpl = (
"{fn_name} is not available. `notdiamond` can generate LLM responses after "
"installing additional dependencies via `pip install notdiamond[create]`."
)
_default_llm_config_invalid_warning = "The default LLMConfig set is invalid. Defaulting to {provider}/{model}"
_no_default_llm_config_warning = (
"No default LLMConfig set. Defaulting to {provider}/{model}"
)
class _NDRouterClient(BaseModel):
api_key: str
llm_configs: Optional[List[Union[LLMConfig, str]]]
default: Union[LLMConfig, int, str]
max_model_depth: Optional[int]
latency_tracking: bool
hash_content: bool
tradeoff: Optional[str]
preference_id: Optional[str]
tools: Optional[Sequence[Union[Dict[str, Any], Callable]]]
callbacks: Optional[List]
nd_api_url: Optional[str]
user_agent: Union[str, None]
max_retries: Optional[int]
timeout: Optional[Union[float, int]]
class Config:
arbitrary_types_allowed = True
def __init__(
self,
llm_configs: Optional[List[Union[LLMConfig, str]]] = None,
api_key: Optional[str] = None,
default: Union[LLMConfig, int, str] = 0,
max_model_depth: Optional[int] = None,
latency_tracking: bool = True,
hash_content: bool = False,
tradeoff: Optional[str] = None,
preference_id: Optional[str] = None,
callbacks: Optional[List] = None,
tools: Optional[Sequence[Union[Dict[str, Any], Callable]]] = None,
nd_api_url: Optional[str] = settings.NOTDIAMOND_API_URL,
user_agent: Union[str, None] = None,
max_retries: Optional[int] = 3,
timeout: Optional[Union[float, int]] = 60.0,
**kwargs,
):
if api_key is None:
api_key = settings.NOTDIAMOND_API_KEY
NDApiKeyValidator(api_key=api_key)
if user_agent is None:
user_agent = settings.DEFAULT_USER_AGENT
if llm_configs is not None:
llm_configs = self._parse_llm_configs_data(llm_configs)
if max_model_depth is None:
max_model_depth = len(llm_configs)
if max_model_depth > len(llm_configs):
LOGGER.warning(
"WARNING: max_model_depth cannot be bigger than the number of LLMs."
)
max_model_depth = len(llm_configs)
if tradeoff is not None:
if tradeoff not in ["cost", "latency"]:
raise ValueError(
"Invalid tradeoff. Accepted values: cost, latency."
)
if tradeoff is not None:
warnings.warn(
"The tradeoff constructor parameter is deprecated and will be removed in a "
"future version. Please specify the tradeoff when using model_select or invocation methods.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(
api_key=api_key,
llm_configs=llm_configs,
default=default,
max_model_depth=max_model_depth,
latency_tracking=latency_tracking,
hash_content=hash_content,
tradeoff=tradeoff,
preference_id=preference_id,
tools=tools,
callbacks=callbacks,
nd_api_url=nd_api_url,
user_agent=user_agent,
max_retries=max_retries,
timeout=timeout,
**kwargs,
)
self.user_agent = user_agent
assert (
self.api_key is not None
), "API key is not set. Please set a Not Diamond API key."
@property
def chat(self):
return self
@property
def completions(self):
return self
def create_preference_id(self, name: Optional[str] = None) -> str:
return create_preference_id(self.api_key, name, self.nd_api_url)
async def amodel_select(
self,
messages: List[Dict[str, str]],
input: Optional[Dict[str, Any]] = None,
model: Optional[List[LLMConfig]] = None,
default: Optional[Union[LLMConfig, int, str]] = None,
max_model_depth: Optional[int] = None,
latency_tracking: Optional[bool] = None,
hash_content: Optional[bool] = None,
tradeoff: Optional[str] = None,
preference_id: Optional[str] = None,
metric: Metric = Metric("accuracy"),
previous_session: Optional[str] = None,
timeout: Optional[Union[float, int]] = None,
max_retries: Optional[int] = None,
**kwargs,
) -> tuple[str, Optional[LLMConfig]]:
"""
This function calls the NotDiamond backend to fetch the most suitable model for the given prompt,
and leaves the execution of the LLM call to the developer.
The function is async, so it's suitable for async codebases.
Parameters:
messages (List[Dict[str, str]]): List of messages, OpenAI style.
input (Optional[Dict[str, Any]], optional): If the messages contain template variables, use input to
specify the values for those variables. Defaults to None, assuming no variables.
model (Optional[List[LLMConfig]]): List of models to choose from.
default (Optional[Union[LLMConfig, int, str]]): Default LLM.
max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
of routing you're willing to go.
latency_tracking (Optional[bool]): Latency tracking flag.
hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
for the router to determine the best LLM for a given query.
preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
Defaults to None.
metric (Metric, optional): Metric used by NotDiamond router to choose the best LLM.
Defaults to Metric("accuracy").
previous_session (Optional[str], optional): The session ID of a previous session, allowing you to link requests.
timeout (int): The number of seconds to wait before terminating the API call to Not Diamond backend.
max_retries (int): The number of retries to attempt before giving up.
nd_api_url (Optional[str]): The URL of the NotDiamond API. Defaults to settings.NOTDIAMOND_API_URL.
**kwargs: Any other arguments supported by LangChain's invoke method will be passed through.
Returns:
tuple[str, Optional[LLMConfig]]: returns the session_id and the chosen LLM
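Example:
    An illustrative async sketch; the model names and prompt are placeholders, not recommendations:

        client = NotDiamond(llm_configs=["openai/gpt-4o", "anthropic/claude-3-5-sonnet-20240620"])
        session_id, best_llm = await client.chat.completions.amodel_select(
            messages=[{"role": "user", "content": "Summarize this ticket."}],
            tradeoff="cost",
        )
        # best_llm is an LLMConfig; the actual LLM call is left to you.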
"""
if input is None:
input = {}
if model is not None:
llm_configs = self._parse_llm_configs_data(model)
self.llm_configs = llm_configs
self.validate_params(
default=default,
max_model_depth=max_model_depth,
latency_tracking=latency_tracking,
hash_content=hash_content,
tradeoff=tradeoff,
preference_id=preference_id,
)
best_llm, session_id = await amodel_select(
messages=messages,
llm_configs=self.llm_configs,
metric=metric,
notdiamond_api_key=self.api_key,
max_model_depth=self.max_model_depth,
hash_content=self.hash_content,
tradeoff=self.tradeoff,
preference_id=self.preference_id,
tools=self.tools,
previous_session=previous_session,
timeout=timeout or self.timeout,
max_retries=max_retries or self.max_retries,
nd_api_url=self.nd_api_url,
_user_agent=self.user_agent,
)
if not best_llm:
LOGGER.warning(
f"ND API error. Falling back to default provider={self.default_llm.provider}/{self.default_llm.model}."
)
best_llm = self.default_llm
self.call_callbacks("on_model_select", best_llm, best_llm.model)
return session_id, best_llm
def model_select(
self,
messages: List[Dict[str, str]],
input: Optional[Dict[str, Any]] = None,
model: Optional[List[LLMConfig]] = None,
default: Optional[Union[LLMConfig, int, str]] = None,
max_model_depth: Optional[int] = None,
latency_tracking: Optional[bool] = None,
hash_content: Optional[bool] = None,
tradeoff: Optional[str] = None,
preference_id: Optional[str] = None,
metric: Metric = Metric("accuracy"),
previous_session: Optional[str] = None,
timeout: Optional[Union[float, int]] = None,
max_retries: Optional[int] = None,
**kwargs,
) -> tuple[str, Optional[LLMConfig]]:
"""
This function calls the NotDiamond backend to fetch the most suitable model for the given prompt,
and leaves the execution of the LLM call to the developer.
Parameters:
messages (List[Dict[str, str]]): List of messages, OpenAI style.
input (Optional[Dict[str, Any]], optional): If the messages contain template variables, use input to
specify the values for those variables. Defaults to None, assuming no variables.
model (Optional[List[LLMConfig]]): List of models to choose from.
default (Optional[Union[LLMConfig, int, str]]): Default LLM.
max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
of routing you're willing to go.
latency_tracking (Optional[bool]): Latency tracking flag.
hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
for the router to determine the best LLM for a given query.
preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
Defaults to None.
metric (Metric, optional): Metric used by NotDiamond router to choose the best LLM.
Defaults to Metric("accuracy").
previous_session (Optional[str], optional): The session ID of a previous session, allowing you to link requests.
timeout (int): The number of seconds to wait before terminating the API call to Not Diamond backend.
max_retries (int): The number of retries to attempt before giving up.
**kwargs: Any other arguments supported by LangChain's invoke method will be passed through.
Returns:
tuple[str, Optional[LLMConfig]]: returns the session_id and the chosen LLM
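Example:
    An illustrative sketch; the model names are placeholders:

        client = NotDiamond(llm_configs=["openai/gpt-4o-mini", "google/gemini-1.5-pro-latest"])
        session_id, best_llm = client.chat.completions.model_select(
            messages=[{"role": "user", "content": "Draft a release note."}]
        )
        print(best_llm.provider, best_llm.model)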
"""
if input is None:
input = {}
if model is not None:
llm_configs = self._parse_llm_configs_data(model)
self.llm_configs = llm_configs
self.validate_params(
default=default,
max_model_depth=max_model_depth,
latency_tracking=latency_tracking,
hash_content=hash_content,
tradeoff=tradeoff,
preference_id=preference_id,
)
best_llm, session_id = model_select(
messages=messages,
llm_configs=self.llm_configs,
metric=metric,
notdiamond_api_key=self.api_key,
max_model_depth=self.max_model_depth,
hash_content=self.hash_content,
tradeoff=self.tradeoff,
preference_id=self.preference_id,
tools=self.tools,
previous_session=previous_session,
timeout=timeout or self.timeout,
max_retries=max_retries or self.max_retries,
nd_api_url=self.nd_api_url,
_user_agent=self.user_agent,
)
if not best_llm:
LOGGER.warning(
f"ND API error. Falling back to default provider={self.default_llm.provider}/{self.default_llm.model}."
)
best_llm = self.default_llm
self.call_callbacks("on_model_select", best_llm, best_llm.model)
return session_id, best_llm
@staticmethod
def _parse_llm_configs_data(
llm_configs: list,
) -> List[LLMConfig]:
providers = []
for llm_config in llm_configs:
if isinstance(llm_config, LLMConfig):
providers.append(llm_config)
continue
parsed_provider = LLMConfig.from_string(llm_config)
providers.append(parsed_provider)
return providers
def validate_params(
self,
default: Optional[Union[LLMConfig, int, str]] = None,
max_model_depth: Optional[int] = None,
latency_tracking: Optional[bool] = None,
hash_content: Optional[bool] = None,
tradeoff: Optional[str] = None,
preference_id: Optional[str] = None,
):
self.default = default
if max_model_depth is not None:
self.max_model_depth = max_model_depth
if self.llm_configs is None or len(self.llm_configs) == 0:
raise MissingLLMConfigs(
"No LLM config speficied. Specify at least one."
)
if self.max_model_depth is None:
self.max_model_depth = len(self.llm_configs)
if self.max_model_depth == 0:
raise ValueError("max_model_depth has to be bigger than 0.")
if self.max_model_depth > len(self.llm_configs):
LOGGER.warning(
"WARNING: max_model_depth cannot be bigger than the number of LLMs."
)
self.max_model_depth = len(self.llm_configs)
if tradeoff is not None:
if tradeoff not in ["cost", "latency"]:
raise ValueError(
"Invalid tradeoff. Accepted values: cost, latency."
)
self.tradeoff = tradeoff
if preference_id is not None:
self.preference_id = preference_id
if latency_tracking is not None:
self.latency_tracking = latency_tracking
if hash_content is not None:
self.hash_content = hash_content
def bind_tools(
self, tools: Sequence[Union[Dict[str, Any], Callable]]
) -> "NotDiamond":
"""
Bind tools to the LLM object. The tools will be passed to the LLM object when invoking it.
Results in the tools being available in the LLM object.
You can access the tool_calls in the result via `result.tool_calls`.
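Example:
    A minimal sketch, assuming `add` is a tool you define yourself and that every
    configured model supports function calling:

        def add(a: int, b: int) -> int:
            # A simple callable tool: adds two integers.
            return a + b

        client = NotDiamond(llm_configs=["openai/gpt-4o"]).bind_tools([add])
        result, session_id, provider = client.invoke(
            messages=[{"role": "user", "content": "What is 2 + 3?"}]
        )
        # result.tool_calls holds any tool invocations requested by the model.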
"""
for provider in self.llm_configs:
if provider.model not in settings.PROVIDERS[
provider.provider
].get("support_tools", []):
raise ApiError(
f"{provider.provider}/{provider.model} does not support function calling."
)
self.tools = tools
return self
def call_callbacks(self, function_name: str, *args, **kwargs) -> None:
"""
Call all callbacks with a specific function name.
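Only callbacks that define a method with that name are invoked, so a callback may
implement any subset of the hooks used in this module (for example `on_model_select`,
`on_latency_tracking`, `on_api_error`). An illustrative sketch:

    class LoggingCallback:
        def on_model_select(self, llm_config, model_name):
            print(f"Routed to {model_name}")

    client = NotDiamond(llm_configs=["openai/gpt-4o"], callbacks=[LoggingCallback()])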
"""
if self.callbacks is None:
return
for callback in self.callbacks:
if hasattr(callback, function_name):
getattr(callback, function_name)(*args, **kwargs)
def create(*args, **kwargs):
format_str = f"`{inspect.stack()[0].function}`"
raise CreateUnavailableError(
_invoke_error_msg_tmpl.format(fn_name=format_str)
)
async def acreate(*args, **kwargs):
format_str = f"`{inspect.stack()[0].function}`"
raise CreateUnavailableError(
_invoke_error_msg_tmpl.format(fn_name=format_str)
)
def invoke(*args, **kwargs):
format_str = f"`{inspect.stack()[0].function}`"
raise CreateUnavailableError(
_invoke_error_msg_tmpl.format(fn_name=format_str)
)
async def ainvoke(*args, **kwargs):
format_str = f"`{inspect.stack()[0].function}`"
raise CreateUnavailableError(
_invoke_error_msg_tmpl.format(fn_name=format_str)
)
def stream(*args, **kwargs):
raise CreateUnavailableError(
_invoke_error_msg_tmpl.format(
fn_name=inspect.stack()[0].function
)
)
async def astream(*args, **kwargs):
raise CreateUnavailableError(
_invoke_error_msg_tmpl.format(
fn_name=inspect.stack()[0].function
)
)
@property
def default_llm(self) -> LLMConfig:
"""
Return the default LLM that's set on the NotDiamond client class.
"""
if isinstance(self.default, int):
if self.default < len(self.llm_configs):
return self.llm_configs[self.default]
if isinstance(self.default, str):
try:
default = LLMConfig.from_string(self.default)
if default in self.llm_configs:
return default
except Exception as e:
LOGGER.debug(f"Error setting default llm: {e}")
if isinstance(self.default, LLMConfig):
return self.default
default = self.llm_configs[0]
if self.default is None:
LOGGER.info(
_no_default_llm_config_warning.format(
provider=default.provider, model=default.model
)
)
else:
LOGGER.info(
_default_llm_config_invalid_warning.format(
provider=default.provider, model=default.model
)
)
return default
# Do not import from langchain_core directly, as it is now an optional SDK dependency
try:
LLM = _module_check("langchain_core.language_models.llms", "LLM")
BaseMessageChunk = _module_check(
"langchain_core.messages", "BaseMessageChunk"
)
JsonOutputParser = _module_check(
"langchain_core.output_parsers", "JsonOutputParser"
)
ChatPromptTemplate = _module_check(
"langchain_core.prompts", "ChatPromptTemplate"
)
except (ModuleNotFoundError, ImportError) as ierr:
msg = _invoke_error_msg_tmpl.format(fn_name="NotDiamond creation")
if import_target == _NDClientTarget.INVOKER:
msg += " Create was requested, however - raising..."
raise ImportError(msg) from ierr
else:
LOGGER.debug(msg)
return _NDRouterClient
class _NDInvokerClient(_NDRouterClient, LLM):
"""
Implementation of NotDiamond class, the main class responsible for creating and invoking LLM prompts.
The class inherits from Langchain's LLM class. Starting reference is from here:
https://python.langchain.com/docs/modules/model_io/llms/custom_llm
It's mandatory to have an API key set. If the api_key is not explicitly specified,
it will check for NOTDIAMOND_API_KEY in the .env file.
Raises:
MissingLLMConfigs: you must specify at least one LLM provider for the router to work.
ApiError: error raised when the NotDiamond API call fails.
Be sure to set a default LLM provider so your code does not break in that case.
"""
api_key: str
llm_configs: Optional[List[Union[LLMConfig, str]]]
default: Union[LLMConfig, int, str]
max_model_depth: Optional[int]
latency_tracking: bool
hash_content: bool
tradeoff: Optional[str]
preference_id: Optional[str]
tools: Optional[Sequence[Union[Dict[str, Any], Callable]]]
callbacks: Optional[List]
nd_api_url: Optional[str]
user_agent: Union[str, None]
def __init__(
self,
llm_configs: Optional[List[Union[LLMConfig, str]]] = None,
api_key: Optional[str] = None,
default: Union[LLMConfig, int, str] = 0,
max_model_depth: Optional[int] = None,
latency_tracking: bool = True,
hash_content: bool = False,
tradeoff: Optional[str] = None,
preference_id: Optional[str] = None,
tools: Optional[Sequence[Union[Dict[str, Any], Callable]]] = None,
callbacks: Optional[List] = None,
nd_api_url: Optional[str] = settings.NOTDIAMOND_API_URL,
user_agent: Union[str, None] = None,
timeout: Optional[Union[float, int]] = 60.0,
max_retries: Optional[int] = 3,
**kwargs,
) -> None:
super().__init__(
api_key=api_key,
llm_configs=llm_configs,
default=default,
max_model_depth=max_model_depth,
latency_tracking=latency_tracking,
hash_content=hash_content,
tradeoff=tradeoff,
preference_id=preference_id,
tools=tools,
callbacks=callbacks,
nd_api_url=nd_api_url,
user_agent=user_agent,
timeout=timeout,
max_retries=max_retries,
**kwargs,
)
if user_agent is None:
user_agent = settings.DEFAULT_USER_AGENT
if tradeoff is not None:
warnings.warn(
"The tradeoff constructor parameter is deprecated and will be removed in a "
"future version. Please specify the tradeoff when using model_select or invocation methods.",
DeprecationWarning,
stacklevel=2,
)
self.user_agent = user_agent
assert (
self.api_key is not None
), "API key is not set. Please set a Not Diamond API key."
def __repr__(self) -> str:
class_name = self.__class__.__name__
address = hex(id(self)) # Gets the memory address of the object
return f"<{class_name} object at {address}>"
@property
def _llm_type(self) -> str:
return "NotDiamond LLM"
@staticmethod
def _inject_model_instruction(messages, parser):
format_instructions = parser.get_format_instructions()
format_instructions = format_instructions.replace(
"{", "{{"
).replace("}", "}}")
messages[0]["content"] = (
format_instructions + "\n" + messages[0]["content"]
)
return messages
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[Any] = None,
**kwargs: Any,
) -> str:
if stop is not None:
raise ValueError("stop kwargs are not permitted.")
return "This function is deprecated for the latest LangChain version, use invoke instead"
def create(
self,
messages: List[Dict[str, str]],
model: Optional[List[LLMConfig]] = None,
default: Optional[Union[LLMConfig, int, str]] = None,
max_model_depth: Optional[int] = None,
latency_tracking: Optional[bool] = None,
hash_content: Optional[bool] = None,
tradeoff: Optional[str] = None,
preference_id: Optional[str] = None,
metric: Metric = Metric("accuracy"),
previous_session: Optional[str] = None,
response_model: Optional[Type[BaseModel]] = None,
timeout: Optional[Union[float, int]] = None,
max_retries: Optional[int] = None,
**kwargs,
) -> tuple[Any, str, LLMConfig]:
"""
Function call to invoke the LLM, with the same interface
as the OpenAI Python library.
Parameters:
messages (List[Dict[str, str]]): List of messages, OpenAI style.
model (Optional[List[LLMConfig]]): List of models to choose from.
default (Optional[Union[LLMConfig, int, str]]): Default LLM.
max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
of routing you're willing to go.
latency_tracking (Optional[bool]): Latency tracking flag.
hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
for the router to determine the best LLM for a given query.
preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
Defaults to None.
metric (Metric, optional): Metric used by NotDiamond router to choose the best LLM.
Defaults to Metric("accuracy").
previous_session (Optional[str], optional): The session ID of a previous session, allowing you to link requests.
response_model (Optional[Type[BaseModel]], optional): If present, JsonOutputParser is used to parse the
response into the given model, in which case the result will be an instance of that model.
timeout (int): The number of seconds to wait before terminating the API call to Not Diamond backend.
max_retries (int): The number of retries to attempt before giving up.
nd_api_url (Optional[str]): The URL of the NotDiamond API. Defaults to settings.NOTDIAMOND_API_URL.
**kwargs: Any other arguments supported by LangChain's invoke method will be passed through.
Raises:
ApiError: when the NotDiamond API fails
Returns:
tuple[Union[AIMessage, BaseModel], str, LLMConfig]:
result: the LangChain response object containing the LLM output,
or an instance of the response_model if one was provided.
str: session_id returned by the NotDiamond API
LLMConfig: the best LLM selected by the router
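Example:
    An illustrative sketch; the model names and prompt are placeholders:

        client = NotDiamond(llm_configs=["openai/gpt-4o", "anthropic/claude-3-5-sonnet-20240620"])
        result, session_id, provider = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a concise assistant."},
                {"role": "user", "content": "Explain binary search in one paragraph."},
            ]
        )
        print(provider.model, result.content)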
"""
return self.invoke(
messages=messages,
model=model,
default=default,
max_model_depth=max_model_depth,
latency_tracking=latency_tracking,
hash_content=hash_content,
tradeoff=tradeoff,
preference_id=preference_id,
metric=metric,
previous_session=previous_session,
response_model=response_model,
timeout=timeout,
max_retries=max_retries,
**kwargs,
)
async def acreate(
self,
messages: List[Dict[str, str]],
model: Optional[List[LLMConfig]] = None,
default: Optional[Union[LLMConfig, int, str]] = None,
max_model_depth: Optional[int] = None,
latency_tracking: Optional[bool] = None,
hash_content: Optional[bool] = None,
tradeoff: Optional[str] = None,
preference_id: Optional[str] = None,
metric: Metric = Metric("accuracy"),
previous_session: Optional[str] = None,
response_model: Optional[Type[BaseModel]] = None,
timeout: Optional[Union[float, int]] = None,
max_retries: Optional[int] = None,
**kwargs,
) -> tuple[Any, str, LLMConfig]:
"""
Async function call to invoke the LLM, with the same interface
as the OpenAI Python library.
Parameters:
messages (List[Dict[str, str]]): List of messages, OpenAI style.
model (Optional[List[LLMConfig]]): List of models to choose from.
default (Optional[Union[LLMConfig, int, str]]): Default LLM.
max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
of routing you're willing to go.
latency_tracking (Optional[bool]): Latency tracking flag.
hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
for the router to determine the best LLM for a given query.
preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
Defaults to None.
metric (Metric, optional): Metric used by NotDiamond router to choose the best LLM.
Defaults to Metric("accuracy").
previous_session (Optional[str], optional): The session ID of a previous session, allowing you to link requests.
response_model (Optional[Type[BaseModel]], optional): If present, JsonOutputParser is used to parse the
response into the given model, in which case the result will be an instance of that model.
timeout (int): The number of seconds to wait before terminating the API call to Not Diamond backend.
max_retries (int): The number of retries to attempt before giving up.
**kwargs: Any other arguments supported by LangChain's invoke method will be passed through.
Raises:
ApiError: when the NotDiamond API fails
Returns:
tuple[Union[AIMessage, BaseModel], str, LLMConfig]:
result: the LangChain response object containing the LLM output,
or an instance of the response_model if one was provided.
str: session_id returned by the NotDiamond API
LLMConfig: the best LLM selected by the router
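Example:
    An illustrative async sketch; the model names are placeholders:

        async def summarize(text: str) -> str:
            client = NotDiamond(llm_configs=["openai/gpt-4o-mini", "mistral/mistral-large-latest"])
            result, session_id, provider = await client.chat.completions.acreate(
                messages=[{"role": "user", "content": f"Summarize: {text}"}]
            )
            return result.content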
"""
result = await self.ainvoke(
messages=messages,
model=model,
default=default,
max_model_depth=max_model_depth,
latency_tracking=latency_tracking,
hash_content=hash_content,
tradeoff=tradeoff,
preference_id=preference_id,
metric=metric,
previous_session=previous_session,
response_model=response_model,
timeout=timeout,
max_retries=max_retries,
**kwargs,
)
return result
def invoke(
self,
messages: List[Dict[str, str]],
model: Optional[List[LLMConfig]] = None,
default: Optional[Union[LLMConfig, int, str]] = None,
max_model_depth: Optional[int] = None,
latency_tracking: Optional[bool] = None,
hash_content: Optional[bool] = None,
tradeoff: Optional[str] = None,
preference_id: Optional[str] = None,
metric: Metric = Metric("accuracy"),
previous_session: Optional[str] = None,
response_model: Optional[Type[BaseModel]] = None,
timeout: Optional[Union[float, int]] = None,
max_retries: Optional[int] = None,
input: Optional[Dict[str, Any]] = None,
**kwargs,
) -> tuple[Any, str, LLMConfig]:
"""
Function to invoke the LLM. Behind the scenes what happens:
1. API call to NotDiamond backend to get the most suitable LLM for the given prompt
2. Invoke the returned LLM client side
3. Return the response
Parameters:
messages (List[Dict[str, str]]): List of messages, OpenAI style.
model (Optional[List[LLMConfig]]): List of models to choose from.
default (Optional[Union[LLMConfig, int, str]]): Default LLM.
max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
of routing you're willing to go.
latency_tracking (Optional[bool]): Latency tracking flag.
hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
for the router to determine the best LLM for a given query.
preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
Defaults to None.
metric (Metric, optional): Metric used by NotDiamond router to choose the best LLM.
Defaults to Metric("accuracy").
previous_session (Optional[str], optional): The session ID of a previous session, allowing you to link requests.
response_model (Optional[Type[BaseModel]], optional): If present, JsonOutputParser is used to parse the
response into the given model, in which case the result will be an instance of that model.
timeout (int): The number of seconds to wait before terminating the API call to Not Diamond backend.
max_retries (int): The number of retries to attempt before giving up.
input (Optional[Dict[str, Any]], optional): If the messages contain template variables, use input to
specify the values for those variables. Defaults to None, assuming no variables.
**kwargs: Any other arguments supported by LangChain's invoke method will be passed through.
Raises:
ApiError: when the NotDiamond API fails
Returns:
tuple[Union[AIMessage, BaseModel], str, LLMConfig]:
result: the LangChain response object containing the LLM output,
or an instance of the response_model if one was provided.
str: session_id returned by the NotDiamond API
LLMConfig: the best LLM selected by the router
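Example:
    A sketch of structured output via `response_model`, assuming every configured
    model supports response modeling (see `verify_against_response_model`):

        from pydantic.v1 import BaseModel

        class Movie(BaseModel):
            title: str
            year: int

        client = NotDiamond(llm_configs=["openai/gpt-4o"])
        result, session_id, provider = client.invoke(
            messages=[{"role": "user", "content": "Name one classic sci-fi movie."}],
            response_model=Movie,
        )
        # result is a Movie instance parsed from the model's JSON response.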
"""
if model is not None:
llm_configs = self._parse_llm_configs_data(model)
self.llm_configs = llm_configs
self.validate_params(
default=default,
max_model_depth=max_model_depth,
latency_tracking=latency_tracking,
hash_content=hash_content,
tradeoff=tradeoff,
preference_id=preference_id,
)
# If response_model is present, we will parse the response into the given model
# doing this here so that if validation errors occur, we can raise them before making the API call
response_model_parser = None
if response_model is not None:
self.verify_against_response_model()
response_model_parser = JsonOutputParser(
pydantic_object=response_model
)
if input is None:
input = {}
best_llm, session_id = model_select(
messages=messages,
llm_configs=self.llm_configs,
metric=metric,
notdiamond_api_key=self.api_key,
max_model_depth=self.max_model_depth,
hash_content=self.hash_content,
tradeoff=self.tradeoff,
preference_id=self.preference_id,
tools=self.tools,
previous_session=previous_session,
timeout=timeout or self.timeout,
max_retries=max_retries or self.max_retries,
nd_api_url=self.nd_api_url,
)
is_default = False
if not best_llm:
LOGGER.warning(
f"ND API error. Falling back to default provider={self.default_llm.provider}/{self.default_llm.model}."
)
best_llm = self.default_llm
is_default = True
if best_llm.system_prompt is not None:
messages = inject_system_prompt(
messages, best_llm.system_prompt
)
messages = o1_system_prompt_translate(messages, best_llm)
self.call_callbacks("on_model_select", best_llm, best_llm.model)
llm = self._llm_from_config(best_llm, callbacks=self.callbacks)
if self.tools and not is_o1_model(best_llm):
llm = llm.bind_tools(self.tools)
if response_model is not None:
messages = _NDInvokerClient._inject_model_instruction(
messages, response_model_parser
)
chain_messages = [
(msg["role"], _curly_escape(msg["content"]))
for msg in messages
]
prompt_template = ChatPromptTemplate.from_messages(chain_messages)
chain = prompt_template | llm
accepted_errors = _get_accepted_invoke_errors(best_llm.provider)
try:
if self.latency_tracking:
result = self._invoke_with_latency_tracking(
session_id=session_id,
chain=chain,
llm_config=best_llm,
is_default=is_default,
input=input,
**kwargs,
)
else:
result = chain.invoke(input, **kwargs)
except accepted_errors as e:
if best_llm.provider == "google":
LOGGER.warning(
f"Submitted chat messages are violating Google requirements with error {e}. "
"If you see this message, `notdiamond` has returned a Google model as the best option, "
"but the LLM call will fail. If possible, `notdiamond` will fall back to a non-Google model."
)
non_google_llm = next(
(
llm_config
for llm_config in self.llm_configs
if llm_config.provider != "google"
),
None,
)
if non_google_llm is not None:
best_llm = non_google_llm
llm = self._llm_from_config(
best_llm, callbacks=self.callbacks
)
if response_model is not None:
messages = (
_NDInvokerClient._inject_model_instruction(
messages, response_model_parser
)
)
chain_messages = [
(msg["role"], _curly_escape(msg["content"]))
for msg in messages
]
prompt_template = ChatPromptTemplate.from_messages(
chain_messages
)
chain = prompt_template | llm
if self.latency_tracking:
result = self._invoke_with_latency_tracking(
session_id=session_id,
chain=chain,
llm_config=best_llm,
is_default=is_default,
input=input,
**kwargs,
)
else:
result = chain.invoke(input, **kwargs)
else:
raise e
else:
raise e
if response_model is not None:
parsed_dict = response_model_parser.parse(result.content)
result = response_model.parse_obj(parsed_dict)
return result, session_id, best_llm
async def ainvoke(
self,
messages: List[Dict[str, str]],
model: Optional[List[LLMConfig]] = None,
default: Optional[Union[LLMConfig, int, str]] = None,
max_model_depth: Optional[int] = None,
latency_tracking: Optional[bool] = None,
hash_content: Optional[bool] = None,
tradeoff: Optional[str] = None,
preference_id: Optional[str] = None,
metric: Metric = Metric("accuracy"),
previous_session: Optional[str] = None,
response_model: Optional[Type[BaseModel]] = None,
timeout: Optional[Union[float, int]] = None,
max_retries: Optional[int] = None,
input: Optional[Dict[str, Any]] = None,
**kwargs,
) -> tuple[Any, str, LLMConfig]:
"""
Function to invoke the LLM. Behind the scenes what happens:
1. API call to NotDiamond backend to get the most suitable LLM for the given prompt
2. Invoke the returned LLM client side
3. Return the response
Parameters:
messages (List[Dict[str, str]]): List of messages, OpenAI style
model (Optional[List[LLMConfig]]): List of models to choose from.
default (Optional[Union[LLMConfig, int, str]]): Default LLM.
max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
of routing you're willing to go.
latency_tracking (Optional[bool]): Latency tracking flag.
hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
for the router to determine the best LLM for a given query.
preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
Defaults to None.
metric (Metric, optional): Metric used by NotDiamond router to choose the best LLM.
Defaults to Metric("accuracy").
previous_session (Optional[str], optional): The session ID of a previous session, allowing you to link requests.
response_model (Optional[Type[BaseModel]], optional): If present, JsonOutputParser is used to parse the
response into the given model, in which case the result will be an instance of that model.
timeout (int): The number of seconds to wait before terminating the API call to Not Diamond backend.
max_retries (int): The number of retries to attempt before giving up.
input (Optional[Dict[str, Any]], optional): If the messages contain template variables, use input to
specify the values for those variables. Defaults to None, assuming no variables.
**kwargs: Any other arguments supported by LangChain's invoke method will be passed through.
Raises:
ApiError: when the NotDiamond API fails
Returns:
tuple[Union[AIMessage, BaseModel], str, LLMConfig]:
result: the LangChain response object containing the LLM output,
or an instance of the response_model if one was provided.
str: session_id returned by the NotDiamond API
LLMConfig: the best LLM selected by the router
"""
if model is not None:
llm_configs = self._parse_llm_configs_data(model)
self.llm_configs = llm_configs
self.validate_params(
default=default,
max_model_depth=max_model_depth,
latency_tracking=latency_tracking,
hash_content=hash_content,
tradeoff=tradeoff,
preference_id=preference_id,
)
response_model_parser = None
if response_model is not None:
self.verify_against_response_model()
response_model_parser = JsonOutputParser(
pydantic_object=response_model
)
if input is None:
input = {}
best_llm, session_id = await amodel_select(
messages=messages,
llm_configs=self.llm_configs,
metric=metric,
notdiamond_api_key=self.api_key,
max_model_depth=self.max_model_depth,
hash_content=self.hash_content,
tradeoff=self.tradeoff,
preference_id=self.preference_id,
tools=self.tools,
previous_session=previous_session,
timeout=timeout or self.timeout,
max_retries=max_retries or self.max_retries,
nd_api_url=self.nd_api_url,
)
is_default = False
if not best_llm:
LOGGER.warning(
f"ND API error. Falling back to default provider={self.default_llm.provider}/{self.default_llm.model}."
)
best_llm = self.default_llm
is_default = True
if best_llm.system_prompt is not None:
messages = inject_system_prompt(
messages, best_llm.system_prompt
)
messages = o1_system_prompt_translate(messages, best_llm)
self.call_callbacks("on_model_select", best_llm, best_llm.model)
llm = self._llm_from_config(best_llm, callbacks=self.callbacks)
if self.tools and not is_o1_model(best_llm):
llm = llm.bind_tools(self.tools)
if response_model is not None:
messages = _NDInvokerClient._inject_model_instruction(
messages, response_model_parser
)
chain_messages = [
(msg["role"], _curly_escape(msg["content"]))
for msg in messages
]
prompt_template = ChatPromptTemplate.from_messages(chain_messages)
chain = prompt_template | llm
accepted_errors = _get_accepted_invoke_errors(best_llm.provider)
try:
if self.latency_tracking:
result = await self._async_invoke_with_latency_tracking(
session_id=session_id,
chain=chain,
llm_config=best_llm,
is_default=is_default,
input=input,
**kwargs,
)
else:
result = await chain.ainvoke(input, **kwargs)
except accepted_errors as e:
if best_llm.provider == "google":
LOGGER.warning(
f"Submitted chat messages are violating Google requirements with error {e}. "
"If you see this message, `notdiamond` has returned a Google model as the best option, "
"but the LLM call will fail. If possible, `notdiamond` will fall back to a non-Google model."
)
non_google_llm = next(
(
llm_config
for llm_config in self.llm_configs
if llm_config.provider != "google"
),
None,
)
if non_google_llm is not None:
best_llm = non_google_llm
llm = self._llm_from_config(
best_llm, callbacks=self.callbacks
)
if response_model is not None:
messages = (
_NDInvokerClient._inject_model_instruction(
messages, response_model_parser
)
)
chain_messages = [
(msg["role"], _curly_escape(msg["content"]))
for msg in messages
]
prompt_template = ChatPromptTemplate.from_messages(
chain_messages
)
chain = prompt_template | llm
if self.latency_tracking:
result = (
await self._async_invoke_with_latency_tracking(
session_id=session_id,
chain=chain,
llm_config=best_llm,
is_default=is_default,
input=input,
**kwargs,
)
)
else:
result = await chain.ainvoke(input, **kwargs)
else:
raise e
else:
raise e
if response_model is not None:
parsed_dict = response_model_parser.parse(result.content)
result = response_model.parse_obj(parsed_dict)
return result, session_id, best_llm
def stream(
self,
messages: List[Dict[str, str]],
model: Optional[List[LLMConfig]] = None,
default: Optional[Union[LLMConfig, int, str]] = None,
max_model_depth: Optional[int] = None,
latency_tracking: Optional[bool] = None,
hash_content: Optional[bool] = None,
tradeoff: Optional[str] = None,
preference_id: Optional[str] = None,
metric: Metric = Metric("accuracy"),
previous_session: Optional[str] = None,
response_model: Optional[Type[BaseModel]] = None,
timeout: Optional[Union[float, int]] = None,
max_retries: Optional[int] = None,
**kwargs,
) -> Iterator[Union[BaseMessageChunk, BaseModel]]:
"""
This function calls the NotDiamond backend to fetch the most suitable model for the given prompt,
and calls the LLM client side to stream the response.
Parameters:
messages (List[Dict[str, str]]): List of messages, OpenAI style.
model (Optional[List[LLMConfig]]): List of models to choose from.
default (Optional[Union[LLMConfig, int, str]]): Default LLM.
max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
of routing you're willing to go.
latency_tracking (Optional[bool]): Latency tracking flag.
hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
for the router to determine the best LLM for a given query.
preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
Defaults to None.
metric (Metric, optional): Metric used by NotDiamond router to choose the best LLM.
Defaults to Metric("accuracy").
previous_session (Optional[str], optional): The session ID of a previous session, allowing you to link requests.
response_model (Optional[Type[BaseModel]], optional): If present, JsonOutputParser is used to parse the
response into the given model, in which case each yielded chunk is a partial instance of that model.
timeout (int): The number of seconds to wait before terminating the API call to Not Diamond backend.
max_retries (int): The number of retries to attempt before giving up.
**kwargs: Any other arguments supported by LangChain's invoke method will be passed through.
Raises:
ApiError: when the NotDiamond API fails
Yields:
Iterator[Union[BaseMessageChunk, BaseModel]]: returns the response in chunks.
If response_model is present, it will return the partial model object
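Example:
    An illustrative sketch; the model names are placeholders:

        client = NotDiamond(llm_configs=["openai/gpt-4o", "anthropic/claude-3-5-sonnet-20240620"])
        for chunk in client.stream(
            messages=[{"role": "user", "content": "Write a haiku about routers."}]
        ):
            print(chunk.content, end="", flush=True)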
"""
if model is not None:
llm_configs = self._parse_llm_configs_data(model)
self.llm_configs = llm_configs
self.validate_params(
default=default,
max_model_depth=max_model_depth,
latency_tracking=latency_tracking,
hash_content=hash_content,
tradeoff=tradeoff,
preference_id=preference_id,
)
response_model_parser = None
if response_model is not None:
self.verify_against_response_model()
response_model_parser = JsonOutputParser(
pydantic_object=response_model
)
best_llm, session_id = model_select(
messages=messages,
llm_configs=self.llm_configs,
metric=metric,
notdiamond_api_key=self.api_key,
max_model_depth=self.max_model_depth,
hash_content=self.hash_content,
tradeoff=self.tradeoff,
preference_id=self.preference_id,
tools=self.tools,
previous_session=previous_session,
timeout=timeout or self.timeout,
max_retries=max_retries or self.max_retries,
nd_api_url=self.nd_api_url,
)
if not best_llm:
LOGGER.warning(
f"ND API error. Falling back to default provider={self.default_llm.provider}/{self.default_llm.model}."
)
best_llm = self.default_llm
if best_llm.system_prompt is not None:
messages = inject_system_prompt(
messages, best_llm.system_prompt
)
if response_model is not None:
messages = _NDInvokerClient._inject_model_instruction(
messages, response_model_parser
)
self.call_callbacks("on_model_select", best_llm, best_llm.model)
llm = self._llm_from_config(best_llm, callbacks=self.callbacks)
if self.tools:
llm = llm.bind_tools(self.tools)
if response_model is not None:
chain = llm | response_model_parser
else:
chain = llm
for chunk in chain.stream(messages, **kwargs):
if response_model is None:
yield chunk
else:
partial_model = create_partial_model(response_model)
yield partial_model(**chunk)
async def astream(
self,
messages: List[Dict[str, str]],
model: Optional[List[LLMConfig]] = None,
default: Optional[Union[LLMConfig, int, str]] = None,
max_model_depth: Optional[int] = None,
latency_tracking: Optional[bool] = None,
hash_content: Optional[bool] = None,
tradeoff: Optional[str] = None,
preference_id: Optional[str] = None,
metric: Metric = Metric("accuracy"),
previous_session: Optional[str] = None,
response_model: Optional[Type[BaseModel]] = None,
timeout: Optional[Union[float, int]] = None,
max_retries: Optional[int] = None,
**kwargs,
) -> AsyncIterator[Union[BaseMessageChunk, BaseModel]]:
"""
This function calls the NotDiamond backend to fetch the most suitable model for the given prompt,
and calls the LLM client side to stream the response. The function is async, so it's suitable for async codebases.
Parameters:
messages (List[Dict[str, str]]): List of messages, OpenAI style.
model (Optional[List[LLMConfig]]): List of models to choose from.
default (Optional[Union[LLMConfig, int, str]]): Default LLM.
max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
of routing you're willing to go.
latency_tracking (Optional[bool]): Latency tracking flag.
hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
for the router to determine the best LLM for a given query.
preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
Defaults to None.
metric (Metric, optional): Metric used by NotDiamond router to choose the best LLM.
Defaults to Metric("accuracy").
previous_session (Optional[str], optional): The session ID of a previous session, allowing you to link requests.
response_model (Optional[Type[BaseModel]], optional): If present, JsonOutputParser is used to parse the
response into the given model, in which case each yielded chunk is a partial instance of that model.
timeout (int): The number of seconds to wait before terminating the API call to Not Diamond backend.
max_retries (int): The number of retries to attempt before giving up.
**kwargs: Any other arguments supported by LangChain's invoke method will be passed through.
Raises:
ApiError: when the NotDiamond API fails
Yields:
AsyncIterator[Union[BaseMessageChunk, BaseModel]]: returns the response in chunks.
If response_model is present, it will return the partial model object
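Example:
    An illustrative async sketch; the model name is a placeholder:

        async def run():
            client = NotDiamond(llm_configs=["openai/gpt-4o-mini"])
            async for chunk in client.astream(
                messages=[{"role": "user", "content": "Stream a short limerick."}]
            ):
                print(chunk.content, end="", flush=True)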
"""
if model is not None:
llm_configs = self._parse_llm_configs_data(model)
self.llm_configs = llm_configs
self.validate_params(
default=default,
max_model_depth=max_model_depth,
latency_tracking=latency_tracking,
hash_content=hash_content,
tradeoff=tradeoff,
preference_id=preference_id,
)
response_model_parser = None
if response_model is not None:
self.verify_against_response_model()
response_model_parser = JsonOutputParser(
pydantic_object=response_model
)
best_llm, session_id = await amodel_select(
messages=messages,
llm_configs=self.llm_configs,
metric=metric,
notdiamond_api_key=self.api_key,
max_model_depth=self.max_model_depth,
hash_content=self.hash_content,
tradeoff=self.tradeoff,
preference_id=self.preference_id,
tools=self.tools,
previous_session=previous_session,
timeout=timeout or self.timeout,
max_retries=max_retries or self.max_retries,
nd_api_url=self.nd_api_url,
)
if not best_llm:
LOGGER.warning(
f"ND API error. Falling back to default provider={self.default_llm.provider}/{self.default_llm.model}."
)
best_llm = self.default_llm
if best_llm.system_prompt is not None:
messages = inject_system_prompt(
messages, best_llm.system_prompt
)
if response_model is not None:
messages = _NDInvokerClient._inject_model_instruction(
messages, response_model_parser
)
self.call_callbacks("on_model_select", best_llm, best_llm.model)
llm = self._llm_from_config(best_llm, callbacks=self.callbacks)
if self.tools:
llm = llm.bind_tools(self.tools)
if response_model is not None:
chain = llm | response_model_parser
else:
chain = llm
async for chunk in chain.astream(messages, **kwargs):
if response_model is None:
yield chunk
else:
partial_model = create_partial_model(response_model)
yield partial_model(**chunk)
async def _async_invoke_with_latency_tracking(
self,
session_id: str,
chain: Any,
llm_config: LLMConfig,
input: Optional[Dict[str, Any]] = {},
is_default: bool = True,
**kwargs,
):
if session_id in ("NO-SESSION-ID", "") and not is_default:
error_message = (
"ND session_id is not valid for latency tracking. "
+ "Please check the API response."
)
self.call_callbacks("on_api_error", error_message)
raise ApiError(error_message)
start_time = time.time()
result = await chain.ainvoke(input, **kwargs)
end_time = time.time()
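# Throughput is approximated as completion tokens divided by wall-clock generation time.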
tokens_completed = token_counter(
model=llm_config.model,
messages=[{"role": "assistant", "content": result.content}],
)
tokens_per_second = tokens_completed / (end_time - start_time)
report_latency(
session_id=session_id,
llm_config=llm_config,
tokens_per_second=tokens_per_second,
notdiamond_api_key=self.api_key,
nd_api_url=self.nd_api_url,
_user_agent=self.user_agent,
)
self.call_callbacks(
"on_latency_tracking",
session_id,
llm_config,
tokens_per_second,
)
return result
def _invoke_with_latency_tracking(
self,
session_id: str,
chain: Any,
llm_config: LLMConfig,
input: Optional[Dict[str, Any]] = {},
is_default: bool = True,
**kwargs,
):
LOGGER.debug(f"Latency tracking enabled, session_id={session_id}")
if session_id in ("NO-SESSION-ID", "") and not is_default:
error_message = (
"ND session_id is not valid for latency tracking. "
+ "Please check the API response."
)
self.call_callbacks("on_api_error", error_message)
raise ApiError(error_message)
start_time = time.time()
result = chain.invoke(input, **kwargs)
end_time = time.time()
tokens_completed = token_counter(
model=llm_config.model,
messages=[{"role": "assistant", "content": result.content}],
)
tokens_per_second = tokens_completed / (end_time - start_time)
report_latency(
session_id=session_id,
llm_config=llm_config,
tokens_per_second=tokens_per_second,
notdiamond_api_key=self.api_key,
nd_api_url=self.nd_api_url,
_user_agent=self.user_agent,
)
self.call_callbacks(
"on_latency_tracking",
session_id,
llm_config,
tokens_per_second,
)
return result
@staticmethod
def _llm_from_config(
provider: LLMConfig,
callbacks: Optional[List] = None,
) -> Any:
default_kwargs = {"max_retries": 5, "timeout": 120}
passed_kwargs = {**default_kwargs, **provider.kwargs}
if provider.provider == "openai":
ChatOpenAI = _module_check(
"langchain_openai.chat_models",
"ChatOpenAI",
provider.provider,
)
if is_o1_model(provider):
passed_kwargs["temperature"] = 1.0
return ChatOpenAI(
openai_api_key=provider.api_key,
model_name=provider.model,
callbacks=callbacks,
**passed_kwargs,
)
if provider.provider == "anthropic":
ChatAnthropic = _module_check(
"langchain_anthropic", "ChatAnthropic", provider.provider
)
return ChatAnthropic(
anthropic_api_key=provider.api_key,
model=provider.model,
callbacks=callbacks,
**passed_kwargs,
)
if provider.provider == "google":
ChatGoogleGenerativeAI = _module_check(
"langchain_google_genai",
"ChatGoogleGenerativeAI",
provider.provider,
)
return ChatGoogleGenerativeAI(
google_api_key=provider.api_key,
model=provider.model,
convert_system_message_to_human=True,
callbacks=callbacks,
**passed_kwargs,
)
if provider.provider == "cohere":
ChatCohere = _module_check(
"langchain_cohere.chat_models",
"ChatCohere",
provider.provider,
)
return ChatCohere(
cohere_api_key=provider.api_key,
model=provider.model,
callbacks=callbacks,
**passed_kwargs,
)
if provider.provider == "mistral":
ChatMistralAI = _module_check(
"langchain_mistralai.chat_models",
"ChatMistralAI",
provider.provider,
)
return ChatMistralAI(
mistral_api_key=provider.api_key,
model=provider.model,
callbacks=callbacks,
**passed_kwargs,
)
if provider.provider == "togetherai":
provider_settings = settings.PROVIDERS.get(
provider.provider, None
)
model_prefixes = provider_settings.get("model_prefix", None)
model_prefix = model_prefixes.get(provider.model, None)
del passed_kwargs["max_retries"]
del passed_kwargs["timeout"]
if model_prefix is not None:
model = f"{model_prefix}/{provider.model}"
else:
# No prefix configured for this model; fall back to the bare model name.
model = provider.model
Together = _module_check(
"langchain_together", "Together", provider.provider
)
return Together(
together_api_key=provider.api_key,
model=model,
callbacks=callbacks,
**passed_kwargs,
)
if provider.provider == "perplexity":
del passed_kwargs["max_retries"]
passed_kwargs["request_timeout"] = passed_kwargs["timeout"]
del passed_kwargs["timeout"]
ChatPerplexity = _module_check(
"langchain_community.chat_models",
"ChatPerplexity",
provider.provider,
)
return ChatPerplexity(
pplx_api_key=provider.api_key,
model=provider.model,
callbacks=callbacks,
**passed_kwargs,
)
if provider.provider == "replicate":
provider_settings = settings.PROVIDERS.get(
provider.provider, None
)
model_prefixes = provider_settings.get("model_prefix", None)
model_prefix = model_prefixes.get(provider.model, None)
passed_kwargs["request_timeout"] = passed_kwargs["timeout"]
del passed_kwargs["timeout"]
if model_prefix is not None:
model = f"replicate/{model_prefix}/{provider.model}"
else:
# No prefix configured for this model; fall back to the unprefixed form.
model = f"replicate/{provider.model}"
ChatLiteLLM = _module_check(
"langchain_community.chat_models",
"ChatLiteLLM",
provider.provider,
)
return ChatLiteLLM(
model=model,
callbacks=callbacks,
replicate_api_key=provider.api_key,
**passed_kwargs,
)
raise ValueError(f"Unsupported provider: {provider.provider}")
def verify_against_response_model(self) -> bool:
"""
Verify that the LLMs support response modeling.
"""
for provider in self.llm_configs:
if provider.model not in settings.PROVIDERS[
provider.provider
].get("support_response_model", []):
raise ApiError(
f"{provider.provider}/{provider.model} does not support response modeling."
)
return True
if import_target is _NDClientTarget.ROUTER:
return _NDRouterClient
return _NDInvokerClient
_NDClient = _ndllm_factory()
class NotDiamond(_NDClient):
api_key: str
"""
API key required for making calls to NotDiamond.
You can get an API key via our dashboard: https://app.notdiamond.ai
If an API key is not set, it will check for NOTDIAMOND_API_KEY in .env file.
"""
llm_configs: Optional[List[Union[LLMConfig, str]]]
"""The list of LLMs that are available to route between."""
default: Union[LLMConfig, int, str]
"""
Set a default LLM so that if anything goes wrong in the flow (for example, the
NotDiamond API call fails), your code won't break and you still have a fallback model.
There are various ways to configure a default model:
- Integer, specifying the index of the default provider in the llm_configs list
- String, in the same 'provider_name/model_name' form used for llm_configs
- LLMConfig, directly specifying the provider object
By default, we will set your first LLM in the list as the default.
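For example, given configs = ["openai/gpt-4o", "anthropic/claude-3-5-sonnet-20240620"]
(illustrative model names; LLMConfig is importable from notdiamond.llms.config),
the following are equivalent ways to set the fallback:
- NotDiamond(llm_configs=configs, default=0)
- NotDiamond(llm_configs=configs, default="openai/gpt-4o")
- NotDiamond(llm_configs=configs, default=LLMConfig.from_string("openai/gpt-4o"))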
"""
max_model_depth: Optional[int]
"""
If your top recommended model is down, specify up to which depth of routing you're willing to go.
If max_model_depth is not set, it defaults to the length of the llm_configs list.
If max_model_depth is set to 0, the init will fail.
If the value is larger than the llm_configs list length, we reset the value to len(llm_configs).
"""
latency_tracking: bool
"""
Track and send the latency of each LLM call to the NotDiamond server as feedback, so we can improve our router.
By default this is turned on; set it to False to turn it off.
"""
hash_content: bool
"""
Hash the content before it is sent to the NotDiamond API.
By default this is False.
"""
tradeoff: Optional[str]
"""
[DEPRECATED] The tradeoff constructor parameter is deprecated and will be removed in a future version.
Please specify the tradeoff when using model_select or invocation methods.
Define tradeoff between "cost" and "latency" for the router to determine the best LLM for a given query.
If None is specified, then the router will not consider either cost or latency.
The supported values: "cost", "latency"
Defaults to None.
"""
preference_id: Optional[str]
"""The ID of the router preference that was configured via the Dashboard. Defaults to None."""
tools: Optional[Sequence[Union[Dict[str, Any], Callable]]]
"""Bind tools to the LLM object. The tools will be passed to the LLM object when invoking it."""
nd_api_url: Optional[str]
"""The URL of the NotDiamond API. Defaults to settings.NOTDIAMOND_API_URL."""
user_agent: Union[str, None]
max_retries: int
"""The maximum number of retries to make when calling the Not Diamond API."""
timeout: float
"""The timeout for the Not Diamond API call."""
class Config:
arbitrary_types_allowed = True
def __init__(
self,
nd_api_url: Optional[str] = settings.NOTDIAMOND_API_URL,
user_agent: Union[str, None] = settings.DEFAULT_USER_AGENT,
*args,
**kwargs,
):
super().__init__(
nd_api_url=nd_api_url, user_agent=user_agent, *args, **kwargs
)
self.nd_api_url = nd_api_url
if kwargs.get("tradeoff") is not None:
warnings.warn(
"The tradeoff constructor parameter is deprecated and will be removed in a "
"future version. Please specify the tradeoff when using model_select or invocation methods.",
DeprecationWarning,
stacklevel=2,
)
def _get_accepted_invoke_errors(provider: str) -> Tuple:
if provider == "google":
ChatGoogleGenerativeAIError = _module_check(
"langchain_google_genai.chat_models",
"ChatGoogleGenerativeAIError",
provider,
)
accepted_errors = (ChatGoogleGenerativeAIError, ValueError)
else:
accepted_errors = (ValueError,)
return accepted_errors