From c099e6662b8a6e320ac314d31eda9b40455e5aa7 Mon Sep 17 00:00:00 2001 From: lyg <1543117173@qq.com> Date: 星期四, 22 五月 2025 09:27:37 +0800 Subject: [PATCH] 修改指令json生成相关提示词和代码逻辑 --- knowledgebase/db/doc_db_helper.py | 44 ++++++++++++++++++++++++++++++++------------ 1 files changed, 32 insertions(+), 12 deletions(-) diff --git a/knowledgebase/db/doc_db_helper.py b/knowledgebase/db/doc_db_helper.py index d1231cc..6b8d6cb 100644 --- a/knowledgebase/db/doc_db_helper.py +++ b/knowledgebase/db/doc_db_helper.py @@ -82,6 +82,12 @@ self.session.commit() return paragraph_entity_link.id + def get_entity(self, entity): + ret = self.session.query(TEntity).where( + TEntity.name == entity.name and TEntity.type == entity.type and TEntity.doc_type == entity.doc_type).first() + if ret: + return ret + def add_entity(self, entity): """ 添加实体 @@ -106,28 +112,42 @@ def get_docs(self) -> list[TDoc]: return self.session.query(TDoc).all() + def get_texts_with_entities(self, entity_names: list[str]): + """ + 根据实体词获取文本内容列表 + :param entity_names: list[str] - 实体词 + :return: list[str] - 文本列表 + """ + if not entity_names: + return "" + _entities = self.session.query(TEntity).where(TEntity.name.in_(entity_names)).all() + _entitie_ids = [entity.id for entity in _entities] + links = self.session.query(TParagraphEntityLink).where(TParagraphEntityLink.entity_id.in_(_entitie_ids)).all() + _paragraphs = [link.paragraph for link in links] + return [self.get_paragraph_full_text(p) for p in _paragraphs] def get_text_with_entities(self, entity_names: list[str]) -> str: """ 根据实体词获取文本内容 :param entity_names: list[str] - 实体词 :return: str - 文本 """ - if not entity_names: - return "" + texts = self.get_texts_with_entities(entity_names) + return '\n'.join(texts) - return '\n'.join([entity.name for entity in self.get_all_entities() if entity.name in entity_names]) + def get_entities_by_names(self, names: list[str]): + _entities = self.session.query(TEntity).where(TEntity.name.in_(names)).all() + 
return _entities + + def get_paragraph_full_text(self, p: TParagraph): + result = p.text if p.title_level == 0 else p.title_num + ' ' + p.text + return result + '\n' + '\n'.join([self.get_paragraph_full_text(p) for p in p.children]) + + def get_entities_by_doc_type(self, doc_type): + _entities = self.session.query(TEntity).where(TEntity.doc_type == doc_type).all() + return _entities def commit(self): self.session.commit() doc_dbh = DocDbHelper() - -# if __name__ == '__main__': -# doc_db = DocDbHelper() -# # doc_db.insert_entities() -# doc = doc_db.add_doc(DocInfo(file='aaa', file_name='test')) -# p1 = doc_db.add_paragraph(doc.id, None, ParagraphInfo(text='test1', title_level=1, num=1, num_level=1)) -# p2 = doc_db.add_paragraph(doc.id, p1.id, ParagraphInfo(text='test2', title_level=2, num=1, num_level=2)) -# p3 = doc_db.add_paragraph(doc.id, p2.id, ParagraphInfo(text='test3', title_level=3, num=1, num_level=3)) -# doc_db.add_paragraph_ref_link(TParagraphRefLink(parent_id=p1.id, child_id=p3.id)) -- Gitblit v1.9.1