lyg
2025-04-08 1e85c429ceaad860aba16d1f518160d263c094c0
生成指令帧和包格式结构
4个文件已修改
1 文件已复制
10个文件已添加
2 文件已重命名
4236 ■■■■■ 已修改文件
.gitignore 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
data_templates.py 337 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
db_struct_flow.py 1218 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
knowledgebase/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
knowledgebase/db/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
knowledgebase/db/data_creator.py 327 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
knowledgebase/db/db_helper.py 408 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
knowledgebase/db/models.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
knowledgebase/markitdown/__about__.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
knowledgebase/markitdown/__init__.py 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
knowledgebase/markitdown/__main__.py 82 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
knowledgebase/markitdown/_markitdown.py 1708 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
knowledgebase/utils.py 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
main.py 32 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
prompts.json 14 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
requirements.txt 补丁 | 查看 | 原始文档 | blame | 历史
tc_frame_format.json 77 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
.gitignore
@@ -1,5 +1,7 @@
/db.db
/out
/out_bak
/doc
/datas
/.conda
/docs
/out*
data_templates.py
New file
@@ -0,0 +1,337 @@
vc_pkt_fields = [
    {
        "name": "版本号",
        "id": "Ver",
        "pos": 0,
        "length": 3,
        "type": "para",
        "content": "0",
        "dataTy": "INVAR"
    },
    {
        "name": "类型",
        "id": "TM_Type",
        "pos": 3,
        "length": 1,
        "type": "para",
        "content": "0",
        "dataTy": "INVAR"
    },
    {
        "name": "副导头标志",
        "id": "Vice_Head",
        "pos": 4,
        "length": 1,
        "type": "para",
        "dataTy": "RANDOM"
    },
    {
        "name": "应用过程标识符",
        "id": "Proc_Sign",
        "pos": 5,
        "length": 11,
        "type": "para",
        "dataTy": "ENUM",
        "is_key": True
    },
    {
        "name": "分组标志",
        "id": "Group_Sign",
        "pos": 16,
        "length": 2,
        "type": "para",
        "content": "3",
        "dataTy": "INVAR"
    },
    {
        "name": "包序列计数",
        "id": "Package_Count",
        "pos": 18,
        "length": 14,
        "type": "para",
        "dataTy": "RANDOM"
    },
    {
        "name": "包长",
        "id": "Pack_Len",
        "pos": 32,
        "length": 16,
        "type": "para",
        "content": "1Bytes/EPDU_Data.length - 1",
        "dataTy": "LEN"
    },
    {
        "name": "服务",
        "id": "service",
        "pos": 48,
        "length": 8,
        "type": "para",
        "content": None,
        "dataTy": "ENUM",
        "is_key": True
    },
    {
        "name": "子服务",
        "id": "subservice",
        "pos": 56,
        "length": 8,
        "type": "para",
        "content": None,
        "dataTy": "ENUM",
        "is_key": True
    },
    {
        "name": "数据域",
        "id": "EPDU_DATA",
        "pos": 64,
        "length": "length-current",
        "type": "any",
        "children": []
    }
]
def get_tm_frame(data):
    return {
        "name": "遥测帧",
        "id": "TM_Frame",
        "type": "enc",
        "pos": 0,
        "length": 8192,
        "children": [
            {
                "name": "同步头",
                "id": "Sync_Head",
                "type": "para",
                "pos": 0,
                "content": "0x1ACFFC1D",
                "dataTy": "INVAR",
                "length": 32
            },
            {
                "name": "VCDU",
                "id": "VCDU",
                "type": "enc",
                "pos": 32,
                "length": 8160,
                "content": "1",
                "children": [
                    {
                        "name": "传输帧版本号",
                        "id": "Ver_",
                        "type": "para",
                        "pos": 0,
                        "length": 2,
                        "content": "01B",
                        "dataTy": "INVAR"
                    },
                    {
                        "name": "航天器标识符SCID",
                        "id": "SCID",
                        "type": "para",
                        "pos": 2,
                        "length": 8,
                        "content": "0x01",
                        "dataTy": "INVAR"
                    },
                    {
                        "name": "虚拟信道标识符VCID",
                        "id": "VCID",
                        "type": "para",
                        "pos": 10,
                        "length": 6,
                        "content": data['vcidContent'],
                        "dataTy": "ENUM"
                    },
                    {
                        "name": "VCDU计数",
                        "id": "VCDUCnt",
                        "type": "para",
                        "pos": 16,
                        "length": 24,
                        "content": "0:16777215:1",
                        "dataTy": "INCREASE"
                    },
                    {
                        "name": "回放标志",
                        "id": "PlaybackFlag",
                        "type": "para",
                        "pos": 40,
                        "length": 1,
                        "content": "回放,1 不回放,0",
                        "dataTy": "ENUM"
                    },
                    {
                        "name": "保留位",
                        "id": "spare",
                        "type": "para",
                        "pos": 41,
                        "length": 7,
                        "content": "0",
                        "dataTy": "INVAR"
                    },
                    {
                        "name": "插入域",
                        "id": "InsertionDomain",
                        "type": "linear",
                        "pos": 48,
                        "length": 640,
                        "content": None,
                        "children": data['insertDomain']
                    },
                    {
                        "name": "传输帧数据域",
                        "id": "DataDomain",
                        "type": "enc",
                        "pos": 688,
                        "length": 7456
                    },
                    {
                        "name": "传输帧尾",
                        "id": "FrameTail",
                        "type": "para",
                        "pos": 8144,
                        "length": 16,
                        "content": "CRC_check;1;All;this.START+0;this.CURRENT-1",
                        "dataTy": "CHECKSUM"
                    }
                ]
            }
        ]
    }
def get_bus_datas(pkts):
    return [
        {
            "name": "传输消息类型",
            "id": "BMessageType",
            "type": "para",
            "pos": 0,
            "length": 8,
            "content": "广播BC/RT_传统,0x00 广播RT/RT_传统,0xFF BC/RT_传统,0x11 RT/RT_传统,0x12 时分复用模式的BC/RT,0x21",
            "dataTy": "ENUM",
            "is_key": True
        },
        {
            "name": "消息传输格式及消息体",
            "id": "BMessagePro",
            "type": "enc",
            "pos": 8,
            "length": "length-current",
            "children": [
                {
                    "name": "BMessagePro",
                    "id": "BMessagePro",
                    "type": "enc",
                    "pos": 0,
                    "length": "length-current",
                    "vals": "0x11/",
                    "children": [
                        {
                            "id": "BRT_Add",
                            "name": "RT地址",
                            "type": "para",
                            "pos": 0,
                            "content": "1,1 2,2 3,3 4,4 5,5 6,6 7,7 8,8 9,9 10,10 11,11 12,12 13,13 14,14 15,15 16,16 17,17 18,18 19,19 20,20 21,21 22,22 23,23 24,24 25,25 26,26 27,27 28,28 29,29 30,30 31,31",
                            "length": 8,
                            "dataTy": "ENUM",
                            "is_key": True,
                        },
                        {
                            "id": "BSub_add",
                            "name": "子地址",
                            "type": "para",
                            "pos": 8,
                            "content": "1,1 2,2 3,3 4,4 5,5 6,6 7,7 8,8 9,9 10,10 11,11 12,12 13,13 14,14 15,15 16,16 17,17 18,18 19,19 20,20 21,21 22,22 23,23 24,24 25,25 26,26 27,27 28,28 29,29 30,30 31,31",
                            "length": 8,
                            "dataTy": "ENUM",
                            "is_key": True,
                        },
                        {
                            "id": "BT_R_M",
                            "name": "传输方向/方式代号",
                            "type": "para",
                            "pos": 16,
                            "content": "RT2BC,0xAA BC2RT,0xBB 方式字,0xCC",
                            "length": 8,
                            "dataTy": "ENUM",
                            "is_key": True
                        },
                        {
                            "id": "BFrame",
                            "name": "帧号",
                            "type": "para",
                            "pos": 24,
                            "content": "0:19:1",
                            "length": 8,
                            "dataTy": "INCREASE",
                            "is_key": True
                        },
                        {
                            "id": "BusA_B",
                            "name": "总线A/B",
                            "type": "para",
                            "pos": 32,
                            "content": "A总线,1 B总线,0",
                            "length": 8,
                            "dataTy": "ENUM"
                        },
                        {
                            "id": "BErrorFlag",
                            "name": "Error Flag(status word.bit12)",
                            "type": "para",
                            "pos": 40,
                            "content": None,
                            "length": 16,
                            "dataTy": "RANDOM"
                        },
                        {
                            "id": "BControlWord",
                            "name": "ControlWord",
                            "type": "para",
                            "pos": 56,
                            "content": None,
                            "length": 16,
                            "dataTy": "RANDOM"
                        },
                        {
                            "id": "BCommandWord",
                            "name": "CommandWord",
                            "type": "para",
                            "pos": 72,
                            "content": None,
                            "length": 16,
                            "dataTy": "RANDOM"
                        },
                        {
                            "id": "BStatusWord",
                            "name": "StatusWord",
                            "type": "para",
                            "pos": 88,
                            "content": None,
                            "length": 16,
                            "dataTy": "RANDOM"
                        },
                        {
                            "id": "BTime",
                            "name": "传输时间",
                            "type": "para",
                            "pos": 104,
                            "content": None,
                            "length": 64,
                            "dataTy": "RANDOM"
                        },
                        {
                            "id": "SA7_258",
                            "name": "综合数管单元数据块传输",
                            "type": "any",
                            "pos": 168,
                            "length": "length-current",
                            "children": pkts
                        }
                    ]
                }
            ]
        }
    ]
db_struct_flow.py
@@ -1,20 +1,52 @@
import os
import time
from datetime import datetime
from openai import OpenAI
from pathlib import Path
import re
import json
import copy
from datas import pkt_vc, pkt_datas, dev_pkt, proj_data
from db.db_generate import create_project, create_device, create_data_stream
from db.models import TProject, TDevice
import data_templates
from knowledgebase.db.db_helper import create_project, create_device, create_data_stream, \
    update_rule_enc, create_extend_info, create_ref_ds_rule_stream, create_ins_format
from knowledgebase.db.data_creator import create_prop_enc, create_enc_pkt, get_data_ty, create_any_pkt
from knowledgebase.db.models import TProject
file_map = {
    "文档合并": "./doc/文档合并.md",
    "遥测源包设计报告": "./doc/XA-5D无人机分系统探测源包设计报告(公开).md",
    "遥测大纲": "./doc/XA-5D无人机探测大纲(公开).md",
    "总线传输通信帧分配": "./doc/XA-5D无人机1314A总线传输通信帧分配(公开).md",
    "应用软件用户需求": "./doc/XA-5D无人机软件用户需求(公开).docx.md",
    "指令格式": "./doc/ZL格式(公开).docx.md"
}
# file_map = {
#     "遥测源包设计报告": "./docs/HY-4A数管分系统遥测源包设计报告 Z 240824 更改3(内部) .docx.md",
#     "遥测大纲": "./docs/HY-4A卫星遥测大纲 Z 240824 更改3(内部).docx.md",
#     "总线传输通信帧分配": "./docs/HY-4A卫星1553B总线传输通信帧分配 Z 240824 更改3(内部).docx.md",
#     "应用软件用户需求": "./docs/HY-4A数管分系统应用软件用户需求(星务管理分册) Z 240831 更改4(内部).docx.md"
# }
# file_map = {
#     "文档合并": "./doc/文档合并.md",
#     "遥测源包设计报告": "./doc/XA-5D无人机分系统探测源包设计报告(公开).md",
#     "遥测大纲": "./doc/XA-5D无人机探测大纲(公开).md",
#     "总线传输通信帧分配": "./doc/XA-5D无人机1314A总线传输通信帧分配(公开).md"
# }
BASE_URL = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
API_KEY = 'sk-15ecf7e273ad4b729c7f7f42b542749e'
MODEL_NAME = 'qwen-long'
MODEL_NAME = 'qwen2.5-14b-instruct-1m'
# BASE_URL = 'http://10.74.15.164:11434/v1/'
# API_KEY = 'ollama'
# MODEL_NAME = 'qwen2.5:32b-128k'
# BASE_URL = 'http://10.74.15.164:1001/api'
# API_KEY = 'sk-a909385bc14d4491a718b6ee264c3227'
# MODEL_NAME = 'qwen2.5:32b-128k'
USE_CACHE = True
assistant_msg = """
# 角色
你是一个专业的文档通信分析师,擅长进行文档分析和通信协议分析,同时能够解析 markdown 类型的文档。拥有成熟准确的文档阅读与分析能力,能够妥善处理多文档间存在引用关系的复杂情况。
@@ -25,64 +57,259 @@
2. 分析文档的结构、主题和重点内容,同样只依据文档进行表述。
3. 如果文档间存在引用关系,梳理引用脉络,明确各文档之间的关联,且仅呈现文档中体现的内容。
### 技能 2:通信协议分析
1. 接收通信协议相关信息,理解协议的规则和流程,仅依据所给信息进行分析。
## 背景知识
###软件主要功能与运行机制总结如下:
1. 数据采集和处理:
   DIU负责根据卫星的工作状态或模式提供遥测数据,包括模拟量(AN)、总线信号(BL)以及温度(TH)和数字量(DS),并将这些信息打包,通过总线发送给SMU。
   SMU则收集硬通道上的遥测参数,并通过总线接收DIU采集的信息。
2. 多路复用与数据传输:
   遥测源包被组织成E-PDU,进一步复用为M-PDU,并填充到VCDU中构成遥测帧。
   利用CCSDS AOS CADU格式进行遥测数据的多路复用和传输。
3. 虚拟信道(VC)调度机制:
   通过常规遥测VC、突发数据VC、延时遥测VC、记录数据VC以及回放VC实现不同类型的数据下传。
4. 遥控指令处理:
   上行遥控包括直接指令和间接指令,需经过格式验证后转发给相应单机执行。
   遥控帧通过特定的虚拟信道(VC)进行传输。
这些知识需要你记住,再后续的处理中可以帮助你理解要处理的数据。
## 目标导向
1. 通过对文档和通信协议的分析,为用户提供清晰、准确的数据结构,帮助用户更好地理解和使用相关信息。
2. 以 JSON 格式组织输出内容,确保数据结构的完整性和可读性。
## 规则
1. 每一个型号都会有一套文档,需准确判断是否为同一个型号的文档后再进行整体分析。
2. 每次只分析同一个型号。
3. 大多数文档结构为:型号下包含设备,设备下包含数据流,数据流下包含数据帧,数据帧中有一块是包域,包域中会挂载各种类型的数据包。
4. 这些文档都是数据传输协议的描述,在数据流、数据帧、数据包等传输实体中都描述了各个字段的分布和每个字段的大小,且大小单位不统一,需理解这些单位,并将所有输出单位统一为 bits,统一使用length表示。
5. 如果有层级,使用树形 JSON 输出,子节点 key 使用children;需保证相同类型的数据结构统一,并且判断每个层级是什么类型,输出类型字段,类型字段的 key 使用 type ;例如当前层级为字段时使用:type:"field";当前层级为设备时使用:type:"device"
6.名称相关的字段的 key 使用name;代号或者唯一标识相关的字段的key使用id;序号相关的字段的key使用number;其他没有举例的字段使用精简的翻译作为字段的key;
7.探测帧为CADU,其中包含同步头和VCDU,按照习惯需要使用VCDU层级包含下一层级中传输帧主导头、传输帧插入域、传输帧数据域、传输帧尾的结构
1. 每一个型号都会有一套文档,需准确判断是否为同一个型号的文档后再进行整体分析,每次只分析同一个型号的文档。
2. 大多数文档结构为:型号下包含设备,设备下包含数据流,数据流下包含数据帧,数据帧中有一块是包域,包域中会挂载各种类型的数据包。
3. 文档都是对于数据传输协议的描述,在数据流、数据帧、数据包等传输实体中都描述了各个字段的分布、各个字段的大小和位置等信息,且大小单位不统一,需理解这些单位,并将所有输出单位统一为 bits,长度字段使用 length 表示,位置字段使用 pos 表示,如果为变长使用“"变长"”表示。
4. 如果有层级,使用树形 JSON 输出,如果有子节点,子节点 key 使用children;需保证一次输出的数据结构统一,并且判断每个层级是什么类型,输出类型字段(type),类型字段的 key 使用 type,类型包括:型号(project)、设备(dev)、封装包(enc)、线性包(linear)、参数(para),封装包子级有数据包,所以type为enc,线性包子级只有参数,所以type为linear;每个层级都包含偏移位置(pos),每个层级的偏移位置从0开始。
5. 名称相关的字段的 key 使用name;代号、编号或者唯一标识相关的字段的key使用id,id由数字、英文字母、下划线组成且以英文字母开头,长度尽量简短;序号相关的字段的key使用number;偏移位置相关字段的key使用pos;其他没有举例的字段使用精简的翻译作为字段的key;每个结构必须包含name和id。
6. 遥测帧为CADU,其中包含同步头和VCDU,按照习惯需要使用VCDU层级嵌套传输帧主导头、传输帧插入域、传输帧数据域、传输帧尾的结构。
7. 数据包字段包括:name、id、type、pos、length、children;参数字段包括:name、id、pos、type、length;必须包含pos和length字段。
8. 常用id参考:遥测(TM)、遥控(TC)、总线(BUS)、版本号(Ver)、应用过程标识(APID)。
9. 注意:一定要记得morkdown文档中会将一些特殊字符进行转义,以此来保证文档的正确性,这些转义符号(也就是反斜杠‘\’)不需要在结果中输出。
10. 以 JSON 格式组织输出内容,确保数据结构的完整性和可读性,注意:生成的JSON语法格式必须符合json规范,避免出现错误。
## 限制:
- 所输出的内容必须按照JSON格式进行组织,不能偏离框架要求,且严格遵循文档内容进行输出,只输出 JSON ,不要输出其它文字。
- 不输出任何注释等描述性信息
- 不输出任何注释等描述性信息。
"""
g_completion = None
def read_from_file(cache_file):
    with open(cache_file, 'r', encoding='utf-8') as f:
        text = f.read()
    return text
def save_to_file(text, file_path):
    if USE_CACHE:
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(text)
json_pat = re.compile(r'```json(.*?)```', re.DOTALL)
def remove_markdown(text):
    # 使用正则表达式提取json文本
    try:
        return json_pat.findall(text)[0]
    except IndexError:
        return text
