Skip to content

Commit

Permalink
Resolve review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
SonglinLyu committed Jan 16, 2025
1 parent 43c81ad commit e78c030
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 164 deletions.
3 changes: 2 additions & 1 deletion .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,15 @@ KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE=5
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE=0.3
KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE=5
KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_RECALL_SCORE=0.7
KNOWLEDGE_GRAPH_TEXT_SEARCH_TOP_SIZE=5
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE=20
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE=0.0

GRAPH_COMMUNITY_SUMMARY_ENABLED=True # enable the graph community summary
TRIPLET_GRAPH_ENABLED=True # enable the graph search for triplets
DOCUMENT_GRAPH_ENABLED=True # enable the graph search for documents and chunks
SIMILARITY_SEARCH_ENABLED=True # enable the similarity search for entities and chunks
TEXT2GQL_SEARCH_ENABLED=False # enable the text2gql search for entities and relations.
TEXT_SEARCH_ENABLED=False # enable the text search for entities and relations.

KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE=5 # the top size of knowledge graph search for chunks
KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE=20 # the batch size of triplet extraction from the text
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/rag/transformer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,5 @@ class TranslatorBase(TransformerBase, ABC):
"""Translator base class."""

@abstractmethod
async def translate(self, text: str) -> Dict:
    """Translate the given text into a structured dict result.

    Concrete translators implement this to run their LLM pipeline on
    ``text`` and return the parsed response.
    """
8 changes: 4 additions & 4 deletions dbgpt/rag/transformer/intent_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import logging
import re
from typing import Dict, List, Optional
from typing import Dict, List

from dbgpt.core import BaseMessage, HumanPromptTemplate, LLMClient
from dbgpt.rag.transformer.llm_translator import LLMTranslator
Expand Down Expand Up @@ -69,10 +69,10 @@ def __init__(self, llm_client: LLMClient, model_name: str):
super().__init__(llm_client, model_name, INTENT_INTERPRET_PT)

def _format_messages(self, text: str, history: str = None) -> List[BaseMessage]:
# interprete intent with single prompt only.
# interprete intention with single prompt only.
template = HumanPromptTemplate.from_template(self._prompt_template)

