| | |
| | | import time |
| | | import json |
| | | |
| | | import data_templates |
| | | from knowledgebase.db.doc_db_helper import doc_dbh |
| | | from knowledgebase.llm import llm |
| | | from knowledgebase import utils |
| | | |
| | | from langchain_core.prompts import ChatPromptTemplate |
| | | from langchain_core.messages import HumanMessage,SystemMessage |
| | | import textwrap |
| | | |
| | | from knowledgebase.log import Log |
| | | |
| | | USE_CACHE = True |
| | | |
| | | |
| | | class JsonGenerate: |
| | | project: dict |
| | | devs: list[dict] |
| | | cadu: dict |
| | | vcs: list[dict] |
| | | tm_pkts: list[dict] |
| | | vc_pkts: list[dict] |
| | | bus_pkts: list[dict] |
| | | tc_frame: dict |
| | | tc_pkt_format: dict |
| | | tc_pkts: dict |
| | | |
| | | def __init__(self): |
| | | self.llm = llm |
| | | self.systemPrompt = """ |
| | |
| | | """ |
| | | |
| | | # 模型调用 |
| | | def call_model(self, msg, cache_file, doc, validation=None, try_cnt=3): |
| | | def call_model(self, msg: str, cache_file: str, doc: str, validation=None, try_cnt=3) -> str: |
| | | """ |
| | | 调用大模型 |
| | | :param msg: 问题 |
| | | :param cache_file: 生成结果缓存文件 |
| | | :param doc: 文档文本 |
| | | :param validation: 校验函数,(text: str)-> None |
| | | :param try_cnt: 失败重试次数 |
| | | :return: 生成的文本 |
| | | """ |
| | | if USE_CACHE and os.path.isfile(cache_file): |
| | | with open(cache_file, 'r', encoding='utf-8') as f: |
| | | text = f.read() |
| | |
| | | messages.append(HumanMessage(info)) |
| | | prompt = ChatPromptTemplate.from_messages(messages) |
| | | chain = prompt | self.llm |
| | | # 去除多余的缩进 |
| | | msg = textwrap.dedent(msg).strip() |
| | | resp = chain.invoke({"msg": msg}) |
| | | text = resp.content |
| | | if validation: |
| | | try: |
| | | validation(text) |
| | | except BaseException as e: |
| | | print(e) |
| | | Log.error(e) |
| | | if try_cnt <= 0: |
| | | raise RuntimeError('生成失败,重试次数太多,强制结束!') |
| | | return self.call_model(msg, cache_file, validation, try_cnt - 1) |
| | | if cache_file: |
| | | with open(cache_file, 'w', encoding='utf-8') as f: |
| | | f.write(text) |
| | | print(f'耗时:{time.time() - s}') |
| | | Log.info(f'耗时:{time.time() - s}') |
| | | return text |
| | | |
| | | # :param type: int - None 全部、1 遥测、2 遥控 |
def run(self, type):
    """Generate structured data from the document for the selected category.

    :param type: ``None`` runs both the telemetry (yc) and telecommand (yk)
        handlers, ``1`` runs the telemetry handler only, ``2`` runs the
        telecommand handler only; any other value runs nothing.
    """
    if type is None:
        # No category selected: process everything.
        self.handle_yc_structured_data()
        self.handle_yk_structured_data()
        return
    if type == 1:
        self.handle_yc_structured_data()
    if type == 2:
        self.handle_yk_structured_data()
@staticmethod
def get_text_with_entity(entity_names: list[str]) -> str:
    """Fetch document text for the given entity names.

    Delegates to ``doc_dbh.get_text_with_entities``.

    :param entity_names: names of the entities to look up
    :return: the text returned by the document database helper
    """
    doc_text = doc_dbh.get_text_with_entities(entity_names)
    return doc_text
| | | |
| | | # 遥测-start |
| | | def handle_yc_structured_data(self): |
def run(self):
    """Generate structured data from the document.

    Runs the telemetry (TM) handler first, then the telecommand (TC) handler.
    """
    self.handle_tm_structured_data()
    self.handle_tc_structured_data()
| | | |
| | | # region start 遥测 |
def handle_tm_structured_data(self):
    """Run the telemetry generation steps: project info, then device info."""
    self.gen_project()
    self.gen_device()
