lyg
2025-05-14 b75a49c22e7d2b9aa8d3dc4975df8801c52b4d5b
修改文档拆分和实体词提取逻辑,增加实体词文本抽取
4个文件已修改
39 ■■■■ 已修改文件
knowledgebase/db/doc_db_helper.py 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
knowledgebase/doc/doc_processor.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
knowledgebase/doc/entity_helper.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
knowledgebase/doc/entity_recognition.py 24 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
knowledgebase/db/doc_db_helper.py
@@ -114,8 +114,16 @@
        """
        if not entity_names:
            return ""
        _entities = self.session.query(TEntity).where(TEntity.name.in_(entity_names)).all()
        _entitie_ids = [entity.id for entity in _entities]
        links = self.session.query(TParagraphEntityLink).where(TParagraphEntityLink.entity_id.in_(_entitie_ids)).all()
        _paragraphs = [link.paragraph for link in links]
        return '\n'.join([entity.name for entity in self.get_all_entities() if entity.name in entity_names])
        return '\n'.join([self.get_paragraph_full_text(p) for p in _paragraphs])
    def get_paragraph_full_text(self, p: TParagraph):
        result = p.text if p.title_level == 0 else p.title_num + ' ' + p.text
        return result + '\n' + '\n'.join([self.get_paragraph_full_text(p) for p in p.children])
    def commit(self):
        self.session.commit()
@@ -124,6 +132,8 @@
doc_dbh = DocDbHelper()
# if __name__ == '__main__':
#     text = doc_dbh.get_text_with_entities(['遥控包格式'])
#     print(text)
#     doc_db = DocDbHelper()
#     # doc_db.insert_entities()
#     doc = doc_db.add_doc(DocInfo(file='aaa', file_name='test'))
knowledgebase/doc/doc_processor.py
@@ -96,7 +96,7 @@
        r"D:\workspace\PythonProjects\KnowledgeBase\doc\XA-5D无人机软件用户需求(公开).docx",
        r"D:\workspace\PythonProjects\KnowledgeBase\doc\XA-5D无人机遥测大纲(公开).docx",
        r"D:\workspace\PythonProjects\KnowledgeBase\doc\XA-5D无人机遥测信号分配表(公开).docx",
        r"D:\workspace\PythonProjects\KnowledgeBase\doc\XA-5D无人机指令格式与编码定义(公开).docx",
        # r"D:\workspace\PythonProjects\KnowledgeBase\doc\XA-5D无人机指令格式与编码定义(公开).docx",
        r"D:\workspace\PythonProjects\KnowledgeBase\doc\指令格式(公开).docx"
    ]
    for file in files:
knowledgebase/doc/entity_helper.py
@@ -38,6 +38,7 @@
                        _entity = TEntity(name=entity, type=ty, doc_type=doc_ty,
                                          prompts=obj2[doc_ty]['entities'][entity])
                        doc_dbh.add_entity(_entity)
                        self.entities.append(_entity)
                        Log.info(f"新增Entity:{entity},id:{_entity.id}")
knowledgebase/doc/entity_recognition.py
@@ -12,6 +12,7 @@
from knowledgebase import utils
from knowledgebase.doc.entity_helper import entity_helper
from knowledgebase.log import Log
llm = ChatOpenAI(temperature=0,
                 model="qwen2.5-72b-instruct",
@@ -25,30 +26,35 @@
    使用langchain构建实体抽取流程。
    """
    use_cache = False
    cache_file = "entity_recognition.cache"
    def __init__(self, doc_type: str):
        # 实体词列表
        entities = filter(lambda x: x.doc_type == doc_type, entity_helper.entities)
        entity_list = ';\n'.join([f'- {entity.name}:{entity.prompts}' for entity in entities]) + "。"
        msg = HumanMessagePromptTemplate.from_template(template="""
        entities = list(filter(lambda x: x.doc_type == doc_type, entity_helper.entities))
        entity_list = ','.join([entity.name for entity in entities]) + "。"
        entity_rules = ";\n".join([f"- {entity.name}:{entity.prompts}" for entity in entities]) + "。"
        tpl = """
# 指令
请从给定的文本中提取实体词列表,实体词列表定义如下:
## 实体词列表及识别规则
请根据实体词判断规则从给定的文本中判断是否有下列实体词相关内容,如果有则输出相关的实体词,没有则不输出,实体词列表定义如下:
""" + entity_list + """
## 实体词判断规则:
""" + entity_rules + """
# 约束
- 输出格式为JSON格式;
- 提取的实体词必须是上面列举的实体词;
- 提取的实体词必须是:""" + entity_list + """;
- 如果没有复合上述规则的实体词则不要输出任何实体词;
- 输出数据结构为字符串数组。
# 示例
```json
["遥控帧格式","遥控包格式"]
[\"""" + entities[0].name + """\"]
```
# 文本如下:
{text}
"""
                                                       )
        Log.info(tpl)
        msg = HumanMessagePromptTemplate.from_template(template=tpl)
        prompt = ChatPromptTemplate.from_messages([msg])
        parser = JsonOutputParser(pydantic_object=list[str])
        self.chain = prompt | llm | parser
@@ -77,7 +83,7 @@
        """
        # 缓存命中
        text_md5 = utils.generate_text_md5(in_text)
        if text_md5 in self.cache:
        if self.use_cache and text_md5 in self.cache:
            return self.cache[text_md5]
        result = self.chain.invoke({"text": in_text})
        self.cache[text_md5] = result