# -*- coding: utf-8 -*-
|
# @file: doc_processor.py
|
# @author: lyg
|
# @date: 2025-5-13
|
# @version:
|
# @description: 处理文档,拆分文档,将拆分后的章节保存到数据库中。
|
from langchain_core.messages import HumanMessage
|
|
from knowledgebase.doc.docx_split import DocSplit
|
import asyncio
|
from knowledgebase.db.doc_db_helper import doc_dbh
|
from knowledgebase.doc.entity_helper import entity_helper
|
from knowledgebase.doc.entity_recognition import EntityRecognition
|
import os.path
|
|
from knowledgebase.doc.models import DocInfo, ParagraphInfo
|
from knowledgebase.llm import llm
|
from knowledgebase.log import Log
|
from knowledgebase import utils
|
|
|
class DocProcessor:
    """
    Process one .docx document end to end.

    Splits the document into paragraphs (via DocSplit), recognizes entity
    words for every paragraph (via EntityRecognition), and stores the
    document, its paragraphs and the paragraph/entity links in the database
    (via doc_dbh).
    """

    def __init__(self, docx_file: str):
        """
        Prepare processing of one document.

        Construction is not side-effect free: it creates a DocSplit for the
        file and calls the LLM once to classify the document type from the
        file name (see get_doc_type).

        :param docx_file: path of the document to process
        """
        Log.info(f'开始处理文档:{docx_file}')
        self.docx_file = docx_file
        self.doc_split = DocSplit(docx_file)
        self.doc_type = self.get_doc_type()
        self.entity_recognition = EntityRecognition(self.doc_type)
        # Database id of the document; assigned by save_to_db().
        self.doc_id = 0

    @staticmethod
    def _format_rules(prompt_map) -> str:
        """
        Render a {doc_type: description} mapping as the rule list of the
        classification prompt.

        Each entry becomes `- doc_type:description`; entries are joined with
        ';\\n'. An empty mapping yields an empty string.
        """
        return ';\n'.join(f'- {doc_type}:{desc}' for doc_type, desc in prompt_map.items())

    @staticmethod
    def _match_entities(names, candidates) -> list:
        """
        Resolve recognized entity names to entity objects.

        :param names: entity names returned by entity recognition
            (assumed hashable, presumably strings — TODO confirm)
        :param candidates: iterable of entity objects exposing `.name`
        :return: the matched entity objects in the order of `names`; names
            with no candidate are dropped. When several candidates share a
            name, the first one wins.
        """
        by_name = {}
        for candidate in candidates:
            # setdefault keeps the FIRST candidate per name, same result as a
            # linear first-match search but with O(1) lookups afterwards.
            by_name.setdefault(candidate.name, candidate)
        matched = (by_name.get(name) for name in names)
        return [entity for entity in matched if entity]

    def get_doc_type(self) -> str:
        """
        Ask the LLM to classify the document type from its file name.

        The prompt lists every type from entity_helper.doc_prompt_map and asks
        the model to output nothing when it cannot decide, so the result may
        be an empty string.

        :return: the recognized document type with surrounding whitespace
            removed (when the LLM content is a string)
        """
        Log.info(f'识别文档类型:{self.docx_file}')
        rules = self._format_rules(entity_helper.doc_prompt_map)
        msg = HumanMessage(f'''
# 指令
请从下面的文件名中识别文档类型,如果识别失败不要输出任何字符。
文件名:{os.path.basename(self.docx_file)}
# 识别规则
{rules}
# 示例
遥测大纲
''')
        resp = llm.invoke([msg])
        content = resp.content
        # LLM replies frequently carry stray leading/trailing whitespace or
        # newlines; strip them so the type is a clean label (presumably
        # compared against doc_prompt_map keys downstream — verify).
        doc_type = content.strip() if isinstance(content, str) else content
        Log.info(f'识别结果:{doc_type}')
        return doc_type

    async def gen_sect_entities(self, paragraph: 'ParagraphInfo'):
        """
        Recognize the entity words of one paragraph and attach them.

        When recognition returns a non-empty result, `paragraph.entities` is
        replaced by the matching objects from entity_helper.entities (unknown
        names are dropped). An empty result leaves `paragraph.entities`
        untouched.

        :param paragraph: the paragraph whose `full_text` is analysed
        """
        # EntityRecognition.run is a synchronous call; run it in a worker
        # thread so the paragraphs of one batch are processed concurrently.
        entities = await asyncio.to_thread(self.entity_recognition.run, paragraph.full_text)
        Log.info(f'章节实体词:{entities}')
        if entities:
            paragraph.entities = self._match_entities(entities, entity_helper.entities)

    def process(self, batch_size: int = 10):
        """
        Split the document, tag every paragraph with entities, then save.

        Paragraphs are processed in consecutive batches of `batch_size`; the
        paragraphs inside one batch run concurrently, batches run one after
        another. All batches share a single event loop.

        :param batch_size: number of paragraphs processed concurrently
        :raises ValueError: if batch_size is smaller than 1
        """
        if batch_size < 1:
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')
        self.doc_split.split()
        paragraphs = self.doc_split.paragraphs

        async def run_batches():
            for start in range(0, len(paragraphs), batch_size):
                batch = paragraphs[start:start + batch_size]
                await asyncio.gather(*(self.gen_sect_entities(p) for p in batch))

        asyncio.run(run_batches())
        # Persist only after every paragraph has its entities attached.
        self.save_to_db()

    def save_to_db(self):
        """
        Save the document, its paragraphs and the paragraph/entity links.

        The document is identified by the MD5 of its file bytes. The id
        returned by doc_dbh.add_doc is stored in `self.doc_id` and used for
        every paragraph insert.
        """
        Log.info('保存段落和段落实体词关系到数据库...')
        with open(self.docx_file, 'rb') as f:
            md5 = utils.generate_bytes_md5(f.read())
        doc = DocInfo(os.path.basename(self.docx_file), md5, self.doc_type, self.doc_split.paragraph_tree)
        self.doc_id = doc_dbh.add_doc(doc)
        for paragraph in doc.paragraphs:
            # None is passed as the second argument — presumably "no parent
            # paragraph" for top-level entries; confirm against doc_dbh.
            doc_dbh.add_paragraph(self.doc_id, None, paragraph)
        Log.info('保存段落和段落实体词关系到数据库完成')
|
|
if __name__ == '__main__':
    import sys

    # Sample documents processed when no paths are given on the command line.
    default_files = [
        r"D:\workspace\PythonProjects\KnowledgeBase\doc\XA-5D无人机1553B总线传输通信帧分配(公开).docx",
        r"D:\workspace\PythonProjects\KnowledgeBase\doc\XA-5D无人机分系统遥测源包设计报告(公开).docx",
        r"D:\workspace\PythonProjects\KnowledgeBase\doc\XA-5D无人机软件用户需求(公开).docx",
        r"D:\workspace\PythonProjects\KnowledgeBase\doc\XA-5D无人机遥测大纲(公开).docx",
        r"D:\workspace\PythonProjects\KnowledgeBase\doc\XA-5D无人机遥测信号分配表(公开).docx",
        r"D:\workspace\PythonProjects\KnowledgeBase\doc\XA-5D无人机指令格式与编码定义(公开).docx",
        r"D:\workspace\PythonProjects\KnowledgeBase\doc\指令格式(公开).docx",
    ]
    # Paths passed as arguments override the built-in list, so the script can
    # be run on other machines/documents without editing the source.
    files = sys.argv[1:] or default_files

    for file in files:
        doc_processor = DocProcessor(file)
        doc_processor.process()
|