lyg
2025-05-14 b75a49c22e7d2b9aa8d3dc4975df8801c52b4d5b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# -*- coding: utf-8 -*-
# @author: lyg
# @date: 2025-5-12
# @version: 1
# @description: 文档数据库助手,mysql数据库
 
import json
 
from knowledgebase.db.doc_db_models import init_doc_db, TDoc, TEntity, TParagraph, TParagraphLink, TParagraphRefLink, \
    TParagraphEntityLink
 
from knowledgebase.doc.models import ParagraphInfo, DocInfo
 
 
class DocDbHelper:
    """
    Document database helper.

    Wraps the session returned by ``init_doc_db()`` and provides helpers to persist
    documents, paragraphs, entities and the links between them, plus a few read
    queries. Every ``add_*`` method commits immediately; if a commit fails, the
    session is rolled back before the error is re-raised so the (shared) session
    stays usable for subsequent calls.
    """

    def __init__(self):
        # Session used by every method of this helper.
        self.session = init_doc_db()

    def _add_and_commit(self, obj):
        """
        Add ``obj`` to the session and commit it.

        On any error the session is rolled back and the original exception is
        re-raised: a session whose flush/commit failed refuses further work until
        ``rollback()`` is called, which would break every later call on this helper.
        :param obj: mapped instance to persist
        :return: ``obj`` itself (its primary key is populated after the commit)
        """
        try:
            self.session.add(obj)
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise
        return obj

    def add_doc(self, doc_info: DocInfo) -> int:
        """
        Add a document.
        :param doc_info: document info; its ``file`` and ``file_name`` are stored
        :return: id of the newly inserted document
        """
        _doc = TDoc(
            file=doc_info.file,
            file_name=doc_info.file_name,
            is_del=0,
        )
        self._add_and_commit(_doc)
        return _doc.id

    def add_paragraph(self, doc_id: int, parent_id: int, paragraph_info: ParagraphInfo) -> TParagraph:
        """
        Add a paragraph and, recursively, all of its children.

        Besides the paragraph row itself, this records a parent->child paragraph link
        (when ``parent_id`` is not None) and one paragraph-entity link per entity in
        ``paragraph_info.entities``.
        :param doc_id: id of the owning document
        :param parent_id: id of the parent paragraph, or None for a top-level paragraph
        :param paragraph_info: paragraph info (text, title/numbering data, entities, children)
        :return: the persisted paragraph
        """
        _paragraph = TParagraph(
            doc_id=doc_id,
            text=paragraph_info.text,
            title_level=paragraph_info.title_level,
            title_num=paragraph_info.title_num,
            num=paragraph_info.num,
            num_level=paragraph_info.num_level,
            parent_id=parent_id,
            is_del=0,
        )
        # Commit first so that _paragraph.id is available for the links below.
        self._add_and_commit(_paragraph)
        if parent_id is not None:
            self.add_paragraph_link(TParagraphLink(parent_id=parent_id, child_id=_paragraph.id))
        # entity.id is read directly, so the entities presumably must already be persisted.
        for entity in paragraph_info.entities or []:
            self.add_paragraph_entity_link(TParagraphEntityLink(paragraph_id=_paragraph.id, entity_id=entity.id))
        for child in paragraph_info.children or []:
            self.add_paragraph(doc_id, _paragraph.id, child)
        return _paragraph

    def add_paragraph_link(self, paragraph_link):
        """
        Add a paragraph parent/child relation.
        :param paragraph_link: paragraph relation (TParagraphLink)
        :return: id of the inserted relation
        """
        self._add_and_commit(paragraph_link)
        return paragraph_link.id

    def add_paragraph_entity_link(self, paragraph_entity_link):
        """
        Add a paragraph-entity relation.
        :param paragraph_entity_link: paragraph-entity relation (TParagraphEntityLink)
        :return: id of the inserted relation
        """
        self._add_and_commit(paragraph_entity_link)
        return paragraph_entity_link.id

    def add_entity(self, entity):
        """
        Add an entity.
        :param entity: entity (TEntity)
        :return: id of the inserted entity
        """
        self._add_and_commit(entity)
        return entity.id

    def add_paragraph_ref_link(self, paragraph_ref_link):
        """
        Add a paragraph reference relation.
        :param paragraph_ref_link: paragraph reference relation (TParagraphRefLink)
        :return: the persisted relation object itself -- unlike the other ``add_*``
                 methods, which return the id (kept as-is for existing callers)
        """
        self._add_and_commit(paragraph_ref_link)
        return paragraph_ref_link

    def get_all_entities(self) -> list[TEntity]:
        """
        Return every row of the entity table.
        """
        return self.session.query(TEntity).all()

    def get_docs(self) -> list[TDoc]:
        """
        Return every row of the document table.
        """
        return self.session.query(TDoc).all()

    def get_text_with_entities(self, entity_names: list[str]) -> str:
        """
        Get the text content of the paragraphs linked to the given entity names.

        Each linked paragraph is rendered (with its children) via
        ``get_paragraph_full_text`` and the results are joined by newlines. A paragraph
        linked to several of the requested entities is emitted only once, in the
        order its first link was returned.
        :param entity_names: list[str] - entity names
        :return: str - text ("" when nothing matches)
        """
        if not entity_names:
            return ""
        _entities = self.session.query(TEntity).where(TEntity.name.in_(entity_names)).all()
        _entity_ids = [entity.id for entity in _entities]
        if not _entity_ids:
            return ""
        links = self.session.query(TParagraphEntityLink).where(
            TParagraphEntityLink.entity_id.in_(_entity_ids)).all()

        _paragraphs = []
        seen_ids = set()
        for link in links:
            paragraph = link.paragraph
            # Skip dangling links and paragraphs already collected via another entity.
            if paragraph is None or paragraph.id in seen_ids:
                continue
            seen_ids.add(paragraph.id)
            _paragraphs.append(paragraph)

        return '\n'.join([self.get_paragraph_full_text(p) for p in _paragraphs])

    def get_paragraph_full_text(self, p: TParagraph):
        """
        Render a paragraph and, recursively, all of its children as text.

        The paragraph's own line is its text, prefixed by ``title_num`` and a space
        when it is a title (``title_level != 0``) that has a title number. That line
        is followed by a newline and the full text of each child, joined by newlines.
        :param p: paragraph to render
        :return: str - rendered text
        """
        if p.title_level == 0 or p.title_num is None:
            # Body paragraph, or a title without a number: nothing to prefix.
            result = p.text
        else:
            result = p.title_num + ' ' + p.text
        return result + '\n' + '\n'.join([self.get_paragraph_full_text(child) for child in p.children])

    def commit(self):
        """
        Commit the current session; roll back and re-raise if the commit fails.
        """
        try:
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise
 
 
# Shared module-level helper instance. Constructing it calls init_doc_db() at import
# time, so whatever that does (presumably connecting to the document DB -- verify)
# happens as a side effect of importing this module.
doc_dbh = DocDbHelper()

# NOTE(review): the example below is stale -- add_doc() returns the new document id (an int),
# not a TDoc, so `doc.id` would raise AttributeError; use the returned id directly.
# if __name__ == '__main__':
#     text = doc_dbh.get_text_with_entities(['遥控包格式'])
#     print(text)
#     doc_db = DocDbHelper()
#     # doc_db.insert_entities()
#     doc = doc_db.add_doc(DocInfo(file='aaa', file_name='test'))
#     p1 = doc_db.add_paragraph(doc.id, None, ParagraphInfo(text='test1', title_level=1, num=1, num_level=1))
#     p2 = doc_db.add_paragraph(doc.id, p1.id, ParagraphInfo(text='test2', title_level=2, num=1, num_level=2))
#     p3 = doc_db.add_paragraph(doc.id, p2.id, ParagraphInfo(text='test3', title_level=3, num=1, num_level=3))
#     doc_db.add_paragraph_ref_link(TParagraphRefLink(parent_id=p1.id, child_id=p3.id))