SHI-Labs / Neighborhood-Attention-Transformer

Neighborhood Attention Transformer, arxiv 2022 / CVPR 2023. Dilated Neighborhood Attention Transformer, arxiv 2022
MIT License

Fix for import errors #25

Closed: alihassanijr closed this pull request 2 years ago

alihassanijr commented 2 years ago

Moving the cuda code into natten causes a name collision on natten.natten, so that module needs to be refactored as well.
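An import smoke test is one quick way to catch this kind of breakage before merging. The sketch below is illustrative and not part of the repository: find_import_errors tries to import each listed module and collects the failures, and resolved_origin reports which file a dotted name (for example natten.natten) actually resolves to, which helps when a package and a module compete for the same name. Both function names are hypothetical.

```python
import importlib
import importlib.util


def find_import_errors(module_names):
    """Try to import each module; return {name: error message} for those that fail."""
    failures = {}
    for name in module_names:
        try:
            importlib.import_module(name)
        except ImportError as exc:  # also catches ModuleNotFoundError
            failures[name] = f"{type(exc).__name__}: {exc}"
    return failures


def resolved_origin(name):
    """Return the file a (possibly dotted) module name resolves to, or None if not found.

    For a dotted name, importlib.util.find_spec imports the parent package first,
    so a missing or broken parent also yields None here.
    """
    try:
        spec = importlib.util.find_spec(name)
    except (ImportError, ValueError):
        return None
    return spec.origin if spec is not None else None
```

Run from the repository root, find_import_errors(["natten"]) returns an empty dict when the package imports cleanly, and resolved_origin("natten.natten") shows which file Python picks for that name.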

alihassanijr commented 2 years ago

Manual tests
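Each of the three manual tests below starts by running natten/gradcheck.py, which prints "QK+RPB Gradients Ok" and "AV Gradients Ok". The thread does not show how that script works internally; gradient checks of this kind (for example torch.autograd.gradcheck) compare analytically computed gradients against finite-difference estimates. A minimal pure-Python sketch of that comparison, on a toy function that loosely mirrors the softmax-then-weight structure of attention (it is not NATTEN's kernel, and every name here is hypothetical):

```python
import math


def softmax(x):
    """Numerically stable softmax of a list of floats."""
    m = max(x)
    e = [math.exp(v - m) for v in x]
    z = sum(e)
    return [v / z for v in e]


def weighted_softmax_sum(x, w):
    """Toy forward pass: softmax weights applied to scalar 'values' w."""
    return sum(s * wi for s, wi in zip(softmax(x), w))


def analytic_grad(x, w):
    """Hand-derived gradient: d/dx_j = s_j * (w_j - f(x))."""
    s = softmax(x)
    fx = sum(si * wi for si, wi in zip(s, w))
    return [sj * (wj - fx) for sj, wj in zip(s, w)]


def numeric_grad(func, x, eps=1e-6):
    """Central finite-difference estimate of the gradient of func at x."""
    grad = []
    for j in range(len(x)):
        xp, xm = list(x), list(x)
        xp[j] += eps
        xm[j] -= eps
        grad.append((func(xp) - func(xm)) / (2 * eps))
    return grad


def gradcheck(func, grad_fn, x, atol=1e-6):
    """True if grad_fn(x) matches the finite-difference gradient within atol."""
    analytic = grad_fn(x)
    numeric = numeric_grad(func, x)
    return all(abs(a - n) <= atol for a, n in zip(analytic, numeric))
```

For example, with x = [0.1, -0.3, 0.7] and w = [1.0, 2.0, -0.5], gradcheck(lambda v: weighted_softmax_sum(v, w), lambda v: analytic_grad(v, w), x) passes, while a deliberately wrong gradient (all zeros) fails.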

Classification

> python3 natten/gradcheck.py
Verifying backward pass...
QK+RPB Gradients Ok
AV Gradients Ok
> python3 validate.py --model nat_base --pretrained ImageNet/ --amp
Validating in mixed precision with native PyTorch AMP.  
Data processing configuration for current model + dataset: 
 input_size: (3, 224, 224) 
 interpolation: bicubic
 mean: (0.485, 0.456, 0.406)  
 std: (0.229, 0.224, 0.225)
 crop_pct: 0.875
WARNING: Unsupported operator aten::mul encountered 148 time(s)
WARNING: Unsupported operator aten::softmax encountered 30 time(s) 
WARNING: Unsupported operator aten::add encountered 118 time(s)
WARNING: Unsupported operator aten::gelu encountered 30 time(s)
WARNING: Unsupported operator aten::rand encountered 58 time(s)
WARNING: Unsupported operator aten::floor_ encountered 58 time(s)  
WARNING: Unsupported operator aten::div encountered 58 time(s) 
WARNING: Unsupported operator aten::adaptive_avg_pool1d encountered 1 time(s) 
Model nat_base created, 89.770M Params and 13.728GFLOPs 
Test: [   0/196]  Time: 2.624s (2.624s,   97.55/s)  Loss:  0.3684 (0.3684)  Acc@1:  94.531 ( 94.531)  Acc@5:  98.828 ( 98.828)
Test: [  10/196]  Time: 0.374s (0.599s,  427.65/s)  Loss:  0.7993 (0.5135)  Acc@1:  81.641 ( 89.134)  Acc@5:  96.875 ( 98.438)
Test: [  20/196]  Time: 0.378s (0.499s,  512.84/s)  Loss:  0.4612 (0.5249)  Acc@1:  92.969 ( 88.895)  Acc@5:  98.438 ( 98.289)
Test: [  30/196]  Time: 0.374s (0.460s,  556.96/s)  Loss:  0.6294 (0.4928)  Acc@1:  87.891 ( 89.869)  Acc@5:  97.266 ( 98.374)
Test: [  40/196]  Time: 0.383s (0.439s,  582.62/s)  Loss:  0.4846 (0.5322)  Acc@1:  89.453 ( 88.729)  Acc@5:  97.656 ( 98.104)
Test: [  50/196]  Time: 0.376s (0.427s,  599.58/s)  Loss:  0.3374 (0.5330)  Acc@1:  94.922 ( 88.603)  Acc@5:  98.047 ( 98.108)
Test: [  60/196]  Time: 0.377s (0.419s,  611.52/s)  Loss:  0.6553 (0.5505)  Acc@1:  84.375 ( 88.140)  Acc@5:  96.875 ( 98.079)
Test: [  70/196]  Time: 0.377s (0.413s,  620.45/s)  Loss:  0.6382 (0.5395)  Acc@1:  85.938 ( 88.281)  Acc@5:  98.828 ( 98.195)
Test: [  80/196]  Time: 0.380s (0.408s,  627.45/s)  Loss:  0.9580 (0.5570)  Acc@1:  73.438 ( 87.871)  Acc@5:  95.703 ( 98.042)
Test: [  90/196]  Time: 0.380s (0.404s,  632.91/s)  Loss:  1.3174 (0.5819)  Acc@1:  67.188 ( 87.139)  Acc@5:  91.797 ( 97.785)
Test: [ 100/196]  Time: 0.378s (0.402s,  637.40/s)  Loss:  0.7969 (0.6144)  Acc@1:  80.469 ( 86.313)  Acc@5:  96.094 ( 97.478)
Test: [ 110/196]  Time: 0.377s (0.399s,  641.17/s)  Loss:  0.6367 (0.6257)  Acc@1:  85.938 ( 85.987)  Acc@5:  98.438 ( 97.410)
Test: [ 120/196]  Time: 0.379s (0.397s,  644.33/s)  Loss:  0.8379 (0.6304)  Acc@1:  80.859 ( 85.886)  Acc@5:  94.531 ( 97.333)
Test: [ 130/196]  Time: 0.376s (0.396s,  647.13/s)  Loss:  0.4229 (0.6473)  Acc@1:  90.625 ( 85.287)  Acc@5:  99.219 ( 97.188)
Test: [ 140/196]  Time: 0.377s (0.394s,  649.30/s)  Loss:  0.6094 (0.6538)  Acc@1:  88.672 ( 85.131)  Acc@5:  97.656 ( 97.144)
Test: [ 150/196]  Time: 0.381s (0.393s,  651.36/s)  Loss:  0.5747 (0.6605)  Acc@1:  88.672 ( 84.993)  Acc@5:  96.875 ( 97.072)
Test: [ 160/196]  Time: 0.377s (0.392s,  653.18/s)  Loss:  0.3867 (0.6670)  Acc@1:  92.578 ( 84.841)  Acc@5:  98.438 ( 96.962)
Test: [ 170/196]  Time: 0.377s (0.391s,  654.73/s)  Loss:  0.3965 (0.6762)  Acc@1:  91.406 ( 84.539)  Acc@5:  98.828 ( 96.870)
Test: [ 180/196]  Time: 0.378s (0.390s,  656.10/s)  Loss:  0.9956 (0.6861)  Acc@1:  76.562 ( 84.267)  Acc@5:  96.094 ( 96.828)
Test: [ 190/196]  Time: 0.376s (0.389s,  657.43/s)  Loss:  1.0391 (0.6890)  Acc@1:  75.000 ( 84.207)  Acc@5:  96.875 ( 96.832)
 * Acc@1 84.254 (15.746) Acc@5 96.858 (3.142)

Passed!
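The data processing configuration printed by validate.py above (input_size 224, crop_pct 0.875, bicubic interpolation, ImageNet mean and std) fixes the evaluation transform. Assuming the standard timm evaluation pipeline, the shorter image side is resized to floor(224 / 0.875) = 256 pixels, a 224x224 center crop is taken, and each channel is normalized as (x - mean) / std. A minimal sketch of that arithmetic; the function names are illustrative, not timm's API:

```python
import math

# Values printed by validate.py in the log above.
INPUT_SIZE = 224
CROP_PCT = 0.875
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def eval_sizes(input_size, crop_pct):
    """Return (shorter-side resize target, center-crop size), assuming a timm-style eval transform."""
    scale_size = math.floor(input_size / crop_pct)
    return scale_size, input_size


def normalize(rgb, mean=MEAN, std=STD):
    """Per-channel (x - mean) / std for an RGB triple already scaled to [0, 1]."""
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))
```

With a 256-pixel resize and a 224-pixel crop, the crop keeps 224 / 256 = 87.5% of the shorter side, which is exactly what crop_pct expresses.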

Detection

> python3 natten/gradcheck.py
Verifying backward pass...
QK+RPB Gradients Ok
AV Gradients Ok
> ./dist_test.sh \
    configs/nat/cascade_mask_rcnn_nat_base_3x_coco.py \
    http://ix.cs.uoregon.edu/\~alih/nat/checkpoints/DET/nat_base_cascademaskrcnn.pth \
    $NUM_GPUS \
    --eval bbox segm
Evaluating bbox...
Loading and preparing results...
DONE (t=0.12s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *bbox*
DONE (t=19.72s).
Accumulating evaluation results...
DONE (t=2.99s).
Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.522
Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=1000 ] = 0.709
Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=1000 ] = 0.568
Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.352
Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.559
Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.672
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.647
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=300 ] = 0.647
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=1000 ] = 0.647
Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.476
Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.683
Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.795
Loading and preparing results...
DONE (t=0.83s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *segm*
DONE (t=22.17s).
Accumulating evaluation results...
DONE (t=2.99s).
Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.451
Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=1000 ] = 0.683
Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=1000 ] = 0.491
Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.254
Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.484
Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.640
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.566
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=300 ] = 0.566
Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=1000 ] = 0.566
Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.392
Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.604
Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.725
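In the bbox and segm summaries above, "IoU=0.50:0.95" means the metric is averaged over ten IoU thresholds from 0.50 to 0.95 in steps of 0.05, while the IoU=0.50 and IoU=0.75 rows use a single threshold; a detection can count as a true positive at a given threshold only if its IoU with a ground-truth object reaches that threshold. The sketch below shows the box IoU and the threshold grid. It assumes (x1, y1, x2, y2) boxes in continuous coordinates and is not the pycocotools implementation, which additionally handles matching, crowd regions, and precision/recall accumulation.

```python
def box_iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2) in continuous coordinates."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


# The ten thresholds averaged by "IoU=0.50:0.95": 0.50, 0.55, ..., 0.95.
IOU_THRESHOLDS = [round(0.50 + 0.05 * i, 2) for i in range(10)]


def thresholds_passed(iou):
    """Thresholds at which a detection with this IoU could count as a true positive."""
    return [t for t in IOU_THRESHOLDS if iou >= t]
```

For example, two 10x10 boxes offset by 5 pixels in each direction overlap in a 5x5 square, giving IoU = 25 / 175, which is about 0.14 and below every threshold; a detection with IoU 0.72 would count at the five thresholds from 0.50 to 0.70 only.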

Segmentation

> python3 natten/gradcheck.py
Verifying backward pass...
QK+RPB Gradients Ok
AV Gradients Ok
> ./dist_test.sh \
    configs/nat/upernet_nat_base_512x512_160k_ade20k.py \
    http://ix.cs.uoregon.edu/\~alih/nat/checkpoints/SEG/nat_base_upernet.pth \
    $NUM_GPUS \
    --eval mIoU
+---------------------+-------+-------+
|        Class        |  IoU  |  Acc  |
+---------------------+-------+-------+
|         wall        | 75.53 | 87.16 |
|       building      | 81.87 | 92.29 |
|         sky         | 94.03 | 97.37 |
|        floor        | 80.33 | 90.81 |
|         tree        |  74.2 | 88.32 |
|       ceiling       | 82.15 | 90.36 |
|         road        | 82.84 | 90.57 |
|         bed         | 87.74 | 95.82 |
|      windowpane     | 60.96 | 76.56 |
|        grass        | 67.09 | 81.89 |
|       cabinet       | 61.49 | 76.18 |
|       sidewalk      | 67.19 | 81.96 |
|        person       | 79.83 | 92.83 |
|        earth        | 38.42 | 52.22 |
|         door        | 44.98 | 61.33 |
|        table        |  60.2 | 73.64 |
|       mountain      | 56.32 | 72.01 |
|        plant        | 50.81 | 63.63 |
|       curtain       | 69.39 |  84.6 |
|        chair        | 55.48 | 66.65 |
|         car         | 82.87 | 90.32 |
|        water        |  50.9 | 64.06 |
|       painting      | 70.19 | 87.05 |
|         sofa        | 67.36 | 83.98 |
|        shelf        | 41.04 | 58.61 |
|        house        | 43.24 | 52.96 |
|         sea         | 60.35 | 90.98 |
|        mirror       | 63.47 | 71.86 |
|         rug         |  57.3 | 63.43 |
|        field        | 29.64 | 47.03 |
|       armchair      |  42.7 | 61.84 |
|         seat        |  57.0 | 77.13 |
|        fence        | 40.73 | 55.21 |
|         desk        | 48.13 | 68.16 |
|         rock        | 41.16 | 63.11 |
|       wardrobe      | 51.11 | 67.45 |
|         lamp        | 61.11 | 73.86 |
|       bathtub       | 75.58 |  80.8 |
|       railing       | 35.51 | 47.36 |
|       cushion       | 58.02 | 69.98 |
|         base        | 29.12 |  37.3 |
|         box         | 24.72 | 31.98 |
|        column       | 44.52 | 56.47 |
|      signboard      | 36.14 | 54.87 |
|   chest of drawers  | 42.74 | 55.06 |
|       counter       | 27.87 | 34.73 |
|         sand        | 46.79 | 66.99 |
|         sink        | 72.58 | 79.75 |
|      skyscraper     | 51.11 | 62.53 |
|      fireplace      | 71.35 | 88.81 |
|     refrigerator    |  73.6 | 83.02 |
|      grandstand     | 39.79 | 62.46 |
|         path        | 18.91 | 25.99 |
|        stairs       | 30.03 | 35.55 |
|        runway       | 66.85 | 86.44 |
|         case        | 44.63 | 58.72 |
|      pool table     | 92.43 |  96.1 |
|        pillow       | 57.72 | 66.74 |
|     screen door     | 64.12 | 70.04 |
|       stairway      |  33.1 | 40.94 |
|        river        |  14.2 | 23.59 |
|        bridge       | 71.83 | 80.26 |
|       bookcase      |  40.3 | 59.67 |
|        blind        | 44.15 | 53.07 |
|     coffee table    | 57.86 | 82.82 |
|        toilet       | 85.69 | 91.19 |
|        flower       | 42.66 | 60.32 |
|         book        | 47.28 | 66.83 |
|         hill        |  6.35 |  8.63 |
|        bench        | 44.94 | 52.18 |
|      countertop     | 55.84 | 75.16 |
|        stove        | 75.39 | 82.15 |
|         palm        | 52.56 | 76.77 |
|    kitchen island   | 41.19 | 63.49 |
|       computer      | 63.74 | 75.12 |
|     swivel chair    | 47.24 | 66.98 |
|         boat        | 42.57 |  48.7 |
|         bar         |  28.3 | 39.14 |
|    arcade machine   | 50.75 | 53.53 |
|        hovel        | 53.32 | 62.04 |
|         bus         | 86.09 | 95.96 |
|        towel        | 60.64 | 74.56 |
|        light        | 56.48 | 64.49 |
|        truck        | 28.85 |  41.2 |
|        tower        | 12.96 | 20.85 |
|      chandelier     | 67.49 | 85.36 |
|        awning       | 26.03 | 32.65 |
|     streetlight     | 25.12 | 34.19 |
|        booth        | 47.54 | 49.03 |
| television receiver |  66.0 | 75.97 |
|       airplane      | 59.69 |  75.2 |
|      dirt track     |  4.58 | 15.82 |
|       apparel       | 32.95 | 48.14 |
|         pole        | 20.75 | 30.21 |
|         land        |  5.57 |  6.8  |
|      bannister      | 14.69 | 17.79 |
|      escalator      | 32.09 | 39.28 |
|       ottoman       | 46.33 | 57.33 |
|        bottle       | 36.13 | 63.37 |
|        buffet       | 45.26 | 51.34 |
|        poster       | 26.99 | 34.73 |
|        stage        | 19.62 | 27.18 |
|         van         | 44.84 | 60.52 |
|         ship        | 48.57 | 73.42 |
|       fountain      | 22.63 |  22.9 |
|    conveyer belt    |  66.7 | 84.57 |
|        canopy       | 18.26 | 24.12 |
|        washer       | 69.22 | 71.25 |
|      plaything      | 25.17 | 35.71 |
|    swimming pool    | 54.02 | 68.86 |
|        stool        | 44.41 | 61.67 |
|        barrel       | 57.82 | 66.71 |
|        basket       | 29.68 | 41.88 |
|      waterfall      | 44.44 | 51.48 |
|         tent        | 94.86 | 98.16 |
|         bag         | 18.95 | 24.15 |
|       minibike      | 73.53 | 85.26 |
|        cradle       | 79.39 | 94.99 |
|         oven        |  50.8 | 77.29 |
|         ball        |  50.8 |  61.8 |
|         food        | 57.46 | 68.13 |
|         step        | 13.11 | 15.34 |
|         tank        | 33.42 | 36.27 |
|      trade name     | 19.69 | 21.62 |
|      microwave      | 79.51 | 85.19 |
|         pot         | 40.92 | 48.68 |
|        animal       | 62.86 | 66.61 |
|       bicycle       | 53.48 |  76.9 |
|         lake        | 43.06 | 48.81 |
|      dishwasher     | 62.97 | 77.94 |
|        screen       | 64.13 | 86.63 |
|       blanket       | 12.65 | 15.11 |
|      sculpture      |  55.4 | 71.16 |
|         hood        |  56.1 | 61.26 |
|        sconce       | 44.63 | 54.54 |
|         vase        | 40.14 | 60.16 |
|    traffic light    | 31.21 | 48.37 |
|         tray        |  5.48 |  7.83 |
|        ashcan       | 40.68 | 59.46 |
|         fan         | 59.42 |  70.0 |
|         pier        | 36.14 | 49.78 |
|      crt screen     |  9.07 | 22.83 |
|        plate        | 50.39 | 68.97 |
|       monitor       | 13.44 | 17.81 |
|    bulletin board   | 47.89 | 68.15 |
|        shower       |  1.94 |  5.87 |
|       radiator      | 56.73 | 64.17 |
|        glass        | 14.15 | 15.59 |
|        clock        | 35.96 | 42.74 |
|         flag        | 39.61 | 43.14 |
+---------------------+-------+-------+
Summary:

+-------+-------+-------+
|  aAcc |  mIoU |  mAcc |
+-------+-------+-------+
| 82.45 | 48.53 | 60.22 |
+-------+-------+-------+
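The per-class IoU and Acc columns and the summary row can all be derived from a pixel-level confusion matrix. A minimal sketch, assuming the usual definitions (as in mmsegmentation's evaluation): IoU is intersection over union per class, Acc is correctly labelled pixels over ground-truth pixels per class, aAcc is all correct pixels over all labelled pixels, and mIoU and mAcc are means over the classes whose value is defined. The tables above report these values as percentages. The function name seg_metrics is hypothetical.

```python
import math


def seg_metrics(conf):
    """Compute IoU, Acc, aAcc, mIoU, mAcc from conf[i][j] = pixels of true class i predicted as j."""
    n = len(conf)
    intersect = [conf[i][i] for i in range(n)]
    area_label = [sum(conf[i]) for i in range(n)]
    area_pred = [sum(conf[r][i] for r in range(n)) for i in range(n)]
    iou, acc = [], []
    for i in range(n):
        union = area_label[i] + area_pred[i] - intersect[i]
        # Undefined ratios become NaN and are skipped by the class means below.
        iou.append(intersect[i] / union if union else math.nan)
        acc.append(intersect[i] / area_label[i] if area_label[i] else math.nan)
    valid_iou = [v for v in iou if not math.isnan(v)]
    valid_acc = [v for v in acc if not math.isnan(v)]
    return {
        "IoU": iou,
        "Acc": acc,
        "aAcc": sum(intersect) / sum(area_label),
        "mIoU": sum(valid_iou) / len(valid_iou),
        "mAcc": sum(valid_acc) / len(valid_acc),
    }
```

Skipping NaN entries means a class that never appears in either the ground truth or the predictions does not drag mIoU toward zero, while a class that is predicted but absent from the ground truth still contributes an IoU of 0.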
alihassanijr commented 2 years ago

All three tests passed; ready to merge into main.