quality: Avoid consecutive assistant entries
clemlesne committed Mar 4, 2024
1 parent 47fde82 commit 19f34f0
Showing 2 changed files with 21 additions and 11 deletions.
17 changes: 11 additions & 6 deletions helpers/call.py
@@ -140,13 +140,18 @@ async def handle_play(
     See: https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts
     """
     if store:
-        call.messages.append(
-            MessageModel(
-                content=text,
-                persona=MessagePersonaEnum.ASSISTANT,
-                style=style,
+        if (
+            call.messages[-1].persona == MessagePersonaEnum.ASSISTANT
+        ):  # If the last message was from the assistant, append to it
+            call.messages[-1].content += f" {text}"
+        else:  # Otherwise, create a new message
+            call.messages.append(
+                MessageModel(
+                    content=text,
+                    persona=MessagePersonaEnum.ASSISTANT,
+                    style=style,
+                )
             )
-        )
 
     # Split text in chunks of max 400 characters, separated by sentence
     chunks = []
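
The change in handle_play merges consecutive assistant turns: when the last
transcript entry already comes from the assistant, new text is appended to it
instead of being stored as a separate MessageModel. A minimal, self-contained
sketch of the same merge pattern, assuming simplified stand-ins for the
project's MessageModel and MessagePersonaEnum (the sketch also guards the
empty-history case, which the patched code leaves unguarded):

from dataclasses import dataclass
from enum import Enum


class MessagePersonaEnum(Enum):
    ASSISTANT = "assistant"
    HUMAN = "human"


@dataclass
class MessageModel:
    content: str
    persona: MessagePersonaEnum


def store_assistant_text(messages: list[MessageModel], text: str) -> None:
    if messages and messages[-1].persona == MessagePersonaEnum.ASSISTANT:
        # Last entry is already from the assistant: extend it in place
        messages[-1].content += f" {text}"
    else:
        # Otherwise open a new assistant entry
        messages.append(MessageModel(text, MessagePersonaEnum.ASSISTANT))


history = [MessageModel("Hello.", MessagePersonaEnum.ASSISTANT)]
store_assistant_text(history, "How can I help?")
assert len(history) == 1
assert history[0].content == "Hello. How can I help?"
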
15 changes: 10 additions & 5 deletions main.py
@@ -792,9 +792,15 @@ async def execute_llm_chat(
     4. `CallModel`, the updated model
     """
     _logger.debug("Running LLM chat")
+    content_full = ""
     should_user_answer = True
 
-    async def _buffer_user_callback(
+    async def _tools_callback(text: str, style: MessageStyleEnum) -> None:
+        nonlocal content_full
+        content_full += f" {text}"
+        await user_callback(text, style)
+
+    async def _content_callback(
         buffer: str, style: MessageStyleEnum
     ) -> MessageStyleEnum:
         # Remove tool calls from buffer content and detect style
@@ -850,7 +856,7 @@ async def _tools_cancellation_callback() -> None:
         post_call_next=post_call_next,
         post_call_synthesis=post_call_synthesis,
         search=search,
-        user_callback=user_callback,
+        user_callback=_tools_callback,
     )
 
     tools = []
@@ -862,7 +868,6 @@ async def _tools_cancellation_callback() -> None:
 
     # Execute LLM inference
     content_buffer_pointer = 0
-    content_full = ""
     tool_calls_buffer: dict[int, MessageToolModel] = {}
     try:
         async for delta in completion_stream(
@@ -885,7 +890,7 @@ async def _tools_cancellation_callback() -> None:
                     content_full[content_buffer_pointer:], False
                 ):
                     content_buffer_pointer += len(sentence)
-                    plugins.style = await _buffer_user_callback(sentence, plugins.style)
+                    plugins.style = await _content_callback(sentence, plugins.style)
     except ReadError:
         _logger.warn("Network error", exc_info=True)
         return True, True, should_user_answer, call
@@ -895,7 +900,7 @@ async def _tools_cancellation_callback() -> None:
 
     # Flush the remaining buffer
     if content_buffer_pointer < len(content_full):
-        plugins.style = await _buffer_user_callback(
+        plugins.style = await _content_callback(
             content_full[content_buffer_pointer:], plugins.style
         )
 
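In main.py, content_full now lives at function scope and tool output is routed
through _tools_callback, which records each message in that shared buffer
before forwarding it to the original user_callback. Streamed completion text
and tool-generated text therefore end up in the same transcript buffer. A
runnable sketch of this accumulate-then-forward closure, assuming a plain
string for the style value and a stand-in user_callback (the real one is not
shown in this diff):

import asyncio


async def user_callback(text: str, style: str) -> None:
    # Stand-in for the real downstream callback (e.g. speech playback)
    print(f"[{style}] {text}")


async def demo() -> None:
    content_full = ""

    async def _tools_callback(text: str, style: str) -> None:
        # Record the text in the shared buffer, then forward it unchanged
        nonlocal content_full
        content_full += f" {text}"
        await user_callback(text, style)

    await _tools_callback("Your claim was found.", "none")
    await _tools_callback("It was filed yesterday.", "none")
    # Both tool messages are now part of the same transcript buffer
    assert content_full == " Your claim was found. It was filed yesterday."


asyncio.run(demo())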

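The last two hunks only rename the sentence-streaming callback, but the
surrounding pattern is worth spelling out: content_buffer_pointer marks how
much of content_full has already been emitted sentence by sentence, and the
leftover tail is flushed once the stream ends. A standalone sketch of that
pointer-and-flush loop, where sentence_split is a hypothetical stand-in for
the project's splitter (not shown in this diff):

import re


def sentence_split(text: str) -> list[str]:
    # Hypothetical splitter: returns only complete sentences, keeping the
    # trailing separator so emitted lengths add up to the consumed prefix
    return [m.group(0) for m in re.finditer(r"[^.!?]*[.!?]\s*", text)]


content_full = ""
pointer = 0  # Index of the first character not yet emitted

for delta in ["Hel", "lo there. How", " are you? I am"]:  # Simulated stream
    content_full += delta
    for sentence in sentence_split(content_full[pointer:]):
        pointer += len(sentence)
        print("emit:", sentence)

# Flush the remaining buffer once the stream has ended
if pointer < len(content_full):
    print("emit:", content_full[pointer:])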