lyg
2 天以前 22f370322412074174cde20ecfd14ec03657ab63
knowledgebase/db/doc_db_helper.py
@@ -6,6 +6,7 @@
# @description: 文档数据库助手,mysql数据库
import json
from threading import RLock
from knowledgebase.db.doc_db_models import init_doc_db, TDoc, TEntity, TParagraph, TParagraphLink, TParagraphRefLink, \
    TParagraphEntityLink
@@ -17,9 +18,13 @@
    """
    Document database helper.
    """
    # Class-level re-entrant lock, shared by every instance of this class;
    # several methods acquire it via `with self.lock` around their queries.
    lock = RLock()
    def __init__(self):
        self.session = init_doc_db()
        self.session = None
    def set_project_path(self, project_path):
        self.session = init_doc_db(project_path)
    def add_doc(self, doc_info: DocInfo) -> int:
        """
@@ -32,6 +37,7 @@
        )
        self.session.add(_doc)
        self.session.commit()
        doc_info.id = _doc.id
        return _doc.id
    def add_paragraph(self, doc_id: int, parent_id: int, paragraph_info: ParagraphInfo) -> TParagraph:
@@ -62,6 +68,7 @@
        if paragraph_info.children:
            for child in paragraph_info.children:
                self.add_paragraph(doc_id, _paragraph.id, child)
        paragraph_info.id = _paragraph.id
        return _paragraph
    def add_paragraph_link(self, paragraph_link):
@@ -73,6 +80,18 @@
        self.session.commit()
        return paragraph_link.id
    def add_paragraph_ref_link(self, paren_id: int, child_id: int) -> int:
        """
        添加段落引用关系
        :param paren_id: 引用段落
        :param child_id: 被引用段落
        :return:
        """
        link = TParagraphRefLink(parent_id=paren_id, child_id=child_id, is_del=0)
        self.session.add(link)
        self.session.commit()
        return link.id
    def add_paragraph_entity_link(self, paragraph_entity_link):
        """
        添加段落实体关系
@@ -81,6 +100,13 @@
        self.session.add(paragraph_entity_link)
        self.session.commit()
        return paragraph_entity_link.id
    def get_entity(self, entity):
        with self.lock:
            ret = self.session.query(TEntity).where(
                TEntity.name == entity.name and TEntity.type == entity.type and TEntity.doc_type == entity.doc_type).first()
            if ret:
                return ret
    def add_entity(self, entity):
        """
@@ -91,20 +117,43 @@
        self.session.commit()
        return entity.id
    def add_paragraph_ref_link(self, paragraph_ref_link):
        """
        添加段落引用关系
        :param paragraph_ref_link: 段落引用关系
        """
        self.session.add(paragraph_ref_link)
        self.session.commit()
        return paragraph_ref_link
    def get_all_entities(self) -> list[TEntity]:
        return self.session.query(TEntity).all()
        with self.lock:
            return self.session.query(TEntity).all()
    def get_docs(self) -> list[TDoc]:
        return self.session.query(TDoc).all()
        with self.lock:
            return self.session.query(TDoc).all()
    def get_texts_with_entities(self, entity_names: list[str]):
        """
        根据实体词获取文本内容列表
        :param entity_names: list[str] - 实体词
        :return: list[str] - 文本列表
        """
        with self.lock:
            if not entity_names:
                return ""
            _entities = self.session.query(TEntity).where(TEntity.name.in_(entity_names)).all()
            _entitie_ids = [entity.id for entity in _entities]
            links = self.session.query(TParagraphEntityLink).where(
                TParagraphEntityLink.entity_id.in_(_entitie_ids)).all()
            _paragraphs:[TParagraph] = [link.paragraph for link in links]
            ref_paragraphs = []
            for p in _paragraphs:
                ref_paragraphs.extend([x.child for x in p.ref_links])
            _paragraphs.extend(ref_paragraphs)
            id_map = {}
            result = []
            for p in _paragraphs:
                if p.id in id_map:
                    continue
                else:
                    id_map[p.id] = p
                    result.append(p)
            return [p.text for p in result]
    def get_text_with_entities(self, entity_names: list[str]) -> str:
        """
@@ -112,22 +161,29 @@
        :param entity_names: list[str] - 实体词
        :return: str - 文本
        """
        if not entity_names:
            return ""
        texts = self.get_texts_with_entities(entity_names)
        return '\n'.join(texts)
        return '\n'.join([entity.name for entity in self.get_all_entities() if entity.name in entity_names])
    def get_entities_by_names(self, names: list[str]):
        _entities = self.session.query(TEntity).where(TEntity.name.in_(names)).all()
        return _entities
    def get_paragraph_full_text(self, p: TParagraph):
        result = p.text if p.title_level == 0 else p.title_num + ' ' + p.text
        return result + '\n' + '\n'.join([self.get_paragraph_full_text(p) for p in p.children])
    def get_entities_by_doc_type(self, doc_type):
        with self.lock:
            _entities = self.session.query(TEntity).where(TEntity.doc_type == doc_type).all()
        return _entities
    def get_entities_by_type(self, ty: str)->list[TEntity]:
        with self.lock:
            _entities = self.session.query(TEntity).where(TEntity.type == ty).all()
        return _entities
    def commit(self):
        self.session.commit()
# Shared module-level helper instance. Its session is None after construction;
# call set_project_path() before using any method that touches the database.
doc_dbh = DocDbHelper()
# if __name__ == '__main__':
#     doc_db = DocDbHelper()
#     # doc_db.insert_entities()
#     doc = doc_db.add_doc(DocInfo(file='aaa', file_name='test'))
#     p1 = doc_db.add_paragraph(doc.id, None, ParagraphInfo(text='test1', title_level=1, num=1, num_level=1))
#     p2 = doc_db.add_paragraph(doc.id, p1.id, ParagraphInfo(text='test2', title_level=2, num=1, num_level=2))
#     p3 = doc_db.add_paragraph(doc.id, p2.id, ParagraphInfo(text='test3', title_level=3, num=1, num_level=3))
#     doc_db.add_paragraph_ref_link(TParagraphRefLink(parent_id=p1.id, child_id=p3.id))