Open Klanly opened 2 years ago
You can give it a try; if it blows up, we'll deal with it then.
There's a ready-made one here. The generated output looks odd, and I haven't found a way to feed in text yet: https://github.com/TabuaTambalam/vqqncnn
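OK, my mistake: I had multiplied by 255 one time too many. Would anyone be interested in helping convert min-dalle's Bert encoder/decoder? The 12 GB of RAM on free Colab isn't enough for torch.jit.trace (and locally I only have a beat-up 8 GB laptop).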
mindd = MinDalle(is_mega=True, is_reusable=False)
tokens = mindd.tokenizer.tokenize("blue apple")
text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
text_tokens[0, :2] = [tokens[0], tokens[-1]]
text_tokens[1, :len(tokens)] = tokens
text_tokens = torch.tensor(text_tokens).to(torch.long)
mindd.init_encoder()
with torch.jit.optimized_execution(True):
    traced_enc = torch.jit.trace(mindd.encoder, (text_tokens),optimize=True)
encoder_state = traced_enc(text_tokens)
with torch.jit.optimized_execution(True):
    traced_enc.save('/content/enc.pt')
del traced_enc
del mindd.encoder
mindd.init_decoder()
torch.manual_seed(444)
'''
DalleBartDecoder.forward (now DalleBartDecoder.decode_initial?) must be patched: remove all int arguments from it.
'''
with torch.jit.optimized_execution(True):
    traced_dec = torch.jit.trace(mindd.decoder.eval(), (text_tokens,encoder_state),optimize=True) #OOM here when Mega
image_tokens = traced_dec(text_tokens,encoder_state) #OOM here when Mini
with torch.jit.optimized_execution(True):
    traced_dec.save('/content/dec.pt')
Once you have the .pt files, the pnnx commands should be:
!pnnx /content/enc_mega.pt inputshape=[2,64]i32
!pnnx /content/dec_mega.pt inputshape=[2,64]i32,[2,64,1024]f32
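For anyone puzzling over the [2,64] input shape: the tracing script above fills a 2x64 array with ones, puts only the first and last token in row 0, and the full prompt in row 1. A minimal pure-Python sketch of that layout follows. The helper name `build_text_tokens` and the token ids are made up for illustration, and the reading of row 0 as an "empty prompt" row is my assumption, not something the thread states.

```python
def build_text_tokens(tokens, width=64, pad=1):
    """Return two rows of `width` token ids, padded with `pad`.

    Row 0 keeps only the first and last token (assumed: an empty prompt).
    Row 1 holds the full tokenized prompt.
    """
    if len(tokens) < 2 or len(tokens) > width:
        raise ValueError("token count must be between 2 and width")
    empty_row = [tokens[0], tokens[-1]] + [pad] * (width - 2)
    prompt_row = list(tokens) + [pad] * (width - len(tokens))
    return [empty_row, prompt_row]


# Hypothetical token ids standing in for tokenizer output.
rows = build_text_tokens([0, 123, 456, 2])
```

Both rows always have 64 entries, which is why the encoder input is declared as [2,64]i32 regardless of prompt length.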
Tried to export with Colab Pro, but that example code does not work.
# I added this
!pip install min-dalle
import torch
import numpy
from min_dalle import MinDalle
initializing DalleBartDecoder
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
<ipython-input-1-6e9f1de962ed> in <module>()
34
35 with torch.jit.optimized_execution(True):
---> 36 traced_dec = torch.jit.trace(mindd.decoder.eval(), (text_tokens,encoder_state),optimize=True) #OOM here when Mega
37
38
4 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _forward_unimplemented(self, *input)
199 registered hooks while the latter silently ignores them.
200 """
--> 201 raise NotImplementedError
202
203
NotImplementedError:
If I move to cuda I get this
initializing DalleBartDecoder
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
<ipython-input-1-b1c410a0cb25> in <module>()
34
35 with torch.jit.optimized_execution(True):
---> 36 traced_dec = torch.jit.trace(mindd.decoder.eval().cuda(), (text_tokens.cuda(),encoder_state.cuda()),optimize=True) #OOM here when Mega
37
38
4 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _forward_unimplemented(self, *input)
199 registered hooks while the latter silently ignores them.
200 """
--> 201 raise NotImplementedError
202
203
NotImplementedError:
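Both tracebacks end in `_forward_unimplemented`, which is what torch.nn.Module falls back to when the object being called defines no `forward`. Tracing calls the module, so a decoder that only exposes named methods fails this way on CPU and CUDA alike. A common workaround is a thin wrapper whose `forward` delegates to the method you want traced. The sketch below uses pure-Python stand-ins rather than real torch classes so it runs anywhere; `Module`, `Decoder` and `DecodeInitialWrapper` are illustrative names, and the real `decode_initial` takes more arguments than shown here.

```python
class Module:
    """Stand-in for torch.nn.Module: calling the object dispatches to forward()."""

    def forward(self, *inputs):
        raise NotImplementedError

    def __call__(self, *inputs):
        return self.forward(*inputs)


class Decoder(Module):
    """Stand-in for a decoder that has named methods but no forward()."""

    def decode_initial(self, text_tokens, encoder_state):
        return ("initial", len(text_tokens), len(encoder_state))


class DecodeInitialWrapper(Module):
    """Gives the decoder a forward() that delegates to decode_initial()."""

    def __init__(self, decoder):
        self.decoder = decoder

    def forward(self, text_tokens, encoder_state):
        return self.decoder.decode_initial(text_tokens, encoder_state)


decoder = Decoder()
try:
    decoder([1, 2], [0.0])  # no forward() defined, mirrors the error above
    direct_call_failed = False
except NotImplementedError:
    direct_call_failed = True

wrapped_result = DecodeInitialWrapper(decoder)([1, 2], [0.0])
```

With real torch, the wrapper would subclass torch.nn.Module and be the object handed to torch.jit.trace; whether the decoder then traces cleanly is a separate question.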
hmmm, the bert decoder seems to be a big issue...
Any luck with an older version of dalle_bart_decoder.py, maybe? https://github.com/kuprel/min-dalle/blob/deefd24919f5f5b5b96127b749675f70cedb9435/min_dalle/models/dalle_bart_decoder.py (from when DalleBartDecoder still had a forward function)
The current DalleBartDecoder works like this: an initializer returns 4 outputs; seed and image_count are ints, so they can be hardcoded or removed.
encoder_state, attention_mask, attention_state, image_tokens_new = (
    mindd.decoder.decode_initial(
        seed=seed,
        image_count=image_count,
        text_tokens=text_tokens,
        encoder_state=encoder_state
    )
)
Then loop 16 times, feeding each row's outputs back in as the next row's inputs:
for row_index in range(16):
    attention_state, image_tokens = mindd.decoder.decode_row(
        row_index,
        temperature=temperature,
        top_k=top_k,
        supercondition_factor=supercondition_factor,
        encoder_state=encoder_state,
        attention_mask=attention_mask,
        attention_state=attention_state,
        image_tokens_sequence=image_tokens
    )
row_index probably can't be hardcoded or removed. Can NCNN handle this kind of control flow? @nihui
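One possible way around the control-flow question is to export only a single row step and keep the 16-iteration loop in host code, passing row_index in as an ordinary input on every call. The sketch below shows just that driving pattern with a dummy step function. `decode_row_step` and `run_rows` are hypothetical names, the state is a plain list instead of tensors, and whether pnnx/NCNN can actually convert decode_row in this form is not confirmed here.

```python
def decode_row_step(row_index, attention_state, image_tokens):
    """Hypothetical stand-in for one exported decode_row call.

    Takes the row index plus the carried state and returns the updated state.
    A real step would run the network; this one only records what it was given.
    """
    new_state = attention_state + [row_index]
    new_tokens = image_tokens + [row_index * 10]
    return new_state, new_tokens


def run_rows(row_count=16):
    """Host-side loop: the exported graph itself contains no loop."""
    attention_state, image_tokens = [], []
    for row_index in range(row_count):
        attention_state, image_tokens = decode_row_step(
            row_index, attention_state, image_tokens
        )
    return attention_state, image_tokens


final_state, final_tokens = run_rows()
```

The design point is that each step is a pure function of (row_index, state), so the converter would only ever see a straight-line graph.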
error log | 日志或报错信息 | ログ
model | 模型 | モデル
how to reproduce | 复现步骤 | 再現方法
1. Same as https://github.com/pnnx/pnnx/issues/42
I converted its VQGAN decoder with pnnx and got this: vq.ncnn.zip. But the output is wrong when I run it. I saw the author say that the VQGAN decoder must use float32: https://github.com/kuprel/min-dalle/commit/c199507a7a2a827dffd651b82c8da60260ee1c44 https://github.com/kuprel/min-dalle/issues/60#issuecomment-1177548528
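So is pnnx's ncnn::convert_to_fp16_model(g) what causes the wrong output? Can the fp16 conversion be disabled? It turned out I had multiplied by 255 one time too many; also, PIL has to be used rather than cv2. Next, I'd like to discuss converting min-dalle's Bert encoder/decoder to NCNN.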
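To illustrate the "multiplied by 255 one time too many" mistake: if the decoder output is already scaled to 0..255 and the post-processing multiplies by 255 again before clipping, nearly every pixel saturates to white. A small sketch, assuming a single channel value and a hypothetical helper `to_uint8` (not part of min-dalle or ncnn):

```python
def to_uint8(value, already_scaled):
    """Convert one channel value to an integer in 0..255.

    `already_scaled` says whether `value` is already in the 0..255 range;
    if it is, multiplying by 255 again would be the bug described above.
    """
    scaled = value if already_scaled else value * 255.0
    return max(0, min(255, int(round(scaled))))


# A mid-gray value in 0..1, scaled once: the intended result.
correct = to_uint8(0.4, already_scaled=False)

# The same pixel after it was already scaled, scaled a second time: saturates.
saturated = to_uint8(float(correct), already_scaled=False)
```

Checking the value range of the raw output before converting to an image is a quick way to tell which case applies.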