def rt_pkt_map_gen(pkt, trans_ser, rt_pkt_map, pkt_id, vals):
    # 逻辑封装包,数据块传输的只有一个,取数的根据RT地址、子地址和帧号划分
    frame_num = pkt['frameNum']
    if trans_ser == '数据块传输':
        # 数据块传输根据RT地址和子地址划分
        key = f'{pkt["rt"]}_{pkt["subAddr"]}'
        name = f'{pkt["rt"]}_{pkt["subAddr"]}_{trans_ser}'
    else:
        # 取数根据RT地址、子地址和帧号划分
        key = f'{pkt["rt"]}_{pkt["subAddr"]}_{pkt["frameNum"]}'
        name = f'{pkt["rt"]}_{pkt["subAddr"]}_帧号{frame_num}_{trans_ser}'
    #
    if key not in rt_pkt_map:
        rt_pkt_map[key] = {
            "name": name,
            "id": pkt_id,
            "type": "logic",
            "pos": 0,
            "content": "CYCLEBUFFER,Message,28,0xFFFF",
            "length": "",
            "vals": vals,
            "children": []
        }
    frame = f'{pkt["frameNum"]}'
    interval = f'{pkt["interval"]}'.replace(".", "_")
    if trans_ser == '取数':
        _key = f'RT{pkt["rtAddr"]}Frame{frame.replace("|", "_")}_Per{interval}'
    else:
        # 数据块传输
        if pkt['burst']:
            _key = f'RT{pkt["rtAddr"]}FrameALL'
        else:
            _key = f'RT{pkt["rtAddr"]}Frame{frame}Per{interval}'
    _pkt = next(filter(lambda it: it['name'] == _key, rt_pkt_map[key]['children']), None)
    if _pkt is None:
        ext_info = None
        if trans_ser == '数据块传输' and not pkt['burst']:
            # 数据块传输且有周期的包需要
            ext_info = [{"id": "PeriodTriger", "name": "时分复用总线触发属性", "val": f"{pkt['interval']}"},
                        {"id": "FrameNumber", "name": "时分复用协议帧号", "val": frame}]
        _pkt = {
            "name": _key,
            "id": _key,
            "type": "enc",
            "pos": 0,
            "content": "1:N;EPDU",
            "length": "length",
            "extInfo": ext_info,
            "children": [
                {
                    "id": "C02_ver",
                    "name": "遥测版本",
                    "type": "para",
                    "pos": 0,
                    "length": 3,
                    "dataTy": "INVAR",
                    "content": "0"
                },
                {
                    "id": "C02_type",
                    "name": "类型",
                    "type": "para",
                    "pos": 3,
                    "length": 1,
                    "dataTy": "INVAR",
                    "content": "0"
                },
                {
                    "id": "C02_viceHead",
                    "name": "副导头标识",
                    "type": "para",
                    "pos": 4,
                    "length": 1,
                    "content": "1",
                    "dataTy": "INVAR"
                },
                {
                    "id": "C02_PackSign",
                    "name": "APID",
                    "type": "para",
                    "pos": 5,
                    "length": 11,
                    "is_key": True,
                    "dataTy": "ENUM"
                },
                {
                    "id": "C02_SerCtr_1",
                    "name": "序列标记",
                    "type": "para",
                    "pos": 16,
                    "length": 2,
                    "content": "3"
                },
                {
                    "id": "C02_SerCtr_2",
                    "name": "包序计数",
                    "type": "para",
                    "pos": 18,
                    "length": 14,
                    "content": "0:167772:1",
                    "dataTy": "INCREASE"
                },
                {
                    "id": "C02_PackLen",
                    "name": "包长",
                    "type": "para",
                    "pos": 32,
                    "length": 16,
                    "content": "1Bytes/C02_Data.length+1",
                    "dataTy": "LEN"
                },
                {
                    "id": "C02_Ser",
                    "name": "服务",
                    "type": "para",
                    "pos": 48,
                    "length": 8,
                    "is_key": True,
                    "dataTy": "ENUM"
                },
                {
                    "id": "C02_SubSer",
                    "name": "子服务",
                    "type": "para",
                    "pos": 56,
                    "length": 8,
                    "is_key": True,
                    "dataTy": "ENUM"
                },
                {
                    "id": "C02_Data",
                    "name": "数据区",
                    "type": "linear",
                    "pos": 64,
                    "length": 'length-current',
                    "children": []
                },
            ]
        }
        rt_pkt_map[key]['children'].append(_pkt)
    # 数据区下面的包
    data_area = next(filter(lambda it: it['name'] == '数据区', _pkt['children']), None)
    ser_sub_ser: str = pkt['service']
    ser = ''
    sub_ser = ''
    if ser_sub_ser:
        nums = re.findall(r'\d+', ser_sub_ser)
        if len(nums) == 2:
            ser = nums[0]
            sub_ser = nums[1]
    if 'children' not in pkt:
        pkt['children'] = []
    p_name = pkt['id'] + '_' + pkt['name']
    data_area['children'].append({
        "name": p_name,
        "id": pkt["id"],
        "type": "linear",
        "pos": 0,
        "length": pkt["length"],
        "vals": f"0x{pkt['apid']}/{ser}/{sub_ser}/",
        "children": pkt['children'],
    })
def build_vcid_content(vcs):
    _vcs = []
    for vc in vcs:
        _vcs.append(vc['name'] + ',' + vc['VCID'])
    return ' '.join(_vcs)
class DbStructFlow:
    files = []
    file_objects = []
    # 工程
    proj: TProject = None
    # 遥测源包列表,仅包名称、包id和hasParams
    tm_pkts = []
    # vc源包
    vc_pkts = []
    def __init__(self, doc_files):
    def __init__(self):
        self.client = OpenAI(
            api_key=API_KEY,
            base_url=BASE_URL,
            # api_key="ollama",
            # base_url="http://192.168.1.48:11434/v1/",
        )
        if doc_files:
            self.files = doc_files
        self.load_file_objs()
        self.delete_all_files()
        self.upload_files()
    def load_file_objs(self):
        file_stk = self.client.files.list()
        self.file_objects = file_stk.data
    def delete_all_files(self):
        for file_object in self.file_objects:
            self.client.files.delete(file_object.id)
    def upload_file(self, file_path):
        file_object = self.client.files.create(file=Path(file_path), purpose="file-extract")
        return file_object
    def upload_files(self):
        self.file_objects = []
        for file_path in self.files:
            file_object = self.upload_file(file_path)
            self.file_objects.append(file_object)
    def run(self):
        # 生成型号结构
@@ -90,78 +317,24 @@
        # 生成数据流结构 CADU
        # 生成VCDU结构
        # 生成遥测数据包结构
        proj = self.gen_project([])
        # proj = TProject(C_PROJECT_PK='2e090a487c1a4f7f741be3a437374e2f')
        self.proj = self.gen_project()
        devs = self.gen_device([], proj)
        # with open('datas/设备列表.json', 'w', encoding='utf8') as f:
        #     json.dump(devs, f, ensure_ascii=False, indent=4)
        #
        # proj['devices'] = devs
        #
        # messages = []
        # cadu = self.gen_tm_frame(messages)
        # with open("datas/探测帧.json", 'w', encoding='utf8') as f:
        #     json.dump(cadu, f, ensure_ascii=False, indent=4)
        #
        # messages = []
        # vcs = self.gen_vc(messages)
        # with open('datas/虚拟信道.json', 'w', encoding='utf8') as f:
        #     json.dump(vcs, f, ensure_ascii=False, indent=4)
        #
        # messages = []
        # pkt_vcs = self.gen_pkt_vc(messages)
        # with open('datas/VC源包.json', 'w', encoding='utf8') as f:
        #     json.dump(pkt_vcs, f, ensure_ascii=False, indent=4)
        #
        # messages = []
        # dev_pkts = self.gen_dev_pkts(messages)
        # with open('datas/设备源包.json', 'w', encoding='utf8') as f:
        #     json.dump(dev_pkts, f, ensure_ascii=False, indent=4)
        #
        # messages = []
        # _pkts = self.gen_pkts()
        # pkts = []
        # for pkt in _pkts:
        #     _pkt = self.gen_pkt_details(pkt['name'])
        #     pkts.append(_pkt)
        # with open('datas/源包列表.json', 'w', encoding='utf8') as f:
        #     json.dump(pkts, f, ensure_ascii=False, indent=4)
        #
        # for dev in devs:
        #     ds = dev['data_streams'][0]
        #     _cadu = copy.deepcopy(cadu)
        #     ds['cadu'] = _cadu
        #     _vcdu = next(filter(lambda it: it['name'] == '传输帧', _cadu['children']))
        #     vcdu_data = next(filter(lambda it: it['name'] == '传输帧数据域', _vcdu['children']))
        #     _vcs = copy.deepcopy(vcs)
        #     vcdu_data['children'] = _vcs
        #     dev_pkt = next(filter(lambda it: it['name'] == dev['name'], dev_pkts), None)
        #     if dev_pkt is None:
        #         continue
        #     for pkt in dev_pkt['pkts']:
        #         for vc in _vcs:
        #             _pkt = next(
        #                 filter(lambda it: it['name'] == pkt['name'] and it['vcs'].__contains__(vc['code']), pkt_vcs),
        #                 None)
        #             if _pkt:
        #                 if vc.__contains__('pkts') is False:
        #                     vc['pkts'] = []
        #                 _pkt = next(filter(lambda it: it['name'] == _pkt['name'], pkts), None)
        #                 if _pkt:
        #                     vc['pkts'].append(_pkt)
        #
        # with open("datas/型号.json", 'w', encoding='utf8') as f:
        #     json.dump(proj, f, ensure_ascii=False, indent=4)
        devs = self.gen_device(self.proj)
        # self.gen_tc()
        return ''
    def _gen(self, msgs, msg):
    def _gen(self, msgs, msg, files=None):
        if files is None:
            files = [file_map['文档合并']]
        messages = [] if msgs is None else msgs
        doc_text = ''
        for file in files:
            doc_text += '\n' + read_from_file(file)
        if len(messages) == 0:
            # 如果是第一次提问加入文档
            # 如果是第一次提问加入system消息
            messages.append({'role': 'system', 'content': assistant_msg})
            for file_object in self.file_objects:
                messages.append({'role': 'system', 'content': 'fileid://' + file_object.id})
            messages.append({'role': 'user', 'content': "以下是文档内容:\n" + doc_text})
        messages.append({'role': 'user', 'content': msg})
        completion = self.client.chat.completions.create(
@@ -171,34 +344,71 @@
            temperature=0.0,
            top_p=0,
            timeout=30 * 60000,
            max_completion_tokens=1000000
            max_completion_tokens=1000000,
            seed=0
            # stream_options={"include_usage": True}
        )
        g_completion = completion
        text = ''
        for chunk in completion:
            if chunk.choices[0].delta.content is not None:
                text += chunk.choices[0].delta.content
                print(chunk.choices[0].delta.content, end="")
        print("")
        g_completion = None
        return text
    def gen_project(self, messages):
        _msg = f"""
根据文档输出型号信息,型号字段包括:名称和代号,仅输出型号的属性,不输出其他层级数据
        """
        print('型号信息:')
        text = self._gen(messages, _msg)
        messages.append({'role': 'assistant', 'content': text})
        text = self.remove_markdown(text)
        proj_dict = json.loads(text)
        # return proj_dict
    def generate_text(self, msg, cache_file, msgs=None, files=None, validation=None, try_cnt=5):
        if msgs is None:
            msgs = []
        if USE_CACHE and os.path.isfile(cache_file):
            text = read_from_file(cache_file)
        else:
            s = time.time()
            text = self._gen(msgs, msg, files)
            text = remove_markdown(text)
            if validation:
                try:
                    validation(text)
                except BaseException as e:
                    print(e)
                    if try_cnt <= 0:
                        raise RuntimeError('生成失败,重试次数太多,强制结束!')
                    return self.generate_text(msg, cache_file, msgs, files, validation, try_cnt - 1)
            save_to_file(text, cache_file)
            print(f'耗时:{time.time() - s}')
        return text
    def generate_tc_text(self, msg, cache_file, messages=None, files=None, validation=None, try_cnt=5):
        if messages is None:
            messages = []
        doc_text = ''
        for file in files:
            doc_text += '\n' + read_from_file(file)
        if len(messages) == 0:
            # 如果是第一次提问加入system消息
            messages.append({'role': 'user', 'content': "以下是文档内容:\n" + doc_text})
        return self.generate_text(msg, cache_file, messages, files, validation, try_cnt)
    def gen_project(self):
        #         _msg = """
        # 根据文档输出型号信息,型号字段包括:名称和代号。仅输出型号这一级。
        # 例如:{"name":"xxx","id":"xxx"}
        # """
        #         print('型号信息:')
        #         text = self.generate_text(_msg, 'out/型号信息.json', files=[file_map['应用软件用户需求']])
        #         proj_dict = json.loads(text)
        # 工程信息从系统获取
        proj_dict = {
            "id": "JB200001",
            "name": "HY-4A"
        }
        code = proj_dict['id']
        name = proj_dict['name']
        proj = create_project(code, name, code, name, "", datetime.now())
        return proj
    def gen_device(self, messages, proj):
    def gen_device(self, proj):
        """
        设备列表生成规则:
        1.如文档中有1553协议描述,加入1553设备
@@ -208,28 +418,29 @@
        设备类型:工控机[0]、1553B[1]
        :param messages:
        :param proj:
        :return:
        """
        proj_pk = proj.C_PROJECT_PK
        devices = []
        _msg = f"""
输出所有设备列表,设备字段包括名称(name)、代号(code),如果没有代号则使用名称的英文翻译缩写代替且缩写长度不超过5个字符,JSON格式,并且给每个设备增加三个字段,第一个字段hasTcTm“是否包含遥控遥测”,判断该设备是否包含遥控遥测的功能;第二个字段hasTemperatureAnalog“是否包含温度量、模拟量等数据的采集”,判断该设备是否包含温度量等信息的采集功能;第三个字段hasBus“是否是总线设备”,判断该设备是否属于总线设备,是否有RT地址;每个字段的值都使用true或false来表示。
仅输出JSON,不要输出JSON以外的任何字符。
输出分系统下的硬件产品(设备)列表,字段包括:名称(name)、代号(code),硬件产品名称一般会包含“管理单元”或者“接口单元”,如果没有代号则使用名称的英文缩写代替缩写长度不超过5个字符;
并且给每个硬件产品增加三个字段:第一个字段hasTcTm“是否包含遥控遥测”,判断该硬件产品是否包含遥控遥测的功能、
第二个字段hasTemperatureAnalog“是否包含温度量、模拟量等数据的采集”,判断该硬件产品是否包含温度量等信息的采集功能、
第三个字段hasBus“是否是总线硬件产品”,判断该设备是否属于总线硬件产品,是否有RT地址;每个字段的值都使用true或false来表示。
仅输出JSON,结构最外层为数组,数组元素为设备信息,不要输出JSON以外的任何字符。
        """
        print('设备列表:')
        text = self._gen(messages, _msg)
        text = self.remove_markdown(text)
        cache_file = 'out/设备列表.json'
        def validation(gen_text):
            _devs = json.loads(gen_text)
            assert isinstance(_devs, list), '数据结构最外层不是数组'
            assert next(filter(lambda it: it['name'].endswith('管理单元'), _devs), None), '生成的设备列表中没有管理单元'
        text = self.generate_text(_msg, cache_file, files=[file_map['应用软件用户需求']], validation=validation)
        devs = json.loads(text)
        hasBus = any(d['hasBus'] for d in devs)
        if hasBus:
            # 总线设备
            dev = create_device("B1553", "1553总线", '1', 'StandardProCommunicationDev', proj_pk)
            devices.append(dev)
            # 创建数据流
            ds_u153 = create_data_stream(proj_pk, dev.C_DEV_PK, 'ECSS上行总线数据', 'U153', 'B153', '0', 'E153', '001')
            ds_d153 = create_data_stream(proj_pk, dev.C_DEV_PK, 'ECSS下行总线数据', 'D153', 'B153', '1', 'E153', '001')
        # 类SMU设备,包含遥测和遥控功能,名称结尾为“管理单元”
        like_smu_devs = list(filter(lambda it: it['hasTcTm'] and it['name'].endswith('管理单元'), devs))
@@ -237,13 +448,31 @@
            dev = create_device(dev['code'], dev['name'], '0', 'StandardProCommunicationDev', proj.C_PROJECT_PK)
            devices.append(dev)
            # 创建数据流
            ds_tmfl = create_data_stream(proj_pk, dev.C_DEV_PK, 'AOS遥测', 'TMFL', 'TMFL', '1', 'TMFL', '001')
            ds_tcfl = create_data_stream(proj_pk, dev.C_DEV_PK, '遥控指令', 'TCFL', 'TCFL', '0', 'TCFL', '006')
            ds_tmfl, rule_stream, _ = create_data_stream(proj_pk, dev.C_DEV_PK, 'AOS遥测', 'TMF1', 'TMFL', '1', 'TMF1',
                                                         '001')
            self.gen_tm_frame(proj_pk, rule_stream.C_RULE_PK, ds_tmfl, rule_stream.C_PATH)
            # ds_tcfl, rule_stream, _ = create_data_stream(proj_pk, dev.C_DEV_PK, '遥控指令', 'TCFL', 'TCFL', '0', 'TCFL',
            #                                              '006')
        hasBus = any(d['hasBus'] for d in devs)
        if hasBus:
            # 总线设备
            dev = create_device("1553", "1553总线", '1', 'StandardProCommunicationDev', proj_pk)
            create_extend_info(proj_pk, "BusType", "总线类型", "ECSS_Standard", dev.C_DEV_PK)
            devices.append(dev)
            # 创建数据流
            ds_u153, rs_u153, rule_enc = create_data_stream(proj_pk, dev.C_DEV_PK, '上行总线数据', 'U15E', 'B153',
                                                            '0', '1553', '001')
            # 创建总线结构
            self.gen_bus(proj_pk, rule_enc, '1553', ds_u153, rs_u153.C_PATH, dev.C_DEV_NAME)
            ds_d153, rule_stream, rule_enc = create_data_stream(proj_pk, dev.C_DEV_PK, '下行总线数据', 'D15E', 'B153',
                                                                '1', '1553', '001', rs_u153.C_RULE_PK)
            create_ref_ds_rule_stream(proj_pk, rule_stream.C_STREAM_PK, rule_stream.C_STREAM_ID,
                                      rule_stream.C_STREAM_NAME, rule_stream.C_STREAM_DIR, rs_u153.C_STREAM_PK)
        # 类RTU设备,包含温度量和模拟量功能,名称结尾为“接口单元”
        like_rtu_devs = list(filter(lambda it: it['hasTemperatureAnalog'] and it['name'].endswith('接口单元'), devs))
        for dev in like_rtu_devs:
            dev = create_device(dev['code'], dev['name'], '0', 'StandardProCommunicationDev', proj.C_PROJECT_PK)
        # like_rtu_devs = list(filter(lambda it: it['hasTemperatureAnalog'] and it['name'].endswith('接口单元'), devs))
        # for dev in like_rtu_devs:
        #     dev = create_device(dev['code'], dev['name'], '0', 'StandardProCommunicationDev', proj.C_PROJECT_PK)
        # for dev in like_rtu_devs:
        #     dev = create_device(dev['code'], dev['name'], '0', '', proj.C_PROJECT_PK)
@@ -252,151 +481,626 @@
        #     ds_tmfl = create_data_stream(proj.C_PROJECT_PK, '温度量', 'TMFL', 'TMFL', '1', 'TMFL', '001')
        #     ds_tcfl = create_data_stream(proj.C_PROJECT_PK, '模拟量', 'TCFL', 'TCFL', '0', 'TCFL', '006')
        print()
        # 总线设备
        # print('是否有总线设备:', end='')
        # _msg = "文档中描述的有总线相关内容吗?仅回答:“有”或“无”,不要输出其他文本。"
        # text = self._gen([], _msg)
        # if text == "有":
        #     _msg = f"""
        #     文档中描述的总线型号是多少,仅输出总线型号不要输出型号以外的其他任何文本,总线型号由数字和英文字母组成。
        #     """
        #     print('设备ID:')
        #     dev_code = self._gen([], _msg)
        #     dev = create_device(dev_code, dev_code, '1', '', proj.C_PROJECT_PK)
        #     devices.append(dev)
        # 类SMU软件
        # print('是否有类SMU设备:', end='')
        # _msg = "文档中有描述遥测和遥控功能吗?仅回答:“有”或“无”,不要输出其他文本。"
        # text = self._gen([], _msg)
        # if text == "有":
        #     # 系统管理单元
        #     print('是否有系统管理单元(SMU):', end='')
        #     _msg = f"文档中有描述系统管理单元(SMU)吗?仅回答“有”或“无”,不要输出其他文本。"
        #     text = self._gen([], _msg)
        #     if text == "有":
        #         dev = create_device("SMU", "系统管理单元", '0', '', proj.C_PROJECT_PK)
        #         devices.append(dev)
        #     # 中心控制单元(CTU)
        #     print('是否有中心控制单元(CTU):', end='')
        #     _msg = f"文档中有描述中心控制单元(CTU)吗?仅回答“有”或“无”,不要输出其他文本。"
        #     text = self._gen([], _msg)
        #     if text == "有":
        #         dev = create_device("CTU", "中心控制单元", '0', '', proj.C_PROJECT_PK)
        #         devices.append(dev)
        #
        # # 类RTU
        # print('是否有类RTU设备:', end='')
        # _msg = "文档中有描述模拟量采集和温度量采集功能吗?仅回答:“有”或“无”,不要输出其他文本。"
        # text = self._gen([], _msg)
        # if text == "有":
        #     dev = create_device("RTU", "远置单元", '0', '', proj.C_PROJECT_PK)
        #     devices.append(dev)
        # device_dicts = json.loads(text)
        # for device_dict in device_dicts:
        #     data_stream = {'name': '数据流', 'code': 'DS'}
        #     device_dict['data_streams'] = [data_stream]
        #
        # return device_dicts
        return devices
    def gen_tm_frame(self, messages):
        _msg = f"""
输出探测帧的结构,探测帧字段包括:探测帧代号(id)、探测帧名称(name)、长度(length)、下级数据单元列表(children)。代号如果没有则用名称的英文翻译,包括下级数据单元。
    def gen_insert_domain_params(self):
        _msg = """
