
Commit

Addressing comments
srjoglekar246 committed Jan 8, 2025
1 parent 1e18fb3 commit 652333d
Showing 15 changed files with 157 additions and 128 deletions.
2 changes: 1 addition & 1 deletion python/packages/autogen-agentchat/tests/test_group_chat.py
@@ -34,10 +34,10 @@
from autogen_agentchat.teams._group_chat._swarm_group_chat import SwarmGroupChatManager
from autogen_agentchat.ui import Console
from autogen_core import AgentId, CancellationToken
from autogen_core.models import ReplayChatCompletionClient
from autogen_core.tools import FunctionTool
from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.models.replay import ReplayChatCompletionClient
from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -18,7 +18,7 @@
)
from autogen_agentchat.teams._group_chat._magentic_one._magentic_one_orchestrator import MagenticOneOrchestrator
from autogen_core import AgentId, CancellationToken
from autogen_ext.models.replay import ReplayChatCompletionClient
from autogen_core.models import ReplayChatCompletionClient
from utils import FileLogHandler

logger = logging.getLogger(EVENT_LOGGER_NAME)
@@ -6,7 +6,9 @@
"source": [
"# Models\n",
"\n",
"In many cases, agents need access to LLM model services such as OpenAI, Azure OpenAI, or local models. Since there are many different providers with different APIs, `autogen-core` implements a protocol for [model clients](../../core-user-guide/framework/model-clients.ipynb) and `autogen-ext` implements a set of model clients for popular model services. AgentChat can use these model clients to interact with model services. "
"In many cases, agents need access to LLM model services such as OpenAI, Azure OpenAI, or local models. Since there are many different providers with different APIs, `autogen-core` implements a protocol for [model clients](../../core-user-guide/framework/model-clients.ipynb) and `autogen-ext` implements a set of model clients for popular model services. AgentChat can use these model clients to interact with model services. \n",
"\n",
"**NOTE:** See {py:class}`~autogen_core.models.ChatCompletionCache` for a caching wrapper to use with the following clients."
]
},
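As a quick orientation for the model-client protocol described in the cell above, a minimal sketch of creating and calling one of the `autogen-ext` clients might look like the following. This is illustrative only: it assumes the `openai` extra is installed, `OPENAI_API_KEY` is set, and the model name is just an example.

```python
import asyncio

from autogen_core.models import UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    # The client reads OPENAI_API_KEY from the environment unless api_key is passed.
    client = OpenAIChatCompletionClient(model="gpt-4o")
    # Send a single user message and print the completion.
    result = await client.create([UserMessage(content="What is the capital of France?", source="user")])
    print(result.content)


asyncio.run(main())
```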
{
@@ -96,7 +96,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Default [Model Capabilities](../faqs.md#what-are-model-capabilities-and-how-do-i-specify-them) may be overridden should the need arise.\n",
"Default [Model Capabilities](../faqs.md#what-are-model-capabilities-and-how-do-i-specify-them) may be overridden should the need arise.\n"
]
},
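For context, the override is usually supplied when the client is constructed. The sketch below is an illustration rather than the notebook's own example: it assumes a hypothetical OpenAI-compatible local endpoint and that the installed version accepts a `model_info` argument (older releases used `model_capabilities` instead).

```python
from autogen_core.models import ModelFamily, ModelInfo
from autogen_ext.models.openai import OpenAIChatCompletionClient

# Hypothetical OpenAI-compatible endpoint; model_info overrides the capabilities
# the client would otherwise infer from the model name.
client = OpenAIChatCompletionClient(
    model="llama-3.1-8b",
    base_url="http://localhost:11434/v1",
    api_key="placeholder",
    model_info=ModelInfo(
        vision=False,
        function_calling=True,
        json_output=True,
        family=ModelFamily.UNKNOWN,
    ),
)
```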
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"\n",
"### Streaming Response\n",
@@ -174,6 +180,58 @@
"**NB the default usage response is to return zero values**"
]
},
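To make the usage note concrete, here is a rough sketch of consuming a streamed response. It reuses the `client` and `messages` assumed by the surrounding cells; the final yielded item is a `CreateResult` whose `usage` field holds the (possibly zero-valued) counts.

```python
from autogen_core.models import CreateResult


async def stream_example() -> None:
    async for item in client.create_stream(messages=messages):
        if isinstance(item, CreateResult):
            # The final item carries the aggregated content and the usage counts.
            print("\n\nUsage:", item.usage)
        else:
            print(item, end="", flush=True)
```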
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Caching Wrapper\n",
"\n",
"`autogen_core` implements a {py:class}`~autogen_core.models.ChatCompletionCache` that can wrap any {py:class}`~autogen_core.models.ChatCompletionClient`. Using this wrapper avoids incurring token usage when querying the underlying client with the same prompt multiple times. \n",
"\n",
"{py:class}`~autogen_core.models.ChatCompletionCache` uses a {py:class}`~autogen_core.CacheStore` protocol to allow duck-typing any storage object that has a pair of `get` & `set` methods (such as `redis.Redis` or `diskcache.Cache`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, Dict, Optional\n",
"\n",
"from autogen_core import CacheStore\n",
"from autogen_core.models import ChatCompletionCache\n",
"\n",
"\n",
"# Simple CacheStore implementation using in-memory dict,\n",
"# you can also use redis.Redis or diskcache.Cache\n",
"class DictStore(CacheStore):\n",
" def __init__(self) -> None:\n",
" self._store: Dict[Any, Any] = {}\n",
"\n",
" def get(self, key: Any, default: Optional[Any] = None) -> Optional[Any]:\n",
" return self._store.get(key, default)\n",
"\n",
" def set(self, key: Any, value: Any) -> None:\n",
" self._store[key] = value\n",
"\n",
"\n",
"cached_client = ChatCompletionCache(model_client, DictStore())\n",
"response = await cached_client.create(messages=messages)\n",
"\n",
"cached_response = await cached_client.create(messages=messages)\n",
"print(cached_response.cached)"
]
},
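As a variant of the cell above, the wrapper can also sit on top of an external store. This sketch mirrors the Redis wiring from the `ChatCompletionCache` docstring further down and assumes a Redis server is reachable on localhost with the `redis` package installed.

```python
import redis

# redis.Redis duck-types as a CacheStore through its get/set methods.
redis_client = redis.Redis(host="localhost", port=6379, db=0)
redis_cached_client = ChatCompletionCache(model_client, redis_client)
```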
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Inspecting `cached_client.total_usage()` (or `model_client.total_usage()`) before and after a cached response should yield idential counts.\n",
"\n",
"Note that the caching is sensitive to the exact arguments provided to `cached_client.create` or `cached_client.create_stream`, so changing `tools` or `json_output` arguments might lead to a cache miss."
]
},
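A small sketch of the checks described above, reusing `cached_client` and `messages` from the earlier cells; the `json_output` flag is used here only to force a different cache key.

```python
usage_before = cached_client.total_usage()
_ = await cached_client.create(messages=messages)  # identical arguments: served from the cache
usage_after = cached_client.total_usage()
print(usage_before.prompt_tokens == usage_after.prompt_tokens)  # True: no new tokens consumed

miss = await cached_client.create(messages=messages, json_output=True)
print(miss.cached)  # False: different arguments produce a different cache key
```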
{
"cell_type": "markdown",
"metadata": {},
2 changes: 2 additions & 0 deletions python/packages/autogen-core/src/autogen_core/__init__.py
@@ -10,6 +10,7 @@
from ._agent_runtime import AgentRuntime
from ._agent_type import AgentType
from ._base_agent import BaseAgent
from ._cache_store import CacheStore
from ._cancellation_token import CancellationToken
from ._closure_agent import ClosureAgent, ClosureContext
from ._component_config import (
@@ -80,6 +81,7 @@
"AgentMetadata",
"AgentRuntime",
"BaseAgent",
"CacheStore",
"CancellationToken",
"AgentInstantiationContext",
"TopicId",
@@ -1,7 +1,7 @@
from typing import Any, Optional, Protocol


class AbstractStore(Protocol):
class CacheStore(Protocol):
"""
This protocol defines the basic interface for store/cache operations.
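Because `CacheStore` is a `Protocol`, any object exposing compatible `get` and `set` methods satisfies it by duck typing, with no subclassing required. A brief illustrative sketch, assuming the `diskcache` package is installed:

```python
from autogen_core import CacheStore
from diskcache import Cache

store: CacheStore = Cache("/tmp/autogen_cache")  # duck-types via its get/set methods
store.set("greeting", "hello")
print(store.get("greeting"))  # hello
```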
@@ -1,4 +1,6 @@
from ._cache import ChatCompletionCache
from ._model_client import ChatCompletionClient, ModelCapabilities, ModelFamily, ModelInfo # type: ignore
from ._replay_chat_completion_client import ReplayChatCompletionClient
from ._types import (
AssistantMessage,
ChatCompletionTokenLogprob,
@@ -15,6 +17,7 @@

__all__ = [
"ModelCapabilities",
"ChatCompletionCache",
"ChatCompletionClient",
"SystemMessage",
"UserMessage",
@@ -29,4 +32,5 @@
"ChatCompletionTokenLogprob",
"ModelFamily",
"ModelInfo",
"ReplayChatCompletionClient",
]
@@ -3,17 +3,19 @@
import warnings
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union, cast

from autogen_core import CancellationToken
from autogen_core.models import (
from .._cache_store import CacheStore
from .._cancellation_token import CancellationToken
from ..tools import Tool, ToolSchema
from ._model_client import (
ChatCompletionClient,
CreateResult,
LLMMessage,
ModelCapabilities, # type: ignore
ModelInfo,
)
from ._types import (
CreateResult,
LLMMessage,
RequestUsage,
)
from autogen_core.store import AbstractStore
from autogen_core.tools import Tool, ToolSchema


class ChatCompletionCache(ChatCompletionClient):
@@ -22,13 +24,42 @@ class ChatCompletionCache(ChatCompletionClient):
Cache hits do not contribute to token usage of the original client.
"""

def __init__(self, client: ChatCompletionClient, store: AbstractStore):
def __init__(self, client: ChatCompletionClient, store: CacheStore):
"""
Initialize a new ChatCompletionCache.
First initialize (for example) a Redis store:
```python
import redis
redis_client = redis.Redis(host="localhost", port=6379, db=0)
```
or diskcache store:
```python
from diskcache import Cache
diskcache_client = Cache("/tmp/diskcache")
```
Then initialize the ChatCompletionCache with the store:
```python
from autogen_core.models import ChatCompletionCache
from autogen_ext.models.openai import OpenAIChatCompletionClient
# Original client
client = OpenAIChatCompletionClient(...)
# Cached version
cached_client = ChatCompletionCache(client, redis_client)
```
Args:
client (ChatCompletionClient): The original ChatCompletionClient to wrap.
store (AbstractStore): A store object that implements get and set methods.
store (CacheStore): A store object that implements get and set methods.
The user is responsible for managing the store's lifecycle & clearing it (if needed).
"""
self.client = client
@@ -40,17 +71,11 @@ def _check_cache(
tools: Sequence[Tool | ToolSchema],
json_output: Optional[bool],
extra_create_args: Mapping[str, Any],
force_cache: bool,
force_client: bool,
) -> tuple[Optional[Union[CreateResult, List[Union[str, CreateResult]]]], str]:
"""
Helper function to check the cache for a result.
Returns a tuple of (cached_result, cache_key).
cached_result is None if the cache is empty or force_client is True.
Raises an error if there is a cache miss and force_cache is True.
"""
if force_client and force_cache:
raise ValueError("force_cache and force_client cannot both be True")

data = {
"messages": [message.model_dump() for message in messages],
@@ -61,12 +86,9 @@ def _check_cache(
serialized_data = json.dumps(data, sort_keys=True)
cache_key = hashlib.sha256(serialized_data.encode()).hexdigest()

if not force_client:
cached_result = cast(Optional[CreateResult], self.store.get(cache_key))
if cached_result is not None:
return cached_result, cache_key
elif force_cache:
raise ValueError("Encountered cache miss for force_cache request")
cached_result = cast(Optional[CreateResult], self.store.get(cache_key))
if cached_result is not None:
return cached_result, cache_key

return None, cache_key

@@ -78,23 +100,15 @@ async def create(
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
force_cache: bool = False,
force_client: bool = False,
) -> CreateResult:
"""
Cached version of ChatCompletionClient.create.
If the result of a call to create has been cached, it will be returned immediately
without invoking the underlying client.
NOTE: cancellation_token is ignored for cached results.
Additional parameters:
- force_cache: If True, the cache will be used and an error will be raised if a result is unavailable.
- force_client: If True, the cache will be bypassed and the underlying client will be called.
"""
cached_result, cache_key = self._check_cache(
messages, tools, json_output, extra_create_args, force_cache, force_client
)
cached_result, cache_key = self._check_cache(messages, tools, json_output, extra_create_args)
if cached_result:
assert isinstance(cached_result, CreateResult)
cached_result.cached = True
@@ -118,27 +132,21 @@ def create_stream(
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
force_cache: bool = False,
force_client: bool = False,
) -> AsyncGenerator[Union[str, CreateResult], None]:
"""
Cached version of ChatCompletionClient.create_stream.
If the result of a call to create_stream has been cached, it will be returned
without streaming from the underlying client.
NOTE: cancellation_token is ignored for cached results.
Additional parameters:
- force_cache: If True, the cache will be used and an error will be raised if a result is unavailable.
- force_client: If True, the cache will be bypassed and the underlying client will be called.
"""

if force_client and force_cache:
raise ValueError("force_cache and force_client cannot both be True")

async def _generator() -> AsyncGenerator[Union[str, CreateResult], None]:
cached_result, cache_key = self._check_cache(
messages, tools, json_output, extra_create_args, force_cache, force_client
messages,
tools,
json_output,
extra_create_args,
)
if cached_result:
assert isinstance(cached_result, list)
@@ -153,6 +161,7 @@ async def _generator() -> AsyncGenerator[Union[str, CreateResult], None]:
tools=tools,
json_output=json_output,
extra_create_args=extra_create_args,
cancellation_token=cancellation_token,
)

output_results: List[Union[str, CreateResult]] = []
@@ -4,17 +4,20 @@
import warnings
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union

from autogen_core import EVENT_LOGGER_NAME, CancellationToken
from autogen_core.models import (
from .. import EVENT_LOGGER_NAME
from .._cancellation_token import CancellationToken
from ..tools import Tool, ToolSchema
from ._model_client import (
ChatCompletionClient,
CreateResult,
LLMMessage,
ModelCapabilities, # type: ignore
ModelFamily,
ModelInfo,
)
from ._types import (
CreateResult,
LLMMessage,
RequestUsage,
)
from autogen_core.tools import Tool, ToolSchema

logger = logging.getLogger(EVENT_LOGGER_NAME)

@@ -40,8 +43,7 @@ class ReplayChatCompletionClient(ChatCompletionClient):
.. code-block:: python
from autogen_ext.models.replay import ReplayChatCompletionClient
from autogen_core.models import UserMessage
from autogen_core.models import ReplayChatCompletionClient, UserMessage
async def example():
@@ -60,8 +62,7 @@ async def example():
.. code-block:: python
import asyncio
from autogen_ext.models.replay import ReplayChatCompletionClient
from autogen_core.models import UserMessage
from autogen_core.models import ReplayChatCompletionClient, UserMessage
async def example():
@@ -86,8 +87,7 @@ async def example():
.. code-block:: python
import asyncio
from autogen_ext.models.replay import ReplayChatCompletionClient
from autogen_core.models import UserMessage
from autogen_core.models import ReplayChatCompletionClient, UserMessage
async def example():

This file was deleted.
