lyg
2 天以前 22f370322412074174cde20ecfd14ec03657ab63
knowledgebase/db/doc_db_helper.py
@@ -6,6 +6,7 @@
# @description: 文档数据库助手,mysql数据库
import json
from threading import RLock
from knowledgebase.db.doc_db_models import init_doc_db, TDoc, TEntity, TParagraph, TParagraphLink, TParagraphRefLink, \
    TParagraphEntityLink
@@ -17,9 +18,13 @@
    """
    文档数据库助手
    """
    lock = RLock()
    def __init__(self):
        self.session = init_doc_db()
        self.session = None
    def set_project_path(self, project_path):
        self.session = init_doc_db(project_path)
    def add_doc(self, doc_info: DocInfo) -> int:
        """
@@ -32,6 +37,7 @@
        )
        self.session.add(_doc)
        self.session.commit()
        doc_info.id = _doc.id
        return _doc.id
    def add_paragraph(self, doc_id: int, parent_id: int, paragraph_info: ParagraphInfo) -> TParagraph:
@@ -62,6 +68,7 @@
        if paragraph_info.children:
            for child in paragraph_info.children:
                self.add_paragraph(doc_id, _paragraph.id, child)
        paragraph_info.id = _paragraph.id
        return _paragraph
    def add_paragraph_link(self, paragraph_link):
@@ -73,6 +80,18 @@
        self.session.commit()
        return paragraph_link.id
    def add_paragraph_ref_link(self, paren_id: int, child_id: int) -> int:
        """
        添加段落引用关系
        :param paren_id: 引用段落
        :param child_id: 被引用段落
        :return:
        """
        link = TParagraphRefLink(parent_id=paren_id, child_id=child_id, is_del=0)
        self.session.add(link)
        self.session.commit()
        return link.id
    def add_paragraph_entity_link(self, paragraph_entity_link):
        """
        添加段落实体关系
@@ -83,6 +102,7 @@
        return paragraph_entity_link.id
    def get_entity(self, entity):
        with self.lock:
        ret = self.session.query(TEntity).where(
            TEntity.name == entity.name and TEntity.type == entity.type and TEntity.doc_type == entity.doc_type).first()
        if ret:
@@ -97,20 +117,14 @@
        self.session.commit()
        return entity.id
    def add_paragraph_ref_link(self, paragraph_ref_link):
        """
        添加段落引用关系
        :param paragraph_ref_link: 段落引用关系
        """
        self.session.add(paragraph_ref_link)
        self.session.commit()
        return paragraph_ref_link
    def get_all_entities(self) -> list[TEntity]:
        with self.lock:
        return self.session.query(TEntity).all()
    def get_docs(self) -> list[TDoc]:
        with self.lock:
        return self.session.query(TDoc).all()
    def get_texts_with_entities(self, entity_names: list[str]):
        """
@@ -118,13 +132,29 @@
        :param entity_names: list[str] - 实体词
        :return: list[str] - 文本列表
        """
        with self.lock:
        if not entity_names:
            return ""
        _entities = self.session.query(TEntity).where(TEntity.name.in_(entity_names)).all()
        _entitie_ids = [entity.id for entity in _entities]
        links = self.session.query(TParagraphEntityLink).where(TParagraphEntityLink.entity_id.in_(_entitie_ids)).all()
        _paragraphs = [link.paragraph for link in links]
        return [self.get_paragraph_full_text(p) for p in _paragraphs]
            links = self.session.query(TParagraphEntityLink).where(
                TParagraphEntityLink.entity_id.in_(_entitie_ids)).all()
            _paragraphs:[TParagraph] = [link.paragraph for link in links]
            ref_paragraphs = []
            for p in _paragraphs:
                ref_paragraphs.extend([x.child for x in p.ref_links])
            _paragraphs.extend(ref_paragraphs)
            id_map = {}
            result = []
            for p in _paragraphs:
                if p.id in id_map:
                    continue
                else:
                    id_map[p.id] = p
                    result.append(p)
            return [p.text for p in result]
    def get_text_with_entities(self, entity_names: list[str]) -> str:
        """
        根据实体词获取文本内容
@@ -143,9 +173,15 @@
        return result + '\n' + '\n'.join([self.get_paragraph_full_text(p) for p in p.children])
    def get_entities_by_doc_type(self, doc_type):
        with self.lock:
        _entities = self.session.query(TEntity).where(TEntity.doc_type == doc_type).all()
        return _entities
    def get_entities_by_type(self, ty: str)->list[TEntity]:
        with self.lock:
            _entities = self.session.query(TEntity).where(TEntity.type == ty).all()
        return _entities
    def commit(self):
        self.session.commit()