bes-dev / MobileStyleGAN.pytorch

An official implementation of MobileStyleGAN in PyTorch
Apache License 2.0

MobileStyleGAN Checkpoint converted to ONNX generates grey images #44

Open IvonaTau opened 2 years ago

IvonaTau commented 2 years ago

Hi!

Thank you for an amazing repository. I successfully converted my StyleGAN2-ADA rosinality checkpoint by running the following command:

python convert_rosinality_ckpt.py --ckpt {path_to_rosinality_stylegan2_ckpt} --ckpt-mnet output/mnet.ckpt --ckpt-snet output/snet.ckpt --cfg-path output/config.json

I tested the checkpoint with demo.py and it produces images as expected.

I then converted it to ONNX by running:

python train.py --cfg output/config.json --export-model onnx --export-dir onnx-2

I tried to use the converted model in the MobileStyleGAN web demo (https://github.com/cyrildiagne/mobilestylegan-web-demo). It produces uniform grey images for all seeds. The web demo works fine with the authors' FFHQ checkpoint, so the issue seems to be with the converted model.

Do you have any thoughts on what might be causing this?
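
One way to confirm the symptom numerically, independent of the web demo, is to check whether the raw generator output is nearly constant. The following is a minimal sketch, assuming the output is a float array in roughly [-1, 1] (a common convention for StyleGAN-family generators, not confirmed here for the exported model). The name `describe_output` and the `1e-3` threshold are hypothetical choices for illustration.

```python
import numpy as np


def describe_output(img, flat_tol=1e-3):
    """Summarise a generator output to spot 'uniform grey' results.

    Assumes `img` holds floats in roughly [-1, 1]; that range is an
    assumption, not something confirmed for the exported model.
    """
    img = np.asarray(img, dtype=np.float32)
    # Map [-1, 1] to [0, 255], replacing NaN/inf first so the cast is safe.
    scaled = (np.nan_to_num(img) + 1.0) * 127.5
    as_uint8 = np.clip(scaled, 0, 255).astype(np.uint8)
    return {
        "min": float(img.min()),
        "max": float(img.max()),
        "std": float(img.std()),
        "has_nan": bool(np.isnan(img).any()),
        # A near-zero spread means every pixel has almost the same value.
        "is_flat": bool(img.std() < flat_tol),
        "mean_pixel": float(as_uint8.mean()),
    }


# An all-zero output maps to mid-grey after the [-1, 1] -> [0, 255] mapping.
flat = describe_output(np.zeros((1, 3, 8, 8)))
# A random output has a clearly non-zero spread.
noisy = describe_output(np.random.default_rng(0).uniform(-1, 1, (1, 3, 8, 8)))
```

If `is_flat` is true for every seed, the problem is likely inside the exported graph rather than in the demo's display code; if `has_nan` is true, the grey may come from NaN values being sanitised downstream.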

[Screenshot 2022-10-05 at 17:57:19]

johndpope commented 3 weeks ago

It seems the ONNX export has a problem with FusedLeakyReLU.

Related: https://github.com/rosinality/stylegan2-pytorch/issues/322

I am using this same function in my own project: https://github.com/bes-dev/MobileStyleGAN.pytorch/blob/a9776ff8f05a868b2d3b637bda14eca4c074d2a3/core/models/modules/ops/fused_act.py#L22
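
For reference when checking an exported graph, the fused activation in StyleGAN2-style code is usually equivalent to adding a per-channel bias, applying a leaky ReLU, and rescaling the result. Below is a NumPy sketch of that arithmetic. The defaults (`negative_slope=0.2`, `scale=sqrt(2)`) are the values commonly seen in StyleGAN2 ports and are an assumption here, so check them against the linked `fused_act.py`. The name `fused_leaky_relu_ref` is hypothetical.

```python
import numpy as np


def fused_leaky_relu_ref(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    """Reference arithmetic: scale * leaky_relu(x + bias).

    `x` has shape (N, C, ...) and `bias` has shape (C,). The default slope
    and scale are assumptions based on common StyleGAN2 ports.
    """
    x = np.asarray(x, dtype=np.float32)
    bias = np.asarray(bias, dtype=np.float32)
    # Reshape bias to (1, C, 1, ..., 1) so it broadcasts over batch and space.
    bias = bias.reshape((1, -1) + (1,) * (x.ndim - 2))
    y = x + bias
    # Leaky ReLU: keep non-negative values, multiply negatives by the slope.
    y = np.where(y >= 0, y, y * negative_slope)
    return y * scale


# Input of shape (1, 2, 1, 2): one sample, two channels, two pixels each.
x = np.array([[[[1.0, -1.0]], [[2.0, -2.0]]]], dtype=np.float32)
out = fused_leaky_relu_ref(x, bias=np.array([0.0, 1.0]))
```

Running the same input through one exported activation node and through this reference can help show whether the activation itself is where the outputs diverge.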

I have a class that exports to ONNX. The opset version is very important, as lower opset versions lack the needed operator support: https://github.com/johndpope/IMF/blob/main/onnxconv.py#L153

However, this hasn't resolved things for me yet.

Still digging...

Wait a second, maybe this recent upstream fix in PyTorch solves things (I still need to test it): https://github.com/pytorch/pytorch/issues/125753

pip install onnxscript -U
pip install onnxruntime -U
pip install onnxconverter_common -U
pip install onnx  -U
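
Whether an exporter upgrade helps can be checked by feeding the same input to the PyTorch model and to the exported ONNX model, then comparing the two outputs. The sketch below covers only the comparison step, on arrays already obtained from both runtimes. The name `compare_outputs` and the tolerances are hypothetical and may need loosening for deep generators.

```python
import numpy as np


def compare_outputs(reference, candidate, rtol=1e-3, atol=1e-4):
    """Compare a reference output (e.g. PyTorch) with a candidate (e.g. ONNX)."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    if reference.shape != candidate.shape:
        return {"match": False, "reason": "shape mismatch",
                "shapes": (reference.shape, candidate.shape)}
    diff = np.abs(reference - candidate)
    return {
        "match": bool(np.allclose(candidate, reference, rtol=rtol, atol=atol)),
        "max_abs_diff": float(diff.max()),
        "mean_abs_diff": float(diff.mean()),
    }


ref = np.linspace(-1.0, 1.0, 12).reshape(1, 3, 2, 2)
# Tiny numerical noise should still count as a match.
close = compare_outputs(ref, ref + 1e-6)
# A constant output, as in a uniform grey image, should not.
broken = compare_outputs(ref, np.zeros_like(ref))
```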

N.B.: the wonnx library includes a diagnostic tool for ONNX models: https://github.com/webonnx/wonnx

cargo install --git https://github.com/webonnx/wonnx.git wonnx-cli
nnx info ./data/models/opt-squeeze.onnx

[2024-10-10T21:07:06Z ERROR nnx::info] Node '/latent_token_encoder/linear_layers.2_1/Gemm' input '/latent_token_encoder/activation_7/LeakyRelu_output_0' has unknown shape
[2024-10-10T21:07:06Z ERROR nnx::info] Node '/latent_token_encoder/activation_8/LeakyRelu' input '/latent_token_encoder/linear_layers.2_1/Gemm_output_0' has unknown shape
[2024-10-10T21:07:06Z ERROR nnx::info] Node '/latent_token_encoder/linear_layers.3_1/Gemm' input '/latent_token_encoder/activation_8/LeakyRelu_output_0' has unknown shape
[2024-10-10T21:07:06Z ERROR nnx::info] Node '/latent_token_encoder/activation_9/LeakyRelu' input '/latent_token_encoder/linear_layers.3_1/Gemm_output_0' has unknown shape
[2024-10-10T21:07:06Z ERROR nnx::info] Node '/latent_token_encoder/final_linear_1/Gemm' input '/latent_token_encoder/activation_9/LeakyRelu_output_0' has unknown shape
+------------------+-------------------------------------------------------------------+
| Model version    | 0                                                                 |
+------------------+-------------------------------------------------------------------+
| IR version       | 8                                                                 |
+------------------+-------------------------------------------------------------------+
| Producer name    | pytorch                                                           |
+------------------+-------------------------------------------------------------------+
| Producer version | 2.4.0                                                             |
+------------------+-------------------------------------------------------------------+
| Opsets           | 15                                                                |
+------------------+-------------------------------------------------------------------+
| Inputs           | +-------------+-------------+----------------------------+------+ |
|                  | | Name        | Description | Shape                      | Type | |
|                  | +-------------+-------------+----------------------------+------+ |
|                  | | x_current   |             | batch_size x 3 x 256 x 256 | f32  | |
|                  | +-------------+-------------+----------------------------+------+ |
|                  | | x_reference |             | batch_size x 3 x 256 x 256 | f32  | |
|                  | +-------------+-------------+----------------------------+------+ |
+------------------+-------------------------------------------------------------------+
| Outputs          | +------+-------------+----------------------------+------+        |
|                  | | Name | Description | Shape                      | Type |        |
|                  | +------+-------------+----------------------------+------+        |
|                  | | f_r  |             | batch_size x 128 x 64 x 64 | f32  |        |
|                  | +------+-------------+----------------------------+------+        |
|                  | | t_r  |             | batch_size x 256 x 32 x 32 | f32  |        |
|                  | +------+-------------+----------------------------+------+        |
|                  | | t_c  |             | batch_size x 512 x 16 x 16 | f32  |        |
|                  | +------+-------------+----------------------------+------+        |
|                  | | x    |             | batch_size x 512 x 8 x 8   | f32  |        |
|                  | +------+-------------+----------------------------+------+        |
|                  | | 3032 |             | Gemm3032_dim_0 x 32        | f32  |        |
|                  | +------+-------------+----------------------------+------+        |
|                  | | 4799 |             | Gemm4799_dim_0 x 32        | f32  |        |
|                  | +------+-------------+----------------------------+------+        |
+------------------+-------------------------------------------------------------------+
| Ops used         | +-----------------+---------------------+                         |
|                  | | Op              | Attributes          |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Add             |                     |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Cast            | to=7                |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Concat          | axis=0              |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Constant        | value=<TENSOR>      |                         |
|                  | +-----------------+---------------------+                         |
|                  | | ConstantOfShape | value=<TENSOR>      |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Conv            | dilations=<INTS>    |                         |
|                  | |                 | group=1             |                         |
|                  | |                 | kernel_shape=<INTS> |                         |
|                  | |                 | pads=<INTS>         |                         |
|                  | |                 | strides=<INTS>      |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Div             |                     |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Gather          | axis=0              |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Gemm            | alpha=1             |                         |
|                  | |                 | beta=1              |                         |
|                  | |                 | transB=1            |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Identity        |                     |                         |
|                  | +-----------------+---------------------+                         |
|                  | | LeakyRelu       | alpha=0.2           |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Mul             |                     |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Pad             | mode=constant       |                         |
|                  | +-----------------+---------------------+                         |
|                  | | ReduceMean      | axes=<INTS>         |                         |
|                  | |                 | keepdims=0          |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Relu            |                     |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Reshape         | allowzero=0         |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Shape           |                     |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Slice           |                     |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Sub             |                     |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Transpose       | perm=<INTS>         |                         |
|                  | +-----------------+---------------------+                         |
|                  | | Unsqueeze       |                     |                         |
|                  | +-----------------+---------------------+                         |
+------------------+-------------------------------------------------------------------+
| Memory usage     | +--------------+-----------+                                      |
|                  | | Inputs       |       0 B |                                      |
|                  | +--------------+-----------+                                      |
|                  | | Outputs      |       0 B |                                      |
|                  | +--------------+-----------+                                      |
|                  | | Intermediate |       0 B |                                      |
|                  | +--------------+-----------+                                      |
|                  | | Weights      | 204.4 MiB |                                      |
|                  | +--------------+-----------+                                      |
|                  | | Total        | 204.4 MiB |                                      |
|                  | +--------------+-----------+                                      |
+------------------+-------------------------------------------------------------------+
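
The `has unknown shape` errors in the output above all sit in a chain of Gemm and LeakyRelu nodes. For a longer log, the affected nodes can be grouped with a small script. This is a sketch that assumes the log line format shown above; the helper names are hypothetical.

```python
import re
from collections import Counter

# Matches lines such as:
#   ... ERROR nnx::info] Node '<node>' input '<input>' has unknown shape
UNKNOWN_SHAPE = re.compile(r"Node '([^']+)' input '([^']+)' has unknown shape")


def unknown_shape_nodes(log_text):
    """Return (node, input) pairs reported with an unknown shape."""
    return UNKNOWN_SHAPE.findall(log_text)


def count_by_op(pairs):
    """Count affected nodes by op type, taken as the last path component."""
    return Counter(node.rsplit("/", 1)[-1] for node, _ in pairs)


sample = """
[2024-10-10T21:07:06Z ERROR nnx::info] Node '/latent_token_encoder/linear_layers.2_1/Gemm' input '/latent_token_encoder/activation_7/LeakyRelu_output_0' has unknown shape
[2024-10-10T21:07:06Z ERROR nnx::info] Node '/latent_token_encoder/activation_8/LeakyRelu' input '/latent_token_encoder/linear_layers.2_1/Gemm_output_0' has unknown shape
"""
pairs = unknown_shape_nodes(sample)
counts = count_by_op(pairs)
```

The first node in such a chain is usually the most informative one to inspect, since an unknown shape tends to propagate to every node downstream of it.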