From 22f370322412074174cde20ecfd14ec03657ab63 Mon Sep 17 00:00:00 2001
From: lyg <1543117173@qq.com>
Date: 星期一, 07 七月 2025 16:20:25 +0800
Subject: [PATCH] 生成数据库

---
 knowledgebase/db/doc_db_helper.py |   84 ++++++++++++++++++++++++++++++------------
 1 files changed, 60 insertions(+), 24 deletions(-)

diff --git a/knowledgebase/db/doc_db_helper.py b/knowledgebase/db/doc_db_helper.py
index 6b8d6cb..a0b82b1 100644
--- a/knowledgebase/db/doc_db_helper.py
+++ b/knowledgebase/db/doc_db_helper.py
@@ -6,6 +6,7 @@
 # @description: 鏂囨。鏁版嵁搴撳姪鎵嬶紝mysql鏁版嵁搴�
 
 import json
+from threading import RLock
 
 from knowledgebase.db.doc_db_models import init_doc_db, TDoc, TEntity, TParagraph, TParagraphLink, TParagraphRefLink, \
     TParagraphEntityLink
@@ -17,9 +18,13 @@
     """
     鏂囨。鏁版嵁搴撳姪鎵�
     """
+    lock = RLock()
 
     def __init__(self):
-        self.session = init_doc_db()
+        self.session = None
+
+    def set_project_path(self, project_path):
+        self.session = init_doc_db(project_path)
 
     def add_doc(self, doc_info: DocInfo) -> int:
         """
@@ -32,6 +37,7 @@
         )
         self.session.add(_doc)
         self.session.commit()
+        doc_info.id = _doc.id
         return _doc.id
 
     def add_paragraph(self, doc_id: int, parent_id: int, paragraph_info: ParagraphInfo) -> TParagraph:
@@ -62,6 +68,7 @@
         if paragraph_info.children:
             for child in paragraph_info.children:
                 self.add_paragraph(doc_id, _paragraph.id, child)
+        paragraph_info.id = _paragraph.id
         return _paragraph
 
     def add_paragraph_link(self, paragraph_link):
@@ -73,6 +80,18 @@
         self.session.commit()
         return paragraph_link.id
 
+    def add_paragraph_ref_link(self, paren_id: int, child_id: int) -> int:
+        """
+        娣诲姞娈佃惤寮曠敤鍏崇郴
+        :param paren_id: 寮曠敤娈佃惤
+        :param child_id: 琚紩鐢ㄦ钀�
+        :return:
+        """
+        link = TParagraphRefLink(parent_id=paren_id, child_id=child_id, is_del=0)
+        self.session.add(link)
+        self.session.commit()
+        return link.id
+
     def add_paragraph_entity_link(self, paragraph_entity_link):
         """
         娣诲姞娈佃惤瀹炰綋鍏崇郴
@@ -83,10 +102,11 @@
         return paragraph_entity_link.id
 
     def get_entity(self, entity):
-        ret = self.session.query(TEntity).where(
-            TEntity.name == entity.name and TEntity.type == entity.type and TEntity.doc_type == entity.doc_type).first()
-        if ret:
-            return ret
+        with self.lock:
+            ret = self.session.query(TEntity).where(
+                TEntity.name == entity.name and TEntity.type == entity.type and TEntity.doc_type == entity.doc_type).first()
+            if ret:
+                return ret
 
     def add_entity(self, entity):
         """
@@ -97,20 +117,14 @@
         self.session.commit()
         return entity.id
 
-    def add_paragraph_ref_link(self, paragraph_ref_link):
-        """
-        娣诲姞娈佃惤寮曠敤鍏崇郴
-        :param paragraph_ref_link: 娈佃惤寮曠敤鍏崇郴
-        """
-        self.session.add(paragraph_ref_link)
-        self.session.commit()
-        return paragraph_ref_link
-
     def get_all_entities(self) -> list[TEntity]:
-        return self.session.query(TEntity).all()
+        with self.lock:
+            return self.session.query(TEntity).all()
 
     def get_docs(self) -> list[TDoc]:
-        return self.session.query(TDoc).all()
+        with self.lock:
+            return self.session.query(TDoc).all()
+
 
     def get_texts_with_entities(self, entity_names: list[str]):
         """
@@ -118,13 +132,29 @@
         :param entity_names: list[str] - 瀹炰綋璇�
         :return: list[str] - 鏂囨湰鍒楄〃
         """
-        if not entity_names:
-            return ""
-        _entities = self.session.query(TEntity).where(TEntity.name.in_(entity_names)).all()
-        _entitie_ids = [entity.id for entity in _entities]
-        links = self.session.query(TParagraphEntityLink).where(TParagraphEntityLink.entity_id.in_(_entitie_ids)).all()
-        _paragraphs = [link.paragraph for link in links]
-        return [self.get_paragraph_full_text(p) for p in _paragraphs]
+        with self.lock:
+            if not entity_names:
+                return ""
+            _entities = self.session.query(TEntity).where(TEntity.name.in_(entity_names)).all()
+            _entitie_ids = [entity.id for entity in _entities]
+            links = self.session.query(TParagraphEntityLink).where(
+                TParagraphEntityLink.entity_id.in_(_entitie_ids)).all()
+            _paragraphs:[TParagraph] = [link.paragraph for link in links]
+            ref_paragraphs = []
+            for p in _paragraphs:
+                ref_paragraphs.extend([x.child for x in p.ref_links])
+            _paragraphs.extend(ref_paragraphs)
+            id_map = {}
+            result = []
+            for p in _paragraphs:
+                if p.id in id_map:
+                    continue
+                else:
+                    id_map[p.id] = p
+                    result.append(p)
+            return [p.text for p in result]
+
+
     def get_text_with_entities(self, entity_names: list[str]) -> str:
         """
         鏍规嵁瀹炰綋璇嶈幏鍙栨枃鏈唴瀹�
@@ -143,7 +173,13 @@
         return result + '\n' + '\n'.join([self.get_paragraph_full_text(p) for p in p.children])
 
     def get_entities_by_doc_type(self, doc_type):
-        _entities = self.session.query(TEntity).where(TEntity.doc_type == doc_type).all()
+        with self.lock:
+            _entities = self.session.query(TEntity).where(TEntity.doc_type == doc_type).all()
+        return _entities
+
+    def get_entities_by_type(self, ty: str)->list[TEntity]:
+        with self.lock:
+            _entities = self.session.query(TEntity).where(TEntity.type == ty).all()
         return _entities
 
     def commit(self):

--
Gitblit v1.9.1