wangqiang9 / SketchKnitter

PyTorch implementation of "SketchKnitter: Vectorized Sketch Generation with Diffusion Models" (ICLR 2023, Spotlight).
MIT License

There is a problem with the sample.py #4

Closed: TO1bHI8Q closed this issue 1 year ago

TO1bHI8Q commented 1 year ago

TypeError: ddim_sample_loop() missing 3 required positional arguments: 'data', 'raster', and 'loss'

wangqiang9 commented 1 year ago

Hi, can you post your full error message?

TO1bHI8Q commented 1 year ago

When I run sample.py, I also use the model050000.pt weights. "all_images" is defined as all on line 55 of sample.py; what does "len(all_images)" mean?
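The thread does not answer this question. As a hedged sketch only: sampling scripts derived from OpenAI's guided-diffusion usually collect one array per generated batch in a list, so len(all_images) counts batches and len(all_images) * batch_size counts samples. Whether sample.py follows exactly this pattern is an assumption; collect_samples, fake_batch and the array shapes below are hypothetical, and the loop is single-process.

```python
import numpy as np


def collect_samples(sample_batch, batch_size, num_samples):
    """Single-process sketch of a guided-diffusion-style sampling loop (assumed pattern)."""
    all_images = []  # one array of shape (batch_size, ...) per generated batch
    # len(all_images) is the number of batches so far; times batch_size gives samples.
    while len(all_images) * batch_size < num_samples:
        all_images.append(sample_batch(batch_size))
    # The last batch can overshoot, so trim to exactly num_samples.
    return np.concatenate(all_images, axis=0)[:num_samples]


calls = []


def fake_batch(n):
    """Hypothetical sampler: returns n placeholder sketches of 96 (x, y, pen) points."""
    calls.append(n)
    return np.zeros((n, 96, 3))


samples = collect_samples(fake_batch, batch_size=4, num_samples=10)
```

With batch_size=4 and num_samples=10 the loop runs three times (0, 4 and 8 samples are all below 10), produces 12 samples, and the final slice keeps 10.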

    Traceback (most recent call last):
      File "D:/SketchKnitter-main/sample.py", line 95, in <module>
        main()
      File "D:/SketchKnitter-main/sample.py", line 66, in main
        sample, pen_state = sample_fn(
    TypeError: ddim_sample_loop() missing 3 required positional arguments: 'data', 'raster', and 'loss'
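This TypeError is what Python raises when a function gains new required positional parameters but an older call site still passes only the original ones. The sketch below reproduces it with a hypothetical stand-in; only the names data, raster and loss come from the error message, while model, shape and the return value are assumptions and not the real SketchKnitter signature.

```python
def ddim_sample_loop(model, shape, data, raster, loss):
    """Hypothetical stand-in: only data, raster and loss are taken from the error message."""
    return "sampled"


# Old call site: supplies only model and shape, so three parameters stay unfilled.
try:
    ddim_sample_loop("model", (1, 96, 3))
except TypeError as err:
    old_error = str(err)

# Fixed call site: the three new arguments are passed explicitly by keyword.
result = ddim_sample_loop("model", (1, 96, 3), data=[], raster=None, loss=None)
```

Passing the new arguments by keyword keeps the call readable and fails loudly if a parameter is later renamed.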

wangqiang9 commented 1 year ago

The bug is fixed in this commit: https://github.com/XDUWQ/SketchKnitter/commit/8250c3534ffd653e884fc10845acbd3eb2427241

wangqiang9 commented 1 year ago

For example, you can specify data, raster, and loss as follows:

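    import torch as th

    # Assumed context: load_data and AttentionMap are helpers from the SketchKnitter
    # code base, and args is the parsed argument namespace of sample.py.
    loss = th.nn.MSELoss()
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        # category=["moon.npz", "airplane.npz", "fish.npz", "umbrella.npz", "train.npz",
        #           "spider.npz", "shoe.npz", "apple.npz", "lion.npz", "bus.npz"],
        category=["apple.npz"],  # sample a single category
        class_cond=True,
    )
    raster = AttentionMap().cuda()  # requires a CUDA device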
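
Since the required parameters changed between versions, one way to check a call site against the current signature before running a long sampling job is the standard-library inspect module. This is a hedged sketch: missing_required is a hypothetical helper, and the stub ddim_sample_loop (including model, shape and progress) only mirrors the parameter names from the error message, not the real signature.

```python
import inspect


def missing_required(fn, *args, **kwargs):
    """Return names of required parameters of fn that the given call would leave unfilled."""
    sig = inspect.signature(fn)
    bound = sig.bind_partial(*args, **kwargs)
    fillable = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    return [
        name
        for name, p in sig.parameters.items()
        if name not in bound.arguments
        and p.default is inspect.Parameter.empty
        and p.kind in fillable
    ]


def ddim_sample_loop(model, shape, data, raster, loss, progress=False):
    """Hypothetical stand-in mirroring the post-fix parameter names."""
    ...


before_fix = missing_required(ddim_sample_loop, "model", (1, 96, 3))
after_fix = missing_required(
    ddim_sample_loop, "model", (1, 96, 3), data=[], raster=None, loss=None
)
```

bind_partial accepts incomplete calls, so the helper reports what is missing instead of raising; parameters with defaults (such as progress here) are not counted.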

TO1bHI8Q commented 1 year ago

You also need to change line 67 from "sample, pen_state" to "sample, pen_state, _". Then it works. Thanks!
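The extra "_" is needed presumably because the fixed sampler returns three values, and unpacking three values into two names raises a ValueError. The sketch below illustrates this with a hypothetical stub; the placeholder values, and the assumption that the third value can be ignored, are not taken from the SketchKnitter code.

```python
def ddim_sample_loop_stub():
    """Hypothetical stand-in for the fixed sampler, assumed to return three values."""
    sample = [[0.0, 0.0, 1.0]]  # placeholder stroke points
    pen_state = [1]             # placeholder pen states
    extra = None                # third value; its meaning is not stated in the thread
    return sample, pen_state, extra


def unpack_two():
    # Old call site: expects exactly two return values.
    sample, pen_state = ddim_sample_loop_stub()
    return sample, pen_state


def unpack_three():
    # Fixed call site: the third value is bound to `_` and ignored.
    sample, pen_state, _ = ddim_sample_loop_stub()
    return sample, pen_state


try:
    unpack_two()
except ValueError as err:
    old_error = str(err)  # e.g. "too many values to unpack (expected 2)"

sample, pen_state = unpack_three()
```

Binding an unused value to "_" is the conventional Python way to signal that it is deliberately discarded.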

wangqiang9 commented 1 year ago

Thank you very much for reporting the bug; it has been fixed.