From e60d75228fb161e464ca59fa2526bf0765f4d902 Mon Sep 17 00:00:00 2001
From: lyg <1543117173@qq.com>
Date: 星期四, 22 五月 2025 12:35:55 +0800
Subject: [PATCH] 修改指令json生成,加入fastapi

---
 knowledgebase/doc/entity_recognition.py |   28 +++++++++++++++++-----------
 1 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/knowledgebase/doc/entity_recognition.py b/knowledgebase/doc/entity_recognition.py
index 8b3d58e..6366f8f 100644
--- a/knowledgebase/doc/entity_recognition.py
+++ b/knowledgebase/doc/entity_recognition.py
@@ -11,7 +11,8 @@
 import json
 
 from knowledgebase import utils
-from knowledgebase.doc.entity_helper import entity_helper
+from knowledgebase.db.doc_db_helper import doc_dbh
+from knowledgebase.log import Log
 
 llm = ChatOpenAI(temperature=0,
                  model="qwen2.5-72b-instruct",
@@ -25,30 +26,35 @@
 
     浣跨敤langchain鏋勫缓瀹炰綋鎶藉彇娴佺▼銆�
     """
+    use_cache = False
     cache_file = "entity_recognition.cache"
 
     def __init__(self, doc_type: str):
         # 瀹炰綋璇嶅垪琛�
-        entities = filter(lambda x: x.doc_type == doc_type, entity_helper.entities)
-        entity_list = '锛沑n'.join([f'- {entity.name}锛歿entity.prompts}' for entity in entities]) + "銆�"
-        msg = HumanMessagePromptTemplate.from_template(template="""
+        entities = doc_dbh.get_entities_by_doc_type(doc_type)
+        entity_list = '锛�'.join([entity.name for entity in entities]) + "銆�"
+        entity_rules = "锛沑n".join([f"- {entity.name}锛歿entity.prompts}" for entity in entities]) + "銆�"
+        tpl = """
 # 鎸囦护
-璇蜂粠缁欏畾鐨勬枃鏈腑鎻愬彇瀹炰綋璇嶅垪琛紝瀹炰綋璇嶅垪琛ㄥ畾涔夊涓嬶細
-## 瀹炰綋璇嶅垪琛ㄥ強璇嗗埆瑙勫垯
+璇锋牴鎹疄浣撹瘝鍒ゆ柇瑙勫垯浠庣粰瀹氱殑鏂囨湰涓垽鏂槸鍚︽湁涓嬪垪瀹炰綋璇嶇浉鍏冲唴瀹癸紝濡傛灉鏈夊垯杈撳嚭鐩稿叧鐨勫疄浣撹瘝锛屾病鏈夊垯涓嶈緭鍑猴紝瀹炰綋璇嶅垪琛ㄥ畾涔夊涓嬶細
 """ + entity_list + """
+## 瀹炰綋璇嶅垽鏂鍒欙細
+""" + entity_rules + """
 # 绾︽潫
 - 杈撳嚭鏍煎紡涓篔SON鏍煎紡锛�
-- 鎻愬彇鐨勫疄浣撹瘝蹇呴』鏄笂闈㈠垪涓剧殑瀹炰綋璇嶏紱
+- 鎻愬彇鐨勫疄浣撹瘝蹇呴』鏄細""" + entity_list + """锛�
+- 濡傛灉娌℃湁澶嶅悎涓婅堪瑙勫垯鐨勫疄浣撹瘝鍒欎笉瑕佽緭鍑轰换浣曞疄浣撹瘝锛�
 - 杈撳嚭鏁版嵁缁撴瀯涓哄瓧绗︿覆鏁扮粍銆�
 # 绀轰緥
 ```json
-["閬ユ帶甯ф牸寮�","閬ユ帶鍖呮牸寮�"]
+[\"""" + entities[0].name + """\"]
 ```
 
 # 鏂囨湰濡備笅锛�
 {text}
 """
-                                                       )
+        Log.info(tpl)
+        msg = HumanMessagePromptTemplate.from_template(template=tpl)
         prompt = ChatPromptTemplate.from_messages([msg])
         parser = JsonOutputParser(pydantic_object=list[str])
         self.chain = prompt | llm | parser
@@ -68,7 +74,7 @@
         淇濆瓨缂撳瓨銆�
         """
         text = json.dumps(self.cache)
-        utils.save_to_file(text, self.cache_file)
+        utils.save_text_to_file(text, self.cache_file)
 
     def run(self, in_text: str) -> list[str]:
         """
@@ -77,7 +83,7 @@
         """
         # 缂撳瓨鍛戒腑
         text_md5 = utils.generate_text_md5(in_text)
-        if text_md5 in self.cache:
+        if self.use_cache and text_md5 in self.cache:
             return self.cache[text_md5]
         result = self.chain.invoke({"text": in_text})
         self.cache[text_md5] = result

--
Gitblit v1.9.1