lyg
2025-05-14 37c40c84aa27ff68f6dc7325fb45c9a8c7b70fe8
修改大模型生成json逻辑
2个文件已修改
268 ■■■■ 已修改文件
knowledgebase/db/doc_db_helper.py 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
knowledgebase/gen_base_db/json_generate.py 256 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
knowledgebase/db/doc_db_helper.py
@@ -17,6 +17,7 @@
    """
    文档数据库助手
    """
    def __init__(self):
        self.session = init_doc_db()
@@ -105,6 +106,17 @@
    def get_docs(self) -> list[TDoc]:
        return self.session.query(TDoc).all()
    def get_text_with_entities(self, entity_names: list[str]) -> str:
        """
        根据实体词获取文本内容
        :param entity_names: list[str] - 实体词
        :return: str - 文本
        """
        if not entity_names:
            return ""
        return '\n'.join([entity.name for entity in self.get_all_entities() if entity.name in entity_names])
    def commit(self):
        self.session.commit()
knowledgebase/gen_base_db/json_generate.py
@@ -9,15 +9,32 @@
import time
import json
import data_templates
from knowledgebase.db.doc_db_helper import doc_dbh
from knowledgebase.llm import llm
from knowledgebase import utils
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage,SystemMessage
import textwrap
from knowledgebase.log import Log
USE_CACHE = True
class JsonGenerate:
    project: dict
    devs: list[dict]
    cadu: dict
    vcs: list[dict]
    tm_pkts: list[dict]
    vc_pkts: list[dict]
    bus_pkts: list[dict]
    tc_frame: dict
    tc_pkt_format: dict
    tc_pkts: dict
    def __init__(self):
        self.llm = llm
        self.systemPrompt = """
@@ -70,7 +87,16 @@
        """
        
    # 模型调用
    def call_model(self, msg, cache_file, doc, validation=None, try_cnt=3):
    def call_model(self, msg: str, cache_file: str, doc: str, validation=None, try_cnt=3) -> str:
        """
        调用大模型
        :param msg: 问题
        :param cache_file: 生成结果缓存文件
        :param doc: 文档文本
        :param validation: 校验函数,(text: str)-> None
        :param try_cnt: 失败重试次数
        :return: 生成的文本
        """
        if USE_CACHE and os.path.isfile(cache_file):
            with open(cache_file, 'r', encoding='utf-8') as f:
                text = f.read()
@@ -81,44 +107,55 @@
                messages.append(HumanMessage(info))
            prompt = ChatPromptTemplate.from_messages(messages)
            chain = prompt | self.llm
            # 去除多余的缩进
            msg = textwrap.dedent(msg).strip()
            resp = chain.invoke({"msg": msg})
            text = resp.content
            if validation:
                try:
                    validation(text)
                except BaseException as e:
                    print(e)
                    Log.error(e)
                    if try_cnt <= 0:
                        raise RuntimeError('生成失败,重试次数太多,强制结束!')
                    return self.call_model(msg, cache_file, validation, try_cnt - 1)
            if cache_file:
                with open(cache_file, 'w', encoding='utf-8') as f:
                    f.write(text)
            print(f'耗时:{time.time() - s}')
            Log.info(f'耗时:{time.time() - s}')
        return text
    # :param type: int - None 全部、1 遥测、2 遥控
    def run(self,type):
        # 根据文档,生成结构化数据
        if type is not None:
            if type == 1:
                self.handle_yc_structured_data()
            if type == 2:
                self.handle_yk_structured_data()
        else:
            self.handle_yc_structured_data()
            self.handle_yk_structured_data()
    @staticmethod
    def get_text_with_entity(entity_names: list[str]) -> str:
        """
        根据实体词获取文档文本
        :param entity_names: str - 实体词名称
        :return: str - 文本内容
        """
        return doc_dbh.get_text_with_entities(entity_names)
    
    # 遥测-start
    def handle_yc_structured_data(self):
    def run(self):
        # 根据文档,生成结构化数据
        self.handle_tm_structured_data()
        self.handle_tc_structured_data()
    # region start 遥测
    def handle_tm_structured_data(self):
        self.gen_project()
        self.gen_device()
    # 获取项目信息
    def gen_project(self):
        _msg = '根据文档输出型号信息,型号字段包括:名称和代号。仅输出型号这一级。例如:{"name":"xxx","id":"xxx"}'
        result = self.call_model(_msg, 'out/型号信息.json', ['这里是文档中抽取的内容'])
        print('型号信息:' + result)
        _msg = """
        # 指令
        根据文档内容分析型号信息,型号字段包括:名称和代号。
        # 例子
        {"name":"xxx","id":"xxx"}
        """
        doc_text = self.get_text_with_entity(['系统概述'])
        text = self.call_model(_msg, 'out/型号信息.json', doc_text)
        self.project = json.loads(text)
        Log.info('型号信息:' + self.project)
    # 获取设备信息
    def gen_device(self):