分析文档,输出插入域的参数列表,将所有参数全部输出,不要有遗漏。
数据结构最外层为数组,数组元素为参数信息对象,参数信息字段包括:name、id、pos、length、type。
1个字节的长度为8位,使用B0-B7来表示,请认真计算参数长度。
文档中位置描述信息可能存在跨字节的情况,,例如:"Byte1_B6~Byte2_B0":表示从第1个字节的第7位到第2个字节的第1位,长度是3;"Byte27_B7~Byte28_B0":表示从第27个字节的第8位到第28个字节的第1位,长度是2。
"""
        print('插入域参数列表:')
        files = [file_map['遥测大纲']]
        def validation(gen_text):
            params = json.loads(gen_text)
            assert isinstance(params, list), '插入域参数列表数据结构最外层必须是数组'
            assert len(params), '插入域参数列表不能为空'
        text = self.generate_text(_msg, './out/插入域参数列表.json', files=files, validation=validation)
        return json.loads(text)
    def gen_tm_frame_data(self):
        _msg = """
        """
        print('探测帧信息:')
        text = self._gen(messages, _msg)
        messages.append({'role': 'assistant', 'content': text})
        text = self.remove_markdown(text)
        cadu = json.loads(text)
        files = [file_map['遥测大纲']]
        def validation(gen_text):
            pass
    def gen_tm_frame(self, proj_pk, rule_pk, ds, name_path):
        # 插入域参数列表
        insert_domain = self.gen_insert_domain_params()
        # VC源包格式
        vc_pkt_fields = data_templates.vc_pkt_fields  # self.gen_pkt_format()
        # 获取虚拟信道 vc
        vcs = self.gen_vc()
        for vc in vcs:
            vc['children'] = []
            vc['VCID'] = str(int(vc['VCID'], 2))
            for field in vc_pkt_fields:
                if field['name'] == '数据域':
                    field['children'] = []
                vc['children'].append(dict(field))
        # VCID 字段内容
        vcid_content = build_vcid_content(vcs)
        # 遥测帧结构由模板生成,只需提供特定参数
        tm_data = {
            "vcidContent": vcid_content,
            'insertDomain': insert_domain,
        }
        cadu = data_templates.get_tm_frame(tm_data)
        # VC源包
        self.vc_pkts = self.gen_pkt_vc()
        # 遥测源包设计中的源包列表
        self.tm_pkts = self.gen_pkts()
        # 处理VC下面的遥测包数据
        for vc in vcs:
            # 此VC下的遥测包过滤
            _vc_pkts = filter(lambda it: it['vcs'].__contains__(vc['id']), self.vc_pkts)
            for _pkt in _vc_pkts:
                # 判断遥测包是否有详细定义
                if not next(filter(lambda it: it['name'] == _pkt['name'] and it['hasParams'], self.tm_pkts), None):
                    continue
                # 获取包详情
                _pkt = self.gen_pkt_details(_pkt['name'], _pkt['id'])
                epdu = next(filter(lambda it: it['name'] == '数据域', vc['children']), None)
                if epdu and _pkt:
                    _pkt['children'] = _pkt['datas']
                    _last_par = _pkt['children'][len(_pkt['children']) - 1]
                    _pkt['length'] = (_last_par['pos'] + _last_par['length'])
                    _pkt['pos'] = 0
                    if 'children' not in epdu:
                        epdu['children'] = []
                    # 添加解析规则后缀防止重复
                    _pkt['id'] = _pkt['id'] + '_' + vc['VCID']
                    # 给包名加代号前缀
                    if not _pkt['name'].startswith(_pkt['id']):
                        _pkt['name'] = _pkt['id'] + '_' + _pkt['name']
                    epdu['children'].append(_pkt)
                    apid_node = next(filter(lambda it: it['name'].__contains__('应用过程'), _pkt['headers']), None)
                    ser_node = next(filter(lambda it: it['name'] == '服务', _pkt['headers']), None)
                    sub_ser_node = next(filter(lambda it: it['name'] == '子服务', _pkt['headers']), None)
                    _pkt['vals'] = \
                        f"{apid_node['content']}/{int(ser_node['content'], 16)}/{int(sub_ser_node['content'], 16)}/"
        # 重新计数起始偏移
        self.compute_length_pos(cadu['children'])
        # 将数据插入数据库
        seq = 1
        for cadu_it in cadu['children']:
            if cadu_it['name'] == 'VCDU':
                # VCDU
                # 将信道替换到数据域位置
                vc_data = next(filter(lambda it: it['name'].__contains__('数据域'), cadu_it['children']), None)
                if vc_data:
                    idx = cadu_it['children'].index(vc_data)
                    cadu_it['children'].pop(idx)
                    for vc in vcs:
                        # 处理虚拟信道属性
                        vc['type'] = 'logic'
                        vc['length'] = vc_data['length']
                        vc['pos'] = vc_data['pos']
                        vc['content'] = 'CCSDSMPDU'
                        vcid = vc['VCID']
                        vc['condition'] = f'VCID=={vcid}'
                        # 将虚拟信道插入到VCDU
                        cadu_it['children'].insert(idx, vc)
                        idx += 1
                for vc in vcs:
                    self.compute_length_pos(vc['children'])
                # 设置VCID的content
                vcid_node = next(filter(lambda it: it['name'].__contains__('VCID'), cadu_it['children']), None)
                if vcid_node:
                    vcid_node['content'] = vcid_content
                create_enc_pkt(proj_pk, rule_pk, cadu_it, rule_pk, seq, name_path, ds, '001', 'ENC')
            else:
                # 参数
                create_prop_enc(proj_pk, rule_pk, cadu_it, get_data_ty(cadu_it), seq)
                seq += 1
        return cadu
    def gen_vc(self, messages):
        _msg = f"""
输出探测虚拟信道的划分,不需要描述信息,使用一个数组输出,字段包括:代号(code)、vcid、名称(name)。
        """
    def gen_vc(self):
        _msg = """
请分析文档中的遥测包格式,输出遥测虚拟信道的划分,数据结构最外层为数组,数组元素为虚拟信道信息字典,字典包含以下键值对:
id: 虚拟信道代号
name: 虚拟信道名称
VCID: 虚拟信道VCID(二进制)
format: 根据虚拟信道类型获取对应的数据包的格式的名称
深入理解文档中描述的关系,例如:文档中描述了常规遥测是常规数据的下传信道,并且还描述了分系统常规遥测参数包就是实时遥测参数包,并且文档中对实时遥测参数包的格式进行了描述,所以常规遥测VC应该输出为:{"id": "1", "name": "常规遥测VC", "VCID": "0", "format": "实时遥测参数包"}
"""
        def validation(gen_text):
            vcs = json.loads(gen_text)
            assert next(filter(lambda it: re.match('^[0-1]+$', it['VCID']), vcs)), '生成的VCID必须是二进制'
        print('虚拟信道:')
        text = self._gen(messages, _msg)
        messages.append({'role': 'assistant', 'content': text})
        text = self.remove_markdown(text)
        text = self.generate_text(_msg, "out/虚拟信道.json", files=[file_map['遥测大纲']], validation=validation)
        vcs = json.loads(text)
        return vcs
    def gen_dev_pkts(self, messages):
    def gen_dev_pkts(self):
        _msg = f"""
输出文档中探测源包类型定义描述的设备以及设备下面的探测包,数据结构:最外层为设备列表 > 探测包列表(pkts),设备字段包括:名称(name)、代号(id),源包字段包括:名称(name)、代号(id)
输出文档中遥测源包类型定义描述的设备以及设备下面的遥测包,数据结构:最外层为数组 > 设备 > 遥测包列表(pkts),设备字段包括:名称(name)、代号(id),源包字段包括:名称(name)、代号(id)
        """
        print('设备探测源包信息:')
        file = next(filter(lambda it: it.filename == 'XA-5D无人机分系统探测源包设计报告(公开).md', self.file_objects),
                    None)
        messages = [{'role': 'system', 'content': assistant_msg}, {'role': 'system', 'content': 'fileid://' + file.id}]
        text = self._gen(messages, _msg)
        messages.append({'role': 'assistant', 'content': text})
        text = self.remove_markdown(text)
        print('设备遥测源包信息:')
        files = [file_map["遥测源包设计报告"]]
        text = self.generate_text(_msg, 'out/设备数据包.json', [], files)
        dev_pkts = json.loads(text)
        return dev_pkts
    def gen_pkt_details(self, pkt_name):
    def pkt_in_tm_pkts(self, pkt_name):
        cache_file = f'out/数据包-{pkt_name}.json'
        if os.path.isfile(cache_file):
            return True
        files = [file_map['遥测源包设计报告']]
        print(f'文档中有无“{pkt_name}”的字段描述:', end='')
        _msg = f"""
输出文档中描述的“{pkt_name}”探测包。
探测包字段包括:名称(name)、代号(id)、包头属性列表(headers)、数据域参数列表(datas),
包头属性字段包括:位置(pos)、名称(name)、代号(id)、定义(val),
数据域参数字段包括:位置(pos)、名称(name)、代号(id)、字节顺序(byteOrder),
如果没有代号用名称的英文翻译代替,如果没有名称用代号代替,
输出内容仅输出json,不要输出任何其他内容!
        """
        print(f'探测源包“{pkt_name}”信息:')
        file = next(filter(lambda it: it.filename == 'XA-5D无人机分系统探测源包设计报告(公开).md', self.file_objects),
                    None)
        messages = [{'role': 'system', 'content': assistant_msg}, {'role': 'system', 'content': 'fileid://' + file.id}]
        text = self._gen(messages, _msg)
        messages.append({'role': 'assistant', 'content': text})
        text = self.remove_markdown(text)
        pkt = json.loads(text)
文档中有遥测包“{pkt_name}”的字段表描述吗?遥测包名称必须完全匹配。输出:“无”或“有”,不要输出其他任何内容。
注意:遥测包的字段表紧接着遥测包章节标题,如果章节标题后面省略了或者详见xxx则是没有字段表描述。
根据文档内容输出。"""
        text = self.generate_text(_msg, f'out/pkts/有无数据包-{pkt_name}.txt', [], files)
        return text == '有'
    def gen_pkt_details(self, pkt_name, pkt_id):
        cache_file = f'out/数据包-{pkt_name}.json'
        files = [file_map['遥测源包设计报告']]
        if not os.path.isfile(cache_file):
            _msg = f"""
输出文档中描述的名称为“{pkt_name}”代号为“{pkt_id}”遥测包;
遥测包字段包括:名称(name)、代号(id)、类型(type)、包头属性列表(headers)、数据域参数列表(datas),类型为 linear;
包头属性字段包括:名称(name)、代号(id)、位置(pos)、定义(content)、长度(length)、类型(type),类型为 para;
数据域参数字段包括:参数名称(name)、参数代号(id)、位置(pos)、长度(length)、字节顺序(byteOrder),类型为 para;
如果没有名称用代号代替,如果没有代号用名称的英文翻译代替,翻译尽量简短;
你需要理解数据包的位置信息,并且将所有输出单位统一转换为 bits,位置字段的输出格式必须为数值类型;
数据结构仅只包含遥测包,仅输出json,不要输出任何其他内容。"""
            print(f'遥测源包“{pkt_name}”信息:')
            def validation(gen_text):
                _pkt = json.loads(gen_text)
                assert 'headers' in _pkt, '包结构中必须包含headers字段'
                assert 'datas' in _pkt, '包结构中必须包含datas字段'
            text = self.generate_text(_msg, cache_file, [], files, validation)
            pkt = json.loads(text)
        else:
            pkt = json.loads(read_from_file(cache_file))
        pkt_len = 0
        for par in pkt['datas']:
            par['pos'] = pkt_len
            pkt_len += par['length']
        pkt['length'] = pkt_len
        return pkt
    def gen_pkts(self):
        _msg = f"""
输出文档中描述的探测包。
探测包字段包括:名称(name)、代号(id),
输出文档中描述的遥测包。
遥测包字段包括:名称(name)、代号(id)、hasParams,
名称中不要包含代号,
hasParams表示当前遥测包是否有参数列表,遥测包的参数表紧接着遥测包章节标题,如果章节标题后面省略了或者详见xxx则是没有参数表,
如果没有代号用名称的英文翻译代替,如果没有名称用代号代替,
顶级结构直接从探测包开始,不包括探测包下面的参数。
            """
        print(f'探测源包列表:')
        file = next(
            filter(lambda it: it.filename == 'XA-5D无人机分系统探测源包设计报告(公开).md', self.file_objects),
            None)
        messages = [{'role': 'system', 'content': assistant_msg},
                    {'role': 'system', 'content': 'fileid://' + file.id}]
        text = self._gen(messages, _msg)
        messages.append({'role': 'assistant', 'content': text})
        text = self.remove_markdown(text)
数据结构最外层为数组数组元素为遥测包,不包括遥测包下面的参数。
"""
        print(f'遥测源包列表:')
        files = [file_map['遥测源包设计报告']]
        text = self.generate_text(_msg, 'out/源包列表.json', [], files)
        pkt = json.loads(text)
        return pkt
    def gen_pkt_vc(self, messages):
    def gen_pkt_vc(self):
        _msg = f"""
根据探测源包下传时机定义,输出各个探测源包信息列表,顶级结构为数组元素为探测源包,源包字段包括:包代号(id),名称(name),所属虚拟信道(vcs),下传时机(timeTags)
根据遥测源包下传时机定义,输出各个遥测源包信息列表,顶级结构为数组元素为遥测源包,源包字段包括:包代号(id),名称(name),所属虚拟信道(vcs),下传时机(timeTags)
        """
        print('探测源包所属虚拟信道:')
        text = self._gen(messages, _msg)
        messages.append({'role': 'assistant', 'content': text})
        text = self.remove_markdown(text)
        files = [file_map['遥测大纲']]
        print('遥测源包所属虚拟信道:')
        def validation(gen_text):
            pkts = json.loads(gen_text)
            assert len(pkts), 'VC源包列表不能为空'
        text = self.generate_text(_msg, 'out/遥测VC源包.json', files=files, validation=validation)
        pkt_vcs = json.loads(text)
        return pkt_vcs
    def remove_markdown(self, text):
        # 去掉开头的```json
        text = re.sub(r'^```json', '', text)
        # 去掉结尾的```json
        text = re.sub(r'```$', '', text)
        return text
    def gen_pkt_format(self):
        _msg = f"""
请仔细分系文档,输出各个数据包的格式,数据结构最外层为数组,数组元素为数据包格式,将主导头的子级提升到主导头这一级并且去除主导头,数据包type为logic,包数据域type为any。
包格式children包括:版本号(id:Ver)、类型(id:TM_Type)、副导头标志(id:Vice_Head)、应用过程标识符(id:Proc_Sign)、分组标志(id:Group_Sign)、包序列计数(id:Package_Count)、包长(id:Pack_Len)、数据域(id:EPDU_DATA)。
children元素的字段包括:name、id、pos、length、type
注意:生成的JSON语法格式要合法。
"""
        print('遥测包格式:')
        text = self.generate_text(_msg, 'out/数据包格式.json', files=[file_map['遥测大纲']])
        pkt_formats = json.loads(text)
        return pkt_formats
    def compute_length_pos(self, items: list):
        length = 0
        pos = 0
        for child in items:
            if 'children' in child:
                self.compute_length_pos(child['children'])
            child['pos'] = pos
            if 'length' in child and isinstance(child['length'], int):
                length = length + child['length']
                pos = pos + child['length']
        # node['length'] = length
    def gen_bus(self, proj_pk, rule_enc, rule_id, ds, name_path, dev_name):
        _msg = f"""
