From c099e6662b8a6e320ac314d31eda9b40455e5aa7 Mon Sep 17 00:00:00 2001
From: lyg <1543117173@qq.com>
Date: 星期四, 22 五月 2025 09:27:37 +0800
Subject: [PATCH] 修改指令json生成相关提示词和代码逻辑

---
 knowledgebase/db/doc_db_helper.py |   40 +++++++++++++++++++++++++---------------
 1 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/knowledgebase/db/doc_db_helper.py b/knowledgebase/db/doc_db_helper.py
index fe36fb2..6b8d6cb 100644
--- a/knowledgebase/db/doc_db_helper.py
+++ b/knowledgebase/db/doc_db_helper.py
@@ -82,6 +82,12 @@
         self.session.commit()
         return paragraph_entity_link.id
 
+    def get_entity(self, entity):
+        ret = self.session.query(TEntity).where(
+            TEntity.name == entity.name and TEntity.type == entity.type and TEntity.doc_type == entity.doc_type).first()
+        if ret:
+            return ret
+
     def add_entity(self, entity):
         """
         娣诲姞瀹炰綋
@@ -106,11 +112,11 @@
     def get_docs(self) -> list[TDoc]:
         return self.session.query(TDoc).all()
 
-    def get_text_with_entities(self, entity_names: list[str]) -> str:
+    def get_texts_with_entities(self, entity_names: list[str]):
         """
-        鏍规嵁瀹炰綋璇嶈幏鍙栨枃鏈唴瀹�
+        鏍规嵁瀹炰綋璇嶈幏鍙栨枃鏈唴瀹瑰垪琛�
         :param entity_names: list[str] - 瀹炰綋璇�
-        :return: str - 鏂囨湰
+        :return: list[str] - 鏂囨湰鍒楄〃
         """
         if not entity_names:
             return ""
@@ -118,26 +124,30 @@
         _entitie_ids = [entity.id for entity in _entities]
         links = self.session.query(TParagraphEntityLink).where(TParagraphEntityLink.entity_id.in_(_entitie_ids)).all()
         _paragraphs = [link.paragraph for link in links]
+        return [self.get_paragraph_full_text(p) for p in _paragraphs]
+    def get_text_with_entities(self, entity_names: list[str]) -> str:
+        """
+        鏍规嵁瀹炰綋璇嶈幏鍙栨枃鏈唴瀹�
+        :param entity_names: list[str] - 瀹炰綋璇�
+        :return: str - 鏂囨湰
+        """
+        texts = self.get_texts_with_entities(entity_names)
+        return '\n'.join(texts)
 
-        return '\n'.join([self.get_paragraph_full_text(p) for p in _paragraphs])
+    def get_entities_by_names(self, names: list[str]):
+        _entities = self.session.query(TEntity).where(TEntity.name.in_(names)).all()
+        return _entities
 
     def get_paragraph_full_text(self, p: TParagraph):
         result = p.text if p.title_level == 0 else p.title_num + ' ' + p.text
         return result + '\n' + '\n'.join([self.get_paragraph_full_text(p) for p in p.children])
+
+    def get_entities_by_doc_type(self, doc_type):
+        _entities = self.session.query(TEntity).where(TEntity.doc_type == doc_type).all()
+        return _entities
 
     def commit(self):
         self.session.commit()
 
 
 doc_dbh = DocDbHelper()
-
-# if __name__ == '__main__':
-#     text = doc_dbh.get_text_with_entities(['閬ユ帶鍖呮牸寮�'])
-#     print(text)
-#     doc_db = DocDbHelper()
-#     # doc_db.insert_entities()
-#     doc = doc_db.add_doc(DocInfo(file='aaa', file_name='test'))
-#     p1 = doc_db.add_paragraph(doc.id, None, ParagraphInfo(text='test1', title_level=1, num=1, num_level=1))
-#     p2 = doc_db.add_paragraph(doc.id, p1.id, ParagraphInfo(text='test2', title_level=2, num=1, num_level=2))
-#     p3 = doc_db.add_paragraph(doc.id, p2.id, ParagraphInfo(text='test3', title_level=3, num=1, num_level=3))
-#     doc_db.add_paragraph_ref_link(TParagraphRefLink(parent_id=p1.id, child_id=p3.id))

--
Gitblit v1.9.1