# -*- coding: utf-8 -*-
#
# @author: lyg
# @date: 2025-5-12
# @version: 1
# @description: Document database helper (MySQL database)
import json
from threading import RLock

from knowledgebase.db.doc_db_models import init_doc_db, TDoc, TEntity, TParagraph, TParagraphLink, TParagraphRefLink, \
    TParagraphEntityLink
from knowledgebase.doc.models import ParagraphInfo, DocInfo


class DocDbHelper:
    """
    Document database helper.

    Wraps a single database session (created by ``set_project_path``) and
    exposes helpers to store documents, paragraphs, entities and the links
    between them, plus a few read queries.
    """
    # Class-level re-entrant lock, shared by every instance; the read queries
    # below hold it while they use the session.
    lock = RLock()

    def __init__(self):
        # No session until set_project_path() is called.
        self.session = None

    def set_project_path(self, project_path):
        """
        Bind this helper to the document database of a project.

        :param project_path: value forwarded to ``init_doc_db``
            (presumably the project directory - confirm against init_doc_db)
        """
        self.session = init_doc_db(project_path)

    def add_doc(self, doc_info: DocInfo) -> int:
        """
        Add a document.

        The generated id is also written back to ``doc_info.id``.

        :param doc_info: document information (``file`` and ``file_name`` are stored)
        :return: id of the new document row
        """
        _doc = TDoc(
            file=doc_info.file,
            file_name=doc_info.file_name,
            is_del=0,
        )
        self.session.add(_doc)
        self.session.commit()
        doc_info.id = _doc.id
        return _doc.id

    def add_paragraph(self, doc_id: int, parent_id: int, paragraph_info: ParagraphInfo) -> TParagraph:
        """
        Add a paragraph together with its links and, recursively, its children.

        The generated id is also written back to ``paragraph_info.id``.

        :param doc_id: document id
        :param parent_id: parent paragraph id, or None when there is no parent
        :param paragraph_info: paragraph information
        :return: the stored paragraph row
        """
        _paragraph = TParagraph(
            doc_id=doc_id,
            text=paragraph_info.text,
            title_level=paragraph_info.title_level,
            title_num=paragraph_info.title_num,
            num=paragraph_info.num,
            num_level=paragraph_info.num_level,
            parent_id=parent_id,
            is_del=0,
        )
        self.session.add(_paragraph)
        # Commit first so that _paragraph.id is available for the links below.
        self.session.commit()
        if parent_id is not None:
            paragraph_link = TParagraphLink(parent_id=parent_id, child_id=_paragraph.id)
            self.add_paragraph_link(paragraph_link)
        if paragraph_info.entities:
            for entity in paragraph_info.entities:
                self.add_paragraph_entity_link(
                    TParagraphEntityLink(paragraph_id=_paragraph.id, entity_id=entity.id))
        if paragraph_info.children:
            for child in paragraph_info.children:
                self.add_paragraph(doc_id, _paragraph.id, child)
        paragraph_info.id = _paragraph.id
        return _paragraph

    def add_paragraph_link(self, paragraph_link):
        """
        Add a parent/child paragraph link.

        :param paragraph_link: paragraph link row to store
        :return: id of the stored link
        """
        self.session.add(paragraph_link)
        self.session.commit()
        return paragraph_link.id

    def add_paragraph_ref_link(self, paren_id: int, child_id: int) -> int:
        """
        Add a paragraph reference link.

        :param paren_id: id of the referencing paragraph
        :param child_id: id of the referenced paragraph
        :return: id of the stored link
        """
        link = TParagraphRefLink(parent_id=paren_id, child_id=child_id, is_del=0)
        self.session.add(link)
        self.session.commit()
        return link.id

    def add_paragraph_entity_link(self, paragraph_entity_link):
        """
        Add a paragraph/entity link.

        :param paragraph_entity_link: paragraph/entity link row to store
        :return: id of the stored link
        """
        self.session.add(paragraph_entity_link)
        self.session.commit()
        return paragraph_entity_link.id

    def get_entity(self, entity):
        """
        Look up a stored entity with the same name, type and doc_type.

        :param entity: object providing ``name``, ``type`` and ``doc_type``
        :return: the first matching entity row, or None when there is none
        """
        with self.lock:
            # The criteria are passed as separate arguments so that they are
            # AND-ed in SQL. Python's ``and`` cannot be overloaded: chaining
            # the comparisons with it collapses them into a single criterion.
            return self.session.query(TEntity).where(
                TEntity.name == entity.name,
                TEntity.type == entity.type,
                TEntity.doc_type == entity.doc_type,
            ).first()

    def add_entity(self, entity):
        """
        Add an entity.

        :param entity: entity row to store
        :return: id of the stored entity
        """
        self.session.add(entity)
        self.session.commit()
        return entity.id

    def get_all_entities(self) -> list[TEntity]:
        """
        Return every entity row.
        """
        with self.lock:
            return self.session.query(TEntity).all()

    def get_docs(self) -> list[TDoc]:
        """
        Return every document row.
        """
        with self.lock:
            return self.session.query(TDoc).all()

    def get_texts_with_entities(self, entity_names: list[str]) -> list[str]:
        """
        Get the texts of the paragraphs linked to the given entity names.

        Paragraphs referenced by those paragraphs (through ``ref_links``) are
        appended after them. Each paragraph appears once, at its first position.

        :param entity_names: list[str] - entity names
        :return: list[str] - paragraph texts (empty list for empty input)
        """
        if not entity_names:
            return []
        with self.lock:
            _entities = self.session.query(TEntity).where(TEntity.name.in_(entity_names)).all()
            _entity_ids = [entity.id for entity in _entities]
            links = self.session.query(TParagraphEntityLink).where(
                TParagraphEntityLink.entity_id.in_(_entity_ids)).all()
            _paragraphs: list[TParagraph] = [link.paragraph for link in links]
            # Collect the referenced paragraphs separately so the list is not
            # extended while it is being iterated.
            ref_paragraphs = []
            for paragraph in _paragraphs:
                ref_paragraphs.extend(ref_link.child for ref_link in paragraph.ref_links)
            _paragraphs.extend(ref_paragraphs)
            # Drop duplicates by paragraph id, keeping the first occurrence.
            seen_ids = set()
            texts = []
            for paragraph in _paragraphs:
                if paragraph.id in seen_ids:
                    continue
                seen_ids.add(paragraph.id)
                texts.append(paragraph.text)
            return texts

    def get_text_with_entities(self, entity_names: list[str]) -> str:
        """
        Get the paragraph texts for the given entity names as one string.

        :param entity_names: list[str] - entity names
        :return: str - texts joined with newlines
        """
        texts = self.get_texts_with_entities(entity_names)
        return '\n'.join(texts)

    def get_entities_by_names(self, names: list[str]) -> list[TEntity]:
        """
        Return the entity rows whose name is in ``names``.

        :param names: list[str] - entity names
        """
        # Hold the lock like the other read queries of this class.
        with self.lock:
            return self.session.query(TEntity).where(TEntity.name.in_(names)).all()

    def get_paragraph_full_text(self, p: TParagraph):
        """
        Build the text of a paragraph followed by the text of all its descendants.

        When ``title_level`` is not 0 the paragraph text is prefixed with its
        ``title_num`` and a space.

        :param p: paragraph row
        :return: str - own text, a newline, then the children's full texts
            joined with newlines
        """
        own_text = p.text if p.title_level == 0 else p.title_num + ' ' + p.text
        children_text = '\n'.join([self.get_paragraph_full_text(child) for child in p.children])
        return own_text + '\n' + children_text

    def get_entities_by_doc_type(self, doc_type):
        """
        Return the entity rows with the given ``doc_type``.
        """
        with self.lock:
            return self.session.query(TEntity).where(TEntity.doc_type == doc_type).all()

    def get_entities_by_type(self, ty: str) -> list[TEntity]:
        """
        Return the entity rows with the given ``type``.
        """
        with self.lock:
            return self.session.query(TEntity).where(TEntity.type == ty).all()

    def commit(self):
        """
        Commit the current session.
        """
        self.session.commit()


# Module-level helper instance; set_project_path() must be called before use.
doc_dbh = DocDbHelper()