请析文档,列出总线通信包传输约定中描述的所有数据包列表,
数据包字段包括:id、name、apid(16进制字符串)、service(服务子服务)、length(bit长度)、interval(传输周期)、subAddr(子地址/模式)、frameNum(通信帧号)、
transSer(传输服务)、note(备注)、rtAddr(所属RT的地址十进制)、rt(所属rt名称)、throughBus(是否经过总线)、burst(是否突发)、transDirect(传输方向),
数据结构最外层是数组,数组元素为数据包,以JSON格式输出,不要输出JSON以外的任何文本。
通信帧号:使用文档中的文本不要做任何转换。
subAddr:值为“深度”、“平铺”、“数字”或null。
是否经过总线的判断依据:“备注”列填写了内容类似“不经过总线”的文字表示不经过总线否则经过总线。
传输服务分三种:SetData(置数)、GetData(取数)、DataBlock(数据块传输)。
传输方向分:”收“和”发“,传输服务如果是”取数“是”收“,如果是”数据块传输“则根据包所在的分系统以及表格的”传输方向“列进行判断,判断对于SMU来说是收还是发。
是否突发的判断依据:根据表格中的”传输周期“列进行判断,如果填写了类似”突发“的文字表示是突发否则表示不是突发。
"""
        print('总线数据包:')
        def validation(gen_text):
            json.loads(gen_text)
        text = self.generate_text(_msg, 'out/总线.json', files=[file_map['总线传输通信帧分配']], validation=validation)
        pkts = json.loads(text)
        # 筛选经总线的数据包
        pkts = list(filter(lambda it: it['throughBus'], pkts))
        no_apid_pkts = list(filter(lambda it: not it['apid'], pkts))
        # 筛选有apid的数据包
        pkts = list(filter(lambda it: it['apid'], pkts))
        pkts2 = []
        for pkt in pkts:
            if self.pkt_in_tm_pkts(pkt["name"]):
                pkts2.append(pkt)
        for pkt in pkts2:
            _pkt = self.gen_pkt_details(pkt['name'], pkt['id'])
            if _pkt:
                pkt['children'] = []
                pkt['children'].extend(_pkt['datas'])
                pkt['length'] = _pkt['length']
        rt_pkt_map = {}
        for pkt in pkts:
            # 根据数据块传输和取数分组
            # 逻辑封装包的解析规则ID:RT[rt地址]SUB[子地址]S(S代表取数,方向是AA表示发送;R代表置数,方向是BB表示接受)
            # 取数:逻辑封装包根据子地址和帧号组合创建,有几个组合就创建几个逻辑封装包
            # 数据块:只有一个逻辑封装包
            # 处理子地址
            if pkt['burst']:
                # 突发包子地址是18~26
                pkt['subAddr'] = 26
            elif pkt['subAddr'] == '平铺' or pkt['subAddr'] is None:
                # 平铺:11~26,没有填写的默认为平铺
                pkt['subAddr'] = 26
            elif pkt['subAddr'] == '深度':
                # 深度:11
                pkt['subAddr'] = 11
            # 处理帧号
            if pkt['burst']:
                # 突发:ALL
                pkt['frameNum'] = 'ALL'
            elif not pkt['frameNum']:
                # 有
                pkt['frameNum'] = ''
            # todo: 处理传输方向
            rt_addr = pkt['rtAddr']
            sub_addr = pkt['subAddr']
            trans_ser = pkt['transSer']
            frame_no = pkt['frameNum'].replace('|', ',')
            if trans_ser == 'GetData':
                # 取数
                pkt_id = f"RT{rt_addr}SUB{sub_addr}"
                vals = f"{rt_addr}/{sub_addr}/0xAA/{frame_no}/"
                rt_pkt_map_gen(pkt, '取数', rt_pkt_map, pkt_id, vals)
            elif trans_ser == 'DataBlock':
                # 数据块
                direct = '0xAA'
                rt_pkt_map_gen(pkt, '数据块传输', rt_pkt_map, f"RT{rt_addr}SUB{sub_addr}{direct}",
                               f"{rt_addr}/{sub_addr}/{direct}/ALL/")
        _pkts = []
        for k in rt_pkt_map:
            _pkts.append(rt_pkt_map[k])
        bus_items = data_templates.get_bus_datas(_pkts)
        seq = 1
        sub_key_nodes = list(filter(lambda it: 'is_key' in it, bus_items))
        has_key = any(sub_key_nodes)
        rule_pk = rule_enc.C_ENC_PK
        sub_key = ''
        key_items = []
        self.compute_length_pos(bus_items)
        for item in bus_items:
            if item['type'] == 'enc':
                if has_key:
                    _prop_enc = create_any_pkt(proj_pk, rule_pk, item, seq, name_path, ds, 'ENC', sub_key_nodes,
                                               key_items)
                else:
                    _prop_enc, rule_stream, _ = create_enc_pkt(proj_pk, rule_pk, item, rule_enc.C_ENC_PK, seq,
                                                               name_path, ds, '001', 'ENC')
            else:
                # 参数
                _prop_enc = create_prop_enc(proj_pk, rule_pk, item, get_data_ty(item), seq)
                if item.__contains__('is_key'):
                    sub_key += _prop_enc.C_ENCITEM_PK + '/'
                    key_items.append(
                        {"pk": _prop_enc.C_ENCITEM_PK,
                         'id': _prop_enc.C_SEGMENT_ID,
                         'name': _prop_enc.C_NAME,
                         'val': ''})
            seq += 1
        if sub_key:
            rule_enc.C_KEY = sub_key
            update_rule_enc(rule_enc)
    def gen_tc(self):
        # 数据帧格式
        frame = self.gen_tc_transfer_frame()
        # 数据包格式
        pkt_format = self.gen_tc_transfer_pkt()
        # 数据包列表
        pkts = self.gen_tc_transfer_pkts()
        for pkt in pkts:
            pf = json.loads(json.dumps(pkt_format))
            pf['name'] = pkt['name']
            ph = next(filter(lambda x: x['name'] == '主导头', pf['children']), None)
            apid = next(filter(lambda x: x['name'] == '应用进程标识符(APID)', ph['children']), None)
            apid['value'] = pkt['apid']
            apid['type'] = 'const'
            sh = next(filter(lambda x: x['name'] == '副导头', pf['children']), None)
            ser = next(filter(lambda x: x['name'] == '服务类型', sh['children']), None)
            sub_ser = next(filter(lambda x: x['name'] == '服务子类型', sh['children']), None)
            ser['value'] = pkt['server']
            ser['type'] = 'const'
            sub_ser['value'] = pkt['subServer']
            sub_ser['type'] = 'const'
            frame['subPkts'].append(pf)
        self.order = 0
        def build_def(item: dict):
            if item['type'] == 'enum':
                return json.dumps({"EnumItems": item['enums'], "CanInput": True})
            elif item['type'] == 'length':
                return None
            elif item['type'] == 'checkSum':
                return json.dumps({"ChecksumType": "CRC-CCITT"})
            elif item['type'] == 'subPkt':
                return json.dumps({"CanInput": False})
            elif item['type'] == 'combPkt':
                return None
            elif 'value' in item:
                return item['value']
        def create_tc_format(parent_pk, field):
            field['order'] = self.order
            self.order += 1
            field['def'] = build_def(field)
            if 'length' in field:
                field['bitWidth'] = field['length']
            field['bitOrder'] = None
            field['attr'] = 0
            if field['type'] == 'length':
                val = field['value']
                field['range'] = val['start'] + "~" + val['end']
                field['formula'] = val['formula']
            ins_format = create_ins_format(self.proj.C_PROJECT_PK, parent_pk, field)
            if 'children' in field:
                autocode = 1
                if field['type'] == 'pkt':
                    ins_format = create_ins_format(self.proj.C_PROJECT_PK, ins_format.C_INS_FORMAT_PK,
                                                   {'order': self.order, 'type': 'subPkt',
                                                    'def': json.dumps({"CanInput": False})})
                    self.order += 1
                for child in field['children']:
                    child['autocode'] = autocode
                    autocode += 1
                    create_tc_format(ins_format.C_INS_FORMAT_PK, child)
            # if 'subPkts' in field:
            #     for pkt in field['subPkts']:
            #         ins_format = create_ins_format(self.proj.C_PROJECT_PK, ins_format.C_INS_FORMAT_PK,
            #                                        {'order': self.order, 'type': 'subPkt',
            #                                         'def': json.dumps({"CanInput": False})})
            #         create_tc_format(ins_format.C_INS_FORMAT_PK, pkt)
        create_tc_format(None, frame)
    def gen_tc_transfer_frame(self):
        _msg = '''
分析YK传送帧格式,提取YK传送帧的数据结构,不包括数据包的数据结构。
## 经验:
字段类型包括:
1.组合包:combPkt,
2.固定码字:const,
3.长度:length,
4.枚举值:enum,
5.校验和:checkSum,
6.数据区:subPkt。
根据字段描述分析字段的类型,分析方法:
1.字段描述中明确指定了字段值的,类型为const,
2.字段中没有明确指定字段值,但是罗列了取值范围的,类型为enum,
3.字段描述中如果存在多层级描述则父级字段的类型为combPkt,
4.字段如果是和“长度”有关,类型为length,
5.如果和数据域有关,类型为subPkt,
6.字段如果和校验和有关,类型为checkSum。
字段值提取方法:
1.字段描述中明确指定了字段值,
2.长度字段的值要根据描述确定起止字段范围以及计算公式,value格式例如:{"start":"<code>","end":"<code>","formula":"N-1"},注意:start和end的值为字段code。
## 限制:
- length 自动转换为bit长度。
- value 根据字段描述提取。
- enums 有些字段是枚举值,根据字段描述提取,枚举元素的数据结构为{"n":"","v":"","c":""}。
- 输出内容必须为严格的json,不能输出除json以外的任何内容。
字段数据结构:
主导头
    版本号、通过标志、控制命令标志、空闲位、HTQ标识、虚拟信道标识、帧长、帧序列号
传送帧数据域
帧差错控制域。
# 输出内容例子:
{
    "name": "YK帧",
    "type": "pkt"
    "children":[
        {
            "name": "主导头",
            "code": "primaryHeader",
            "length": 2,
            "value": "00",
            "type": "combPkt",
            "children": [
                {
                    "name": "版本号",
                    "code": "verNum"
                    "length": 1,
                    "value": "00"
                }
            ]
        }
    ],
    "subPkts":[]
}
'''
        def validation(gen_text):
            json.loads(gen_text)
        text = self.generate_tc_text(_msg, 'out/tc_transfer_frame.json', files=[file_map['指令格式']],
                                     validation=validation)
        frame = json.loads(text)
        return frame
    def gen_tc_transfer_pkt(self):
        _msg = '''
仅分析YK包格式,提取YK包数据结构。
## 经验:
字段类型包括:
1.组合包:combPkt,
2.固定码字:const,
3.长度:length,
4.枚举值:enum,
5.校验和:checkSum,
6.数据区:subPkt。
根据字段描述分析字段的类型,分析方法:
1.字段描述中明确指定了字段值的,类型为const,
2.字段中没有明确指定字段值,但是罗列了取值范围的,类型为enum,
3.字段描述中如果存在多层级描述则父级字段的类型为combPkt,
4.字段如果是和“长度”有关,类型为length,
5.如果和数据域有关,类型为subPkt,
6.字段如果和校验和有关,类型为checkSum。
字段值提取方法:
1.字段描述中明确指定了字段值,
2.长度字段的值要根据描述确定起止字段范围以及计算公式,value格式例如:{"start":"<code>","end":"<code>","formula":"N-1"},注意:start和end的值为字段code。
## 限制:
- length 自动转换为bit长度。
- value 根据字段描述提取。
- enums 有些字段是枚举值,根据字段描述提取,枚举元素的数据结构为{"n":"","v":"","c":""}。
- 输出内容必须为严格的json,不能输出除json以外的任何内容。
字段数据结构:
主导头
    包识别
        包版本号、包类型、数据区头标志、应用进程标识符(APID)
    包序列控制
        序列标志
        包序列计数
    包长
副导头
    CCSDS副导头标志
    YK包版本号
    命令正确应答(Ack)
    服务类型
    服务子类型
    源地址
应用数据区
帧差错控制域。
# 输出内容例子:
{
    "name": "YK包",
    "type": "pkt"
    "children":[
        {
            "name": "主导头",
            "code": "primaryHeader",
            "length": 2,
            "value": "00",
            "type": "combPkt",
            "children": [
                {
                    "name": "版本号",
                    "code": "verNum"
                    "length": 1,
                    "value": "00"
                }
            ]
        }
    ],
    "subPkts":[]
}
'''
        def validation(gen_text):
            json.loads(gen_text)
        text = self.generate_tc_text(_msg, 'out/tc_transfer_pkt.json', files=[file_map['指令格式']],
                                     validation=validation)
        pkt_format = json.loads(text)
        return pkt_format
    def gen_tc_transfer_pkts(self):
        _msg = '''
分析文档列出所有的遥控源包。
## 数据结构如下:
[{
"name": "xxx",
"code":"pkt",
"apid":"0xAA",
"server":"0x1",
"subServer":"0x2"
}]
'''
        def validation(gen_text):
            json.loads(gen_text)
        text = self.generate_tc_text(_msg, 'out/tc_transfer_pkts.json', files=[file_map['指令格式']],
                                     validation=validation)
        pkts = json.loads(text)
        return pkts
if __name__ == '__main__':
    md_file = 'D:\\workspace\\PythonProjects\\KnowledgeBase\\doc\\文档合并.md'
    md_file2 = 'D:\\workspace\\PythonProjects\\KnowledgeBase\\doc\\XA-5D无人机分系统探测源包设计报告(公开).md'
    # 启动大模型处理流程
    ret_text = DbStructFlow([md_file, md_file2]).run()
    try:
        os.makedirs("./out/pkts", exist_ok=True)
        # 启动大模型处理流程
        ret_text = DbStructFlow().run()
    except KeyboardInterrupt:
        if g_completion:
            g_completion.close()
knowledgebase/__init__.py
copy from db/__init__.py copy to knowledgebase/__init__.py
knowledgebase/db/__init__.py
knowledgebase/db/data_creator.py
New file
@@ -0,0 +1,327 @@
import math
from knowledgebase.db.db_helper import create_property_enc, \
    create_rule, create_rule_stream, create_rule_enc, create_enc_linear, create_rule_linear, create_property_linear, \
    update_rule_enc, create_extend_info, create_rulekey_info
from knowledgebase.utils import get_bit_mask
enc_ty_flag_map = {
    "DS": "0",
    "ENC": "1",
    "LOGICENC": "2",
    "ANY": "3",
    "LINEAR": "4",
}
def get_byte_len_str(node: dict):
    length = node['length']
    if node['type'] != 'linear':
        return length
    if isinstance(length, int):
        length = f'{math.ceil(length / 8)}'
    # if 'children' in node and len(node['children']):
    #     last = node['children'][-1:].pop()
    #     if isinstance(last['length'], int):
    #         length = last['pos'] + last['length']
    #         length = f'{math.ceil(length / 8)}'
    return length
def check_gen_content(func, check_fun, try_cnt=6):
    try:
        ret = func()
        check_fun(ret)
        return ret
    except BaseException as e:
        if try_cnt <= 0:
            print('生成失败!')
            raise e
        print(f'生成内容有误重新生成,第{6 - try_cnt}次。')
        return check_gen_content(func, check_fun, try_cnt - 1)
def get_data_ty(node: dict):
    data_ty = 'INVAR'
    if 'dataTy' in node:
        data_ty = node['dataTy']
    return data_ty
def create_prop_enc(proj_pk, enc_pk, node, ty, seq):
    bit_length = node['length']
    if isinstance(bit_length, int):
        pos = node['pos']
        byte_length = math.ceil((pos % 8 + bit_length) / 8)
        start = node['pos'] % 8
        end = start + bit_length - 1
        if start == 0 and bit_length % 8 == 0:
            mask = 'ALL'
        else:
            mask = hex(get_bit_mask(start, end))
    else:
        mask = 'ALL'
        byte_length = bit_length
    para_id = f'{node["id"]}'
    offset = f'{node["pos"] // 8}'
    content = None
    if 'content' in node:
        content = node['content']
    cond = None
    if 'condition' in node:
        cond = node['condition']
    prop_enc = create_property_enc(proj_pk, enc_pk, node['id'], node['name'], ty, content, f'{offset}',
                                   f'{byte_length}', '1', mask, cond, seq, '', para_id)
    return prop_enc
def create_prop_linear(proj_pk, linear_pk, node, seq):
    bit_length = node['length']
    if isinstance(bit_length, int):
        byte_length = math.ceil(bit_length / 8)
        start = node['pos'] % 8
        end = start + bit_length - 1
        mask = hex(get_bit_mask(start, end))
    else:
        mask = 'ALL'
        byte_length = bit_length
    para_id = f'{node["id"]}'
    offset = f'{node["pos"] // 8}'
    return create_property_linear(proj_pk, linear_pk, para_id, node['name'], 'INVAR', '0', f'{offset}',
                                  f'{byte_length}', None, mask, None, None, None, None, None, seq)
def create_key_liner_pkt(proj_pk, rule_pk, node, parent_rule_pk, seq, name_path, ds, content,
                         actual_parent_pk=None):
    # 创建线性包,父级包含子包主键字段的情况
    # 创建解析规则
    rule_name = node['rule_name'] if 'rule_name' in node else node['name']
    rule = create_rule(proj_pk, rule_pk, node['id'], rule_name, get_byte_len_str(node), parent_rule_pk,
                       enc_ty_flag_map['LINEAR'], actual_parent_pk)
    rule_linear = create_rule_linear(proj_pk, rule_pk, node['id'], node['name'], get_byte_len_str(node), content)
    # 创建t_rule_stream
    rule_stream = create_rule_stream(proj_pk,
                                     rule.C_RULE_PK,
                                     ds.C_STREAM_PK,
                                     ds.C_STREAM_ID,
                                     ds.C_NAME,
                                     ds.C_STREAM_DIR,
                                     f"{name_path}{node['name']}/")
    if 'children' in node:
        seq = 1
        for child in node['children']:
            # 创建线性包参数
            create_prop_linear(proj_pk, rule_linear.C_LINEAR_PK, child, seq)
            seq = seq + 1
    return rule
def create_liner_pkt(proj_pk, linear_pk, node, parent_rule_pk, seq, name_path, ds, content):
    # 创建线性包
    prop_enc = create_prop_enc(proj_pk, linear_pk, node, 'LINEAR', seq)
    # 创建 enc_linear
    enc_linear = create_enc_linear(proj_pk, prop_enc.C_ENCITEM_PK, '002')
    rule_pk = enc_linear.C_LINEAR_PK
    # 创建解析规则
    length = get_byte_len_str(node)
    rule_name = node['rule_name'] if 'rule_name' in node else node['name']
    rule = create_rule(proj_pk, rule_pk, node['id'], rule_name, length, parent_rule_pk,
                       enc_ty_flag_map['LINEAR'])
    rule_linear = create_rule_linear(proj_pk, rule_pk, node['id'], node['name'], length, content)
    # 创建t_rule_stream
    rule_stream = create_rule_stream(proj_pk,
                                     rule.C_RULE_PK,
                                     ds.C_STREAM_PK,
                                     ds.C_STREAM_ID,
                                     ds.C_NAME,
                                     ds.C_STREAM_DIR,
                                     f"{name_path}{node['name']}/")
    if 'children' in node:
        seq = 1
        for child in node['children']:
            # 创建线性包参数
            create_prop_linear(proj_pk, rule_linear.C_LINEAR_PK, child, seq)
            seq = seq + 1
def create_any_pkt(proj_pk, linear_pk, node, seq, name_path, ds, pkt_ty, sub_key_nodes, key_items=None):
    # any没有t_rule、t_rule_enc、t_rule_linear
    prop_enc = create_prop_enc(proj_pk, linear_pk, node, pkt_ty, seq)
    rule_name = node['rule_name'] if 'rule_name' in node else node['name']
    length = get_byte_len_str(node)
    rule = create_rule(proj_pk, prop_enc.C_ENCITEM_PK, node['id'], rule_name, length, prop_enc.C_ENC_PK,
                       enc_ty_flag_map['ANY'], None)
    if 'children' in node:
        child_seq = 1
        for child in node['children']:
            vals = None
            if 'vals' in child:
                values = []
                vals = child['vals']
                if vals.endswith("/"):
                    vals = vals[:-1]
                values.extend(vals.split("/"))
                for i in range(0, len(key_items)):
                    key_items[i]['val'] = values[i]
                node_name = '【'
                for i in range(0, len(values)):
                    sub_key_node = sub_key_nodes[i]
                    val = values[i]
                    node_name += f'{sub_key_node["name"]}={val}'
                node_name += '】'
                child['rule_name'] = child['name'] + node_name
            if child['type'] == 'enc':  # 封装包
                enc_linear = create_enc_linear(proj_pk, prop_enc.C_ENCITEM_PK, '001', vals)
                _, __, _rule = create_enc_pkt(proj_pk, enc_linear.C_LINEAR_PK, child, enc_linear.C_ENCITEM_PK,
                                              child_seq, name_path, ds, '001', 'ENC', parent_has_key=True,
                                              actual_parent_pk=prop_enc.C_ENC_PK)
                if key_items:
                    for it in key_items:
                        create_rulekey_info(proj_pk, _rule.C_RULE_PK, _rule.C_RULE_ID, _rule.C_RULE_NAME, it['pk'],
                                            it['id'], it['name'], it['val'])
            elif child['type'] == 'linear':  # 线性包
                # 查询已有的rule
                _rule = None  # find_rule_by_rule_id(child['id'])
                enc_linear = create_enc_linear(proj_pk, prop_enc.C_ENCITEM_PK, '002', vals,
                                               _rule.C_RULE_PK if _rule else None)
                if not _rule:
                    _rule = create_key_liner_pkt(proj_pk, enc_linear.C_LINEAR_PK, child, enc_linear.C_ENCITEM_PK,
                                                 child_seq, name_path, ds, None, prop_enc.C_ENC_PK)
                    if key_items:
                        for it in key_items:
                            create_rulekey_info(proj_pk, _rule.C_RULE_PK, _rule.C_RULE_ID, _rule.C_RULE_NAME,
                                                it['pk'], it['id'], it['name'], it['val'])
                else:
                    # 创建解析规则
                    rule_name = node['rule_name'] if 'rule_name' in node else node['name']
                    _rule = create_rule(proj_pk, _rule.C_RULE_PK, child['id'], rule_name, child['length'],
                                        prop_enc.C_ENCITEM_PK, enc_ty_flag_map['LINEAR'], prop_enc.C_ENC_PK)
                    # rule_linear = create_rule_linear(proj_pk, enc_linear.C_LINEAR_PK, node['id'], node['name'],
                    #                                  node['length'], None)
                    # 创建t_rule_stream
                    rule_stream = create_rule_stream(proj_pk,
                                                     _rule.C_RULE_PK,
                                                     ds.C_STREAM_PK,
                                                     ds.C_STREAM_ID,
                                                     ds.C_NAME,
                                                     ds.C_STREAM_DIR,
                                                     f"{name_path}{child['name']}/")
            elif child['type'] == 'logic':  # 逻辑封装包
                enc_linear = create_enc_linear(proj_pk, prop_enc.C_ENCITEM_PK, '005', vals)
                _, __, _rule = create_enc_pkt(proj_pk, enc_linear.C_LINEAR_PK, child, enc_linear.C_ENCITEM_PK,
                                              child_seq,
                                              name_path, ds, '005', 'ENC', True, prop_enc.C_ENC_PK)
                if key_items:
                    for it in key_items:
                        create_rulekey_info(proj_pk, _rule.C_RULE_PK, _rule.C_RULE_ID, _rule.C_RULE_NAME,
                                            it['pk'], it['id'], it['name'], it['val'])
            child_seq += 1
    return prop_enc
def create_enc_pkt(proj_pk, linear_pk, node, parent_rule_pk, seq, name_path, ds, ty, pkt_ty,
                   parent_has_key=False, actual_parent_pk=None):
    """
    创建封装包
    :param enc_pk:
    :param proj_pk:
    :param node:
    :param parent_rule_pk:
    :param seq:
    :param name_path:
    :param ds:
    :param ty:
    :param is_logic_enc:
    :return:
    """
    prop_enc = None
    key_items = []
    length = get_byte_len_str(node)
    # 查询已有的rule
    if not parent_has_key:
        # 创建封装包
        prop_enc = create_prop_enc(proj_pk, linear_pk, node, pkt_ty, seq)
        encitem_pk = prop_enc.C_ENCITEM_PK
        vals = None
        if 'vals' in node:
            vals = node['vals'] + '/'
        enc_linear = create_enc_linear(proj_pk, encitem_pk, ty, vals)
        rule_pk = enc_linear.C_LINEAR_PK
    else:
        rule_pk = linear_pk
    # 创建封装包下面的解析规则
    rule_name = node['rule_name'] if 'rule_name' in node else node['name']
    rule = create_rule(proj_pk, rule_pk, node['id'], rule_name, length, parent_rule_pk,
                       enc_ty_flag_map[pkt_ty], actual_parent_pk)
    rule_enc = create_rule_enc(proj_pk, rule.C_RULE_PK, node['id'], node['name'],
                               node['content'] if 'content' in node else None)
    name_path = f"{name_path}{node['name']}/"
    rule_stream = create_rule_stream(proj_pk,
                                     rule.C_RULE_PK,
                                     ds.C_STREAM_PK,
                                     ds.C_STREAM_ID,
                                     ds.C_NAME,
                                     ds.C_STREAM_DIR,
                                     name_path)
    if 'extInfo' in node and node['extInfo']:
        for info in node['extInfo']:
            create_extend_info(proj_pk, info['id'], info['name'], info['val'], rule.C_RULE_PK)
    if 'children' in node:
        child_seq = 1
        sub_key_nodes = list(filter(lambda it: it.__contains__('is_key'), node['children']))
        has_key = len(sub_key_nodes) > 0
        sub_key = ''
        for child in node['children']:
            if child['type'] == 'enc':  # 封装包
                if has_key:
                    create_any_pkt(proj_pk, rule.C_RULE_PK, child, child_seq, name_path, ds, 'ENC', sub_key_nodes)
                else:
                    create_enc_pkt(proj_pk, rule.C_RULE_PK, child, rule_pk, child_seq,
                                   name_path, ds, '001', 'ENC')
            elif child['type'] == 'para':  # 数据段参数
                _prop_enc = create_prop_enc(proj_pk, rule_enc.C_ENC_PK, child, get_data_ty(child), child_seq)
                is_key = False
                if 'is_key' in child:
                    is_key = child['is_key']
                if is_key:
                    sub_key += _prop_enc.C_ENCITEM_PK + '/'
                    key_items.append(
                        {"pk": _prop_enc.C_ENCITEM_PK,
                         'id': _prop_enc.C_SEGMENT_ID,
                         'name': _prop_enc.C_NAME,
                         'val': ''})
            elif child['type'] == 'linear':  # 线性包
                if has_key:
                    create_any_pkt(proj_pk, rule.C_RULE_PK, child, child_seq, name_path, ds, 'LINEAR', sub_key_nodes,
                                   key_items)
                else:
                    create_liner_pkt(proj_pk, rule.C_RULE_PK, child, rule_pk, child_seq,
                                     name_path, ds, None)
            elif child['type'] == 'logic':  # 逻辑封装包
                if has_key:
                    create_any_pkt(proj_pk, rule.C_RULE_PK, child, child_seq, name_path, ds, 'LOGICENC', sub_key_nodes,
                                   key_items)
                else:
                    create_enc_pkt(proj_pk, rule.C_RULE_PK, child, rule_pk, child_seq, name_path, ds,
                                   '005', 'LOGICENC')
            elif child['type'] == 'any':  # 任意包
                if has_key:
                    create_any_pkt(proj_pk, rule.C_RULE_PK, child, child_seq, name_path, ds, 'ANY', sub_key_nodes,
                                   key_items)
                else:
                    create_enc_pkt(proj_pk, rule.C_RULE_PK, child, rule_pk, child_seq, name_path, ds,
                                   '005', 'ANY')
            child_seq += 1
            if sub_key:
                rule_enc.C_KEY = sub_key
                update_rule_enc(rule_enc)
    return prop_enc, rule_stream, rule
knowledgebase/db/db_helper.py
New file
@@ -0,0 +1,408 @@
import uuid
from sqlalchemy.orm import sessionmaker, scoped_session
from knowledgebase.db.models import engine, TProject, TDevice, TDataStream, TDevStream, TRule, TRuleEnc, TPropertyEnc, \
    TPropertyLinear, TRuleStream, TEncLinear, TRuleLinear, TParameter, TParameterType, TExtendInfo, TRulekeyInfo, \
    TInsFormat
from hashlib import md5
# 创建一个会话工厂
session_factory = sessionmaker(bind=engine)
# 创建一个会话对象
Session = scoped_session(session_factory)
session = Session()
_para_id_map = {}
def get_pk():
    n = uuid.uuid4().hex
    pk = md5(n.encode('utf-8')).hexdigest()
    return pk
def create_project(sat_id, sat_name, proj_code, proj_name, desc, date_time, ) -> TProject:
    """
    创建project
    :param sat_id:
    :param sat_name:
    :param proj_code:
    :param proj_name:
    :param desc:
    :param date_time:
    :return: 创建完成的project
    """
    project = TProject(C_PROJECT_PK=get_pk(), C_SAT_ID=sat_id, C_SAT_NAME=sat_name, C_PROJECT_CODE=proj_code,
                       C_DESCRIPTION=desc, C_HASH=uuid.uuid4().int & 0xffffffff, C_PROJECT_NAME=proj_name,
                       C_DATETIME=date_time,
                       C_CREATEOR='')
    session.add(project)
    session.commit()
    return project
def create_device(device_id, device_name, device_type, dll, project_pk):
    """
    创建device
    :param device_id:
    :param device_name:
    :param device_type:
    :param dll:
    :param project_pk:
    :return:
    """
    device = TDevice(C_DEV_PK=get_pk(), C_DEV_ID=device_id, C_DEV_NAME=device_name, C_DEV_TYPE=device_type, C_DLL=dll,
                     C_PROJECT_PK=project_pk)
    session.add(device)
    session.commit()
    return device
def create_extend_info(proj_pk, prop_id, prop_name, val, fk):
    ext_info = TExtendInfo(
        C_PK=get_pk(),
        C_PROJECT_PK=proj_pk,
        C_PROPERTY_ID=prop_id,
        C_PROPERTY_NAME=prop_name,
        C_VAL=val,
        C_FOREIGN_PK=fk
    )
    session.add(ext_info)
    session.commit()
def create_data_stream(proj_pk, dev_pk, name, code, data_ty, direct, rule_id, rule_ty, rule_pk=None):
    """
    创建data_stream
    :param proj_pk:
    :param dev_pk:
    :param name:
    :param code:
    :param data_ty:
    :param direct:
    :param rule_id:
    :param rule_ty:
    :return:
    """
    ds = TDataStream(C_STREAM_PK=get_pk(),
                     C_PROJECT_PK=proj_pk,
                     C_STREAM_ID=code,
                     C_DATA_TYPE=data_ty,
                     C_STREAM_DIR=direct,
                     C_NAME=name,
                     C_DESCRIPTION='',
                     C_RULE_ID=rule_id,
                     C_RULE_TYPE=rule_ty)
    session.add(ds)
    link = TDevStream(C_PK=get_pk(), C_DEV_PK=dev_pk, C_STREAM_PK=ds.C_STREAM_PK, C_PROJECT_PK=proj_pk)
    session.add(link)
    rule_enc = None
    # 创建解析规则
    if rule_pk is None:
        rule_pk = get_pk()
        if rule_ty == '001':
            # 封装包
            rule_enc = create_rule_enc(proj_pk, rule_pk, rule_id, rule_id)
    rule = create_rule(proj_pk, ds.C_STREAM_PK, rule_id, name, None, None, '0')
    rule = create_rule(proj_pk, rule_pk, rule_id, rule_id, None, ds.C_STREAM_PK, '1')
    # rule stream
    rule_stream = create_rule_stream(proj_pk,
                                     rule_pk,
                                     ds.C_STREAM_PK,
                                     ds.C_STREAM_ID,
                                     ds.C_NAME,
                                     ds.C_STREAM_DIR,
                                     f"{ds.C_NAME}/{rule_id}/")
    session.add(rule_stream)
    session.commit()
    return ds, rule_stream, rule_enc
def create_rule(proj_pk, rule_pk, rule_id, rule_name, rule_len, parent_pk, flag, actual_parent_pk=None):
    rule = TRule(
        C_PK=get_pk(),
        C_PROJECT_PK=proj_pk,
        C_RULE_PK=rule_pk,
        C_RULE_ID=rule_id,
        C_RULE_NAME=rule_name,
        C_RULE_LENGTH=rule_len,
        C_PARENT_PK=parent_pk,
        C_FLAG=flag,
        C_ACTUAL_PARENT_PK=actual_parent_pk
    )
    session.add(rule)
    session.commit()
    return rule
def find_rule_by_rule_id(rule_id):
    return session.query(TRule).filter(TRule.C_RULE_ID == rule_id).first()
def create_rule_stream(proj_pk, rule_pk, stream_pk, stream_id, stream_name, stream_dir, _path):
    rule_stream = TRuleStream(
        C_PK=get_pk(),
        C_PROJECT_PK=proj_pk,
        C_RULE_PK=rule_pk,
        C_STREAM_PK=stream_pk,
        C_STREAM_ID=stream_id,
        C_STREAM_NAME=stream_name,
        C_STREAM_DIR=stream_dir,
        C_PATH=_path
    )
    session.add(rule_stream)
    session.commit()
    return rule_stream
def create_ref_ds_rule_stream(proj_pk, stream_pk, stream_id, stream_name, stream_dir, target_stream_pk):
    items: list = session.query(TRuleStream).filter(TRuleStream.C_STREAM_PK == target_stream_pk).all()
    for it in items:
        _path = it.C_PATH
        if len(_path.split('/')) == 3:
            continue
        _path = f'{stream_name}/{stream_id}/'.join(_path.split('/')[2:]) + '/'
        create_rule_stream(proj_pk, it.C_RULE_PK, stream_pk, stream_id, stream_name, stream_dir, _path)
def create_rule_enc(proj_pk, enc_pk, enc_id, name, content=None):
    rule_enc = TRuleEnc(
        C_ENC_PK=enc_pk,
        C_PROJECT_PK=proj_pk,
        C_ENC_ID=enc_id,
        C_NAME=name,
        C_CONTENT=content,
    )
    session.add(rule_enc)
    session.commit()
    return rule_enc
def create_rule_linear(proj_pk, linear_pk, linear_id, name, length, content):
    rule_linear = TRuleLinear(
        C_LINEAR_PK=linear_pk,
        C_PROJECT_PK=proj_pk,
        C_LINEAR_ID=linear_id,
        C_NAME=name,
        C_LENGTH=length,
        C_PACKAGE_TYPE=None,
        C_REL_LINEAR_PK=None,
        C_CONTENT=content
    )
    session.add(rule_linear)
    session.commit()
    return rule_linear
def create_property_enc(proj_pk, enc_pk, segment_id, name, ty, content, offset,
                        length, msb_first, mask, cond, seq, rel_enc_item_pk, para_id):
    property_enc = TPropertyEnc(
        C_ENCITEM_PK=get_pk(),
        C_ENC_PK=enc_pk,
        C_SEGMENT_ID=segment_id,
        C_NAME=name,
        C_TYPE=ty,
        C_CONTENT=content,
        C_PUBLISH=None,
        C_OFFSET=offset,
        C_LENGTH=length,
        C_MSBFIRST=msb_first,
        C_MASK=mask,
        C_CONDITION=cond,
        C_PROJECT_PK=proj_pk,
        C_SEQ=seq,
        C_REL_ENCITEM_PK=rel_enc_item_pk,
        C_PAR_ID=para_id
    )
    session.add(property_enc)
    para = TParameter(
        C_PAR_PK=get_pk(),
        C_PROJECT_PK=proj_pk,
        C_PAR_CODE=segment_id,
        C_PAR_NAME=name,
        C_SUBSYS=None,
        C_TYPE='0',
        C_UNIT=None,
        C_VALUE_RANGE=None,
        C_DIS_REQUIRE=None,
        C_MODULUS=None,
        C_PARAMS=None,
        C_PRECISION='0',
        C_REG_PK=None,
        C_METHOD_PK=None
    )
    session.add(para)
    if ty == 'ENUM' and content:
        items: list = content.split(' ')
        for item in items:
            idx = items.index(item)
            name, val = item.split(',')
            pt = TParameterType(
                C_PK=get_pk(),
                C_TYPE_ID=f'{idx}',
                C_TYPE_NAME=name,
                C_VALUE=val,
                C_DATA_TYPE=None,
                C_PAR_PK=para.C_PAR_PK,
                C_PROJECT_PK=proj_pk
            )
            session.add(pt)
    session.commit()
    return property_enc
def get_para_id(para_id):
    for i in range(1, 9999):
        _id = f'{i}'.zfill(4)
        _para_id = para_id + '_' + _id
        if _para_id not in _para_id_map:
            _para_id_map[_para_id] = True
            return _para_id
def create_property_linear(proj_pk, linear_pk, para_id, name, ty, content, offset,
                           length, msb_first, mask, cond, calc_expr, simuval, reg_par, params, seq):
    property_linear = TPropertyLinear(
        C_PK=get_pk(),
        C_LINEAR_PK=linear_pk,
        C_PAR_ID=para_id,
        C_TYPE=ty,
        C_CONTENT=content,
        C_OFFSET=offset,
        C_LENGTH=length,
        C_MSBFIRST=msb_first,
        C_MASK=mask,
        C_CONDITION=cond,
        C_CALC_EXPR=calc_expr,
        C_PAR_PK=get_pk(),
        C_SIMUVAL=simuval,
        C_REG_PAR=reg_par,
        C_PARAMS=params,
        C_PROJECT_PK=proj_pk,
        C_SEQ=seq,
        C_REL_PK=None
    )
    session.add(property_linear)
    if para_id in _para_id_map:
        get_para_id(para_id)
    para = TParameter(
        C_PAR_PK=property_linear.C_PAR_PK,
        C_PROJECT_PK=proj_pk,
        C_PAR_CODE=para_id,
        C_PAR_NAME=name,
        C_SUBSYS=None,
        C_TYPE=None,
        C_UNIT=None,
        C_VALUE_RANGE=None,
        C_DIS_REQUIRE=None,
        C_MODULUS=None,
        C_PARAMS=None,
        C_PRECISION='0',
        C_REG_PK=None,
        C_METHOD_PK=None
    )
    session.add(para)
    if ty == 'ENUM' and content:
        items: list = content.split(' ')
        for item in items:
            idx = items.index(item)
            name, val = item.split(',')
            pt = TParameterType(
                C_PK=get_pk(),
                C_TYPE_ID=f'{idx}',
                C_TYPE_NAME=name,
                C_VALUE=val,
                C_DATA_TYPE=None,
                C_PAR_PK=para.C_PAR_PK,
                C_PROJECT_PK=proj_pk
            )
            session.add(pt)
    session.commit()
    return property_linear
def create_enc_linear(proj_pk, enc_item_pk, ty, vals=None, linear_pk=None):
    """
    创建 t_enc_linear
    :param proj_pk: 工程pk
    :param enc_item_pk:
    :param ty: 001:封装包,002:线性包
    :param vals: 逻辑封装包的key值
    :return:
    """
    if linear_pk is None:
        linear_pk = get_pk()
    enc_linear = TEncLinear(
        C_PK=get_pk(),
        C_LINEAR_PK=linear_pk,
        C_ENCITEM_PK=enc_item_pk,
        C_VALS=vals,
        C_PROJECT_PK=proj_pk,
        C_TYPE=ty,
        C_FOLDER_PK=None
    )
    session.add(enc_linear)
    session.commit()
    return enc_linear
def update_rule_enc(rule_enc):
    # 更新
    session.query(TRuleEnc).filter(TRuleEnc.C_ENC_PK == rule_enc.C_ENC_PK).update({
        TRuleEnc.C_KEY: rule_enc.C_KEY
    })
    session.commit()
def create_rulekey_info(proj_pk, rule_pk, rule_id, rule_name, key_pk, key_id, key_name, key_val):
    info = TRulekeyInfo(
        C_PK=get_pk(),
        C_PROJECT_PK=proj_pk,
        C_RULE_PK=rule_pk,
        C_RULE_ID=rule_id,
        C_RULE_NAME=rule_name,
        C_KEY_PK=key_pk,
        C_KEY_ID=key_id,
        C_KEY_NAME=key_name,
        C_KEY_VAL=key_val
    )
    session.add(info)
    session.commit()
ins_ty = {
    "pkt": 1,
    "subPkt": 22,
    "combPkt": 12,
    "const": 15,
    "length": 17,
    "enum": 26,
    "checkSum": 20,
}
def create_ins_format(proj_pk: str, parent_pk: str, info: dict) -> TInsFormat:
    ins_format = TInsFormat(
        C_INS_FORMAT_PK=get_pk(),
        C_PROJECT_PK=proj_pk,
        C_PARENT_PK=parent_pk,
        C_ORDER=info['order'] if 'order' in info else 0,
        C_AUTOCODE=info['autocode'] if 'autocode' in info else None,
        C_NAME=info['name'] if 'name' in info else '',
        C_CODE=info['code'] if 'code' in info else '',
        C_TYPE=ins_ty[info['type']] if 'type' in info else 0,
        C_DEF=info['def'] if 'def' in info else None,
        C_BIT_WIDTH=info['bitWidth'] if 'bitWidth' in info else 0,
        C_BIT_ORDER=info['bitOrder'] if 'bitOrder' in info else 0,
        C_ATTR=info['attr'] if 'attr' in info else 0,
        C_RANGE=info['range'] if 'range' in info else None,
        C_CONDITION='',
        C_FORMULA=info['formula'] if 'formula' in info else '',
        C_NUMBER='',
    )
    session.add(ins_format)
    session.commit()
    return ins_format
knowledgebase/db/models.py
File was renamed from db/models.py
@@ -1,6 +1,7 @@
# coding: utf-8
from sqlalchemy import create_engine, Column, DateTime, Integer, Text
from sqlalchemy.ext.declarative import declarative_base
import os
Base = declarative_base()
metadata = Base.metadata
@@ -473,5 +474,7 @@
    C_EDIT = Column(Integer)
if os.path.isfile("db.db"):
    os.remove("db.db")
engine = create_engine('sqlite:///db.db', echo=True)
metadata.create_all(engine)
knowledgebase/markitdown/__about__.py
New file
@@ -0,0 +1,4 @@
# SPDX-FileCopyrightText: 2024-present Adam Fourney <adamfo@microsoft.com>
#
# SPDX-License-Identifier: MIT
__version__ = "0.0.1a3"
knowledgebase/markitdown/__init__.py
New file
@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: 2024-present Adam Fourney <adamfo@microsoft.com>
#
# SPDX-License-Identifier: MIT
from ._markitdown import MarkItDown, FileConversionException, UnsupportedFormatException
__all__ = [
    "MarkItDown",
    "FileConversionException",
    "UnsupportedFormatException",
]
knowledgebase/markitdown/__main__.py
New file
@@ -0,0 +1,82 @@
# SPDX-FileCopyrightText: 2024-present Adam Fourney <adamfo@microsoft.com>
#
# SPDX-License-Identifier: MIT
import argparse
import sys
from textwrap import dedent
from .__about__ import __version__
from ._markitdown import MarkItDown, DocumentConverterResult
def main():
    parser = argparse.ArgumentParser(
        description="Convert various file formats to markdown.",
        prog="markitdown",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        usage=dedent(
            """
            SYNTAX:
                markitdown <OPTIONAL: FILENAME>
                If FILENAME is empty, markitdown reads from stdin.
            EXAMPLE:
                markitdown example.pdf
                OR
                cat example.pdf | markitdown
                OR
                markitdown < example.pdf
                OR to save to a file use
                markitdown example.pdf -o example.md
                OR
                markitdown example.pdf > example.md
            """
        ).strip(),
    )
    parser.add_argument(
        "-v",
        "--version",
        action="version",
        version=f"%(prog)s {__version__}",
        help="show the version number and exit",
    )
    parser.add_argument("filename", nargs="?")
    parser.add_argument(
        "-o",
        "--output",
        help="Output file name. If not provided, output is written to stdout.",
    )
    args = parser.parse_args()
    if args.filename is None:
        markitdown = MarkItDown()
        result = markitdown.convert_stream(sys.stdin.buffer)
        _handle_output(args, result)
    else:
        markitdown = MarkItDown()
        result = markitdown.convert(args.filename)
        _handle_output(args, result)
def _handle_output(args, result: DocumentConverterResult):
    """Handle output to stdout or file"""
    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            f.write(result.text_content)
    else:
        print(result.text_content)
if __name__ == "__main__":
    main()
knowledgebase/markitdown/_markitdown.py
New file
@@ -0,0 +1,1708 @@
# type: ignore
import base64
import binascii
import copy
import html
import json
import mimetypes
import os
import re
import shutil
import subprocess
import sys
import tempfile
import traceback
import zipfile
from xml.dom import minidom
from typing import Any, Dict, List, Optional, Union
from pathlib import Path
from urllib.parse import parse_qs, quote, unquote, urlparse, urlunparse
from warnings import warn, resetwarnings, catch_warnings
import mammoth
import markdownify
import olefile
import pandas as pd
import pdfminer
import pdfminer.high_level
import pptx
# File-format detection
import puremagic
import requests
from bs4 import BeautifulSoup
from charset_normalizer import from_path
from bs4 import BeautifulSoup
# Optional Transcription support
IS_AUDIO_TRANSCRIPTION_CAPABLE = False
try:
    # Using warnings' catch_warnings to catch
    # pydub's warning of ffmpeg or avconv missing
    with catch_warnings(record=True) as w:
        import pydub
        if w:
            raise ModuleNotFoundError
    import speech_recognition as sr
    IS_AUDIO_TRANSCRIPTION_CAPABLE = True
except ModuleNotFoundError:
    pass
finally:
    resetwarnings()
# Optional YouTube transcription support
try:
    from youtube_transcript_api import YouTubeTranscriptApi
    IS_YOUTUBE_TRANSCRIPT_CAPABLE = True
except ModuleNotFoundError:
    pass
class _CustomMarkdownify(markdownify.MarkdownConverter):
    """
    A custom version of markdownify's MarkdownConverter. Changes include:
    - Altering the default heading style to use '#', '##', etc.
    - Removing javascript hyperlinks.
    - Truncating images with large data:uri sources.
    - Ensuring URIs are properly escaped, and do not conflict with Markdown syntax
    """
    def __init__(self, **options: Any):
        options["heading_style"] = options.get("heading_style", markdownify.ATX)
        # Explicitly cast options to the expected type if necessary
        super().__init__(**options)
    def convert_hn(self, n: int, el: Any, text: str, convert_as_inline: bool) -> str:
        """Same as usual, but be sure to start with a new line"""
        if not convert_as_inline:
            if not re.search(r"^\n", text):
                return "\n" + super().convert_hn(n, el, text, convert_as_inline)  # type: ignore
        return super().convert_hn(n, el, text, convert_as_inline)  # type: ignore
    def convert_a(self, el: Any, text: str, convert_as_inline: bool):
        """Same as usual converter, but removes Javascript links and escapes URIs."""
        prefix, suffix, text = markdownify.chomp(text)  # type: ignore
        if not text:
            return ""
        href = el.get("href")
        title = el.get("title")
        # Escape URIs and skip non-http or file schemes
        if href:
            try:
                parsed_url = urlparse(href)  # type: ignore
                if parsed_url.scheme and parsed_url.scheme.lower() not in ["http", "https", "file"]:  # type: ignore
                    return "%s%s%s" % (prefix, text, suffix)
                href = urlunparse(parsed_url._replace(path=quote(unquote(parsed_url.path))))  # type: ignore
            except ValueError:  # It's not clear if this ever gets thrown
                return "%s%s%s" % (prefix, text, suffix)
        # For the replacement see #29: text nodes underscores are escaped
        if (
            self.options["autolinks"]
            and text.replace(r"\_", "_") == href
            and not title
            and not self.options["default_title"]
        ):
            # Shortcut syntax
            return "<%s>" % href
        if self.options["default_title"] and not title:
            title = href
        title_part = ' "%s"' % title.replace('"', r"\"") if title else ""
        return (
            "%s[%s](%s%s)%s" % (prefix, text, href, title_part, suffix)
            if href
            else text
        )
    def convert_img(self, el: Any, text: str, convert_as_inline: bool) -> str:
        """Same as usual converter, but removes data URIs"""
        alt = el.attrs.get("alt", None) or ""
        src = el.attrs.get("knowledgebase", None) or ""
        title = el.attrs.get("title", None) or ""
        title_part = ' "%s"' % title.replace('"', r"\"") if title else ""
        if (
            convert_as_inline
            and el.parent.name not in self.options["keep_inline_images_in"]
        ):
            return alt
        # Remove dataURIs
        if src.startswith("data:"):
            src = src.split(",")[0] + "..."
        return "![%s](%s%s)" % (alt, src, title_part)
    def convert_soup(self, soup: Any) -> str:
        return super().convert_soup(soup)  # type: ignore
class DocumentConverterResult:
    """The result of converting a document to text."""
    def __init__(self, title: Union[str, None] = None, text_content: str = ""):
        self.title: Union[str, None] = title
        self.text_content: str = text_content
class DocumentConverter:
    """Abstract superclass of all DocumentConverters."""
    def convert(
        self, local_path: str, **kwargs: Any
    ) -> Union[None, DocumentConverterResult]:
        raise NotImplementedError()
class PlainTextConverter(DocumentConverter):
    """Anything with content type text/plain"""
    def convert(
        self, local_path: str, **kwargs: Any
    ) -> Union[None, DocumentConverterResult]:
        # Guess the content type from any file extension that might be around
        content_type, _ = mimetypes.guess_type(
            "__placeholder" + kwargs.get("file_extension", "")
        )
        # Only accept text files
        if content_type is None:
            return None
        elif all(
            not content_type.lower().startswith(type_prefix)
            for type_prefix in ["text/", "application/json"]
        ):
            return None
        text_content = str(from_path(local_path).best())
        return DocumentConverterResult(
            title=None,
            text_content=text_content,
        )
class HtmlConverter(DocumentConverter):
    """Anything with content type text/html"""
    def convert(
        self, local_path: str, **kwargs: Any
    ) -> Union[None, DocumentConverterResult]:
        # Bail if not html
        extension = kwargs.get("file_extension", "")
        if extension.lower() not in [".html", ".htm"]:
            return None
        result = None
        with open(local_path, "rt", encoding="utf-8") as fh:
            result = self._convert(fh.read())
        return result
    def _convert(self, html_content: str) -> Union[None, DocumentConverterResult]:
        """Helper function that converts and HTML string."""
        # Parse the string
        soup = BeautifulSoup(html_content, "html.parser")
        # Remove javascript and style blocks
        for script in soup(["script", "style"]):
            script.extract()
        # Print only the main content
        body_elm = soup.find("body")
        webpage_text = ""
        if body_elm:
            webpage_text = _CustomMarkdownify().convert_soup(body_elm)
        else:
            webpage_text = _CustomMarkdownify().convert_soup(soup)
        assert isinstance(webpage_text, str)
        return DocumentConverterResult(
            title=None if soup.title is None else soup.title.string,
            text_content=webpage_text,
        )
class RSSConverter(DocumentConverter):
    """Convert RSS / Atom type to markdown"""
    def convert(
        self, local_path: str, **kwargs
    ) -> Union[None, DocumentConverterResult]:
        # Bail if not RSS type
        extension = kwargs.get("file_extension", "")
        if extension.lower() not in [".xml", ".rss", ".atom"]:
            return None
        try:
            doc = minidom.parse(local_path)
        except BaseException as _:
            return None
        result = None
        if doc.getElementsByTagName("rss"):
            # A RSS feed must have a root element of <rss>
            result = self._parse_rss_type(doc)
        elif doc.getElementsByTagName("feed"):
            root = doc.getElementsByTagName("feed")[0]
            if root.getElementsByTagName("entry"):
                # An Atom feed must have a root element of <feed> and at least one <entry>
                result = self._parse_atom_type(doc)
            else:
                return None
        else:
            # not rss or atom
            return None
        return result
    def _parse_atom_type(
        self, doc: minidom.Document
    ) -> Union[None, DocumentConverterResult]:
        """Parse the type of an Atom feed.
        Returns None if the feed type is not recognized or something goes wrong.
        """
        try:
            root = doc.getElementsByTagName("feed")[0]
            title = self._get_data_by_tag_name(root, "title")
            subtitle = self._get_data_by_tag_name(root, "subtitle")
            entries = root.getElementsByTagName("entry")
            md_text = f"# {title}\n"
            if subtitle:
                md_text += f"{subtitle}\n"
            for entry in entries:
                entry_title = self._get_data_by_tag_name(entry, "title")
                entry_summary = self._get_data_by_tag_name(entry, "summary")
                entry_updated = self._get_data_by_tag_name(entry, "updated")
                entry_content = self._get_data_by_tag_name(entry, "content")
                if entry_title:
                    md_text += f"\n## {entry_title}\n"
                if entry_updated:
                    md_text += f"Updated on: {entry_updated}\n"
                if entry_summary:
                    md_text += self._parse_content(entry_summary)
                if entry_content:
                    md_text += self._parse_content(entry_content)
            return DocumentConverterResult(
                title=title,
                text_content=md_text,
            )
        except BaseException as _:
            return None
    def _parse_rss_type(
        self, doc: minidom.Document
    ) -> Union[None, DocumentConverterResult]:
        """Parse the type of an RSS feed.
        Returns None if the feed type is not recognized or something goes wrong.
        """
        try:
            root = doc.getElementsByTagName("rss")[0]
            channel = root.getElementsByTagName("channel")
            if not channel:
                return None
            channel = channel[0]
            channel_title = self._get_data_by_tag_name(channel, "title")
            channel_description = self._get_data_by_tag_name(channel, "description")
            items = channel.getElementsByTagName("item")
            if channel_title:
                md_text = f"# {channel_title}\n"
            if channel_description:
                md_text += f"{channel_description}\n"
            if not items:
                items = []
            for item in items:
                title = self._get_data_by_tag_name(item, "title")
                description = self._get_data_by_tag_name(item, "description")
                pubDate = self._get_data_by_tag_name(item, "pubDate")
                content = self._get_data_by_tag_name(item, "content:encoded")
                if title:
                    md_text += f"\n## {title}\n"
                if pubDate:
                    md_text += f"Published on: {pubDate}\n"
                if description:
                    md_text += self._parse_content(description)
                if content:
                    md_text += self._parse_content(content)
            return DocumentConverterResult(
                title=channel_title,
                text_content=md_text,
            )
        except BaseException as _:
            print(traceback.format_exc())
            return None
    def _parse_content(self, content: str) -> str:
        """Parse the content of an RSS feed item"""
        try:
            # using bs4 because many RSS feeds have HTML-styled content
            soup = BeautifulSoup(content, "html.parser")
            return _CustomMarkdownify().convert_soup(soup)
        except BaseException as _:
            return content
    def _get_data_by_tag_name(
        self, element: minidom.Element, tag_name: str
    ) -> Union[str, None]:
        """Get data from first child element with the given tag name.
        Returns None when no such element is found.
        """
        nodes = element.getElementsByTagName(tag_name)
        if not nodes:
            return None
        fc = nodes[0].firstChild
        if fc:
            return fc.data
        return None
class WikipediaConverter(DocumentConverter):
    """Handle Wikipedia pages separately, focusing only on the main document content."""
    def convert(
        self, local_path: str, **kwargs: Any
    ) -> Union[None, DocumentConverterResult]:
        # Bail if not Wikipedia
        extension = kwargs.get("file_extension", "")
        if extension.lower() not in [".html", ".htm"]:
            return None
        url = kwargs.get("url", "")
        if not re.search(r"^https?:\/\/[a-zA-Z]{2,3}\.wikipedia.org\/", url):
            return None
        # Parse the file
        soup = None
        with open(local_path, "rt", encoding="utf-8") as fh:
            soup = BeautifulSoup(fh.read(), "html.parser")
        # Remove javascript and style blocks
        for script in soup(["script", "style"]):
            script.extract()
        # Print only the main content
        body_elm = soup.find("div", {"id": "mw-content-text"})
        title_elm = soup.find("span", {"class": "mw-page-title-main"})
        webpage_text = ""
        main_title = None if soup.title is None else soup.title.string
        if body_elm:
            # What's the title
            if title_elm and len(title_elm) > 0:
                main_title = title_elm.string  # type: ignore
                assert isinstance(main_title, str)
            # Convert the page
            webpage_text = f"# {main_title}\n\n" + _CustomMarkdownify().convert_soup(
                body_elm
            )
        else:
            webpage_text = _CustomMarkdownify().convert_soup(soup)
        return DocumentConverterResult(
            title=main_title,
            text_content=webpage_text,
        )
class YouTubeConverter(DocumentConverter):
    """Handle YouTube specially, focusing on the video title, description, and transcript."""
    def convert(
        self, local_path: str, **kwargs: Any
    ) -> Union[None, DocumentConverterResult]:
        # Bail if not YouTube
        extension = kwargs.get("file_extension", "")
        if extension.lower() not in [".html", ".htm"]:
            return None
        url = kwargs.get("url", "")
        if not url.startswith("https://www.youtube.com/watch?"):
            return None
        # Parse the file
        soup = None
        with open(local_path, "rt", encoding="utf-8") as fh:
            soup = BeautifulSoup(fh.read(), "html.parser")
        # Read the meta tags
        assert soup.title is not None and soup.title.string is not None
        metadata: Dict[str, str] = {"title": soup.title.string}
        for meta in soup(["meta"]):
            for a in meta.attrs:
                if a in ["itemprop", "property", "name"]:
                    metadata[meta[a]] = meta.get("content", "")
                    break
        # We can also try to read the full description. This is more prone to breaking, since it reaches into the page implementation
        try:
            for script in soup(["script"]):
                content = script.text
                if "ytInitialData" in content:
                    lines = re.split(r"\r?\n", content)
                    obj_start = lines[0].find("{")
                    obj_end = lines[0].rfind("}")
                    if obj_start >= 0 and obj_end >= 0:
                        data = json.loads(lines[0][obj_start : obj_end + 1])
                        attrdesc = self._findKey(data, "attributedDescriptionBodyText")  # type: ignore
                        if attrdesc:
                            metadata["description"] = str(attrdesc["content"])
                    break
        except Exception:
            pass
        # Start preparing the page
        webpage_text = "# YouTube\n"
        title = self._get(metadata, ["title", "og:title", "name"])  # type: ignore
        assert isinstance(title, str)
        if title:
            webpage_text += f"\n## {title}\n"
        stats = ""
        views = self._get(metadata, ["interactionCount"])  # type: ignore
        if views:
            stats += f"- **Views:** {views}\n"
        keywords = self._get(metadata, ["keywords"])  # type: ignore
        if keywords:
            stats += f"- **Keywords:** {keywords}\n"
        runtime = self._get(metadata, ["duration"])  # type: ignore
        if runtime:
            stats += f"- **Runtime:** {runtime}\n"
        if len(stats) > 0:
            webpage_text += f"\n### Video Metadata\n{stats}\n"
        description = self._get(metadata, ["description", "og:description"])  # type: ignore
        if description:
            webpage_text += f"\n### Description\n{description}\n"
        if IS_YOUTUBE_TRANSCRIPT_CAPABLE:
            transcript_text = ""
            parsed_url = urlparse(url)  # type: ignore
            params = parse_qs(parsed_url.query)  # type: ignore
            if "v" in params:
                assert isinstance(params["v"][0], str)
                video_id = str(params["v"][0])
                try:
                    youtube_transcript_languages = kwargs.get(
                        "youtube_transcript_languages", ("en",)
                    )
                    # Must be a single transcript.
                    transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=youtube_transcript_languages)  # type: ignore
                    transcript_text = " ".join([part["text"] for part in transcript])  # type: ignore
                    # Alternative formatting:
                    # formatter = TextFormatter()
                    # formatter.format_transcript(transcript)
                except Exception:
                    pass
            if transcript_text:
                webpage_text += f"\n### Transcript\n{transcript_text}\n"
        title = title if title else soup.title.string
        assert isinstance(title, str)
        return DocumentConverterResult(
            title=title,
            text_content=webpage_text,
        )
    def _get(
        self,
        metadata: Dict[str, str],
        keys: List[str],
        default: Union[str, None] = None,
    ) -> Union[str, None]:
        for k in keys:
            if k in metadata:
                return metadata[k]
        return default
    def _findKey(self, json: Any, key: str) -> Union[str, None]:  # TODO: Fix json type
        if isinstance(json, list):
            for elm in json:
                ret = self._findKey(elm, key)
                if ret is not None:
                    return ret
        elif isinstance(json, dict):
            for k in json:
                if k == key:
                    return json[k]
                else:
                    ret = self._findKey(json[k], key)
                    if ret is not None:
                        return ret
        return None
class IpynbConverter(DocumentConverter):
    """Converts Jupyter Notebook (.ipynb) files to Markdown."""
    def convert(
        self, local_path: str, **kwargs: Any
    ) -> Union[None, DocumentConverterResult]:
        # Bail if not ipynb
        extension = kwargs.get("file_extension", "")
        if extension.lower() != ".ipynb":
            return None
        # Parse and convert the notebook
        result = None
        with open(local_path, "rt", encoding="utf-8") as fh:
            notebook_content = json.load(fh)
            result = self._convert(notebook_content)
        return result
    def _convert(self, notebook_content: dict) -> Union[None, DocumentConverterResult]:
        """Helper function that converts notebook JSON content to Markdown."""
        try:
            md_output = []
            title = None
            for cell in notebook_content.get("cells", []):
                cell_type = cell.get("cell_type", "")
                source_lines = cell.get("source", [])
                if cell_type == "markdown":
                    md_output.append("".join(source_lines))
                    # Extract the first # heading as title if not already found
                    if title is None:
                        for line in source_lines:
                            if line.startswith("# "):
                                title = line.lstrip("# ").strip()
                                break
                elif cell_type == "code":
                    # Code cells are wrapped in Markdown code blocks
                    md_output.append(f"```python\n{''.join(source_lines)}\n```")
                elif cell_type == "raw":
                    md_output.append(f"```\n{''.join(source_lines)}\n```")
            md_text = "\n\n".join(md_output)
            # Check for title in notebook metadata
            title = notebook_content.get("metadata", {}).get("title", title)
            return DocumentConverterResult(
                title=title,
                text_content=md_text,
            )
        except Exception as e:
            raise FileConversionException(
                f"Error converting .ipynb file: {str(e)}"
            ) from e
class BingSerpConverter(DocumentConverter):
    """
    Handle Bing results pages (only the organic search results).
    NOTE: It is better to use the Bing API
    """
    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        # Bail if not a Bing SERP
        extension = kwargs.get("file_extension", "")
        if extension.lower() not in [".html", ".htm"]:
            return None
        url = kwargs.get("url", "")
        if not re.search(r"^https://www\.bing\.com/search\?q=", url):
            return None
        # Parse the query parameters
        parsed_params = parse_qs(urlparse(url).query)
        query = parsed_params.get("q", [""])[0]
        # Parse the file
        soup = None
        with open(local_path, "rt", encoding="utf-8") as fh:
            soup = BeautifulSoup(fh.read(), "html.parser")
        # Clean up some formatting
        for tptt in soup.find_all(class_="tptt"):
            if hasattr(tptt, "string") and tptt.string:
                tptt.string += " "
        for slug in soup.find_all(class_="algoSlug_icon"):
            slug.extract()
        # Parse the algorithmic results
        _markdownify = _CustomMarkdownify()
        results = list()
        for result in soup.find_all(class_="b_algo"):
            # Rewrite redirect urls
            for a in result.find_all("a", href=True):
                parsed_href = urlparse(a["href"])
                qs = parse_qs(parsed_href.query)
                # The destination is contained in the u parameter,
                # but appears to be base64 encoded, with some prefix
                if "u" in qs:
                    u = (
                        qs["u"][0][2:].strip() + "=="
                    )  # Python 3 doesn't care about extra padding
                    try:
                        # RFC 4648 / Base64URL" variant, which uses "-" and "_"
                        a["href"] = base64.b64decode(u, altchars="-_").decode("utf-8")
                    except UnicodeDecodeError:
                        pass
                    except binascii.Error:
                        pass
            # Convert to markdown
            md_result = _markdownify.convert_soup(result).strip()
            lines = [line.strip() for line in re.split(r"\n+", md_result)]
            results.append("\n".join([line for line in lines if len(line) > 0]))
        webpage_text = (
            f"## A Bing search for '{query}' found the following results:\n\n"
            + "\n\n".join(results)
        )
        return DocumentConverterResult(
            title=None if soup.title is None else soup.title.string,
            text_content=webpage_text,
        )
class PdfConverter(DocumentConverter):
    """
    Converts PDFs to Markdown. Most style information is ignored, so the results are essentially plain-text.
    """
    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        # Bail if not a PDF
        extension = kwargs.get("file_extension", "")
        if extension.lower() != ".pdf":
            return None
        return DocumentConverterResult(
            title=None,
            text_content=pdfminer.high_level.extract_text(local_path),
        )
class DocxConverter(HtmlConverter):
    """
    Converts DOCX files to Markdown. Style information (e.g.m headings) and tables are preserved where possible.
    """
    def table_unmerge(self,html):
        # 解析HTML
        soup = BeautifulSoup(html, "html.parser")
        # 获取所有表格
        tables = soup.find_all("table")
        # 遍历每个表格
        for table in tables:
            # 获取表格的行数和列数
            rows = table.find_all("tr")
            row_count = len(rows)
            col_count = max([len(row.find_all(["td", "th"])) for row in rows])
            # 创建一个二维数组来存储表格的数据
            data = []
            for i in range(row_count):
                data.append([])
            # 遍历每个单元格
            for i, row in enumerate(rows):
                cells = row.find_all(["td", "th"])
                for j, cell in enumerate(cells):
                    # 获取之前的所有合并单元格数量
                    # 获取单元格的行列数
                    rowspan = int(cell.get("rowspan", 1))
                    colspan = int(cell.get("colspan", 1))
                    data[i].append([cell.get_text().strip(), rowspan, colspan])
            # 水平合并
            for i in range(len(data)):
                row = data[i]
                for j in range(len(row) - 1, -1, -1):
                    col = row[j]
                    v, rs, cs = col
                    col[2] = 1
                    for k in range(1, cs):
                        row.insert(j + k, [v, rs, 1])
            # 垂直合并
            for i in range(len(data)):
                row = data[i]
                for j in range(len(row)):
                    col = row[j]
                    v, rs, cs = col
                    col[1] = 1
                    for k in range(1, rs):
                        data[i + k].insert(j, [v, 1, cs])
            # 将data转为value二维数组
            result = []
            for i in range(len(data)):
                row = data[i]
                result.append([])
                for j in range(len(row)):
                    col = row[j]
                    v, rs, cs = col
                    result[i].append(v)
            # 将表格的数据转换为DataFrame
            df = pd.DataFrame(result)
            # 将DataFrame转换为HTML表格
            html_table = df.to_html(index=False, header=False)
            # 将HTML表格替换原来的表格
            table.replace_with(BeautifulSoup(html_table, "html.parser"))
        return str(soup)
    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        # Bail if not a DOCX
        extension = kwargs.get("file_extension", "")
        if extension.lower() != ".docx":
            return None
        result = None
        with open(local_path, "rb") as docx_file:
            style_map = kwargs.get("style_map", None)
            result = mammoth.convert_to_html(docx_file, style_map=style_map)
            html_content = self.table_unmerge(result.value)
            result = self._convert(html_content)
        return result
class XlsxConverter(HtmlConverter):
    """
    Converts XLSX files to Markdown, with each sheet presented as a separate Markdown table.
    """
    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        # Bail if not a XLSX
        extension = kwargs.get("file_extension", "")
        if extension.lower() != ".xlsx":
            return None
        sheets = pd.read_excel(local_path, sheet_name=None, engine="openpyxl")
        md_content = ""
        for s in sheets:
            md_content += f"## {s}\n"
            html_content = sheets[s].to_html(index=False)
            md_content += self._convert(html_content).text_content.strip() + "\n\n"
        return DocumentConverterResult(
            title=None,
            text_content=md_content.strip(),
        )
class XlsConverter(HtmlConverter):
    """
    Converts XLS files to Markdown, with each sheet presented as a separate Markdown table.
    """
    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        # Bail if not a XLS
        extension = kwargs.get("file_extension", "")
        if extension.lower() != ".xls":
            return None
        sheets = pd.read_excel(local_path, sheet_name=None, engine="xlrd")
        md_content = ""
        for s in sheets:
            md_content += f"## {s}\n"
            html_content = sheets[s].to_html(index=False)
            md_content += self._convert(html_content).text_content.strip() + "\n\n"
        return DocumentConverterResult(
            title=None,
            text_content=md_content.strip(),
        )
class PptxConverter(HtmlConverter):
    """
    Converts PPTX files to Markdown. Supports heading, tables and images with alt text.
    """
    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        # Bail if not a PPTX
        extension = kwargs.get("file_extension", "")
        if extension.lower() != ".pptx":
            return None
        md_content = ""
        presentation = pptx.Presentation(local_path)
        slide_num = 0
        for slide in presentation.slides:
            slide_num += 1
            md_content += f"\n\n<!-- Slide number: {slide_num} -->\n"
            title = slide.shapes.title
            for shape in slide.shapes:
                # Pictures
                if self._is_picture(shape):
                    # https://github.com/scanny/python-pptx/pull/512#issuecomment-1713100069
                    alt_text = ""
                    try:
                        alt_text = shape._element._nvXxPr.cNvPr.attrib.get("descr", "")
                    except Exception:
                        pass
                    # A placeholder name
                    filename = re.sub(r"\W", "", shape.name) + ".jpg"
                    md_content += (
                        "\n!["
                        + (alt_text if alt_text else shape.name)
                        + "]("
                        + filename
                        + ")\n"
                    )
                # Tables
                if self._is_table(shape):
                    html_table = "<html><body><table>"
                    first_row = True
                    for row in shape.table.rows:
                        html_table += "<tr>"
                        for cell in row.cells:
                            if first_row:
                                html_table += "<th>" + html.escape(cell.text) + "</th>"
                            else:
                                html_table += "<td>" + html.escape(cell.text) + "</td>"
                        html_table += "</tr>"
                        first_row = False
                    html_table += "</table></body></html>"
                    md_content += (
                        "\n" + self._convert(html_table).text_content.strip() + "\n"
                    )
                # Charts
                if shape.has_chart:
                    md_content += self._convert_chart_to_markdown(shape.chart)
                # Text areas
                elif shape.has_text_frame:
                    if shape == title:
                        md_content += "# " + shape.text.lstrip() + "\n"
                    else:
                        md_content += shape.text + "\n"
            md_content = md_content.strip()
            if slide.has_notes_slide:
                md_content += "\n\n### Notes:\n"
                notes_frame = slide.notes_slide.notes_text_frame
                if notes_frame is not None:
                    md_content += notes_frame.text
                md_content = md_content.strip()
        return DocumentConverterResult(
            title=None,
            text_content=md_content.strip(),
        )
    def _is_picture(self, shape):
        if shape.shape_type == pptx.enum.shapes.MSO_SHAPE_TYPE.PICTURE:
            return True
        if shape.shape_type == pptx.enum.shapes.MSO_SHAPE_TYPE.PLACEHOLDER:
            if hasattr(shape, "image"):
                return True
        return False
    def _is_table(self, shape):
        if shape.shape_type == pptx.enum.shapes.MSO_SHAPE_TYPE.TABLE:
            return True
        return False
    def _convert_chart_to_markdown(self, chart):
        md = "\n\n### Chart"
        if chart.has_title:
            md += f": {chart.chart_title.text_frame.text}"
        md += "\n\n"
        data = []
        category_names = [c.label for c in chart.plots[0].categories]
        series_names = [s.name for s in chart.series]
        data.append(["Category"] + series_names)
        for idx, category in enumerate(category_names):
            row = [category]
            for series in chart.series:
                row.append(series.values[idx])
            data.append(row)
        markdown_table = []
        for row in data:
            markdown_table.append("| " + " | ".join(map(str, row)) + " |")
        header = markdown_table[0]
        separator = "|" + "|".join(["---"] * len(data[0])) + "|"
        return md + "\n".join([header, separator] + markdown_table[1:])
class MediaConverter(DocumentConverter):
    """
    Abstract class for multi-modal media (e.g., images and audio)
    """
    def _get_metadata(self, local_path):
        exiftool = shutil.which("exiftool")
        if not exiftool:
            return None
        else:
            try:
                result = subprocess.run(
                    [exiftool, "-json", local_path], capture_output=True, text=True
                ).stdout
                return json.loads(result)[0]
            except Exception:
                return None
class WavConverter(MediaConverter):
    """
    Converts WAV files to markdown via extraction of metadata (if `exiftool` is installed), and speech transcription (if `speech_recognition` is installed).
    """
    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        # Bail if not a WAV
        extension = kwargs.get("file_extension", "")
        if extension.lower() != ".wav":
            return None
        md_content = ""
        # Add metadata
        metadata = self._get_metadata(local_path)
        if metadata:
            for f in [
                "Title",
                "Artist",
                "Author",
                "Band",
                "Album",
                "Genre",
                "Track",
                "DateTimeOriginal",
                "CreateDate",
                "Duration",
            ]:
                if f in metadata:
                    md_content += f"{f}: {metadata[f]}\n"
        # Transcribe
        if IS_AUDIO_TRANSCRIPTION_CAPABLE:
            try:
                transcript = self._transcribe_audio(local_path)
                md_content += "\n\n### Audio Transcript:\n" + (
                    "[No speech detected]" if transcript == "" else transcript
                )
            except Exception:
                md_content += (
                    "\n\n### Audio Transcript:\nError. Could not transcribe this audio."
                )
        return DocumentConverterResult(
            title=None,
            text_content=md_content.strip(),
        )
    def _transcribe_audio(self, local_path) -> str:
        recognizer = sr.Recognizer()
        with sr.AudioFile(local_path) as source:
            audio = recognizer.record(source)
            return recognizer.recognize_google(audio).strip()
class Mp3Converter(WavConverter):
    """
    Converts MP3 files to markdown via extraction of metadata (if `exiftool` is installed), and speech transcription (if `speech_recognition` AND `pydub` are installed).
    """
    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        # Bail if not a MP3
        extension = kwargs.get("file_extension", "")
        if extension.lower() != ".mp3":
            return None
        md_content = ""
        # Add metadata
        metadata = self._get_metadata(local_path)
        if metadata:
            for f in [
                "Title",
                "Artist",
                "Author",
                "Band",
                "Album",
                "Genre",
                "Track",
                "DateTimeOriginal",
                "CreateDate",
                "Duration",
            ]:
                if f in metadata:
                    md_content += f"{f}: {metadata[f]}\n"
        # Transcribe
        if IS_AUDIO_TRANSCRIPTION_CAPABLE:
            handle, temp_path = tempfile.mkstemp(suffix=".wav")
            os.close(handle)
            try:
                sound = pydub.AudioSegment.from_mp3(local_path)
                sound.export(temp_path, format="wav")
                _args = dict()
                _args.update(kwargs)
                _args["file_extension"] = ".wav"
                try:
                    transcript = super()._transcribe_audio(temp_path).strip()
                    md_content += "\n\n### Audio Transcript:\n" + (
                        "[No speech detected]" if transcript == "" else transcript
                    )
                except Exception:
                    md_content += "\n\n### Audio Transcript:\nError. Could not transcribe this audio."
            finally:
                os.unlink(temp_path)
        # Return the result
        return DocumentConverterResult(
            title=None,
            text_content=md_content.strip(),
        )
class ImageConverter(MediaConverter):
    """
    Converts images to markdown via extraction of metadata (if `exiftool` is installed), OCR (if `easyocr` is installed), and description via a multimodal LLM (if an llm_client is configured).
    """
    def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
        # Bail if not an image
        extension = kwargs.get("file_extension", "")
        if extension.lower() not in [".jpg", ".jpeg", ".png"]:
            return None
        md_content = ""
        # Add metadata
        metadata = self._get_metadata(local_path)
        if metadata:
            for f in [
                "ImageSize",
                "Title",
                "Caption",
                "Description",
                "Keywords",
                "Artist",
                "Author",
                "DateTimeOriginal",
                "CreateDate",
                "GPSPosition",
            ]:
                if f in metadata:
                    md_content += f"{f}: {metadata[f]}\n"
        # Try describing the image with GPTV
        llm_client = kwargs.get("llm_client")
        llm_model = kwargs.get("llm_model")
        if llm_client is not None and llm_model is not None:
            md_content += (
                "\n# Description:\n"
                + self._get_llm_description(
                    local_path,
                    extension,
                    llm_client,
                    llm_model,
                    prompt=kwargs.get("llm_prompt"),
                ).strip()
                + "\n"
            )
        return DocumentConverterResult(
            title=None,
            text_content=md_content,
        )
    def _get_llm_description(self, local_path, extension, client, model, prompt=None):
        if prompt is None or prompt.strip() == "":
            prompt = "Write a detailed caption for this image."
        data_uri = ""
        with open(local_path, "rb") as image_file:
            content_type, encoding = mimetypes.guess_type("_dummy" + extension)
            if content_type is None:
                content_type = "image/jpeg"
            image_base64 = base64.b64encode(image_file.read()).decode("utf-8")
            data_uri = f"data:{content_type};base64,{image_base64}"
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": data_uri,
                        },
                    },
                ],
            }
        ]
        response = client.chat.completions.create(model=model, messages=messages)
        return response.choices[0].message.content
class OutlookMsgConverter(DocumentConverter):
    """Converts Outlook .msg files to markdown by extracting email metadata and content.
    Uses the olefile package to parse the .msg file structure and extract:
    - Email headers (From, To, Subject)
    - Email body content
    """
    def convert(
        self, local_path: str, **kwargs: Any
    ) -> Union[None, DocumentConverterResult]:
        # Bail if not a MSG file
        extension = kwargs.get("file_extension", "")
        if extension.lower() != ".msg":
            return None
        try:
            msg = olefile.OleFileIO(local_path)
            # Extract email metadata
            md_content = "# Email Message\n\n"
            # Get headers
            headers = {
                "From": self._get_stream_data(msg, "__substg1.0_0C1F001F"),
                "To": self._get_stream_data(msg, "__substg1.0_0E04001F"),
                "Subject": self._get_stream_data(msg, "__substg1.0_0037001F"),
            }
            # Add headers to markdown
            for key, value in headers.items():
                if value:
                    md_content += f"**{key}:** {value}\n"
            md_content += "\n## Content\n\n"
            # Get email body
            body = self._get_stream_data(msg, "__substg1.0_1000001F")
            if body:
                md_content += body
            msg.close()
            return DocumentConverterResult(
                title=headers.get("Subject"), text_content=md_content.strip()
            )
        except Exception as e:
            raise FileConversionException(
                f"Could not convert MSG file '{local_path}': {str(e)}"
            )
    def _get_stream_data(
        self, msg: olefile.OleFileIO, stream_path: str
    ) -> Union[str, None]:
        """Helper to safely extract and decode stream data from the MSG file."""
        try:
            if msg.exists(stream_path):
                data = msg.openstream(stream_path).read()
                # Try UTF-16 first (common for .msg files)
                try:
                    return data.decode("utf-16-le").strip()
                except UnicodeDecodeError:
                    # Fall back to UTF-8
                    try:
                        return data.decode("utf-8").strip()
                    except UnicodeDecodeError:
                        # Last resort - ignore errors
                        return data.decode("utf-8", errors="ignore").strip()
        except Exception:
            pass
        return None
class ZipConverter(DocumentConverter):
    """Converts ZIP files to markdown by extracting and converting all contained files.
    The converter extracts the ZIP contents to a temporary directory, processes each file
    using appropriate converters based on file extensions, and then combines the results
    into a single markdown document. The temporary directory is cleaned up after processing.
    Example output format:
    ```markdown
    Content from the zip file `example.zip`:
    ## File: docs/readme.txt
    This is the content of readme.txt
    Multiple lines are preserved
    ## File: images/example.jpg
    ImageSize: 1920x1080
    DateTimeOriginal: 2024-02-15 14:30:00
    Description: A beautiful landscape photo
    ## File: data/report.xlsx
    ## Sheet1
    | Column1 | Column2 | Column3 |
    |---------|---------|---------|
    | data1   | data2   | data3   |
    | data4   | data5   | data6   |
    ```
    Key features:
    - Maintains original file structure in headings
    - Processes nested files recursively
    - Uses appropriate converters for each file type
    - Preserves formatting of converted content
    - Cleans up temporary files after processing
    """
    def convert(
        self, local_path: str, **kwargs: Any
    ) -> Union[None, DocumentConverterResult]:
        # Bail if not a ZIP
        extension = kwargs.get("file_extension", "")
        if extension.lower() != ".zip":
            return None
        # Get parent converters list if available
        parent_converters = kwargs.get("_parent_converters", [])
        if not parent_converters:
            return DocumentConverterResult(
                title=None,
                text_content=f"[ERROR] No converters available to process zip contents from: {local_path}",
            )
        extracted_zip_folder_name = (
            f"extracted_{os.path.basename(local_path).replace('.zip', '_zip')}"
        )
        extraction_dir = os.path.normpath(
            os.path.join(os.path.dirname(local_path), extracted_zip_folder_name)
        )
        md_content = f"Content from the zip file `{os.path.basename(local_path)}`:\n\n"
        try:
            # Extract the zip file safely
            with zipfile.ZipFile(local_path, "r") as zipObj:
                # Safeguard against path traversal
                for member in zipObj.namelist():
                    member_path = os.path.normpath(os.path.join(extraction_dir, member))
                    if (
                        not os.path.commonprefix([extraction_dir, member_path])
                        == extraction_dir
                    ):
                        raise ValueError(
                            f"Path traversal detected in zip file: {member}"
                        )
                # Extract all files safely
                zipObj.extractall(path=extraction_dir)
            # Process each extracted file
            for root, dirs, files in os.walk(extraction_dir):
                for name in files:
                    file_path = os.path.join(root, name)
                    relative_path = os.path.relpath(file_path, extraction_dir)
                    # Get file extension
                    _, file_extension = os.path.splitext(name)
                    # Update kwargs for the file
                    file_kwargs = kwargs.copy()
                    file_kwargs["file_extension"] = file_extension
                    file_kwargs["_parent_converters"] = parent_converters
                    # Try converting the file using available converters
                    for converter in parent_converters:
                        # Skip the zip converter to avoid infinite recursion
                        if isinstance(converter, ZipConverter):
                            continue
                        result = converter.convert(file_path, **file_kwargs)
                        if result is not None:
                            md_content += f"\n## File: {relative_path}\n\n"
                            md_content += result.text_content + "\n\n"
                            break
            # Clean up extracted files if specified
            if kwargs.get("cleanup_extracted", True):
                shutil.rmtree(extraction_dir)
            return DocumentConverterResult(title=None, text_content=md_content.strip())
        except zipfile.BadZipFile:
            return DocumentConverterResult(
                title=None,
                text_content=f"[ERROR] Invalid or corrupted zip file: {local_path}",
            )
        except ValueError as ve:
            return DocumentConverterResult(
                title=None,
                text_content=f"[ERROR] Security error in zip file {local_path}: {str(ve)}",
            )
        except Exception as e:
            return DocumentConverterResult(
                title=None,
                text_content=f"[ERROR] Failed to process zip file {local_path}: {str(e)}",
            )
class FileConversionException(BaseException):
    pass
class UnsupportedFormatException(BaseException):
    pass
class MarkItDown:
    """(In preview) An extremely simple text-based document reader, suitable for LLM use.
    This reader will convert common file-types or webpages to Markdown."""
    def __init__(
        self,
        requests_session: Optional[requests.Session] = None,
        llm_client: Optional[Any] = None,
        llm_model: Optional[str] = None,
        style_map: Optional[str] = None,
        # Deprecated
        mlm_client: Optional[Any] = None,
        mlm_model: Optional[str] = None,
    ):
        if requests_session is None:
            self._requests_session = requests.Session()
        else:
            self._requests_session = requests_session
        # Handle deprecation notices
        #############################
        if mlm_client is not None:
            if llm_client is None:
                warn(
                    "'mlm_client' is deprecated, and was renamed 'llm_client'.",
                    DeprecationWarning,
                )
                llm_client = mlm_client
                mlm_client = None
            else:
                raise ValueError(
                    "'mlm_client' is deprecated, and was renamed 'llm_client'. Do not use both at the same time. Just use 'llm_client' instead."
                )
        if mlm_model is not None:
            if llm_model is None:
                warn(
                    "'mlm_model' is deprecated, and was renamed 'llm_model'.",
                    DeprecationWarning,
                )
                llm_model = mlm_model
                mlm_model = None
            else:
                raise ValueError(
                    "'mlm_model' is deprecated, and was renamed 'llm_model'. Do not use both at the same time. Just use 'llm_model' instead."
                )
        #############################
        self._llm_client = llm_client
        self._llm_model = llm_model
        self._style_map = style_map
        self._page_converters: List[DocumentConverter] = []
        # Register converters for successful browsing operations
        # Later registrations are tried first / take higher priority than earlier registrations
        # To this end, the most specific converters should appear below the most generic converters
        self.register_page_converter(PlainTextConverter())
        self.register_page_converter(HtmlConverter())
        self.register_page_converter(RSSConverter())
        self.register_page_converter(WikipediaConverter())
        self.register_page_converter(YouTubeConverter())
        self.register_page_converter(BingSerpConverter())
        self.register_page_converter(DocxConverter())
        self.register_page_converter(XlsxConverter())
        self.register_page_converter(XlsConverter())
        self.register_page_converter(PptxConverter())
        self.register_page_converter(WavConverter())
        self.register_page_converter(Mp3Converter())
        self.register_page_converter(ImageConverter())
        self.register_page_converter(IpynbConverter())
        self.register_page_converter(PdfConverter())
        self.register_page_converter(ZipConverter())
        self.register_page_converter(OutlookMsgConverter())
    def convert(
        self, source: Union[str, requests.Response, Path], **kwargs: Any
    ) -> DocumentConverterResult:  # TODO: deal with kwargs
        """
        Args:
            - source: can be a string representing a path either as string pathlib path object or url, or a requests.response object
            - extension: specifies the file extension to use when interpreting the file. If None, infer from source (path, uri, content-type, etc.)
        """
        # Local path or url
        if isinstance(source, str):
            if (
                source.startswith("http://")
                or source.startswith("https://")
                or source.startswith("file://")
            ):
                return self.convert_url(source, **kwargs)
            else:
                return self.convert_local(source, **kwargs)
        # Request response
        elif isinstance(source, requests.Response):
            return self.convert_response(source, **kwargs)
        elif isinstance(source, Path):
            return self.convert_local(source, **kwargs)
    def convert_local(
        self, path: Union[str, Path], **kwargs: Any
    ) -> DocumentConverterResult:  # TODO: deal with kwargs
        if isinstance(path, Path):
            path = str(path)
        # Prepare a list of extensions to try (in order of priority)
        ext = kwargs.get("file_extension")
        extensions = [ext] if ext is not None else []
        # Get extension alternatives from the path and puremagic
        base, ext = os.path.splitext(path)
        self._append_ext(extensions, ext)
        for g in self._guess_ext_magic(path):
            self._append_ext(extensions, g)
        # Convert
        return self._convert(path, extensions, **kwargs)
    # TODO what should stream's type be?
    def convert_stream(
        self, stream: Any, **kwargs: Any
    ) -> DocumentConverterResult:  # TODO: deal with kwargs
        # Prepare a list of extensions to try (in order of priority)
        ext = kwargs.get("file_extension")
        extensions = [ext] if ext is not None else []
        # Save the file locally to a temporary file. It will be deleted before this method exits
        handle, temp_path = tempfile.mkstemp()
        fh = os.fdopen(handle, "wb")
        result = None
        try:
            # Write to the temporary file
            content = stream.read()
            if isinstance(content, str):
                fh.write(content.encode("utf-8"))
            else:
                fh.write(content)
            fh.close()
            # Use puremagic to check for more extension options
            for g in self._guess_ext_magic(temp_path):
                self._append_ext(extensions, g)
            # Convert
            result = self._convert(temp_path, extensions, **kwargs)
        # Clean up
        finally:
            try:
                fh.close()
            except Exception:
                pass
            os.unlink(temp_path)
        return result
    def convert_url(
        self, url: str, **kwargs: Any
    ) -> DocumentConverterResult:  # TODO: fix kwargs type
        # Send a HTTP request to the URL
        response = self._requests_session.get(url, stream=True)
        response.raise_for_status()
        return self.convert_response(response, **kwargs)
    def convert_response(
        self, response: requests.Response, **kwargs: Any
    ) -> DocumentConverterResult:  # TODO fix kwargs type
        # Prepare a list of extensions to try (in order of priority)
        ext = kwargs.get("file_extension")
        extensions = [ext] if ext is not None else []
        # Guess from the mimetype
        content_type = response.headers.get("content-type", "").split(";")[0]
        self._append_ext(extensions, mimetypes.guess_extension(content_type))
        # Read the content disposition if there is one
        content_disposition = response.headers.get("content-disposition", "")
        m = re.search(r"filename=([^;]+)", content_disposition)
        if m:
            base, ext = os.path.splitext(m.group(1).strip("\"'"))
            self._append_ext(extensions, ext)
        # Read from the extension from the path
        base, ext = os.path.splitext(urlparse(response.url).path)
        self._append_ext(extensions, ext)
        # Save the file locally to a temporary file. It will be deleted before this method exits
        handle, temp_path = tempfile.mkstemp()
        fh = os.fdopen(handle, "wb")
        result = None
        try:
            # Download the file
            for chunk in response.iter_content(chunk_size=512):
                fh.write(chunk)
            fh.close()
            # Use puremagic to check for more extension options
            for g in self._guess_ext_magic(temp_path):
                self._append_ext(extensions, g)
            # Convert
            result = self._convert(temp_path, extensions, url=response.url, **kwargs)
        # Clean up
        finally:
            try:
                fh.close()
            except Exception:
                pass
            os.unlink(temp_path)
        return result
    def _convert(
        self, local_path: str, extensions: List[Union[str, None]], **kwargs
    ) -> DocumentConverterResult:
        error_trace = ""
        for ext in extensions + [None]:  # Try last with no extension
            for converter in self._page_converters:
                _kwargs = copy.deepcopy(kwargs)
                # Overwrite file_extension appropriately
                if ext is None:
                    if "file_extension" in _kwargs:
                        del _kwargs["file_extension"]
                else:
                    _kwargs.update({"file_extension": ext})
                # Copy any additional global options
                if "llm_client" not in _kwargs and self._llm_client is not None:
                    _kwargs["llm_client"] = self._llm_client
                if "llm_model" not in _kwargs and self._llm_model is not None:
                    _kwargs["llm_model"] = self._llm_model
                # Add the list of converters for nested processing
                _kwargs["_parent_converters"] = self._page_converters
                if "style_map" not in _kwargs and self._style_map is not None:
                    _kwargs["style_map"] = self._style_map
                # If we hit an error log it and keep trying
                try:
                    res = converter.convert(local_path, **_kwargs)
                except Exception:
                    error_trace = ("\n\n" + traceback.format_exc()).strip()
                if res is not None:
                    # Normalize the content
                    res.text_content = "\n".join(
                        [line.rstrip() for line in re.split(r"\r?\n", res.text_content)]
                    )
                    res.text_content = re.sub(r"\n{3,}", "\n\n", res.text_content)
                    # Todo
                    return res
        # If we got this far without success, report any exceptions
        if len(error_trace) > 0:
            raise FileConversionException(
                f"Could not convert '{local_path}' to Markdown. File type was recognized as {extensions}. While converting the file, the following error was encountered:\n\n{error_trace}"
            )
        # Nothing can handle it!
        raise UnsupportedFormatException(
            f"Could not convert '{local_path}' to Markdown. The formats {extensions} are not supported."
        )
    def _append_ext(self, extensions, ext):
        """Append a unique non-None, non-empty extension to a list of extensions."""
        if ext is None:
            return
        ext = ext.strip()
        if ext == "":
            return
        # if ext not in extensions:
        extensions.append(ext)
    def _guess_ext_magic(self, path):
        """Use puremagic (a Python implementation of libmagic) to guess a file's extension based on the first few bytes."""
        # Use puremagic to guess
        try:
            guesses = puremagic.magic_file(path)
            # Fix for: https://github.com/microsoft/markitdown/issues/222
            # If there are no guesses, then try again after trimming leading ASCII whitespaces.
            # ASCII whitespace characters are those byte values in the sequence b' \t\n\r\x0b\f'
            # (space, tab, newline, carriage return, vertical tab, form feed).
            if len(guesses) == 0:
                with open(path, "rb") as file:
                    while True:
                        char = file.read(1)
                        if not char:  # End of file
                            break
                        if not char.isspace():
                            file.seek(file.tell() - 1)
                            break
                    try:
                        guesses = puremagic.magic_stream(file)
                    except puremagic.main.PureError:
                        pass
            extensions = list()
            for g in guesses:
                ext = g.extension.strip()
                if len(ext) > 0:
                    if not ext.startswith("."):
                        ext = "." + ext
                    if ext not in extensions:
                        extensions.append(ext)
            return extensions
        except FileNotFoundError:
            pass
        except IsADirectoryError:
            pass
        except PermissionError:
            pass
        return []
    def register_page_converter(self, converter: DocumentConverter) -> None:
        """Register a page text converter."""
        self._page_converters.insert(0, converter)