| | | |
| | | # 获取项目信息 |
def gen_project(self):
    """Ask the model for the project (型号) name and code and store the result.

    Fetches the document text for the '系统概述' entity, sends the prompt
    through ``call_model`` with cache file ``out/型号信息.json``, parses the
    reply as JSON into ``self.project`` and logs the raw reply text.

    :raises json.JSONDecodeError: if the model reply is not valid JSON
    """
    _msg = """
    # 指令
    根据文档内容分析型号信息,型号字段包括:名称和代号。
    # 例子
    {"name":"xxx","id":"xxx"}
    """
    doc_text = self.get_text_with_entity(['系统概述'])
    text = self.call_model(_msg, 'out/型号信息.json', doc_text)
    self.project = json.loads(text)
    # Log the raw reply text: the parsed value is not a str, and
    # concatenating it onto a str would raise TypeError.
    Log.info('型号信息:' + text)
| | | |
| | | # 获取设备信息 |
| | | def gen_device(self): |
| | |
| | | } |
| | | ] |
| | | """ |
| | | result = self.call_model(_msg, 'out/设备列表.json', ['这里是文档中抽取的内容']) |
| | | print('设备列表:' + result) |
| | | doc_text = self.get_text_with_entity(['系统概述', '总线管理']) |
| | | text = self.call_model(_msg, 'out/设备列表.json', doc_text) |
| | | Log.info('设备列表:' + text) |
| | | |
| | | devs = json.loads(result) |
| | | self.devs = json.loads(text) |
| | | # 类SMU设备,包含遥测和遥控功能,名称结尾为“管理单元” |
| | | like_smu_devs = list(filter(lambda it: it['hasTcTm'] and it['name'].endswith('管理单元'), devs)) |
| | | like_smu_devs = list(filter(lambda it: it['hasTcTm'] and it['name'].endswith('管理单元'), self.devs)) |
| | | for dev in like_smu_devs: |
| | | self.gen_tm_frame(dev) |
| | | # 总线 |
| | | hasBus = any(d['hasBus'] for d in self.devs) |
| | | if hasBus: |
| | | self.gen_bus() |
| | | |
| | | def gen_tm_frame(self,dev): |
| | | # 插入域参数列表 |
| | | self.gen_insert_domain_params(dev) |
| | | insert_domain = self.gen_insert_domain_params(dev) |
| | | # VC源包格式 |
| | | vc_pkt_fields = data_templates.vc_pkt_fields |
| | | # 获取虚拟信道 vc |
| | | vcs = self.gen_vc(dev) |
| | | self.vcs = self.gen_vc(dev) |
| | | for vc in self.vcs: |
| | | vc['children'] = [] |
| | | vc['VCID'] = str(int(vc['VCID'], 2)) |
| | | for field in vc_pkt_fields: |
| | | if field['name'] == '数据域': |
| | | field['children'] = [] |
| | | vc['children'].append(dict(field)) |
| | | |
| | | def build_vcid_content(vcs): |
| | | _vcs = [] |
| | | for _vc in vcs: |
| | | _vcs.append(_vc['name'] + ',' + _vc['VCID']) |
| | | return ' '.join(_vcs) |
| | | |
| | | # VCID 字段内容 |
| | | vcid_content = build_vcid_content(self.vcs) |
| | | # 遥测帧结构由模板生成,只需提供特定参数 |
| | | tm_data = { |
| | | "vcidContent": vcid_content, |
| | | 'insertDomain': insert_domain, |
| | | } |
| | | self.cadu = data_templates.get_tm_frame(tm_data) |
| | | |
| | | # 获取vc源包 |
| | | vc_pkts = self.gen_pkt_vc(dev) |
| | | self.vc_pkts = self.gen_pkt_vc(dev) |
| | | # 获取源包列表 |
| | | tm_pkts = self.gen_pkts(dev) |
| | | self.tm_pkts = self.gen_pkts(dev) |
| | | |
| | | # 获取VC下面的遥测包数据 |
| | | for vc in vcs: |
| | | for vc in self.vcs: |
| | | # 此VC下的遥测包过滤 |
| | | _vc_pkts = filter(lambda it: it['vcs'].__contains__(vc['id']), vc_pkts) |
| | | _vc_pkts = filter(lambda it: it['vcs'].__contains__(vc['id']), self.vc_pkts) |
| | | for _pkt in _vc_pkts: |
| | | # 判断遥测包是否有详细定义 |
| | | if not next(filter(lambda it: it['name'] == _pkt['name'] and it['hasParams'], tm_pkts), None): |
| | | if not next(filter(lambda it: it['name'] == _pkt['name'] and it['hasParams'], self.tm_pkts), None): |
| | | continue |
| | | # 获取包详情 |
| | | _pkt = self.gen_pkt_details(_pkt['name'], _pkt['id']) |
| | | epdu = next(filter(lambda it: it['name'] == '数据域', vc['children']), None) |
| | | if epdu and _pkt: |
| | | _pkt['children'] = _pkt['datas'] |
| | | _last_par = _pkt['children'][len(_pkt['children']) - 1] |
| | | _pkt['length'] = (_last_par['pos'] + _last_par['length']) |
| | | _pkt['pos'] = 0 |
| | | if 'children' not in epdu: |
| | | epdu['children'] = [] |
| | | # 添加解析规则后缀防止重复 |
| | | _pkt['id'] = _pkt['id'] + '_' + vc['VCID'] |
| | | # 给包名加代号前缀 |
| | | if not _pkt['name'].startswith(_pkt['id']): |
| | | _pkt['name'] = _pkt['id'] + '_' + _pkt['name'] |
| | | epdu['children'].append(_pkt) |
| | | apid_node = next(filter(lambda it: it['name'].__contains__('应用过程'), _pkt['headers']), None) |
| | | ser_node = next(filter(lambda it: it['name'] == '服务', _pkt['headers']), None) |
| | | sub_ser_node = next(filter(lambda it: it['name'] == '子服务', _pkt['headers']), None) |
| | | _pkt['vals'] = \ |
| | | f"{apid_node['content']}/{int(ser_node['content'], 16)}/{int(sub_ser_node['content'], 16)}/" |
| | | # 重新计数起始偏移 |
| | | self.compute_length_pos(self.cadu['children']) |
| | | |
def compute_length_pos(self, items: list):
    """Assign a running start offset (``pos``) to each node, in place.

    Within ``items`` the first node gets ``pos`` 0 and every following node
    gets the sum of the integer ``length`` values of the siblings before it.
    A node whose ``length`` is missing or not an ``int`` still receives a
    ``pos`` but does not advance the offset. A list stored under a node's
    ``children`` key is processed the same way, starting again from 0.

    :param items: sibling nodes (dicts) to update
    """
    pos = 0
    for child in items:
        if 'children' in child:
            # Offsets of nested nodes are relative to their own sibling list.
            self.compute_length_pos(child['children'])
        child['pos'] = pos
        if 'length' in child and isinstance(child['length'], int):
            pos += child['length']
| | | |
| | | def gen_insert_domain_params(self,dev): |
| | | _msg = """ |
| | |
| | | } |
| | | ] |
| | | """ |
| | | |
| | | def validation(gen_text): |
| | | params = json.loads(gen_text) |
| | | assert isinstance(params, list), '插入域参数列表数据结构最外层必须是数组' |
| | | assert len(params), '插入域参数列表不能为空' |
| | | result = self.call_model(_msg, 'out/'+dev.code+'_插入域参数列表.json', ['这里是文档中抽取的内容'], validation) |
| | | print('插入域参数列表:' + result) |
| | | |
| | | doc_text = self.get_text_with_entity(['插入域']) |
| | | result = self.call_model(_msg, 'out/' + dev.code + '_插入域参数列表.json', doc_text, |
| | | validation) |
| | | Log.info('插入域参数列表:' + result) |
| | | return json.loads(result) |
| | | |
| | | def gen_vc(self,dev): |
| | | _msg = """ |
| | |
| | | def validation(gen_text): |
| | | vcs = json.loads(gen_text) |
| | | assert next(filter(lambda it: re.match('^[0-1]+$', it['VCID']), vcs)), '生成的VCID必须是二进制' |
| | | result = self.call_model(_msg, 'out/'+dev.code+'_虚拟信道.json', ['这里是文档中抽取的内容'], validation) |
| | | print('虚拟信道:' + result) |
| | | doc_text = self.get_text_with_entity(['虚拟信道定义']) |
| | | result = self.call_model(_msg, 'out/' + dev.code + '_虚拟信道.json', doc_text, validation) |
| | | Log.info('虚拟信道:' + result) |
| | | return json.loads(result) |
| | | |
| | | def gen_pkt_vc(self,dev): |
| | |
| | | }, |
| | | ] |
| | | """ |
| | | |
| | | def validation(gen_text): |
| | | pkts = json.loads(gen_text) |
| | | assert len(pkts), 'VC源包列表不能为空' |
| | | result = self.call_model(_msg, 'out/'+dev.code+'_遥测VC源包.json', ['这里是文档中抽取的内容'], validation) |
| | | print('遥测源包所属虚拟信道:' + result) |
| | | return json.loads(result) |
| | | |
| | | text = self.call_model(_msg, 'out/' + dev.code + '_遥测源包下传时机.json', ['遥测源包下传时机'], validation) |
| | | Log.info('遥测源包所属虚拟信道:' + text) |
| | | return json.loads(text) |
| | | |
| | | def gen_pkts(self,dev): |
| | | _msg = """ |
| | |
| | | ] |
| | | """ |
| | | result = self.call_model(_msg, 'out/'+dev.code+'_源包列表.json', ['这里是文档中抽取的内容']) |
| | | print('遥测源包列表:' + result) |
| | | Log.info('遥测源包列表:' + result) |
| | | return json.loads(result) |
| | | |
| | | def gen_pkt_details(self, pkt_name, pkt_id): |
| | | cache_file = f'out/数据包-{pkt_name}.json' |
| | | if not os.path.isfile(cache_file): |
| | | # 先问最后一个参数的字节位置 |
| | | Log.info(f'遥测源包“{pkt_name}”信息:') |
| | | _msg = f""" |
| | | # 指令 |
| | | 我需要从文档中提取遥测源包的最后一个参数的bit位置和数据域参数个数,你要帮我完成参数bit位置和数据域参数个数的提取。 |
| | |
| | | ] |
| | | """ |
| | | |
| | | |
| | | def validation(gen_text): |
| | | _pkt = json.loads(gen_text) |
| | | with open(f'out/tmp/{time.time()}.json', 'w') as f: |
| | | f.write(gen_text) |
| | | assert 'headers' in _pkt, '包结构中必须包含headers字段' |
| | | assert 'datas' in _pkt, '包结构中必须包含datas字段' |
| | | print(f'参数个数:{len(_pkt["datas"])}') |
| | | Log.info(f'参数个数:{len(_pkt["datas"])}') |
| | | # assert par_num == len(_pkt['datas']), f'数据域参数个数不对!预计{par_num}个,实际{len(_pkt["datas"])}' |
| | | assert last_par_pos == _pkt['datas'][-1]['pos'], '最后一个参数的字节位置不对!' |
| | | |
| | | result = self.call_model(_msg, f'out/数据包-{pkt_name}.json', [], ['这里是文档中抽取的内容'], validation) |
| | | print(f'数据包“{pkt_name}”信息:'+result) |
| | | Log.info(f'数据包“{pkt_name}”信息:' + result) |
| | | pkt = json.loads(result) |
| | | else: |
| | | pkt = json.loads(utils.read_from_file(cache_file)) |
| | | pkt_len = 0 |
| | | for par in pkt['datas']: |
| | | par['pos'] = pkt_len |
| | | pkt_len += par['length'] |
| | | pkt['length'] = pkt_len |
| | | return pkt |
| | | |
| | | def gen_bus(self): |
| | | _msg = """ |
| | |
| | | json.loads(gen_text) |
| | | |
| | | result = self.call_model(_msg, 'out/总线.json', ['这里是文档中抽取的内容'], validation) |
| | | print('总线数据包:' + result) |
| | | Log.info('总线数据包:' + result) |
| | | |
| | | pkts = json.loads(result) |
| | | # 筛选经总线的数据包 |
| | |
| | | pkts = list(filter(lambda it: it['apid'], pkts)) |
| | | |
| | | pkts2 = [] |
| | | # todo 这一步应该通过数据库筛选,数据库中已经有所有遥测包以及遥测包对应的定义段落文本 |
| | | for pkt in pkts: |
| | | if self.pkt_in_tm_pkts(pkt["name"]): |
| | | pkts2.append(pkt) |
| | | for pkt in pkts2: |
| | | self.gen_pkt_details(pkt['name'], pkt['id']) |
| | | _pkt = self.gen_pkt_details(pkt['name'], pkt['id']) |
| | | if _pkt: |
| | | pkt['children'] = [] |
| | | pkt['children'].extend(_pkt['datas']) |
| | | pkt['length'] = _pkt['length'] |
| | | self.bus_pkts = pkts |
| | | |
| | | def pkt_in_tm_pkts(self, pkt_name): |
| | | _msg = f""" |
| | |
| | | 有 |
| | | """ |
| | | text = self.call_model(_msg, f'out/pkts/有无数据包-{pkt_name}.txt', ['这里是文档中抽取的内容']) |
| | | print(f'文档中有无“{pkt_name}”的字段描述:'+ text) |
| | | Log.info(f'文档中有无“{pkt_name}”的字段描述:' + text) |
| | | return text == '有' |
| | | |
| | | # 遥测-end |
| | | # endregion 遥测-end |
| | | |
| | | # 遥控-start |
| | | def handle_yk_structured_data(self): |
| | | # region start 遥控 |
def handle_tc_structured_data(self):
    """Generate the telecommand (TC) structures and keep them on the instance.

    Runs, in order: the transfer-frame format (stored in ``self.tc_frame``),
    the packet format (``self.tc_pkt_format``) and the packet list
    (``self.tc_pkts``), then calls ``gen_tc_pkt_details`` for every packet
    in that list. Each generator is invoked exactly once.
    """
    # Transfer frame format
    self.tc_frame = self.gen_tc_transfer_frame_format()
    # TC packet format
    self.tc_pkt_format = self.gen_tc_pkt_format()
    # TC packet list
    self.tc_pkts = self.gen_tc_transfer_pkts()
    for pkt in self.tc_pkts:
        # Data-area content of each TC packet
        self.gen_tc_pkt_details(pkt)
