Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Groups conversations based on the user's messages. #546

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 147 additions & 80 deletions src/codegate/dashboard/post_processing.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import asyncio
import json
import re
from typing import List, Optional, Tuple, Union
from collections import defaultdict
from typing import List, Optional, Union

import structlog

from codegate.dashboard.request_models import (
AlertConversation,
ChatMessage,
Conversation,
PartialConversation,
PartialQuestionAnswer,
PartialQuestions,
QuestionAnswer,
)
from codegate.db.models import GetAlertsWithPromptAndOutputRow, GetPromptWithOutputsRow
Expand Down Expand Up @@ -74,60 +76,57 @@ async def parse_request(request_str: str) -> Optional[str]:
return None

# Only respond with the latest message
return messages[-1]
return messages


async def parse_output(output_str: str) -> Tuple[Optional[str], Optional[str]]:
async def parse_output(output_str: str) -> Optional[str]:
"""
Parse the output string from the pipeline and return the message and chat_id.
Parse the output string from the pipeline and return the message.
"""
try:
if output_str is None:
return None, None
return None

output = json.loads(output_str)
except Exception as e:
logger.warning(f"Error parsing output: {output_str}. {e}")
return None, None
return None

def _parse_single_output(single_output: dict) -> str:
single_chat_id = single_output.get("id")
single_output_message = ""
for choice in single_output.get("choices", []):
if not isinstance(choice, dict):
continue
content_dict = choice.get("delta", {}) or choice.get("message", {})
single_output_message += content_dict.get("content", "")
return single_output_message, single_chat_id
return single_output_message

full_output_message = ""
chat_id = None
if isinstance(output, list):
for output_chunk in output:
output_message, output_chat_id = "", None
output_message = ""
if isinstance(output_chunk, dict):
output_message, output_chat_id = _parse_single_output(output_chunk)
output_message = _parse_single_output(output_chunk)
elif isinstance(output_chunk, str):
try:
output_decoded = json.loads(output_chunk)
output_message, output_chat_id = _parse_single_output(output_decoded)
output_message = _parse_single_output(output_decoded)
except Exception:
logger.error(f"Error reading chunk: {output_chunk}")
else:
logger.warning(
f"Could not handle output: {output_chunk}", out_type=type(output_chunk)
)
chat_id = chat_id or output_chat_id
full_output_message += output_message
elif isinstance(output, dict):
full_output_message, chat_id = _parse_single_output(output)
full_output_message = _parse_single_output(output)

return full_output_message, chat_id
return full_output_message


async def _get_question_answer(
row: Union[GetPromptWithOutputsRow, GetAlertsWithPromptAndOutputRow]
) -> Tuple[Optional[QuestionAnswer], Optional[str]]:
) -> Optional[PartialQuestionAnswer]:
"""
Parse a row from the get_prompt_with_outputs query and return a PartialQuestionAnswer

Expand All @@ -137,17 +136,19 @@ async def _get_question_answer(
request_task = tg.create_task(parse_request(row.request))
output_task = tg.create_task(parse_output(row.output))

request_msg_str = request_task.result()
output_msg_str, chat_id = output_task.result()
request_user_msgs = request_task.result()
output_msg_str = output_task.result()

# If we couldn't parse the request or output, return None
if not request_msg_str:
return None, None
# If we couldn't parse the request, return None
if not request_user_msgs:
return None

request_message = ChatMessage(
message=request_msg_str,
request_message = PartialQuestions(
messages=request_user_msgs,
timestamp=row.timestamp,
message_id=row.id,
provider=row.provider,
type=row.type,
)
if output_msg_str:
output_message = ChatMessage(
Expand All @@ -157,28 +158,7 @@ async def _get_question_answer(
)
else:
output_message = None
chat_id = row.id
return QuestionAnswer(question=request_message, answer=output_message), chat_id


async def parse_get_prompt_with_output(
row: GetPromptWithOutputsRow,
) -> Optional[PartialConversation]:
"""
Parse a row from the get_prompt_with_outputs query and return a PartialConversation

