TMElyralab / MuseTalk

MuseTalk: Real-Time High Quality Lip Synchronization with Latent Space Inpainting
Other
1.84k stars 219 forks source link

您好,请问为什么导出onnx文件会把所有算子的权重都导出? #128

Open DestoryVIP opened 2 weeks ago

DestoryVIP commented 2 weeks ago

您好,我写了一个 ONNX 导出脚本,只导出 unet.model。然而导出后模型并没有保存在单个 model.onnx 文件中,而是 model.onnx 只保存模型结构,权重则被保存成许多零散的文件,请问这是为什么?

导出代码如下:

# =============================== Build the custom operator
import json

import numpy as np
import onnxscript
import torch

# The ONNX opset used inside the onnxscript function has to be aligned with
# the opset given to torch.onnx.register_custom_op_symbolic and to
# torch.onnx.export, both of which use 17 in this script.
from onnxscript.onnx_opset import opset17 as op

# Custom domain/version under which the onnxscript function is emitted; the
# same domain and version are passed to torch.onnx.export via custom_opsets.
custom_opset = onnxscript.values.Opset(domain="torch.onnx", version=17)

# ONNX Script implementation of scaled dot-product attention:
#     Dropout(Softmax(query @ key^T * scale)) @ value, scale = 1 / sqrt(d)
# where d is the size of the last axis of `query`.
# No attention mask and no causal masking is applied here.
#
# Inputs:
#   query, key, value: attention tensors; the last two axes of `key` are
#       swapped before the product, any leading axes are preserved.
#   dropout_p: forwarded as the `ratio` input of the ONNX Dropout op.
# Output: the product of the (dropped-out) attention weights with `value`.
@onnxscript.script(custom_opset)
def ScaledDotProductAttention(
    query,
    key,
    value,
    dropout_p,
):
    # Swap the last two axes of key
    key_shape = op.Shape(key)
    # Split the shape into [leading dims..., second-to-last, last].
    key_last_dim = key_shape[-1:]
    key_second_last_dim = key_shape[-2:-1]
    key_first_dims = key_shape[:-2]
    # Contract the dimensions that are not the last two so we can transpose
    # with a static permutation.
    # The -1 entry lets Reshape infer the size of the flattened leading axis.
    key_squeezed_shape = op.Concat(
        op.Constant(value_ints=[-1]), key_second_last_dim, key_last_dim, axis=0
    )
    key_squeezed = op.Reshape(key, key_squeezed_shape)
    key_squeezed_transposed = op.Transpose(key_squeezed, perm=[0, 2, 1])
    # Restore the leading dims, now with the last two axes exchanged.
    key_transposed_shape = op.Concat(key_first_dims, key_last_dim, key_second_last_dim, axis=0)
    key_transposed = op.Reshape(key_squeezed_transposed, key_transposed_shape)

    # Size of query's last axis, cast to query's dtype so it can be used in
    # floating-point arithmetic below.
    embedding_size = op.CastLike(op.Shape(query)[-1], query)
    scale = op.Div(1.0, op.Sqrt(embedding_size))

    # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
    # Each operand is multiplied by sqrt(scale), so their product carries `scale`.
    query_scaled = op.Mul(query, op.Sqrt(scale))
    key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale))
    attn_weight = op.Softmax(
        op.MatMul(query_scaled, key_transposed_scaled),
        axis=-1,
    )
    # Dropout yields (output, mask); the mask is not needed.
    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
    return op.MatMul(attn_weight, value)

def custom_scaled_dot_product_attention(g, query, key, value, attn_mask, dropout, is_causal, scale=None):
    """Symbolic function that lowers SDPA to the ScaledDotProductAttention onnxscript op.

    Only ``query``, ``key``, ``value`` and ``dropout`` are forwarded to the
    custom op. ``attn_mask``, ``is_causal`` and ``scale`` are accepted so the
    signature matches the aten operator, but they are not used.

    Returns the emitted node's output, typed like ``query``.
    """
    sdpa_output = g.onnxscript_op(ScaledDotProductAttention, query, key, value, dropout)
    output_type = query.type()
    return sdpa_output.setType(output_type)

# Use the custom symbolic for aten::scaled_dot_product_attention when
# exporting at opset 17.
torch.onnx.register_custom_op_symbolic(
    "aten::scaled_dot_product_attention",
    custom_scaled_dot_product_attention,
    17,
)
#=============================================
from diffusers import AutoencoderKL,UNet2DConditionModel
from pathlib import Path

# Export the MuseTalk UNet to ONNX with a dynamic batch dimension.
unet_config_path = "./models/musetalk/musetalk.json"
unet_md_path = "./models/musetalk/pytorch_model.bin"
with open(unet_config_path, "r", encoding="utf-8") as f:
    unet_config = json.load(f)

model = UNet2DConditionModel(**unet_config)
model.eval()

# Load the weights.
# The model and the dummy inputs below all live on the CPU, so the checkpoint
# is always mapped to the CPU instead of being materialised on a GPU first.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
weights = torch.load(unet_md_path, map_location="cpu")
model.load_state_dict(weights)

# Dummy inputs used to trace the model:
#     latent_batch:        (1, 8, 32, 32)
#     timespace:           (1,)  value [0]
#     audio_feature_batch: (1, 50, 384)
fake_latent_batch = torch.from_numpy(
    np.random.random(size=(1, 8, 32, 32)).astype(np.float32)
)  # a 256x256 image compressed to 32x32
# Explicit int64 so the traced input dtype does not depend on the platform's
# default numpy integer width.
timespaces = torch.tensor([0], dtype=torch.int64)
fake_audio_feature_batch = torch.from_numpy(
    np.random.random(size=(1, 50, 384)).astype(np.float32)
)

# NOTE(review): an ONNX file is a protobuf message and protobuf caps a single
# file at 2 GB. The exported parameters here are reported as ~3.2 GB, so the
# exporter presumably falls back to the ONNX external-data format: unet.onnx
# holds the graph and each tensor is written to its own file next to it.
# To get a single weights file, re-save the exported model with the onnx
# package's external-data option that puts all tensors into one file, or
# export the model in half precision so it fits under the limit.
torch.onnx.export(
    model,
    (
        fake_latent_batch, timespaces, fake_audio_feature_batch
    ),
    f="unet.onnx",
    custom_opsets={"torch.onnx": 17},
    opset_version=17,
    input_names=['latent_batch', "timespace", "audio_feature_batch"],
    output_names=['output'],
    verbose=True,
    export_params=True,
    # Only the batch axis is dynamic; all other dimensions are fixed by the
    # dummy inputs above.
    dynamic_axes={
        'latent_batch': {0: 'batch_size'},
        'audio_feature_batch': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)

导出的结果如下所示:(导出结果占3.2G) image