From e60d75228fb161e464ca59fa2526bf0765f4d902 Mon Sep 17 00:00:00 2001
From: lyg <1543117173@qq.com>
Date: 星期四, 22 五月 2025 12:35:55 +0800
Subject: [PATCH] 修改指令json生成,加入fastapi

---
 knowledgebase/db/doc_db_helper.py |   50 +++++++++++++++++++++++++++++++++++++++++---------
 1 files changed, 41 insertions(+), 9 deletions(-)

diff --git a/knowledgebase/db/doc_db_helper.py b/knowledgebase/db/doc_db_helper.py
index 5089e30..6b8d6cb 100644
--- a/knowledgebase/db/doc_db_helper.py
+++ b/knowledgebase/db/doc_db_helper.py
@@ -17,6 +17,7 @@
     """
     鏂囨。鏁版嵁搴撳姪鎵�
     """
+
     def __init__(self):
         self.session = init_doc_db()
 
@@ -81,6 +82,12 @@
         self.session.commit()
         return paragraph_entity_link.id
 
+    def get_entity(self, entity):
+        ret = self.session.query(TEntity).where(
+            TEntity.name == entity.name and TEntity.type == entity.type and TEntity.doc_type == entity.doc_type).first()
+        if ret:
+            return ret
+
     def add_entity(self, entity):
         """
         娣诲姞瀹炰綋
@@ -105,17 +112,42 @@
     def get_docs(self) -> list[TDoc]:
         return self.session.query(TDoc).all()
 
+    def get_texts_with_entities(self, entity_names: list[str]):
+        """
+        鏍规嵁瀹炰綋璇嶈幏鍙栨枃鏈唴瀹瑰垪琛�
+        :param entity_names: list[str] - 瀹炰綋璇�
+        :return: list[str] - 鏂囨湰鍒楄〃
+        """
+        if not entity_names:
+            return ""
+        _entities = self.session.query(TEntity).where(TEntity.name.in_(entity_names)).all()
+        _entitie_ids = [entity.id for entity in _entities]
+        links = self.session.query(TParagraphEntityLink).where(TParagraphEntityLink.entity_id.in_(_entitie_ids)).all()
+        _paragraphs = [link.paragraph for link in links]
+        return [self.get_paragraph_full_text(p) for p in _paragraphs]
+    def get_text_with_entities(self, entity_names: list[str]) -> str:
+        """
+        鏍规嵁瀹炰綋璇嶈幏鍙栨枃鏈唴瀹�
+        :param entity_names: list[str] - 瀹炰綋璇�
+        :return: str - 鏂囨湰
+        """
+        texts = self.get_texts_with_entities(entity_names)
+        return '\n'.join(texts)
+
+    def get_entities_by_names(self, names: list[str]):
+        _entities = self.session.query(TEntity).where(TEntity.name.in_(names)).all()
+        return _entities
+
+    def get_paragraph_full_text(self, p: TParagraph):
+        result = p.text if p.title_level == 0 else p.title_num + ' ' + p.text
+        return result + '\n' + '\n'.join([self.get_paragraph_full_text(p) for p in p.children])
+
+    def get_entities_by_doc_type(self, doc_type):
+        _entities = self.session.query(TEntity).where(TEntity.doc_type == doc_type).all()
+        return _entities
+
     def commit(self):
         self.session.commit()
 
 
 doc_dbh = DocDbHelper()
-
-# if __name__ == '__main__':
-#     doc_db = DocDbHelper()
-#     # doc_db.insert_entities()
-#     doc = doc_db.add_doc(DocInfo(file='aaa', file_name='test'))
-#     p1 = doc_db.add_paragraph(doc.id, None, ParagraphInfo(text='test1', title_level=1, num=1, num_level=1))
-#     p2 = doc_db.add_paragraph(doc.id, p1.id, ParagraphInfo(text='test2', title_level=2, num=1, num_level=2))
-#     p3 = doc_db.add_paragraph(doc.id, p2.id, ParagraphInfo(text='test3', title_level=3, num=1, num_level=3))
-#     doc_db.add_paragraph_ref_link(TParagraphRefLink(parent_id=p1.id, child_id=p3.id))

--
Gitblit v1.9.1