@@ -155,37 +192,98 @@
                }
            ]
        """
        result = self.call_model(_msg, 'out/设备列表.json', ['这里是文档中抽取的内容'])
        print('设备列表:' + result)
        doc_text = self.get_text_with_entity(['系统概述', '总线管理'])
        text = self.call_model(_msg, 'out/设备列表.json', doc_text)
        Log.info('设备列表:' + text)
        devs = json.loads(result)
        self.devs = json.loads(text)
        # 类SMU设备,包含遥测和遥控功能,名称结尾为“管理单元”
        like_smu_devs = list(filter(lambda it: it['hasTcTm'] and it['name'].endswith('管理单元'), devs))
        like_smu_devs = list(filter(lambda it: it['hasTcTm'] and it['name'].endswith('管理单元'), self.devs))
        for dev in like_smu_devs:
            self.gen_tm_frame(dev)
        # 总线
        hasBus = any(d['hasBus'] for d in self.devs)
        if hasBus:
        self.gen_bus()
    def gen_tm_frame(self,dev):
        # 插入域参数列表
        self.gen_insert_domain_params(dev)
        insert_domain = self.gen_insert_domain_params(dev)
        # VC源包格式
        vc_pkt_fields = data_templates.vc_pkt_fields
        # 获取虚拟信道 vc
        vcs = self.gen_vc(dev)
        self.vcs = self.gen_vc(dev)
        for vc in self.vcs:
            vc['children'] = []
            vc['VCID'] = str(int(vc['VCID'], 2))
            for field in vc_pkt_fields:
                if field['name'] == '数据域':
                    field['children'] = []
                vc['children'].append(dict(field))
        def build_vcid_content(vcs):
            _vcs = []
            for _vc in vcs:
                _vcs.append(_vc['name'] + ',' + _vc['VCID'])
            return ' '.join(_vcs)
        # VCID 字段内容
        vcid_content = build_vcid_content(self.vcs)
        # 遥测帧结构由模板生成,只需提供特定参数
        tm_data = {
            "vcidContent": vcid_content,
            'insertDomain': insert_domain,
        }
        self.cadu = data_templates.get_tm_frame(tm_data)
        # 获取vc源包
        vc_pkts = self.gen_pkt_vc(dev)
        self.vc_pkts = self.gen_pkt_vc(dev)
        # 获取源包列表
        tm_pkts = self.gen_pkts(dev)
        self.tm_pkts = self.gen_pkts(dev)
        # 获取VC下面的遥测包数据
        for vc in vcs:
        for vc in self.vcs:
            # 此VC下的遥测包过滤
            _vc_pkts = filter(lambda it: it['vcs'].__contains__(vc['id']), vc_pkts)
            _vc_pkts = filter(lambda it: it['vcs'].__contains__(vc['id']), self.vc_pkts)
            for _pkt in _vc_pkts:
                # 判断遥测包是否有详细定义
                if not next(filter(lambda it: it['name'] == _pkt['name'] and it['hasParams'], tm_pkts), None):
                if not next(filter(lambda it: it['name'] == _pkt['name'] and it['hasParams'], self.tm_pkts), None):
                    continue
                # 获取包详情
                _pkt = self.gen_pkt_details(_pkt['name'], _pkt['id'])
                epdu = next(filter(lambda it: it['name'] == '数据域', vc['children']), None)
                if epdu and _pkt:
                    _pkt['children'] = _pkt['datas']
                    _last_par = _pkt['children'][len(_pkt['children']) - 1]
                    _pkt['length'] = (_last_par['pos'] + _last_par['length'])
                    _pkt['pos'] = 0
                    if 'children' not in epdu:
                        epdu['children'] = []
                    # 添加解析规则后缀防止重复
                    _pkt['id'] = _pkt['id'] + '_' + vc['VCID']
                    # 给包名加代号前缀
                    if not _pkt['name'].startswith(_pkt['id']):
                        _pkt['name'] = _pkt['id'] + '_' + _pkt['name']
                    epdu['children'].append(_pkt)
                    apid_node = next(filter(lambda it: it['name'].__contains__('应用过程'), _pkt['headers']), None)
                    ser_node = next(filter(lambda it: it['name'] == '服务', _pkt['headers']), None)
                    sub_ser_node = next(filter(lambda it: it['name'] == '子服务', _pkt['headers']), None)
                    _pkt['vals'] = \
                        f"{apid_node['content']}/{int(ser_node['content'], 16)}/{int(sub_ser_node['content'], 16)}/"
        # 重新计数起始偏移
        self.compute_length_pos(self.cadu['children'])
    def compute_length_pos(self, items: list):
        length = 0
        pos = 0
        for child in items:
            if 'children' in child:
                self.compute_length_pos(child['children'])
            child['pos'] = pos
            if 'length' in child and isinstance(child['length'], int):
                length = length + child['length']
                pos = pos + child['length']
        # node['length'] = length
    def gen_insert_domain_params(self,dev):
        _msg = """
