# -*- coding: utf-8 -*-
|
#
|
# @author: lyg
|
# @date: 2025-5-12
|
# @version: 1
|
# @description: 文档数据库助手,mysql数据库
|
|
import json
|
from threading import RLock
|
|
from knowledgebase.db.doc_db_models import init_doc_db, TDoc, TEntity, TParagraph, TParagraphLink, TParagraphRefLink, \
|
TParagraphEntityLink
|
|
from knowledgebase.doc.models import ParagraphInfo, DocInfo
|
|
|
class DocDbHelper:
    """
    Document database helper.

    Holds one database session (created by ``init_doc_db`` in
    ``set_project_path``) and offers helpers to store documents, paragraphs,
    entities and their link rows, and to query them back.

    Every access to the session is serialized through the class-level
    ``lock``. That lock is shared by all instances of this class. It is an
    ``RLock``, so the recursive calls in ``add_paragraph`` and
    ``get_paragraph_full_text`` can re-acquire it.
    """

    lock = RLock()

    def __init__(self):
        # No session exists until set_project_path() is called; every other
        # method needs one.
        self.session = None

    def set_project_path(self, project_path):
        """
        Open the document database for ``project_path`` and use it from now on.

        :param project_path: project path handed to ``init_doc_db``.
        """
        # NOTE(review): a previously opened session is replaced without being
        # closed; objects loaded from it may still be held by callers.
        with self.lock:
            self.session = init_doc_db(project_path)

    def _commit(self):
        """
        Commit the current transaction; roll it back if the commit fails.

        Why: with SQLAlchemy, a commit that fails during flush leaves the
        session requiring an explicit ``rollback()`` before it can be used
        again. That would break every later call on this shared helper. The
        original exception is re-raised unchanged.
        """
        with self.lock:
            try:
                self.session.commit()
            except Exception:
                self.session.rollback()
                raise

    def add_doc(self, doc_info: DocInfo) -> int:
        """
        Insert a document row and commit it.

        :param doc_info: document to store; its ``id`` attribute is set to the
            id of the new row.
        :return: id of the inserted ``TDoc`` row.
        """
        with self.lock:
            doc = TDoc(file=doc_info.file, file_name=doc_info.file_name, is_del=0)
            self.session.add(doc)
            self._commit()
            doc_info.id = doc.id
            return doc.id

    def add_paragraph(self, doc_id: int, parent_id: int, paragraph_info: ParagraphInfo) -> TParagraph:
        """
        Insert a paragraph and, recursively, all of its children.

        For each paragraph this also stores:

        - a ``TParagraphLink`` to its parent, when ``parent_id`` is not None;
        - one ``TParagraphEntityLink`` per entry in ``paragraph_info.entities``.

        Each row is committed as soon as it is created.

        :param doc_id: id of the owning document.
        :param parent_id: id of the parent paragraph, or None for a paragraph
            without a parent.
        :param paragraph_info: paragraph to store; its ``id`` is set to the id
            of the new row. Entities in ``paragraph_info.entities`` are expected
            to already carry a database ``id``.
        :return: the inserted ``TParagraph``.
        """
        with self.lock:
            paragraph = TParagraph(
                doc_id=doc_id,
                text=paragraph_info.text,
                title_level=paragraph_info.title_level,
                title_num=paragraph_info.title_num,
                num=paragraph_info.num,
                num_level=paragraph_info.num_level,
                parent_id=parent_id,
                is_del=0,
            )
            self.session.add(paragraph)
            # Commit first so paragraph.id is available for the link rows and
            # for the children inserted below.
            self._commit()
            paragraph_info.id = paragraph.id

            if parent_id is not None:
                self.add_paragraph_link(TParagraphLink(parent_id=parent_id, child_id=paragraph.id))

            for entity in paragraph_info.entities or []:
                self.add_paragraph_entity_link(
                    TParagraphEntityLink(paragraph_id=paragraph.id, entity_id=entity.id))

            for child in paragraph_info.children or []:
                self.add_paragraph(doc_id, paragraph.id, child)

            return paragraph

    def add_paragraph_link(self, paragraph_link):
        """
        Store a parent/child paragraph link and commit it.

        :param paragraph_link: ``TParagraphLink`` instance to insert.
        :return: id of the inserted link row.
        """
        with self.lock:
            self.session.add(paragraph_link)
            self._commit()
            return paragraph_link.id

    def add_paragraph_ref_link(self, paren_id: int, child_id: int) -> int:
        """
        Store a "paragraph references paragraph" link and commit it.

        :param paren_id: id of the referencing paragraph. The parameter name
            is kept as-is for keyword callers.
        :param child_id: id of the referenced paragraph.
        :return: id of the inserted ``TParagraphRefLink`` row.
        """
        with self.lock:
            link = TParagraphRefLink(parent_id=paren_id, child_id=child_id, is_del=0)
            self.session.add(link)
            self._commit()
            return link.id

    def add_paragraph_entity_link(self, paragraph_entity_link):
        """
        Store a paragraph/entity link and commit it.

        :param paragraph_entity_link: ``TParagraphEntityLink`` instance to insert.
        :return: id of the inserted link row.
        """
        with self.lock:
            self.session.add(paragraph_entity_link)
            self._commit()
            return paragraph_entity_link.id

    def get_entity(self, entity):
        """
        Find the stored entity that matches ``entity`` on name, type and doc_type.

        :param entity: object with ``name``, ``type`` and ``doc_type`` attributes.
        :return: the first matching ``TEntity``, or None when there is none.
        """
        with self.lock:
            # Pass the criteria separately so SQLAlchemy ANDs them in SQL.
            # Joining them with Python's ``and`` makes Python evaluate each
            # clause's truthiness instead, so the intended filter never
            # reached the database.
            return self.session.query(TEntity).where(
                TEntity.name == entity.name,
                TEntity.type == entity.type,
                TEntity.doc_type == entity.doc_type,
            ).first()

    def add_entity(self, entity):
        """
        Insert an entity row and commit it.

        :param entity: ``TEntity`` instance to insert.
        :return: id of the inserted entity.
        """
        with self.lock:
            self.session.add(entity)
            self._commit()
            return entity.id

    def get_all_entities(self) -> list[TEntity]:
        """Return every stored entity."""
        with self.lock:
            return self.session.query(TEntity).all()

    def get_docs(self) -> list[TDoc]:
        """Return every stored document."""
        with self.lock:
            return self.session.query(TDoc).all()

    def get_texts_with_entities(self, entity_names: list[str]) -> list[str]:
        """
        Get the texts of the paragraphs linked to the given entity names.

        The result is ordered as follows:

        - first, paragraphs linked to any entity whose name is in
          ``entity_names``;
        - then, the paragraphs those reference through ``ref_links``.

        Each paragraph appears once, at its first position.

        :param entity_names: entity names to search for.
        :return: list of paragraph texts; an empty list when ``entity_names``
            is empty.
        """
        if not entity_names:
            return []
        with self.lock:
            entities = self.session.query(TEntity).where(TEntity.name.in_(entity_names)).all()
            entity_ids = [entity.id for entity in entities]
            links = self.session.query(TParagraphEntityLink).where(
                TParagraphEntityLink.entity_id.in_(entity_ids)).all()

            linked = [link.paragraph for link in links]
            referenced = [ref.child for paragraph in linked for ref in paragraph.ref_links]

            seen_ids = set()
            texts = []
            for paragraph in linked + referenced:
                if paragraph.id in seen_ids:
                    continue
                seen_ids.add(paragraph.id)
                texts.append(paragraph.text)
            return texts

    def get_text_with_entities(self, entity_names: list[str]) -> str:
        """
        Get the texts of the paragraphs linked to the given entity names as one string.

        :param entity_names: entity names to search for.
        :return: the texts from ``get_texts_with_entities`` joined with newlines.
        """
        return '\n'.join(self.get_texts_with_entities(entity_names))

    def get_entities_by_names(self, names: list[str]):
        """Return all entities whose name is in ``names``."""
        with self.lock:
            return self.session.query(TEntity).where(TEntity.name.in_(names)).all()

    def get_paragraph_full_text(self, p: TParagraph) -> str:
        """
        Render a paragraph and all of its descendants as text.

        A paragraph with ``title_level == 0`` contributes only its text. Any
        other paragraph is prefixed with its ``title_num`` and a space. The
        children's renderings follow after a newline, joined by newlines, in
        ``p.children`` order. The result always ends with a newline.

        :param p: the root paragraph.
        :return: the rendered text.
        """
        # Held because p.children may be loaded through the shared session
        # (presumably a lazy relationship -- TODO confirm in doc_db_models).
        with self.lock:
            heading = p.text if p.title_level == 0 else p.title_num + ' ' + p.text
            children_text = '\n'.join(self.get_paragraph_full_text(child) for child in p.children)
            return heading + '\n' + children_text

    def get_entities_by_doc_type(self, doc_type):
        """Return all entities whose ``doc_type`` equals ``doc_type``."""
        with self.lock:
            return self.session.query(TEntity).where(TEntity.doc_type == doc_type).all()

    def get_entities_by_type(self, ty: str) -> list[TEntity]:
        """Return all entities whose ``type`` equals ``ty``."""
        with self.lock:
            return self.session.query(TEntity).where(TEntity.type == ty).all()

    def commit(self):
        """Commit the session; on failure, roll back and re-raise."""
        self._commit()
|
|
|
# Shared module-level helper instance. Its session is None until
# set_project_path() is called on it.
doc_dbh: DocDbHelper = DocDbHelper()
|