messages = (
messages: List[BaseMessage] = (
template.format_messages(text=text, history=history)
if history is not None
else template.format_messages(text=text)
Expand All @@ -86,7 +86,7 @@ def truncate(self):
def drop(self):
"""Do nothing by default."""

def _parse_response(self, text: str, limit: Optional[int] = None) -> Dict:
def _parse_response(self, text: str) -> Dict:
"""Parse llm response."""
"""
The returned diction should contain the following content.
Expand Down
6 changes: 4 additions & 2 deletions dbgpt/rag/transformer/llm_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ async def _extract(
self, text: str, history: str = None, limit: Optional[int] = None
) -> List:
"""Inner extract by LLM."""
# Validate the optional result limit up front, before any LLM work.
# BUG FIX: the original constructed ValueError without raising it,
# making the check a silent no-op.
if limit and limit < 1:
    raise ValueError("optional argument limit >= 1")

template = HumanPromptTemplate.from_template(self._prompt_template)

messages = (
Expand All @@ -80,8 +84,6 @@ async def _extract(
logger.error(f"request llm failed ({code}) {reason}")
return []

if limit and limit < 1:
ValueError("optional argument limit >= 1")
return self._parse_response(response.text, limit)

def truncate(self):
Expand Down
16 changes: 6 additions & 10 deletions dbgpt/rag/transformer/llm_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from typing import Dict, List

from dbgpt.core import BaseMessage, LLMClient, ModelMessage, ModelRequest
from dbgpt.rag.transformer.base import TranslatorBase
Expand All @@ -19,14 +19,12 @@ def __init__(self, llm_client: LLMClient, model_name: str, prompt_template: str)
self._model_name = model_name
self._prompt_template = prompt_template

async def translate(self, text: str) -> Dict:
    """Translate *text* by formatting it into LLM messages and parsing the reply."""
    prompt_messages = self._format_messages(text)
    return await self._translate(prompt_messages)

async def _translate(
self, messages: List[BaseMessage], limit: Optional[int] = None
) -> Dict:
async def _translate(self, messages: List[BaseMessage]) -> Dict:
"""Inner translate by LLM."""
# use default model if needed
if not self._model_name:
Expand All @@ -46,9 +44,7 @@ async def _translate(
logger.error(f"request llm failed ({code}) {reason}")
return {}

if limit and limit < 1:
ValueError("optional argument limit >= 1")
return self._parse_response(response.text, limit)
return self._parse_response(response.text)

def truncate(self):
"""Do nothing by default."""
Expand All @@ -61,5 +57,5 @@ def _format_messages(self, text: str, history: str = None) -> List[BaseMessage]:
"""Parse llm response."""

@abstractmethod
def _parse_response(self, text: str, limit: Optional[int] = None) -> Dict:
def _parse_response(self, text: str) -> Dict:
"""Parse llm response."""
119 changes: 0 additions & 119 deletions dbgpt/rag/transformer/text2cypher.py

This file was deleted.

111 changes: 109 additions & 2 deletions dbgpt/rag/transformer/text2gql.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,117 @@
"""Text2GQL class."""
import json
import logging
import re
from typing import Dict, List

from dbgpt.rag.transformer.base import TranslatorBase
from dbgpt.core import BaseMessage, HumanPromptTemplate, LLMClient
from dbgpt.rag.transformer.llm_translator import LLMTranslator

# Few-shot prompt that translates an interpreted intention (question,
# category, entities, relations) into a Cypher query over the given
# knowledge-graph schema. Fixes vs. previous version: "seperately" typo,
# "has a entity" grammar, and the last example's query now uses "DBGPT"
# so it is consistent with that example's entities list.
TEXT_TO_GQL_PT = (
    "A question written in graph query language style is provided below. "
    "The category of this question, "
    "entities and relations that might be used in the cypher query are also provided. "
    "Given the question, translate the question into a cypher query that "
    "can be executed on the given knowledge graph. "
    "Make sure the syntax of the translated cypher query is correct.\n"
    "To help query generation, the schema of the knowledge graph is:\n"
    "{schema}\n"
    "---------------------\n"
    "Example:\n"
    "Question: Query the entity named TuGraph then return the entity.\n"
    "Category: Single Entity Search\n"
    'entities: ["TuGraph"]\n'
    "relations: []\n"
    'Query:\nMatch (n) WHERE n.id="TuGraph" RETURN n\n'
    "Question: Query all one hop paths between the entity named Alex "
    "and the entity named TuGraph, then return them.\n"
    "Category: One Hop Entity Search\n"
    'entities: ["Alex", "TuGraph"]\n'
    "relations: []\n"
    'Query:\nMATCH p=(n)-[r]-(m) WHERE n.id="Alex" AND m.id="TuGraph" RETURN p \n'
    "Question: Query all one hop paths that has an entity named TuGraph "
    "and a relation named commit, then return them.\n"
    "Category: One Hop Relation Search\n"
    'entities: ["TuGraph"]\n'
    'relations: ["commit"]\n'
    'Query:\nMATCH p=(n)-[r]-(m) WHERE n.id="TuGraph" AND r.id="commit" RETURN p \n'
    "Question: Query all entities that have a two hop path between them "
    "and the entity named Bob, "
    "both entities should have a work for relation with the middle entity.\n"
    "Category: Two Hop Entity Search\n"
    'entities: ["Bob"]\n'
    'relations: ["work for"]\n'
    'Query:\nMATCH p=(n)-[r1]-(m)-[r2]-(l) WHERE n.id="Bob" '
    'AND r1.id="work for" AND r2.id="work for" RETURN p \n'
    "Question: Introduce TuGraph and DBGPT separately.\n"
    "Category: Freestyle Question\n"
    'entities: ["TuGraph", "DBGPT"]\n'
    "relations: []\n"
    "Query:\nMATCH p=(n)-[r:relation*2]-(m) "
    'WHERE n.id IN ["TuGraph", "DBGPT"] RETURN p\n'
    "---------------------\n"
    "Question: {question}\n"
    "Category: {category}\n"
    "entities: {entities}\n"
    "relations: {relations}\n"
    "Query:\n"
)

logger = logging.getLogger(__name__)


class Text2GQL(LLMTranslator):
    """Translate an interpreted intention (JSON text) into a Cypher query."""

    def __init__(self, llm_client: LLMClient, model_name: str):
        """Initialize the Text2GQL translator with the text-to-GQL prompt."""
        super().__init__(llm_client, model_name, TEXT_TO_GQL_PT)

    def _format_messages(self, text: str, history: str = None) -> List[BaseMessage]:
        # Translate intention to gql with single prompt only.
        # `text` is a JSON-encoded intention; presumably produced by the
        # intent interpreter upstream — keys used below must be present.
        intention = json.loads(text)
        prompt_args = {
            "schema": intention["schema"],
            "question": intention["rewritten_question"],
            "category": intention["category"],
            "entities": intention["entities"],
            "relations": intention["relations"],
        }
        # Only pass history through when the caller supplied one.
        if history is not None:
            prompt_args["history"] = history

        template = HumanPromptTemplate.from_template(self._prompt_template)
        return template.format_messages(**prompt_args)

def _parse_response(self, text: str) -> Dict:
"""Parse llm response."""
translation = {}
query = ""

code_block_pattern = re.compile(r"```cypher(.*?)```", re.S)

result = re.findall(code_block_pattern, text)
if result:
query = result[0]
else:
query = text

translation["query"] = query.strip()

return translation
6 changes: 3 additions & 3 deletions dbgpt/storage/graph_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ class GraphStoreConfig(BaseModel):
default=False,
description="Enable similarity search or not.",
)
enable_text2gql_search: bool = Field(
enable_text_search: bool = Field(
default=False,
description="Enable text2gql search or not.",
description="Enable text search or not.",
)


Expand All @@ -46,7 +46,7 @@ def __init__(self, config: GraphStoreConfig):
self._conn = None
self.enable_summary = config.enable_summary
self.enable_similarity_search = config.enable_similarity_search
self.enable_text2gql_search = config.enable_text2gql_search
self.enable_text_search = config.enable_text_search

@abstractmethod
def get_config(self) -> GraphStoreConfig:
Expand Down
Loading

0 comments on commit e78c030

Please sign in to comment.