microsoft / cliffordlayers

https://microsoft.github.io/cliffordlayers
MIT License

Minor fixes to allow `torch.compile` #2

Closed by rejuvyesh 1 year ago

rejuvyesh commented 1 year ago

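Some interesting findings from benchmarking each Clifford convolution layer against a plain convolution with the same parameter count, with and without `torch.compile`: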

CliffordConv1d
| module   | #parameters or shape   | #flops   |
|:---------|:-----------------------|:---------|
| weight   | 0.768K                 | 0.194M   |
|  0       |  (16, 8, 3)            |          |
|  1       |  (16, 8, 3)            |          |
Benchmarking...
no compile 0.700600357055664
compile 1.0268057250976563
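The comment does not include the benchmarking script or state the timing units. A minimal sketch of a with/without-compile timing harness, assuming median wall-clock time after a few untimed warm-up calls (the `benchmark` helper and its defaults are hypothetical, not part of cliffordlayers or the original script):

```python
import statistics
import time
from typing import Any, Callable


def benchmark(fn: Callable[..., Any], *args: Any, warmup: int = 3, repeats: int = 10) -> float:
    """Return the median wall-clock time in seconds of fn(*args) over `repeats` calls.

    The warm-up calls are not timed. For a torch.compile'd module the first call
    triggers compilation, so running it during warm-up keeps that one-off cost
    out of the measurement.
    """
    for _ in range(warmup):
        fn(*args)
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        # For CUDA modules, call torch.cuda.synchronize() here before reading the
        # clock, since GPU kernels are launched asynchronously.
        times.append(time.perf_counter() - start)
    return statistics.median(times)
```

With PyTorch 2.x you would call it once with the module and once with `torch.compile(module)`, passing the same input tensor to both.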

Conv1d
| module   | #parameters or shape   | #flops   |
|:---------|:-----------------------|:---------|
| model    | 0.768K                 | 96.768K  |
|  weight  |  (16, 16, 3)           |          |
Benchmarking...
no compile 0.3552262496948242
compile 0.3716716766357422
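The two 1d layers have identical parameter counts even though their weight shapes differ, because the Clifford layer stores one weight tensor per blade (rows `0` and `1` in its table). The arithmetic, using only the shapes reported above (the helper name `weight_count` is just for illustration):

```python
from math import prod


def weight_count(shapes):
    """Total number of weight entries across a list of tensor shapes."""
    return sum(prod(shape) for shape in shapes)


# CliffordConv1d: one (16, 8, 3) weight per blade, 2 blades.
clifford_1d = weight_count([(16, 8, 3)] * 2)
# Conv1d: a single dense (16, 16, 3) weight.
conv_1d = weight_count([(16, 16, 3)])

print(clifford_1d, conv_1d)  # 768 768, both matching the 0.768K in the tables
```

The same check holds for the 2d pair (4 blades of (16, 8, 3, 3) vs one (16, 32, 3, 3), 4608 each) and the 3d pair (8 blades vs one (16, 64, 3, 3, 3), 27648 each).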

CliffordConv2d
| module   | #parameters or shape   | #flops   |
|:---------|:-----------------------|:---------|
| weight   | 4.608K                 | 0.293G   |
|  0       |  (16, 8, 3, 3)         |          |
|  1       |  (16, 8, 3, 3)         |          |
|  2       |  (16, 8, 3, 3)         |          |
|  3       |  (16, 8, 3, 3)         |          |
Benchmarking...
no compile 1.360364227294922
compile 1.1797920227050782

Conv2d
| module   | #parameters or shape   | #flops   |
|:---------|:-----------------------|:---------|
| model    | 4.608K                 | 73.157M  |
|  weight  |  (16, 32, 3, 3)        |          |
Benchmarking...
no compile 0.3499417495727539
compile 0.3843891143798828
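Equal parameter counts do not mean equal work: the Clifford layers report 2x, 4x, and 8x the FLOPs of the matching plain convolutions. One explanation consistent with those counts, stated here as an assumption, is that the Clifford layer reuses its per-blade weights to assemble one dense kernel of shape `(n_blades * out, n_blades * in, *kernel)` and then runs a standard convolution with it. A sketch of that shape arithmetic for the 2d case (`assembled_kernel_shape` is a hypothetical helper):

```python
from math import prod


def assembled_kernel_shape(blade_shape, n_blades):
    """Shape of a dense kernel built by tiling n_blades x n_blades blade weights."""
    out_ch, in_ch, *kernel = blade_shape
    return (n_blades * out_ch, n_blades * in_ch, *kernel)


clifford_kernel = assembled_kernel_shape((16, 8, 3, 3), n_blades=4)  # (64, 32, 3, 3)
conv_kernel = (16, 32, 3, 3)

# At every output position a dense convolution does one multiply-accumulate per
# kernel entry, so the FLOP ratio is the ratio of kernel sizes.
ratio = prod(clifford_kernel) / prod(conv_kernel)
print(clifford_kernel, ratio)

# Cross-check against the reported counts for CliffordConv2d vs Conv2d.
reported_ratio = 0.293e9 / 73.157e6
print(reported_ratio)
```

The reported ratio comes out at about 4.005, matching the predicted 4 up to the rounding in the tables.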

CliffordConv3d
| module   | #parameters or shape   | #flops   |
|:---------|:-----------------------|:---------|
| weight   | 27.648K                | 5.972G   |
|  0       |  (16, 8, 3, 3, 3)      |          |
|  1       |  (16, 8, 3, 3, 3)      |          |
|  2       |  (16, 8, 3, 3, 3)      |          |
|  3       |  (16, 8, 3, 3, 3)      |          |
|  4       |  (16, 8, 3, 3, 3)      |          |
|  5       |  (16, 8, 3, 3, 3)      |          |
|  6       |  (16, 8, 3, 3, 3)      |          |
|  7       |  (16, 8, 3, 3, 3)      |          |
Benchmarking...
no compile 6.124010620117187
compile 4.024053649902344
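The number of weight tensors per layer (2, 4, and 8 for 1d, 2d, and 3d) matches the number of basis blades of a Clifford algebra with n generators, which is 2^n. A small sketch that enumerates them by grade; the order shown here is a conventional one and is only an assumption about how the weight indices `0`–`7` map to blades:

```python
from itertools import combinations


def basis_blades(n):
    """Basis blades of a Clifford algebra with n generators, listed grade by grade."""
    gens = [f"e{i}" for i in range(1, n + 1)]
    blades = []
    for grade in range(n + 1):
        for combo in combinations(gens, grade):
            # The grade-0 blade (empty product) is the scalar 1.
            blades.append("".join(combo) or "1")
    return blades


for n in (1, 2, 3):
    print(n, len(basis_blades(n)), basis_blades(n))
```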

Conv3d
| module   | #parameters or shape   | #flops   |
|:---------|:-----------------------|:---------|
| model    | 27.648K                | 0.746G   |
|  weight  |  (16, 64, 3, 3, 3)     |          |
Benchmarking...
no compile 2.1221580505371094
compile 2.1798707580566408
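To read the six pairs side by side, the compile speedup can be computed directly from the numbers above. This only reprocesses the reported timings (whose units the output does not state); it does not rerun anything:

```python
# Timings copied from the benchmark output above: (no compile, compile).
timings = {
    "CliffordConv1d": (0.700600357055664, 1.0268057250976563),
    "Conv1d": (0.3552262496948242, 0.3716716766357422),
    "CliffordConv2d": (1.360364227294922, 1.1797920227050782),
    "Conv2d": (0.3499417495727539, 0.3843891143798828),
    "CliffordConv3d": (6.124010620117187, 4.024053649902344),
    "Conv3d": (2.1221580505371094, 2.1798707580566408),
}

speedups = {}
for name, (no_compile, compiled) in timings.items():
    # A value above 1 means the torch.compile'd module was faster.
    speedups[name] = no_compile / compiled
    print(f"{name:15s} no compile={no_compile:.3f} compile={compiled:.3f} speedup={speedups[name]:.2f}x")
```

By this measure `torch.compile` sped up CliffordConv2d (about 1.15x) and CliffordConv3d (about 1.5x), slowed down CliffordConv1d, and left the plain convolutions slightly slower.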