lyg
2 天以前 22f370322412074174cde20ecfd14ec03657ab63
knowledgebase/doc/entity_recognition.py
@@ -5,20 +5,14 @@
# @version: 0.0.1
# @description: 实体抽取,将文本中的实体进行识别和提取。
from langchain_openai.chat_models import ChatOpenAI
from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser
import json
from knowledgebase import utils
from knowledgebase.doc.entity_helper import entity_helper
from knowledgebase.db.doc_db_helper import doc_dbh
from knowledgebase.log import Log
llm = ChatOpenAI(temperature=0,
                 model="qwen2.5-72b-instruct",
                 base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
                 api_key="sk-15ecf7e273ad4b729c7f7f42b542749e")
from knowledgebase.llm import llm
class EntityRecognition:
    """
@@ -31,7 +25,7 @@
    def __init__(self, doc_type: str):
        # 实体词列表
        entities = list(filter(lambda x: x.doc_type == doc_type, entity_helper.entities))
        entities = doc_dbh.get_entities_by_doc_type(doc_type)
        entity_list = ','.join([entity.name for entity in entities]) + "。"
        entity_rules = ";\n".join([f"- {entity.name}:{entity.prompts}" for entity in entities]) + "。"
        tpl = """
@@ -43,7 +37,7 @@
# 约束
- 输出格式为JSON格式;
- 提取的实体词必须是:""" + entity_list + """;
- 如果没有复合上述规则的实体词则不要输出任何实体词;
- 如果没有符合上述规则的实体词则不要输出任何实体词;
- 输出数据结构为字符串数组。
# 示例
```json
@@ -74,9 +68,9 @@
        保存缓存。
        """
        text = json.dumps(self.cache)
        utils.save_to_file(text, self.cache_file)
        utils.save_text_to_file(text, self.cache_file)
    def run(self, in_text: str) -> list[str]:
    async def run(self, in_text: str) -> list[str]:
        """
        运行实体识别抽取。
        :param in_text: str - 输入文本
@@ -85,7 +79,7 @@
        text_md5 = utils.generate_text_md5(in_text)
        if self.use_cache and text_md5 in self.cache:
            return self.cache[text_md5]
        result = self.chain.invoke({"text": in_text})
        result = await self.chain.ainvoke({"text": in_text})
        self.cache[text_md5] = result
        self.save_cache()
        return result