kunheek / style-aware-discriminator

CVPR 2022 - Official PyTorch implementation of "A Style-Aware Discriminator for Controllable Image Translation"
https://arxiv.org/abs/2203.15375
MIT License

Missing count_parameters() in torch_utils.py #7

Closed williamyang1991 closed 1 year ago

williamyang1991 commented 1 year ago

I got an error at this line: https://github.com/kunheek/style-aware-discriminator/blob/a6c38f5525e4f2eef3cd4f5956066b2e68e8b60c/mylib/base_model.py#L39

The function count_parameters() is not defined anywhere.

I found a similar function online and added it to torch_utils.py; after that, the code ran successfully.

def count_parameters(model):
    counts = sum(p.numel() for p in model.parameters() if p.requires_grad)
    #print(f'The model has {counts:,} trainable parameters')
    return counts

kunheek commented 1 year ago

Thanks for letting me know! I think I accidentally deleted it while refactoring the code. I have added the function back :)

My original implementation looks like this:

from torch import nn


def count_parameters(module):
    # Only accept nn.Module instances.
    assert isinstance(module, nn.Module)
    num_params = 0
    for p in module.parameters():
        # Count trainable parameters only.
        if p.requires_grad:
            num_params += p.numel()
    return num_params

It behaves the same as your suggestion, so you can keep using your own implementation.
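
For anyone who wants to check the equivalence themselves, here is a minimal sketch. The toy model and the names count_parameters_sum and count_parameters_loop are made up for illustration (in the repo both would simply be called count_parameters), and it assumes PyTorch is installed.

```python
from torch import nn


def count_parameters_sum(model):
    # Version from the issue report: one-line generator expression.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def count_parameters_loop(module):
    # Version from the maintainer's reply: explicit loop plus a type check.
    assert isinstance(module, nn.Module)
    num_params = 0
    for p in module.parameters():
        if p.requires_grad:
            num_params += p.numel()
    return num_params


# Hypothetical toy model, not taken from the repository.
# Linear(4, 3): 4*3 weights + 3 biases = 15
# Linear(3, 2): 3*2 weights + 2 biases = 8
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

total_before = count_parameters_sum(model)
assert total_before == count_parameters_loop(model)
print(total_before)  # 23

# Freeze the first layer; both functions should now skip its parameters.
for p in model[0].parameters():
    p.requires_grad = False

total_after = count_parameters_sum(model)
assert total_after == count_parameters_loop(model)
print(total_after)  # 8
```

The only behavioural difference is the assert: the loop version rejects anything that is not an nn.Module, while the one-liner accepts any object that has a parameters() method.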