From c099e6662b8a6e320ac314d31eda9b40455e5aa7 Mon Sep 17 00:00:00 2001 From: lyg <1543117173@qq.com> Date: 星期四, 22 五月 2025 09:27:37 +0800 Subject: [PATCH] 修改指令json生成相关提示词和代码逻辑 --- knowledgebase/db/doc_db_helper.py | 40 +++++++++++++++++++++++++--------------- 1 files changed, 25 insertions(+), 15 deletions(-) diff --git a/knowledgebase/db/doc_db_helper.py b/knowledgebase/db/doc_db_helper.py index fe36fb2..6b8d6cb 100644 --- a/knowledgebase/db/doc_db_helper.py +++ b/knowledgebase/db/doc_db_helper.py @@ -82,6 +82,12 @@ self.session.commit() return paragraph_entity_link.id + def get_entity(self, entity): + ret = self.session.query(TEntity).where( + TEntity.name == entity.name and TEntity.type == entity.type and TEntity.doc_type == entity.doc_type).first() + if ret: + return ret + def add_entity(self, entity): """ 娣诲姞瀹炰綋 @@ -106,11 +112,11 @@ def get_docs(self) -> list[TDoc]: return self.session.query(TDoc).all() - def get_text_with_entities(self, entity_names: list[str]) -> str: + def get_texts_with_entities(self, entity_names: list[str]): """ - 鏍规嵁瀹炰綋璇嶈幏鍙栨枃鏈唴瀹� + 鏍规嵁瀹炰綋璇嶈幏鍙栨枃鏈唴瀹瑰垪琛� :param entity_names: list[str] - 瀹炰綋璇� - :return: str - 鏂囨湰 + :return: list[str] - 鏂囨湰鍒楄〃 """ if not entity_names: return "" @@ -118,26 +124,30 @@ _entitie_ids = [entity.id for entity in _entities] links = self.session.query(TParagraphEntityLink).where(TParagraphEntityLink.entity_id.in_(_entitie_ids)).all() _paragraphs = [link.paragraph for link in links] + return [self.get_paragraph_full_text(p) for p in _paragraphs] + def get_text_with_entities(self, entity_names: list[str]) -> str: + """ + 鏍规嵁瀹炰綋璇嶈幏鍙栨枃鏈唴瀹� + :param entity_names: list[str] - 瀹炰綋璇� + :return: str - 鏂囨湰 + """ + texts = self.get_texts_with_entities(entity_names) + return '\n'.join(texts) - return '\n'.join([self.get_paragraph_full_text(p) for p in _paragraphs]) + def get_entities_by_names(self, names: list[str]): + _entities = self.session.query(TEntity).where(TEntity.name.in_(names)).all() + return _entities def get_paragraph_full_text(self, p: TParagraph): result = p.text if p.title_level == 0 else p.title_num + ' ' + p.text return result + '\n' + '\n'.join([self.get_paragraph_full_text(p) for p in p.children]) + + def get_entities_by_doc_type(self, doc_type): + _entities = self.session.query(TEntity).where(TEntity.doc_type == doc_type).all() + return _entities def commit(self): self.session.commit() doc_dbh = DocDbHelper() - -# if __name__ == '__main__': -# text = doc_dbh.get_text_with_entities(['閬ユ帶鍖呮牸寮�']) -# print(text) -# doc_db = DocDbHelper() -# # doc_db.insert_entities() -# doc = doc_db.add_doc(DocInfo(file='aaa', file_name='test')) -# p1 = doc_db.add_paragraph(doc.id, None, ParagraphInfo(text='test1', title_level=1, num=1, num_level=1)) -# p2 = doc_db.add_paragraph(doc.id, p1.id, ParagraphInfo(text='test2', title_level=2, num=1, num_level=2)) -# p3 = doc_db.add_paragraph(doc.id, p2.id, ParagraphInfo(text='test3', title_level=3, num=1, num_level=3)) -# doc_db.add_paragraph_ref_link(TParagraphRefLink(parent_id=p1.id, child_id=p3.id)) -- Gitblit v1.9.1