@@ -212,12 +310,17 @@
            }
            ]
        """
        def validation(gen_text):
            params = json.loads(gen_text)
            assert isinstance(params, list), '插入域参数列表数据结构最外层必须是数组'
            assert len(params), '插入域参数列表不能为空'
        result = self.call_model(_msg, 'out/'+dev.code+'_插入域参数列表.json', ['这里是文档中抽取的内容'], validation)
        print('插入域参数列表:' + result)
        doc_text = self.get_text_with_entity(['插入域'])
        result = self.call_model(_msg, 'out/' + dev.code + '_插入域参数列表.json', doc_text,
                                 validation)
        Log.info('插入域参数列表:' + result)
        return json.loads(result)
    
    def gen_vc(self,dev):
        _msg = """
@@ -246,8 +349,9 @@
        def validation(gen_text):
            vcs = json.loads(gen_text)
            assert next(filter(lambda it: re.match('^[0-1]+$', it['VCID']), vcs)), '生成的VCID必须是二进制'
        result = self.call_model(_msg, 'out/'+dev.code+'_虚拟信道.json', ['这里是文档中抽取的内容'], validation)
        print('虚拟信道:' + result)
        doc_text = self.get_text_with_entity(['虚拟信道定义'])
        result = self.call_model(_msg, 'out/' + dev.code + '_虚拟信道.json', doc_text, validation)
        Log.info('虚拟信道:' + result)
        return json.loads(result)
        
    def gen_pkt_vc(self,dev):
@@ -271,12 +375,14 @@
            },
            ]
        """
        def validation(gen_text):
            pkts = json.loads(gen_text)
            assert len(pkts), 'VC源包列表不能为空'
        result = self.call_model(_msg, 'out/'+dev.code+'_遥测VC源包.json', ['这里是文档中抽取的内容'], validation)
        print('遥测源包所属虚拟信道:' + result)
        return json.loads(result)
        text = self.call_model(_msg, 'out/' + dev.code + '_遥测源包下传时机.json', ['遥测源包下传时机'], validation)
        Log.info('遥测源包所属虚拟信道:' + text)
        return json.loads(text)
    def gen_pkts(self,dev):
        _msg = """
@@ -305,10 +411,14 @@
            ]
        """
        result = self.call_model(_msg, 'out/'+dev.code+'_源包列表.json', ['这里是文档中抽取的内容'])
        print('遥测源包列表:' + result)
        Log.info('遥测源包列表:' + result)
        return json.loads(result)
    
    def gen_pkt_details(self, pkt_name, pkt_id):
        cache_file = f'out/数据包-{pkt_name}.json'
        if not os.path.isfile(cache_file):
            # 先问最后一个参数的字节位置
            Log.info(f'遥测源包“{pkt_name}”信息:')
        _msg = f"""
            # 指令
            我需要从文档中提取遥测源包的最后一个参数的bit位置和数据域参数个数,你要帮我完成参数bit位置和数据域参数个数的提取。
@@ -384,19 +494,27 @@
            ]
        """
        
        def validation(gen_text):
            _pkt = json.loads(gen_text)
            with open(f'out/tmp/{time.time()}.json', 'w') as f:
                f.write(gen_text)
            assert 'headers' in _pkt, '包结构中必须包含headers字段'
            assert 'datas' in _pkt, '包结构中必须包含datas字段'
            print(f'参数个数:{len(_pkt["datas"])}')
                Log.info(f'参数个数:{len(_pkt["datas"])}')
            # assert par_num == len(_pkt['datas']), f'数据域参数个数不对!预计{par_num}个,实际{len(_pkt["datas"])}'
            assert last_par_pos == _pkt['datas'][-1]['pos'], '最后一个参数的字节位置不对!'
        result = self.call_model(_msg, f'out/数据包-{pkt_name}.json', [], ['这里是文档中抽取的内容'], validation)
        print(f'数据包“{pkt_name}”信息:'+result)
            Log.info(f'数据包“{pkt_name}”信息:' + result)
            pkt = json.loads(result)
        else:
            pkt = json.loads(utils.read_from_file(cache_file))
        pkt_len = 0
        for par in pkt['datas']:
            par['pos'] = pkt_len
            pkt_len += par['length']
        pkt['length'] = pkt_len
        return pkt
    def gen_bus(self):
        _msg = """