knowledgebase/utils.py
New file
@@ -0,0 +1,11 @@
import math
def get_bit_mask(start, end):
    bits = math.ceil((end + 1) / 8) * 8
    if bits == 0:
        bits = 8
    mask = 0
    for i in range(start, end + 1):
        mask |= 1 << (bits - i - 1)
    return mask
main.py
@@ -1,6 +1,7 @@
import math
import os
from lang_flow import LangFlow
from markitdown import MarkItDown
from knowledgebase.markitdown import MarkItDown
from doc_to_docx import doc_to_docx
@@ -25,10 +26,10 @@
        if file.endswith(".docx"):
            # 转换为 md
            result = md.convert(dst_dir + file)
            text += '\n\n' + result.text_content
    out_file = dst_dir + 'docs.md'
    with open(out_file, 'w', encoding='utf-8') as f:
        f.write(text)
            text = result.text_content
            out_file = dst_dir + file + '.md'
            with open(out_file, 'w', encoding='utf-8') as f:
                f.write(text)
    return out_file
@@ -36,18 +37,29 @@
# 2.输入文档
# 3.启动LangFlow
def main():
    # doc_dir = "D:\\workspace\\PythonProjects\\KnowledgeBase\\doc\\"
    doc_dir = ".\\doc\\"
    # 处理文档
    # process_docs(doc_dir)
    # 文档转换为markdown
    # md_file = to_markdown(doc_dir)
    md_file = to_markdown(doc_dir)
    md_file = 'D:\\workspace\\PythonProjects\\KnowledgeBase\\doc\\test.md'
    # 启动大模型处理流程
    ret_text = LangFlow([md_file]).run()
    # ret_text = LangFlow([md_file]).run()
    # 保存结果
    # with open('D:\\workspace\\PythonProjects\\KnowledgeBase\\doc\\test.text', 'w', encoding='utf-8') as f:
    #     f.write(ret_text)