| | | |
| | |
| | | def validation(gen_text): |
| | | json.loads(gen_text) |
| | | |
| | | result = self.call_model(_msg, 'out/tc_transfer_frame.json', ['这里是文档中抽取的内容'],validation) |
| | | print('遥控帧格式:' + result) |
| | | text = self.call_model(_msg, 'out/tc_transfer_frame.json', ['这里是文档中抽取的内容'], validation) |
| | | result: dict = json.loads(text) |
| | | format_text = utils.read_from_file('tpl/tc_transfer_frame.json') |
| | | format_text = utils.replace_tpl_paras(format_text, result) |
| | | frame = json.loads(format_text) |
| | | Log.info('遥控帧格式:' + format_text) |
| | | return frame |
| | | |
| | | def gen_tc_pkt_format(self): |
| | | _msg = ''' |
| | |
| | | def validation(gen_text): |
| | | json.loads(gen_text) |
| | | |
| | | result = self.call_model(_msg, 'out/tc_transfer_pkt.json', ['这里是文档中抽取的内容'],validation) |
| | | print('遥控包格式:' + result) |
| | | text = self.call_model(_msg, 'out/tc_transfer_pkt.json', ['这里是文档中抽取的内容'], validation) |
| | | result = json.loads(text) |
| | | |
| | | format_text = utils.read_from_file('tpl/tc_pkt_format.json') |
| | | format_text = utils.replace_tpl_paras(format_text, result) |
| | | pkt_format = json.loads(format_text) |
| | | Log.info('遥控包格式:' + format_text) |
| | | return pkt_format |
| | | |
| | | def gen_tc_transfer_pkts(self): |
| | | _msg = ''' |
| | |
| | | def validation(gen_text): |
| | | json.loads(gen_text) |
| | | |
| | | result = self.call_model(_msg, 'out/tc_transfer_pkts.json', ['这里是文档中抽取的内容'],validation) |
| | | print('遥控包列表:' + result) |
| | | text = self.call_model(_msg, 'out/tc_transfer_pkts.json', ['这里是文档中抽取的内容'], validation) |
| | | Log.info('遥控包列表:' + text) |
| | | return json.loads(text) |
| | | |
| | | def gen_tc_pkt_details(self, pkt): |
| | | tc_name = pkt['name'] |
| | |
| | | def validation(gen_text): |
| | | json.loads(gen_text) |
| | | |
| | | result = self.call_model(_msg, f'out/遥控指令数据域-{tc_code}-{utils.to_file_name(tc_name)}.json', ['这里是文档中抽取的内容'],validation) |
| | | print('遥控指令数据域:' + result) |
| | | result = self.call_model(_msg, f'out/遥控指令数据域-{tc_code}-{utils.to_file_name(tc_name)}.json', |
| | | ['这里是文档中抽取的内容'], validation) |
| | | Log.info('遥控指令数据域:' + result) |
| | | |
| | | # 遥控-end |
| | | # endregion 遥控-end |