# -*- coding: utf-8 -*-
#
# @author: lyg
# @date: 2025-5-12
# @version: 1
# @description: Document database helper (MySQL database)

import json

from knowledgebase.db.doc_db_models import (
    init_doc_db,
    TDoc,
    TEntity,
    TParagraph,
    TParagraphLink,
    TParagraphRefLink,
    TParagraphEntityLink,
)
from knowledgebase.doc.models import ParagraphInfo, DocInfo


class DocDbHelper:
    """
    Document database helper.

    Wraps the session returned by ``init_doc_db()`` and offers small
    add/query helpers for documents, paragraphs, entities and the link
    rows between them. Every ``add_*`` method commits immediately.
    """

    def __init__(self):
        self.session = init_doc_db()

    def add_doc(self, doc_info: DocInfo) -> int:
        """
        Add a document.

        :param doc_info: document info; its ``file`` and ``file_name`` are stored
        :return: id of the newly inserted document row
        """
        _doc = TDoc(
            file=doc_info.file,
            file_name=doc_info.file_name,
            is_del=0,
        )
        self.session.add(_doc)
        self.session.commit()
        return _doc.id

    def add_paragraph(self, doc_id: int, parent_id: int, paragraph_info: ParagraphInfo) -> TParagraph:
        """
        Add a paragraph together with its links and, recursively, its children.

        :param doc_id: document id
        :param parent_id: parent paragraph id, or None when there is no parent
        :param paragraph_info: paragraph info
        :return: the stored paragraph row
        """
        _paragraph = TParagraph(
            doc_id=doc_id,
            text=paragraph_info.text,
            title_level=paragraph_info.title_level,
            title_num=paragraph_info.title_num,
            num=paragraph_info.num,
            num_level=paragraph_info.num_level,
            parent_id=parent_id,
            is_del=0,
        )
        self.session.add(_paragraph)
        # Commit first so that _paragraph.id is available for the link rows below.
        self.session.commit()

        if parent_id is not None:
            paragraph_link = TParagraphLink(parent_id=parent_id, child_id=_paragraph.id)
            self.add_paragraph_link(paragraph_link)

        if paragraph_info.entities:
            for entity in paragraph_info.entities:
                self.add_paragraph_entity_link(
                    TParagraphEntityLink(paragraph_id=_paragraph.id, entity_id=entity.id)
                )

        if paragraph_info.children:
            for child in paragraph_info.children:
                # The paragraph just stored becomes the parent of each child.
                self.add_paragraph(doc_id, _paragraph.id, child)

        return _paragraph

    def add_paragraph_link(self, paragraph_link):
        """
        Add a paragraph (parent/child) link.

        :param paragraph_link: paragraph link row
        :return: id of the stored link row
        """
        self.session.add(paragraph_link)
        self.session.commit()
        return paragraph_link.id

    def add_paragraph_entity_link(self, paragraph_entity_link):
        """
        Add a paragraph-entity link.

        :param paragraph_entity_link: paragraph-entity link row
        :return: id of the stored link row
        """
        self.session.add(paragraph_entity_link)
        self.session.commit()
        return paragraph_entity_link.id

    def get_entity(self, entity):
        """
        Look up a stored entity with the same name, type and doc_type.

        :param entity: object exposing ``name``, ``type`` and ``doc_type``
        :return: the first matching TEntity row, or None when there is none
        """
        # The criteria are passed as separate arguments so that they are
        # AND-ed in SQL. Chaining them with Python's ``and`` would evaluate
        # the first comparison's truthiness and silently drop the others.
        return self.session.query(TEntity).where(
            TEntity.name == entity.name,
            TEntity.type == entity.type,
            TEntity.doc_type == entity.doc_type,
        ).first()

    def add_entity(self, entity):
        """
        Add an entity.

        :param entity: entity row
        :return: id of the stored entity row
        """
        self.session.add(entity)
        self.session.commit()
        return entity.id

    def add_paragraph_ref_link(self, paragraph_ref_link):
        """
        Add a paragraph reference link.

        :param paragraph_ref_link: paragraph reference link row
        :return: the stored link object itself (not its id, unlike the other
            ``add_*`` helpers; kept this way for existing callers)
        """
        self.session.add(paragraph_ref_link)
        self.session.commit()
        return paragraph_ref_link

    def get_all_entities(self) -> list[TEntity]:
        """Return every entity row."""
        return self.session.query(TEntity).all()

    def get_docs(self) -> list[TDoc]:
        """Return every document row."""
        return self.session.query(TDoc).all()

    def get_texts_with_entities(self, entity_names: list[str]) -> list[str]:
        """
        Get the list of paragraph texts linked to the given entity names.

        :param entity_names: list[str] - entity names
        :return: list[str] - one full text per paragraph-entity link found;
            an empty list when ``entity_names`` is empty or nothing matches
        """
        if not entity_names:
            return []
        _entities = self.session.query(TEntity).where(TEntity.name.in_(entity_names)).all()
        _entity_ids = [entity.id for entity in _entities]
        if not _entity_ids:
            # No entity matched, so no link can match either; skip the query.
            return []
        links = self.session.query(TParagraphEntityLink).where(
            TParagraphEntityLink.entity_id.in_(_entity_ids)
        ).all()
        return [self.get_paragraph_full_text(link.paragraph) for link in links]

    def get_text_with_entities(self, entity_names: list[str]) -> str:
        """
        Get the paragraph texts linked to the given entity names as one string.

        :param entity_names: list[str] - entity names
        :return: str - the texts joined with newlines
        """
        texts = self.get_texts_with_entities(entity_names)
        return '\n'.join(texts)

    def get_entities_by_names(self, names: list[str]):
        """
        Get the entity rows whose name is in ``names``.

        :param names: list[str] - entity names
        :return: list of matching TEntity rows
        """
        return self.session.query(TEntity).where(TEntity.name.in_(names)).all()

    def get_paragraph_full_text(self, p: TParagraph):
        """
        Build the text of a paragraph followed by the text of all its descendants.

        :param p: paragraph row
        :return: str - the paragraph's own line, a newline, then the full
            texts of ``p.children`` joined with newlines
        """
        # Paragraphs with title_level 0 contribute their text only; any other
        # level is prefixed with the title number.
        own_text = p.text if p.title_level == 0 else p.title_num + ' ' + p.text
        children_text = '\n'.join(self.get_paragraph_full_text(child) for child in p.children)
        return own_text + '\n' + children_text

    def get_entities_by_doc_type(self, doc_type):
        """
        Get the entity rows with the given doc_type.

        :param doc_type: value compared against ``TEntity.doc_type``
        :return: list of matching TEntity rows
        """
        return self.session.query(TEntity).where(TEntity.doc_type == doc_type).all()

    def commit(self):
        """Commit the current session."""
        self.session.commit()


# Shared module-level helper instance, created when this module is imported.
doc_dbh = DocDbHelper()