The row contains the raw request and output strings from the pipeline.
"""
question_answer, chat_id = await _get_question_answer(row)
if not question_answer or not chat_id:
return None
return PartialConversation(
question_answer=question_answer,
provider=row.provider,
type=row.type,
chat_id=chat_id,
request_timestamp=row.timestamp,
)
return PartialQuestionAnswer(partial_questions=request_message, answer=output_message)


def parse_question_answer(input_text: str) -> str:
Expand All @@ -195,50 +175,135 @@ def parse_question_answer(input_text: str) -> str:
return input_text


def _group_partial_messages(pq_list: List[PartialQuestions]) -> List[List[PartialQuestions]]:
    """
    Group PartialQuestions objects that belong to the same chat conversation.

    A PartialQuestions object contains the user messages provided in one request of a
    chat conversation. Example:
      - PartialQuestions(messages=["Hello"], timestamp=2022-01-01T00:00:00Z)
      - PartialQuestions(messages=["Hello", "How are you?"], timestamp=2022-01-01T00:00:01Z)
    In the above example both PartialQuestions are part of the same conversation and are
    grouped together.

    Grouping rules:
      - A pq is attached to another pq (the superset) when both have the same provider and
        its messages are a *proper* subset of the superset's messages. Messages are compared
        as sets, so order and repetitions are ignored. Two pqs with equal message sets are
        never grouped with each other.
      - If several candidate subsets carry exactly the same messages, only the one closest
        in timestamp to the superset joins the group; the others stay available and are
        processed later on their own.
      - A pq without any subset forms a group by itself.

    Args:
        pq_list: The PartialQuestions to group. `message_id` is used as the identity of each
            element.

    Returns:
        A list of groups. Each group starts with its superset followed by the selected
        subsets. Groups are sorted by the earliest timestamp they contain.
    """
    # 1) Sort by number of messages descending (largest/most-complete first),
    #    then by timestamp ascending for stable processing.
    ordered = sorted(pq_list, key=lambda pq: (-len(pq.messages), pq.timestamp))
    # Build every message set once instead of on each inner-loop comparison.
    message_sets = [frozenset(pq.messages) for pq in ordered]

    used = set()
    groups = []

    # 2) Iterate in order of "largest messages first"
    for sup, sup_set in zip(ordered, message_sets):
        if sup.message_id in used:
            continue  # Already grouped
        used.add(sup.message_id)

        # 3) Collect the unused proper subsets of 'sup' from the same provider, bucketed by
        #    their exact messages so duplicates (e.g. two single 'hello') can be discarded.
        #    Equal sets are deliberately excluded: `<` is the proper-subset test.
        candidates_by_messages = defaultdict(list)
        for sub, sub_set in zip(ordered, message_sets):
            if sub.message_id == sup.message_id or sub.message_id in used:
                continue
            if sub.provider == sup.provider and sub_set < sup_set:
                candidates_by_messages[tuple(sub.messages)].append(sub)

        # 4) The group is 'sup' plus, per distinct messages, the subset closest in time to
        #    it. With no candidates the loop is a no-op and 'sup' stands alone.
        new_group = [sup]
        for same_messages in candidates_by_messages.values():
            closest_subset = min(same_messages, key=lambda s: abs(s.timestamp - sup.timestamp))
            new_group.append(closest_subset)
            used.add(closest_subset.message_id)
        groups.append(new_group)

    # 5) Sort the groups by the earliest timestamp within each group
    groups.sort(key=lambda g: min(pq.timestamp for pq in g))
    return groups


def _get_question_answer_from_partial(
    partial_question_answer: PartialQuestionAnswer,
) -> QuestionAnswer:
    """
    Build a QuestionAnswer out of a PartialQuestionAnswer.

    The question is the last user message of the partial questions, carrying their
    timestamp and message_id. The answer is passed through unchanged.
    """
    partial_questions = partial_question_answer.partial_questions
    # The last entry in the messages list is used as the question.
    latest_user_message = partial_questions.messages[-1]
    question = ChatMessage(
        message=latest_user_message,
        timestamp=partial_questions.timestamp,
        message_id=partial_questions.message_id,
    )
    return QuestionAnswer(question=question, answer=partial_question_answer.answer)


async def match_conversations(
partial_conversations: List[Optional[PartialConversation]],
partial_question_answers: List[Optional[PartialQuestionAnswer]],
) -> List[Conversation]:
"""
Match partial conversations to form a complete conversation.
"""
convers = {}
for partial_conversation in partial_conversations:
if not partial_conversation:
continue

# Group by chat_id
if partial_conversation.chat_id not in convers:
convers[partial_conversation.chat_id] = []
convers[partial_conversation.chat_id].append(partial_conversation)
valid_partial_qas = [
partial_qas for partial_qas in partial_question_answers if partial_qas is not None
]
grouped_partial_questions = _group_partial_messages(
[partial_qs_a.partial_questions for partial_qs_a in valid_partial_qas]
)

# Sort by timestamp
sorted_convers = {
chat_id: sorted(conversations, key=lambda x: x.request_timestamp)
for chat_id, conversations in convers.items()
}
# Create the conversation objects
conversations = []
for chat_id, sorted_convers in sorted_convers.items():
for group in grouped_partial_questions:
questions_answers = []
first_partial_conversation = None
for partial_conversation in sorted_convers:
first_partial_qa = None
for partial_question in sorted(group, key=lambda x: x.timestamp):
# Partial questions don't contain the answer, so we need to find the corresponding partial question-answer by message_id
selected_partial_qa = None
for partial_qa in valid_partial_qas:
if partial_question.message_id == partial_qa.partial_questions.message_id:
selected_partial_qa = partial_qa
break

# check if we have an answer, otherwise do not add it
if partial_conversation.question_answer.answer is not None:
first_partial_conversation = partial_conversation
partial_conversation.question_answer.question.message = parse_question_answer(
partial_conversation.question_answer.question.message
if selected_partial_qa.answer is not None:
# if we don't have a first question, set it
first_partial_qa = first_partial_qa or selected_partial_qa
question_answer = _get_question_answer_from_partial(selected_partial_qa)
question_answer.question.message = parse_question_answer(
question_answer.question.message
)
questions_answers.append(partial_conversation.question_answer)
questions_answers.append(question_answer)

# only add conversation if we have some answers
if len(questions_answers) > 0 and first_partial_conversation is not None:
if len(questions_answers) > 0 and first_partial_qa is not None:
conversations.append(
Conversation(
question_answers=questions_answers,
provider=first_partial_conversation.provider,
type=first_partial_conversation.type,
chat_id=chat_id,
conversation_timestamp=sorted_convers[0].request_timestamp,
provider=first_partial_qa.partial_questions.provider,
type=first_partial_qa.partial_questions.type,
chat_id=first_partial_qa.partial_questions.message_id,
conversation_timestamp=first_partial_qa.partial_questions.timestamp,
)
)

Expand All @@ -254,10 +319,10 @@ async def parse_messages_in_conversations(

# Parse the prompts and outputs in parallel
async with asyncio.TaskGroup() as tg:
tasks = [tg.create_task(parse_get_prompt_with_output(row)) for row in prompts_outputs]
partial_conversations = [task.result() for task in tasks]
tasks = [tg.create_task(_get_question_answer(row)) for row in prompts_outputs]
partial_question_answers = [task.result() for task in tasks]

conversations = await match_conversations(partial_conversations)
conversations = await match_conversations(partial_question_answers)
return conversations


Expand All @@ -269,15 +334,17 @@ async def parse_row_alert_conversation(

The row contains the raw request and output strings from the pipeline.
"""
question_answer, chat_id = await _get_question_answer(row)
if not question_answer or not chat_id:
partial_qa = await _get_question_answer(row)
if not partial_qa:
return None

question_answer = _get_question_answer_from_partial(partial_qa)

conversation = Conversation(
question_answers=[question_answer],
provider=row.provider,
type=row.type,
chat_id=chat_id or "chat-id-not-found",
chat_id=row.id,
conversation_timestamp=row.timestamp,
)
code_snippet = json.loads(row.code_snippet) if row.code_snippet else None
Expand Down
19 changes: 14 additions & 5 deletions src/codegate/dashboard/request_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,25 @@ class QuestionAnswer(BaseModel):
answer: Optional[ChatMessage]


class PartialConversation(BaseModel):
class PartialQuestions(BaseModel):
"""
Represents a partial conversation obtained from a DB row.
Represents all user messages obtained from a DB row.
"""

question_answer: QuestionAnswer
messages: List[str]
timestamp: datetime.datetime
message_id: str
provider: Optional[str]
type: str
chat_id: str
request_timestamp: datetime.datetime


class PartialQuestionAnswer(BaseModel):
    """
    Pairs the user messages of one request with the answer obtained for it, if any.
    """

    # User messages of the request together with its metadata.
    partial_questions: PartialQuestions
    # Answer message for the request; may be None.
    answer: Optional[ChatMessage]


class Conversation(BaseModel):
Expand Down
Loading
Loading