Lyken17 / pytorch-OpCounter

Count the MACs / FLOPs of your PyTorch model.
MIT License
4.85k stars, 528 forks

Count flops by a range #198

Open Angiemaster opened 1 year ago

Angiemaster commented 1 year ago

Hi, I tried this code and it works, but how can I get the FLOPs for only specific ranges or blocks of the model? Thank you.

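    # Assumes `model` was already profiled and its modules still carry
    # the `total_ops` / `total_params` attributes.
    total_ops, total_params = 0, 0
    for m in model.modules():
        if len(list(m.children())) > 0:  # skip non-leaf modules
            continue
        # print layer-wise information here.
        print(str(m), m.total_ops, m.total_params)
        total_ops += m.total_ops
        total_params += m.total_params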

Originally posted by @Lyken17 in https://github.com/Lyken17/pytorch-OpCounter/issues/3#issuecomment-442506327
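
One possible way to restrict the count to specific blocks is to filter leaf modules by their qualified name from `named_modules()` and only sum those. The sketch below is a minimal illustration under stated assumptions, not the library's own method: it does not use thop at all, it registers plain PyTorch forward hooks, and it counts multiply-accumulates (MACs) for `nn.Conv2d` and `nn.Linear` only, ignoring bias terms. Every other layer type is reported as 0, so the totals may differ from what thop reports. The helper `count_macs_in_range`, the toy model, and the block names (`stem`, `block1`, `head`) are hypothetical.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def count_macs_in_range(model, x, prefixes):
    """Return {leaf_name: MACs} for leaf modules under any of `prefixes`.

    Hypothetical helper: only Conv2d and Linear are counted (no bias).
    """
    totals = {}
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(module, nn.Conv2d):
                kh, kw = module.kernel_size
                per_output = (module.in_channels // module.groups) * kh * kw
                macs = output.numel() * per_output
            elif isinstance(module, nn.Linear):
                macs = output.numel() * module.in_features
            else:
                macs = 0  # layer type not covered by this sketch
            totals[name] = totals.get(name, 0) + macs
        return hook

    for name, m in model.named_modules():
        if len(list(m.children())) > 0:  # skip non-leaf modules
            continue
        # Keep the leaf only if it is, or lives under, a selected block.
        if any(name == p or name.startswith(p + ".") for p in prefixes):
            handles.append(m.register_forward_hook(make_hook(name)))

    with torch.no_grad():
        model(x)  # one forward pass triggers the hooks

    for h in handles:
        h.remove()  # leave the model without extra hooks
    return totals


# Toy model with named blocks (names are made up for this example).
model = nn.Sequential(OrderedDict([
    ("stem", nn.Conv2d(3, 8, 3, padding=1)),
    ("block1", nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.ReLU())),
    ("pool", nn.AdaptiveAvgPool2d(1)),
    ("flatten", nn.Flatten()),
    ("head", nn.Linear(8, 10)),
]))
x = torch.randn(1, 3, 16, 16)

per_layer = count_macs_in_range(model, x, ["block1"])
for name, macs in per_layer.items():
    print(name, macs)
print("block1 total MACs:", sum(per_layer.values()))
```

Passing several prefixes, for example `["block1", "head"]`, sums over a range of blocks. Matching on `p + "."` rather than a bare `startswith(p)` avoids accidentally selecting `block10` when asking for `block1`.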