@@ -441,7 +559,7 @@
            json.loads(gen_text)
        result = self.call_model(_msg, 'out/总线.json', ['这里是文档中抽取的内容'], validation)
        print('总线数据包:' + result)
        Log.info('总线数据包:' + result)
        
        pkts = json.loads(result)
        # 筛选经总线的数据包
@@ -450,11 +568,18 @@
        pkts = list(filter(lambda it: it['apid'], pkts))
        pkts2 = []
        # todo 这一步应该通过数据库筛选,数据库中已经有所有遥测包以及遥测包对应的定义段落文本
        for pkt in pkts:
            if self.pkt_in_tm_pkts(pkt["name"]):
                pkts2.append(pkt)
        for pkt in pkts2:
            self.gen_pkt_details(pkt['name'], pkt['id'])
            _pkt = self.gen_pkt_details(pkt['name'], pkt['id'])
            if _pkt:
                pkt['children'] = []
                pkt['children'].extend(_pkt['datas'])
                pkt['length'] = _pkt['length']
        self.bus_pkts = pkts
    def pkt_in_tm_pkts(self, pkt_name):
        _msg = f"""
@@ -471,20 +596,20 @@
            有
        """
        text = self.call_model(_msg, f'out/pkts/有无数据包-{pkt_name}.txt', ['这里是文档中抽取的内容'])
        print(f'文档中有无“{pkt_name}”的字段描述:'+ text)
        Log.info(f'文档中有无“{pkt_name}”的字段描述:' + text)
        return text == '有'
    # 遥测-end
    # endregion 遥测-end
    # 遥控-start
    def handle_yk_structured_data(self):
    # region start 遥控
    def handle_tc_structured_data(self):
        # 数据帧格式
        self.gen_tc_transfer_frame_format()
        self.tc_frame = self.gen_tc_transfer_frame_format()
        # 遥控包格式
        self.gen_tc_pkt_format()
        self.tc_pkt_format = self.gen_tc_pkt_format()
        # 遥控包列表
        pkts = self.gen_tc_transfer_pkts()
        for pkt in pkts:
        self.tc_pkts = self.gen_tc_transfer_pkts()
        for pkt in self.tc_pkts:
            # 遥控包数据区内容
            self.gen_tc_pkt_details(pkt)
    
@@ -517,8 +642,13 @@
        def validation(gen_text):
            json.loads(gen_text)
        result = self.call_model(_msg, 'out/tc_transfer_frame.json', ['这里是文档中抽取的内容'],validation)
        print('遥控帧格式:' + result)
        text = self.call_model(_msg, 'out/tc_transfer_frame.json', ['这里是文档中抽取的内容'], validation)
        result: dict = json.loads(text)
        format_text = utils.read_from_file('tpl/tc_transfer_frame.json')
        format_text = utils.replace_tpl_paras(format_text, result)
        frame = json.loads(format_text)
        Log.info('遥控帧格式:' + format_text)
        return frame
    def gen_tc_pkt_format(self):
        _msg = '''
@@ -551,8 +681,14 @@
        def validation(gen_text):
            json.loads(gen_text)
        result = self.call_model(_msg, 'out/tc_transfer_pkt.json', ['这里是文档中抽取的内容'],validation)
        print('遥控包格式:' + result)
        text = self.call_model(_msg, 'out/tc_transfer_pkt.json', ['这里是文档中抽取的内容'], validation)
        result = json.loads(text)
        format_text = utils.read_from_file('tpl/tc_pkt_format.json')
        format_text = utils.replace_tpl_paras(format_text, result)
        pkt_format = json.loads(format_text)
        Log.info('遥控包格式:' + format_text)
        return pkt_format
        
    def gen_tc_transfer_pkts(self):
        _msg = '''
@@ -571,8 +707,9 @@
        def validation(gen_text):
            json.loads(gen_text)
        result = self.call_model(_msg, 'out/tc_transfer_pkts.json', ['这里是文档中抽取的内容'],validation)
        print('遥控包列表:' + result)
        text = self.call_model(_msg, 'out/tc_transfer_pkts.json', ['这里是文档中抽取的内容'], validation)
        Log.info('遥控包列表:' + text)
        return json.loads(text)
    def gen_tc_pkt_details(self, pkt):
        tc_name = pkt['name']
@@ -620,7 +757,8 @@
        def validation(gen_text):
            json.loads(gen_text)
        result = self.call_model(_msg, f'out/遥控指令数据域-{tc_code}-{utils.to_file_name(tc_name)}.json', ['这里是文档中抽取的内容'],validation)
        print('遥控指令数据域:' + result)
        result = self.call_model(_msg, f'out/遥控指令数据域-{tc_code}-{utils.to_file_name(tc_name)}.json',
                                 ['这里是文档中抽取的内容'], validation)
        Log.info('遥控指令数据域:' + result)
    # 遥控-end
    # endregion 遥控-end