From 19f34f02b0c40e8715ad10cf977fb2a765cfa94a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?=
Date: Mon, 4 Mar 2024 17:55:09 +0100
Subject: [PATCH] quality: Avoid consecutive entries from the assistant

---
 helpers/call.py | 17 +++++++++++------
 main.py         | 15 ++++++++++-----
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/helpers/call.py b/helpers/call.py
index d3d42865..010477f9 100644
--- a/helpers/call.py
+++ b/helpers/call.py
@@ -140,13 +140,18 @@ async def handle_play(
     See: https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts
     """
     if store:
-        call.messages.append(
-            MessageModel(
-                content=text,
-                persona=MessagePersonaEnum.ASSISTANT,
-                style=style,
+        if (
+            call.messages[-1].persona == MessagePersonaEnum.ASSISTANT
+        ):  # If the last message was from the assistant, append to it
+            call.messages[-1].content += f" {text}"
+        else:  # Otherwise, create a new message
+            call.messages.append(
+                MessageModel(
+                    content=text,
+                    persona=MessagePersonaEnum.ASSISTANT,
+                    style=style,
+                )
             )
-        )
 
     # Split text in chunks of max 400 characters, separated by sentence
     chunks = []
diff --git a/main.py b/main.py
index 7a5fb553..cb21b8a0 100644
--- a/main.py
+++ b/main.py
@@ -792,9 +792,15 @@ async def execute_llm_chat(
     4. `CallModel`, the updated model
     """
     _logger.debug("Running LLM chat")
+    content_full = ""
     should_user_answer = True
 
-    async def _buffer_user_callback(
+    async def _tools_callback(text: str, style: MessageStyleEnum) -> None:
+        nonlocal content_full
+        content_full += f" {text}"
+        await user_callback(text, style)
+
+    async def _content_callback(
         buffer: str, style: MessageStyleEnum
     ) -> MessageStyleEnum:
         # Remove tool calls from buffer content and detect style
@@ -850,7 +856,7 @@ async def _tools_cancellation_callback() -> None:
         post_call_next=post_call_next,
         post_call_synthesis=post_call_synthesis,
         search=search,
-        user_callback=user_callback,
+        user_callback=_tools_callback,
     )
 
     tools = []
@@ -862,7 +868,6 @@ async def _tools_cancellation_callback() -> None:
 
     # Execute LLM inference
     content_buffer_pointer = 0
-    content_full = ""
     tool_calls_buffer: dict[int, MessageToolModel] = {}
     try:
         async for delta in completion_stream(
@@ -885,7 +890,7 @@ async def _tools_cancellation_callback() -> None:
                     content_full[content_buffer_pointer:], False
                 ):
                     content_buffer_pointer += len(sentence)
-                    plugins.style = await _buffer_user_callback(sentence, plugins.style)
+                    plugins.style = await _content_callback(sentence, plugins.style)
     except ReadError:
         _logger.warn("Network error", exc_info=True)
         return True, True, should_user_answer, call
@@ -895,7 +900,7 @@ async def _tools_cancellation_callback() -> None:
 
     # Flush the remaining buffer
     if content_buffer_pointer < len(content_full):
-        plugins.style = await _buffer_user_callback(
+        plugins.style = await _content_callback(
             content_full[content_buffer_pointer:], plugins.style
         )
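
Reviewer note: below is a minimal, self-contained sketch of the merge-or-append
behaviour the helpers/call.py hunk introduces. MessagePersonaEnum and
MessageModel here are simplified stand-ins, not the project's real models; the
empty-history guard is an addition for the sketch only, since the patch itself
indexes call.messages[-1] directly and assumes the list is non-empty.

from dataclasses import dataclass
from enum import Enum


class MessagePersonaEnum(str, Enum):  # stand-in for the project's enum
    ASSISTANT = "assistant"
    HUMAN = "human"


@dataclass
class MessageModel:  # stand-in for the project's message model
    content: str
    persona: MessagePersonaEnum
    style: str = "none"


def store_assistant_text(messages: list[MessageModel], text: str, style: str) -> None:
    # Merge into the previous entry when the assistant spoke last, so the
    # history never holds two consecutive assistant messages.
    if messages and messages[-1].persona == MessagePersonaEnum.ASSISTANT:
        messages[-1].content += f" {text}"
    else:
        messages.append(
            MessageModel(content=text, persona=MessagePersonaEnum.ASSISTANT, style=style)
        )


history = [MessageModel("Hello!", MessagePersonaEnum.ASSISTANT)]
store_assistant_text(history, "How can I help you today?", "none")
assert len(history) == 1  # merged, not appended
assert history[0].content == "Hello! How can I help you today?"

Joining with a leading space mirrors the patch's f" {text}" concatenation.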
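
Likewise, a sketch of the callback split in main.py, under simplified
signatures (plain strings, no MessageStyleEnum). Text emitted by tools is
routed through a wrapper that both forwards it to the user and records it
into the same content_full buffer the streaming loop flushes sentence by
sentence, which is why the patch moves the content_full initialization
above the nested callbacks: the nonlocal binding must already exist.

import asyncio


async def demo() -> None:
    # Declared before the callbacks so the closure below shares it through
    # the nonlocal binding, as in the patched main.py.
    content_full = ""

    async def user_callback(text: str) -> None:
        print(f"user hears: {text}")

    # Mirrors _tools_callback: forward tool output to the user, but also
    # record it so the stored assistant entry holds one merged transcript.
    async def tools_callback(text: str) -> None:
        nonlocal content_full
        content_full += f" {text}"
        await user_callback(text)

    await tools_callback("Let me look that up.")  # text produced by a tool
    content_full += " Here is what I found."  # a streamed LLM delta
    print(f"stored transcript:{content_full}")


asyncio.run(demo())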