Commit 9a27b1d
solve some comments
SonglinLyu committed Jan 20, 2025
1 parent 28f8605 commit 9a27b1d
Showing 9 changed files with 43 additions and 49 deletions.
10 changes: 10 additions & 0 deletions dbgpt/rag/transformer/agentic_intent_translator.py
@@ -0,0 +1,10 @@
"""Agentic ntentTranslator class."""
import logging

from dbgpt.rag.transformer.base import TranslatorBase

logger = logging.getLogger(__name__)


class AgenticIntentTranslator(TranslatorBase):
"""Agentic ntentTranslator class."""
10 changes: 0 additions & 10 deletions dbgpt/rag/transformer/awel_intent_interpreter.py

This file was deleted.

10 changes: 10 additions & 0 deletions dbgpt/rag/transformer/awel_intent_translator.py
@@ -0,0 +1,10 @@
"""AwelIntentTranslator class."""
import logging

from dbgpt.rag.transformer.base import TranslatorBase

logger = logging.getLogger(__name__)


class AwelIntentTranslator(TranslatorBase):
"""AwelIntentTranslator class."""
10 changes: 0 additions & 10 deletions dbgpt/rag/transformer/mas_intent_interpreter.py

This file was deleted.

dbgpt/rag/transformer/{intent_interpreter.py → simple_intent_translator.py}
@@ -1,4 +1,4 @@
"""IntentInterpreter class."""
"""SimpleIntentTranslator class."""
import json
import logging
import re
Expand Down Expand Up @@ -42,11 +42,11 @@
logger = logging.getLogger(__name__)


class IntentInterpreter(LLMTranslator):
"""IntentInterpreter class."""
class SimpleIntentTranslator(LLMTranslator):
"""SimpleIntentTranslator class."""

def __init__(self, llm_client: LLMClient, model_name: str):
"""Initialize the IntentInterpreter."""
"""Initialize the SimpleIntentTranslator."""
super().__init__(llm_client, model_name, INTENT_INTERPRET_PT)

def _format_messages(self, text: str, history: str = None) -> List[BaseMessage]:
10 changes: 3 additions & 7 deletions dbgpt/storage/knowledge_graph/community_summary.py
Original file line number Diff line number Diff line change
@@ -12,12 +12,11 @@
 from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
 from dbgpt.rag.transformer.graph_embedder import GraphEmbedder
 from dbgpt.rag.transformer.graph_extractor import GraphExtractor
-from dbgpt.rag.transformer.intent_interpreter import IntentInterpreter
 from dbgpt.rag.transformer.text2gql import Text2GQL
 from dbgpt.rag.transformer.text_embedder import TextEmbedder
 from dbgpt.storage.knowledge_graph.base import ParagraphChunk
 from dbgpt.storage.knowledge_graph.community.community_store import CommunityStore
-from dbgpt.storage.knowledge_graph.graph_retriever.graph_retriever_router import GraphRetrieverRouter
+from dbgpt.storage.knowledge_graph.graph_retriever.graph_retriever import GraphRetriever
 from dbgpt.storage.knowledge_graph.knowledge_graph import (
     GRAPH_PARAMETERS,
     BuiltinKnowledgeGraph,
@@ -362,12 +361,9 @@ def community_store_configure(name: str, cfg: VectorStoreConfig):
             ),
         )

-        self._intent_interpreter = IntentInterpreter(self._llm_client, self._model_name)
-        self._text2gql = Text2GQL(self._llm_client, self._model_name)
-
-        self._knowledge_graph_triplet_search_top_size = 5
-        self._knowledge_graph_document_search_top_size = 5
-        self._graph_retriever_router = GraphRetrieverRouter(
+        self._graph_retriever = GraphRetriever(
             config,
             self._graph_store_apdater,
         )
@@ -548,7 +544,7 @@ async def asimilar_search_with_scores(
         ]
         context = "\n".join(summaries) if summaries else ""

-        subgraph, subgraph_for_doc, text2gql_query = await self._graph_retriever_router.retrieve(text)
+        subgraph, (subgraph_for_doc, text2gql_query) = await self._graph_retriever.retrieve(text)

         knowledge_graph_str = subgraph.format() if subgraph else ""
         knowledge_graph_for_doc_str = (
dbgpt/storage/knowledge_graph/graph_retriever/document_graph_retriever.py
@@ -27,15 +27,15 @@ def __init__(
         self._similarity_search_score_threshold = similarity_search_score_threshold

     async def retrieve(
-        self, subs: Optional[Union[List[str], List[List[float]]]], triplet_graph: Optional[Graph]
+        self, input: Union[Graph, List[str], List[List[float]]]
     ) -> Tuple[Graph, None]:
         """Retrieve from document graph."""

-        if triplet_graph:
+        if isinstance(input, Graph):
             # If retrieve subgraph from triplet graph successfully
             # Using the vids to search chunks and doc
             keywords_for_document_graph = []
-            for vertex in triplet_graph.vertices():
+            for vertex in input.vertices():
                 keywords_for_document_graph.append(vertex.name)
             # Using the vids to search chunks and doc
             # entities -> chunks -> doc
@@ -49,7 +49,7 @@ async def retrieve(
             # Using subs to search chunks
             # subs -> chunks -> doc
             subgraph_for_doc = self._graph_store_apdater.explore_docgraph_without_entities(
-                subs=subs,
+                subs=input,
                 topk=self._similarity_search_topk,
                 score_threshold=self._similarity_search_score_threshold,
                 limit=self._document_topk,
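The two keyword parameters above collapse into a single input argument that is dispatched on its type. Below is a minimal standalone sketch of that dispatch pattern; the Vertex and Graph classes here are stand-ins for illustration, not the project's real types, and the parameter is named input only to mirror the diff.

from dataclasses import dataclass
from typing import List, Union


@dataclass
class Vertex:
    """Stand-in for a graph vertex (illustration only)."""

    name: str


class Graph:
    """Stand-in for dbgpt's Graph type (illustration only)."""

    def __init__(self, vertices: List[Vertex]):
        self._vertices = vertices

    def vertices(self) -> List[Vertex]:
        return self._vertices


def document_search_keys(input: Union[Graph, List[str]]) -> List[str]:
    # Mirrors the change above: a retrieved triplet subgraph contributes its
    # vertex names, while a plain keyword/subs list is used as-is.
    if isinstance(input, Graph):
        return [vertex.name for vertex in input.vertices()]
    return input


print(document_search_keys(Graph([Vertex("DB-GPT"), Vertex("AWEL")])))
print(document_search_keys(["knowledge graph", "retriever"]))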
dbgpt/storage/knowledge_graph/graph_retriever/{graph_retriever_router.py → graph_retriever.py}
@@ -37,8 +37,8 @@
 logger = logging.getLogger(__name__)


-class GraphRetrieverRouter:
-    """Graph Retriever Router class."""
+class GraphRetriever(GraphRetrieverBase):
+    """Graph Retriever class."""

     def __init__(
         self,
@@ -113,7 +113,7 @@ def __init__(
             similarity_search_score_threshold,
         )

-    async def retrieve(self, text: str) -> Tuple[Graph, Graph, str]:
+    async def retrieve(self, text: str) -> Tuple[Graph, Tuple[Graph, str]]:
         """Retrieve subgraph from triplet graph and document graph."""

         subgraph = MemoryGraph()
@@ -124,6 +124,9 @@ async def retrieve(self, text: str) -> Tuple[Graph, Graph, str]:
         if self._enable_text_search:
             # Retrieve from knowledge graph with text.
             subgraph, text2gql_query = await self._text_based_graph_retriever.retrieve(text)
+
+        # Extract keywords from original question
+        keywords: List[str] = await self._keyword_extractor.extract(text)

         if subgraph.vertex_count == 0 and subgraph.edge_count == 0:
             # if not enable text search or text search failed to retrieve subgraph
@@ -135,7 +138,7 @@
                 vector = await self._text_embedder.embed(text)
                 # Embedding the keywords
                 vectors = await self._text_embedder.batch_embed(
-                    keywords, batch_size=self._triplet_embedding_batch_size
+                    keywords, batch_size=self._embedding_batch_size
                 )
                 # Using the embeddings of keywords and question
                 vectors.append(vector)
@@ -145,8 +148,6 @@
                     f"embedding vector:\n[KEYWORDS]:{keywords}\n[QUESTION]:{text}"
                 )
             else:
-                # Extract keywords from original question
-                keywords: List[str] = await self._keyword_extractor.extract(text)
                 subs = keywords
                 logger.info(
                     "Search subgraph with the following keywords:\n"
@@ -166,13 +167,10 @@
         if subgraph.vertex_count == 0 and subgraph.edge_count == 0:
             # If not enable triplet graph or failed to retrieve subgraph
             # Using subs to retrieve from document graph
-            subgraph_for_doc = await self._document_graph_retriever.retrieve(subs=subs)
+            subgraph_for_doc = await self._document_graph_retriever.retrieve(subs)
         else:
             # If retrieve subgraph from triplet graph successfully
             # Using entities in subgraph to search chunks and doc
-            subgraph_for_doc = await self._document_graph_retriever.retrieve(
-                subs=subs,
-                triplet_graph=subgraph
-            )
+            subgraph_for_doc = await self._document_graph_retriever.retrieve(subgraph)

-        return subgraph, subgraph_for_doc, text2gql_query
+        return subgraph, (subgraph_for_doc, text2gql_query)
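For reference, a minimal sketch of how a caller consumes the reshaped retrieve() return value, mirroring the community_summary.py hunk above; the GraphRetriever instance is assumed to be constructed elsewhere with its config and graph store adapter.

from dbgpt.storage.knowledge_graph.graph_retriever.graph_retriever import GraphRetriever


async def query_graph(retriever: GraphRetriever, text: str):
    # retrieve() now bundles the document-graph result with the generated
    # text2gql query, so callers unpack a nested tuple instead of three
    # flat values.
    subgraph, (subgraph_for_doc, text2gql_query) = await retriever.retrieve(text)

    knowledge_graph_str = subgraph.format() if subgraph else ""
    return knowledge_graph_str, subgraph_for_doc, text2gql_query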
dbgpt/storage/knowledge_graph/graph_retriever/text_based_graph_retriever.py
@@ -4,7 +4,7 @@
 import logging
 from typing import Dict, List, Union, Tuple

-from dbgpt.rag.transformer.intent_interpreter import IntentInterpreter
+from dbgpt.rag.transformer.simple_intent_translator import SimpleIntentTranslator
 from dbgpt.rag.transformer.text2gql import Text2GQL
 from dbgpt.storage.graph_store.graph import MemoryGraph, Graph
 from dbgpt.storage.knowledge_graph.graph_retriever.base import GraphRetrieverBase
@@ -20,7 +20,7 @@ def __init__(self, graph_store_apdater, triplet_topk, llm_client, model_name):

         self._graph_store_apdater = graph_store_apdater
         self._triplet_topk = triplet_topk
-        self._intent_interpreter = IntentInterpreter(llm_client, model_name)
+        self._intent_interpreter = SimpleIntentTranslator(llm_client, model_name)
         self._text2gql = Text2GQL(llm_client, model_name)

     async def retrieve(self, text: str) -> Tuple[Graph, str]:
