lyg
2 天以前 22f370322412074174cde20ecfd14ec03657ab63
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# -*- coding: utf-8 -*-
# @author: lyg
# @date: 2025-5-12
# @version: 1
# @description: 文档数据库助手,mysql数据库
 
import json
from threading import RLock
 
from knowledgebase.db.doc_db_models import init_doc_db, TDoc, TEntity, TParagraph, TParagraphLink, TParagraphRefLink, \
    TParagraphEntityLink
 
from knowledgebase.doc.models import ParagraphInfo, DocInfo
 
 
class DocDbHelper:
    """
    Document database helper.

    Wraps a database session created by ``init_doc_db`` and provides helpers to
    persist documents, paragraphs, entities and the links between them, plus
    entity-based text lookups.

    Every access to the shared session is serialized through the class-level
    ``lock``. Because it is an ``RLock``, methods that call other helper methods
    (e.g. the recursive ``add_paragraph``) can re-acquire it on the same thread.
    """
    # Class attribute, so it is shared by every DocDbHelper instance.
    lock = RLock()

    def __init__(self):
        # The session is only created by set_project_path(); all other methods
        # require it to have been called first.
        self.session = None

    def set_project_path(self, project_path):
        """
        Open the document database for a project and keep its session.

        :param project_path: project path passed to ``init_doc_db``
        """
        with self.lock:
            self.session = init_doc_db(project_path)

    def add_doc(self, doc_info: DocInfo) -> int:
        """
        Insert a document row and commit.

        :param doc_info: document info; its ``id`` is set to the new row id
        :return: id of the inserted document
        """
        _doc = TDoc(
            file=doc_info.file,
            file_name=doc_info.file_name,
            is_del=0,
        )
        with self.lock:
            self.session.add(_doc)
            self.session.commit()
            doc_info.id = _doc.id
            return _doc.id

    def add_paragraph(self, doc_id: int, parent_id: int, paragraph_info: ParagraphInfo) -> TParagraph:
        """
        Insert a paragraph, its links and, recursively, all its children.

        Each row is committed individually. When ``parent_id`` is not None a
        parent/child ``TParagraphLink`` is also stored; one
        ``TParagraphEntityLink`` is stored per entry of
        ``paragraph_info.entities`` (each entity must already have an ``id``).

        :param doc_id: id of the owning document
        :param parent_id: id of the parent paragraph, or None for a root paragraph
        :param paragraph_info: paragraph info; its ``id`` is set to the new row id
        :return: the inserted ``TParagraph`` row
        """
        with self.lock:
            _paragraph = TParagraph(
                doc_id=doc_id,
                text=paragraph_info.text,
                title_level=paragraph_info.title_level,
                title_num=paragraph_info.title_num,
                num=paragraph_info.num,
                num_level=paragraph_info.num_level,
                parent_id=parent_id,
                is_del=0,
            )
            self.session.add(_paragraph)
            self.session.commit()
            # Read the generated id once; later commits would otherwise expire
            # the object and reload it on every attribute access.
            paragraph_id = _paragraph.id
            if parent_id is not None:
                self.add_paragraph_link(TParagraphLink(parent_id=parent_id, child_id=paragraph_id))
            for entity in paragraph_info.entities or ():
                self.add_paragraph_entity_link(
                    TParagraphEntityLink(paragraph_id=paragraph_id, entity_id=entity.id))
            for child in paragraph_info.children or ():
                self.add_paragraph(doc_id, paragraph_id, child)
            paragraph_info.id = paragraph_id
            return _paragraph

    def add_paragraph_link(self, paragraph_link):
        """
        Persist a parent/child paragraph relation and commit.

        :param paragraph_link: ``TParagraphLink`` instance to store
        :return: id of the stored link
        """
        with self.lock:
            self.session.add(paragraph_link)
            self.session.commit()
            return paragraph_link.id

    def add_paragraph_ref_link(self, paren_id: int, child_id: int) -> int:
        """
        Persist a paragraph reference relation and commit.

        :param paren_id: id of the referencing paragraph
        :param child_id: id of the referenced paragraph
        :return: id of the stored link
        """
        link = TParagraphRefLink(parent_id=paren_id, child_id=child_id, is_del=0)
        with self.lock:
            self.session.add(link)
            self.session.commit()
            return link.id

    def add_paragraph_entity_link(self, paragraph_entity_link):
        """
        Persist a paragraph/entity relation and commit.

        :param paragraph_entity_link: ``TParagraphEntityLink`` instance to store
        :return: id of the stored link
        """
        with self.lock:
            self.session.add(paragraph_entity_link)
            self.session.commit()
            return paragraph_entity_link.id

    def get_entity(self, entity):
        """
        Look up a stored entity with the same name, type and doc_type.

        :param entity: object exposing ``name``, ``type`` and ``doc_type``
        :return: the first matching ``TEntity`` row, or None if there is none
        """
        with self.lock:
            # The criteria must be passed as separate arguments (SQLAlchemy ANDs
            # them). Joining them with Python's ``and`` evaluates the clauses'
            # truthiness and silently keeps only the first one.
            return self.session.query(TEntity).where(
                TEntity.name == entity.name,
                TEntity.type == entity.type,
                TEntity.doc_type == entity.doc_type,
            ).first()

    def add_entity(self, entity):
        """
        Persist an entity and commit.

        :param entity: ``TEntity`` instance to store
        :return: id of the stored entity
        """
        with self.lock:
            self.session.add(entity)
            self.session.commit()
            return entity.id

    def get_all_entities(self) -> list[TEntity]:
        """Return every ``TEntity`` row."""
        with self.lock:
            return self.session.query(TEntity).all()

    def get_docs(self) -> list[TDoc]:
        """Return every ``TDoc`` row."""
        with self.lock:
            return self.session.query(TDoc).all()

    def get_texts_with_entities(self, entity_names: list[str]) -> list[str]:
        """
        Collect the texts of paragraphs linked to any of the given entity names.

        Paragraphs directly linked to a matching entity come first, followed by
        the paragraphs they reference through ``ref_links``. Duplicates (by
        paragraph id) are dropped, keeping first-seen order.

        :param entity_names: entity names to look up
        :return: paragraph texts; an empty list when no names are given
        """
        if not entity_names:
            return []
        with self.lock:
            entities = self.session.query(TEntity).where(TEntity.name.in_(entity_names)).all()
            entity_ids = [entity.id for entity in entities]
            links = self.session.query(TParagraphEntityLink).where(
                TParagraphEntityLink.entity_id.in_(entity_ids)).all()
            linked = [link.paragraph for link in links]
            referenced = [ref.child for p in linked for ref in p.ref_links]

            seen_ids = set()
            texts = []
            for p in linked + referenced:
                if p.id in seen_ids:
                    continue
                seen_ids.add(p.id)
                texts.append(p.text)
            return texts

    def get_text_with_entities(self, entity_names: list[str]) -> str:
        """
        Return the texts from ``get_texts_with_entities`` joined by newlines.

        :param entity_names: entity names to look up
        :return: newline-joined text ("" when nothing matches)
        """
        texts = self.get_texts_with_entities(entity_names)
        return '\n'.join(texts)

    def get_entities_by_names(self, names: list[str]):
        """
        Return all ``TEntity`` rows whose name is in ``names``.

        :param names: entity names to match
        """
        with self.lock:
            return self.session.query(TEntity).where(TEntity.name.in_(names)).all()

    def get_paragraph_full_text(self, p: TParagraph):
        """
        Render a paragraph and all of its descendants as text.

        Title paragraphs (``title_level != 0``) are prefixed with ``title_num``
        and a space. The rendered paragraph is followed by a newline and then
        each child's rendering (recursively), separated by newlines.

        :param p: paragraph row to render
        :return: rendered text
        """
        with self.lock:
            # NOTE(review): assumes title_num is set whenever title_level != 0 —
            # a None title_num would raise TypeError here; confirm with the model.
            head = p.text if p.title_level == 0 else p.title_num + ' ' + p.text
            return head + '\n' + '\n'.join(
                self.get_paragraph_full_text(child) for child in p.children)

    def get_entities_by_doc_type(self, doc_type):
        """Return all ``TEntity`` rows with the given ``doc_type``."""
        with self.lock:
            _entities = self.session.query(TEntity).where(TEntity.doc_type == doc_type).all()
        return _entities

    def get_entities_by_type(self, ty: str) -> list[TEntity]:
        """Return all ``TEntity`` rows with the given ``type``."""
        with self.lock:
            _entities = self.session.query(TEntity).where(TEntity.type == ty).all()
        return _entities

    def commit(self):
        """Commit the current session transaction."""
        with self.lock:
            self.session.commit()
 
 
# Shared module-level helper instance. Its session starts out as None, so
# set_project_path() has to be called before any database method is used.
doc_dbh: DocDbHelper = DocDbHelper()