import os
from typing import Dict, List, Union
from notdiamond.toolkit._retry import (
AsyncRetryWrapper,
ClientType,
ModelType,
OpenAIMessagesType,
RetryManager,
RetryWrapper,
)
[docs]
def init(
    client: Union[ClientType, List[ClientType]],
    models: ModelType,
    max_retries: Union[int, Dict[str, int]] = 1,
    timeout: Union[float, Dict[str, float]] = 60.0,
    model_messages: Union[Dict[str, OpenAIMessagesType], None] = None,
    api_key: Union[str, None] = None,
    async_mode: bool = False,
    backoff: Union[float, Dict[str, float]] = 2.0,
) -> RetryManager:
    """
    Entrypoint for fallback and retry features without changing existing code.

    Add this to existing codebase without other modifications to enable the following capabilities:

    - Fallback to a different model if a model invocation fails.
    - If configured, fallback to a different *provider* if a model invocation fails
      (eg. azure/gpt-4o fails -> invoke openai/gpt-4o)
    - Load-balance between models and providers, if specified.
    - Pass timeout and retry configurations to each invoke, optionally configured per model.
    - Pass model-specific messages on each retry (prepended to the provided `messages` parameter)

    Parameters:
        client (Union[ClientType, List[ClientType]]): Clients to apply retry/fallback logic to.
            A single client is treated the same as a one-element list.
        models (Union[Dict[str, float], List[str]]):
            Models to use of the format <provider>/<model>.
            Supports two formats:

            - List of models, eg. ["openai/gpt-4o", "azure/gpt-4o"]. Models will be prioritized as listed.
            - Dict of models to weights for load balancing, eg. {"openai/gpt-4o": 0.9, "azure/gpt-4o": 0.1}.
              If a model invocation fails, the next model is selected by sampling using the *remaining* weights.
        max_retries (Union[int, Dict[str, int]]):
            Maximum number of retries. Can be configured globally or per model.
        timeout (Union[float, Dict[str, float]]):
            Timeout in seconds per model. Can be configured globally or per model.
        model_messages (Optional[Dict[str, OpenAIMessagesType]]):
            Model-specific messages to prepend to `messages` on each invocation, formatted OpenAI-style. Can be
            configured using any role which is valid as an initial message (eg. "system" or "user", but not
            "assistant").
        api_key (Optional[str]):
            Not Diamond API key for authentication. Falls back to the ``NOTDIAMOND_API_KEY`` environment
            variable when not provided. Unused for now - will offer logging and metrics in the future.
        async_mode (bool):
            Whether to manage clients as async.
        backoff (Union[float, Dict[str, float]]):
            Backoff factor for exponential backoff per each retry. Can be configured globally or per model.

    Returns:
        RetryManager: Manager object that handles retries and fallbacks. Not required for usage.

    Raises:
        ValueError: If any entry of ``models`` is not a string of the form ``<provider>/<model>``
            (exactly one ``/`` with a non-empty provider and a non-empty model name).

    Model Fallback Prioritization
    -----------------------------

    - If models is a list, the fallback model is selected in order after removing the failed model.
      eg. If "openai/gpt-4o" fails for the list:

      - ["openai/gpt-4o", "azure/gpt-4o"], "azure/gpt-4o" will be tried next
      - ["openai/gpt-4o-mini", "openai/gpt-4o", "azure/gpt-4o"], "openai/gpt-4o-mini" will be tried next.

    - If models is a dict, the next model is selected by sampling using the *remaining* weights.
      eg. If "openai/gpt-4o" fails for the dict:

      - {"openai/gpt-4o": 0.9, "azure/gpt-4o": 0.1}, "azure/gpt-4o" will be invoked 100% of the time
      - {"openai/gpt-4o": 0.5, "azure/gpt-4o": 0.25, "openai/gpt-4o-mini": 0.25}, then "azure/gpt-4o" and
        "openai/gpt-4o-mini" can be invoked with 50% probability each.

    Usage
    -----

    Please refer to tests/test_init.py for more examples on how to use notdiamond.init.

    .. code-block:: python

        # ...existing workflow code, including client initialization...
        openai_client = OpenAI(...)
        azure_client = AzureOpenAI(...)

        # Add `notdiamond.init` to the workflow.
        notdiamond.init(
            [openai_client, azure_client],
            models={"openai/gpt-4o": 0.9, "azure/gpt-4o": 0.1},
            max_retries={"openai/gpt-4o": 3, "azure/gpt-4o": 1},
            timeout={"openai/gpt-4o": 10.0, "azure/gpt-4o": 5.0},
            model_messages={
                "openai/gpt-4o": [{"role": "user", "content": "Here is a prompt for OpenAI."}],
                "azure/gpt-4o": [{"role": "user", "content": "Here is a prompt for Azure."}],
            },
            api_key="sk-...",
            backoff=2.0,
        )

        # ...continue existing workflow code...
        response = openai_client.chat.completions.create(
            model="notdiamond",
            messages=[{"role": "user", "content": "Hello!"}]
        )
    """
    # Fail fast: validate model names before reading the environment or
    # constructing any wrapper. Iterating a dict yields its keys, so this
    # covers both the list form and the weighted-dict form of `models`.
    for model in models:
        if (
            not isinstance(model, str)
            or model.count("/") != 1
            # Reject an empty provider or model name such as "openai/" or "/gpt-4o".
            or not all(model.split("/"))
        ):
            raise ValueError(
                f"Model {model} must be in the format <provider>/<model>."
            )

    # An explicit argument wins; otherwise fall back to the environment.
    api_key = api_key or os.getenv("NOTDIAMOND_API_KEY")

    wrapper_cls = AsyncRetryWrapper if async_mode else RetryWrapper

    # Normalize to a list so single and multiple clients share one code path.
    clients = client if isinstance(client, list) else [client]

    # Every wrapper receives the same retry/fallback configuration.
    client_wrappers = [
        wrapper_cls(
            client=cc,
            models=models,
            max_retries=max_retries,
            timeout=timeout,
            model_messages=model_messages,
            api_key=api_key,
            backoff=backoff,
        )
        for cc in clients
    ]

    return RetryManager(models, client_wrappers)