def get_bit_mask(start, end):
    bits = math.ceil((end + 1) / 8) * 8
    if bits == 0:
        bits = 8
    mask = 0
    for i in range(start, end + 1):
        mask |= 1 << (bits - i - 1)
    return mask
if __name__ == '__main__':
    main()
    main()
prompts.json
New file
@@ -0,0 +1,14 @@
{
  "systemMsg": {
    "desc": "system 消息",
    "prompt": "# 角色\n你是一个专业的文档通信分析师,擅长进行文档分析和通信协议分析,同时能够解析 markdown 类型的文档。拥有成熟准确的文档阅读与分析能力,能够妥善处理多文档间存在引用关系的复杂情况。\n\n## 技能\n### 技能 1:文档分析(包括 markdown 文档)\n1. 当用户提供文档时,仔细阅读文档内容,严格按照文档中的描述提取关键信息,不得加入自己的回答或建议。\n2. 分析文档的结构、主题和重点内容,同样只依据文档进行表述。\n3. 如果文档间存在引用关系,梳理引用脉络,明确各文档之间的关联,且仅呈现文档中体现的内容。\n\n\n### 技能 2:通信协议分析\n1. 接收通信协议相关信息,理解协议的规则和流程,仅依据所给信息进行分析。\n\n## 目标导向\n1. 通过对文档和通信协议的分析,为用户提供清晰、准确的数据结构,帮助用户更好地理解和使用相关信息。\n2. 以 JSON 格式组织输出内容,确保数据结构的完整性和可读性。\n\n## 规则\n1. 每一个型号都会有一套文档,需准确判断是否为同一个型号的文档后再进行整体分析。\n2. 每次只分析同一个型号。\n3. 大多数文档结构为:型号下包含设备,设备下包含数据流,数据流下包含数据帧,数据帧中有一块是包域,包域中会挂载各种类型的数据包。\n4. 这些文档都是数据传输协议的描述,在数据流、数据帧、数据包等传输实体中都描述了各个字段的分布和每个字段的大小,且大小单位不统一,需理解这些单位,并将所有输出单位统一为 bits,统一使用length表示。\n5. 如果有层级,使用树形 JSON 输出,子节点 key 使用children;需保证相同类型的数据结构统一,并且判断每个层级是什么类型,输出类型字段,类型字段的 key 使用 type ;例如当前层级为字段时使用:type:\"field\";当前层级为设备时使用:type:\"device\"\n6.名称相关的字段的 key 使用name;代号或者唯一标识相关的字段的key使用id;序号相关的字段的key使用number;其他没有举例的字段使用精简的翻译作为字段的key;\n7.探测帧为CADU,其中包含同步头和VCDU,按照习惯需要使用VCDU层级包含下一层级中传输帧主导头、传输帧插入域、传输帧数据域、传输帧尾的结构\n\n## 限制:\n- 所输出的内容必须按照JSON格式进行组织,不能偏离框架要求,且严格遵循文档内容进行输出,只输出 JSON ,不要输出其它文字。\n- 不输出任何注释等描述性信息"
  },
  "getProject": {
    "desc": "获取型号信息",
    "prompt": "根据文档输出型号信息,型号字段包括:名称和代号,仅输出型号的属性,不输出其他层级数据"
  },
  "getDevice": {
    "desc": "获取设备信息",
    "prompt": "输出所有设备列表,设备字段包括名称(name)、代号(code),如果没有代号则使用名称的英文翻译缩写代替且缩写长度不超过5个字符,JSON格式,并且给每个设备增加三个字段,第一个字段hasTcTm“是否包含遥控遥测”,判断该设备是否包含遥控遥测的功能;第二个字段hasTemperatureAnalog“是否包含温度量、模拟量等数据的采集”,判断该设备是否包含温度量等信息的采集功能;第三个字段hasBus“是否是总线设备”,判断该设备是否属于总线设备,是否有RT地址;每个字段的值都使用true或false来表示。\n仅输出JSON,不要输出JSON以外的任何字符。"
  }
}
requirements.txt
Binary files differ
tc_frame_format.json
New file
@@ -0,0 +1,77 @@
[
  {
    "name": "帧主导头",
    "type": "combPkt",
    "children": [
      {
        "name": "版本号",
        "length": 2,
        "value": "00B",
        "type": "para",
        "dataTy": "const"
      },
      {
        "name": "通过标志",
        "length": 1,
        "value": "1",
        "type": "para",
        "dataTy": "const"
      },
      {
        "name": "控制命令标志",
        "length": 1,
        "value": "0",
        "type": "para",
        "dataTy": "const"
      },
      {
        "name": "空闲位",
        "length": 2,
        "value": "00",
        "type": "para",
        "dataTy": "const"
      },
      {
        "name": "航天器标识",
        "length": 10,
        "value": "",
        "type": "para",
        "dataTy": "const"
      },
      {
        "name": "虚拟信道标识",
        "length": 6,
        "value": "",
        "type": "para",
        "dataTy": "enum"
      },
      {
        "name": "帧长",
        "length": 10,
        "value": "",
        "type": "para",
        "dataTy": "length"
      },
      {
        "name": "帧序列号",
        "length": 1,
        "value": "00B",
        "type": "para",
        "dataTy": "const"
      }
    ]
  },
  {
    "name": "传送帧数据域",
    "length": 1,
    "value": "",
    "type": "para",
    "dataTy": "subPkt"
  },
  {
    "name": "帧差错控制域",
    "length": 1,
    "value": "00B",
    "type": "para"
  }
]