neurosim / DNN_NeuroSim_V1.3

Benchmark framework of compute-in-memory based accelerators for deep neural network (inference engine focused)

Significant speedup (using CUDA) can be achieved by initialising arrays on the GPU #8

Closed: datMaffin closed this issue 3 years ago

datMaffin commented 3 years ago

I have noticed that calls such as torch.zeros_like() do not specify a device. A significant speedup can be achieved by setting the device to 'cuda' in those calls:

Currently:

torch.zeros_like(outputOrignal)

Improved:

torch.zeros_like(outputOrignal, device='cuda')

The same applies to torch.normal.

The changes would mainly need to be applied to: https://github.com/neurosim/DNN_NeuroSim_V1.3/blob/3754e10e939e80b4952ba4e09a3afb7972456fc9/Inference_pytorch/modules/quantization_cpu_np_infer.py

neurosim commented 3 years ago

Hi, thanks very much for your suggestion! However, for functions like torch.zeros_like(), the output is created on the same device as the input tensor when no device is specified, so in our code the output tensor is already created on the GPU (CUDA). That said, we also want to speed things up and would welcome any useful ideas.
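A minimal sketch of the two default-device rules at play in this exchange, assuming a standard PyTorch install with an unchanged default device: `torch.zeros_like` places its result on the input's device (as the maintainer says), whereas size-based factories such as `torch.zeros` and `torch.full` allocate on the default device (normally the CPU) unless `device=` is passed. The sketch falls back to the CPU when no GPU is present.

```python
import torch

# Use the GPU when one is available; otherwise everything stays on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

weight = torch.randn(4, 3, device=device)

# *_like factories inherit device (and dtype) from their input tensor.
like = torch.zeros_like(weight)

# Size-based factories only see a shape, so they use the default device
# (the CPU, unless the process-wide default device has been changed).
by_size = torch.zeros(weight.size())
by_size_full = torch.full(weight.size(), 0.1)

# Passing device= explicitly keeps them next to `weight`.
on_device = torch.full(weight.size(), 0.1, device=weight.device)

print(like.device, by_size.device, by_size_full.device, on_device.device)
```

In other words, both comments can be right at once: `zeros_like` calls are already on the GPU, while `torch.zeros(...)`/`torch.full(...)` calls built from `weight.size()` are not.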

datMaffin commented 3 years ago

Thanks for the reply.

Hmm, I double-checked, and my patch definitely seems to speed up the FC layer calculations significantly.

I applied the following patches and set inference to 1:


Notes

I am pretty sure that the .cuda() calls do not really matter (nor do they make things worse).

The first patch alone did not fix the poor performance of the FC layer (though the second patch relies on it).
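The observation that the extra `.cuda()` calls are harmless matches PyTorch's documented behavior: `Tensor.to()` (and likewise `Tensor.cuda()` for a tensor already on that GPU) returns the original tensor without copying when it already has the requested device and dtype. A small sketch, using `.to()` so it also runs without a GPU:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.ones(3, device=device)

# Already on the requested device with the requested dtype: no copy is made,
# so the result shares x's memory (x.cuda() on a CUDA tensor behaves the same way).
same = x.to(x.device)

# Changing the dtype (or the device) does allocate a new tensor.
converted = x.to(torch.float64)

print(same.data_ptr() == x.data_ptr())       # no copy
print(converted.data_ptr() == x.data_ptr())  # copy
```

So a redundant `.cuda()` costs a Python call but no host-to-device transfer; the expensive case is a tensor that genuinely starts on the CPU.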


From 48e10984168b6cbe3266fd792b761c5d6e04cbe1 Mon Sep 17 00:00:00 2001
From: Marvin Dostal <st148727@stud.uni-stuttgart.de>
Date: Mon, 2 Aug 2021 13:15:37 +0200
Subject: [PATCH] Replace numpy random with torch random

---
 .../modules/quantization_cpu_np_infer.py      | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/Inference_pytorch/modules/quantization_cpu_np_infer.py b/Inference_pytorch/modules/quantization_cpu_np_infer.py
index 8d13013..ab72dd5 100644
--- a/Inference_pytorch/modules/quantization_cpu_np_infer.py
+++ b/Inference_pytorch/modules/quantization_cpu_np_infer.py
@@ -77,13 +77,13 @@ class QConv2d(nn.Conv2d):
                                 remainder = torch.fmod(X_decimal, cellRange)*mask
                                 # retention
                                 remainder = wage_quantizer.Retention(remainder,self.t,self.v,self.detect,self.target)
-                                variation = np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
+                                variation = torch.normal(torch.zeros(weight.size()), torch.full(weight.size(), self.vari)) # np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
                                 X_decimal = torch.round((X_decimal-remainder)/cellRange)*mask
                                 # Now also consider weight has on/off ratio effects
                                 # Here remainder is the weight mapped to Hardware, so we introduce on/off ratio in this value
                                 # the range of remainder is [0, cellRange-1], we truncate it to [lower, upper]
                                 remainderQ = (upper-lower)*(remainder-0)+(cellRange-1)*lower   # weight cannot map to 0, but to Gmin
-                                remainderQ = remainderQ + remainderQ*torch.from_numpy(variation).cuda()
+                                remainderQ = remainderQ + remainderQ*variation.cuda()
                                 outputPartial= F.conv2d(input, remainderQ*mask, self.bias, self.stride, self.padding, self.dilation, self.groups)
                                 outputDummyPartial= F.conv2d(input, dummyP*mask, self.bias, self.stride, self.padding, self.dilation, self.groups)
                                 scaler = cellRange**k
@@ -106,13 +106,13 @@ class QConv2d(nn.Conv2d):
                                     remainder = torch.fmod(X_decimal, cellRange)*mask
                                     # retention
                                     remainder = wage_quantizer.Retention(remainder,self.t,self.v,self.detect,self.target)
-                                    variation = np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
+                                    variation = torch.normal(torch.zeros(weight.size()), torch.full(weight.size(), self.vari)) # np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
                                     X_decimal = torch.round((X_decimal-remainder)/cellRange)*mask
                                     # Now also consider weight has on/off ratio effects
                                     # Here remainder is the weight mapped to Hardware, so we introduce on/off ratio in this value
                                     # the range of remainder is [0, cellRange-1], we truncate it to [lower, upper]
                                     remainderQ = (upper-lower)*(remainder-0)+(cellRange-1)*lower   # weight cannot map to 0, but to Gmin
-                                    remainderQ = remainderQ + remainderQ*torch.from_numpy(variation).cuda()
+                                    remainderQ = remainderQ + remainderQ*variation.cuda()
                                     outputPartial= F.conv2d(inputB, remainderQ*mask, self.bias, self.stride, self.padding, self.dilation, self.groups)
                                     outputDummyPartial= F.conv2d(inputB, dummyP*mask, self.bias, self.stride, self.padding, self.dilation, self.groups)
                                     # Add ADC quanization effects here !!!
@@ -143,13 +143,13 @@ class QConv2d(nn.Conv2d):
                                     remainder = torch.fmod(X_decimal, cellRange)*mask
                                     # retention
                                     remainder = wage_quantizer.Retention(remainder,self.t,self.v,self.detect,self.target)
-                                    variation = np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
+                                    variation = torch.normal(torch.zeros(weight.size()), torch.full(weight.size(), self.vari)) # np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
                                     X_decimal = torch.round((X_decimal-remainder)/cellRange)*mask
                                     # Now also consider weight has on/off ratio effects
                                     # Here remainder is the weight mapped to Hardware, so we introduce on/off ratio in this value
                                     # the range of remainder is [0, cellRange-1], we truncate it to [lower, upper]*(cellRange-1)
                                     remainderQ = (upper-lower)*(remainder-0)+(cellRange-1)*lower   # weight cannot map to 0, but to Gmin
-                                    remainderQ = remainderQ + remainderQ*torch.from_numpy(variation).cuda()
+                                    remainderQ = remainderQ + remainderQ*variation.cuda()
                                     outputPartial= F.conv2d(inputB, remainderQ*mask, self.bias, self.stride, self.padding, self.dilation, self.groups)
                                     outputDummyPartial= F.conv2d(inputB, dummyP*mask, self.bias, self.stride, self.padding, self.dilation, self.groups)
                                     # Add ADC quanization effects here !!!
@@ -255,13 +255,13 @@ class QLinear(nn.Linear):
                         remainder = torch.fmod(X_decimal, cellRange)*mask
                         # retention
                         remainder = wage_quantizer.Retention(remainder,self.t,self.v,self.detect,self.target)
-                        variation = np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
+                        variation = torch.normal(torch.zeros(weight.size()), torch.full(weight.size(), self.vari)) # np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
                         X_decimal = torch.round((X_decimal-remainder)/cellRange)*mask
                         # Now also consider weight has on/off ratio effects
                         # Here remainder is the weight mapped to Hardware, so we introduce on/off ratio in this value
                         # the range of remainder is [0, cellRange-1], we truncate it to [lower, upper]
                         remainderQ = (upper-lower)*(remainder-0)+(cellRange-1)*lower   # weight cannot map to 0, but to Gmin
-                        remainderQ = remainderQ + remainderQ*torch.from_numpy(variation).cuda()
+                        remainderQ = remainderQ + remainderQ*variation.cuda()
                         outputPartial= F.linear(inputB, remainderQ*mask, self.bias)
                         outputDummyPartial= F.linear(inputB, dummyP*mask, self.bias)
                         # Add ADC quanization effects here !!!
@@ -291,13 +291,13 @@ class QLinear(nn.Linear):
                             remainder = torch.fmod(X_decimal, cellRange)*mask
                             # retention
                             remainder = wage_quantizer.Retention(remainder,self.t,self.v,self.detect,self.target)
-                            variation = np.random.normal(0, self.vari, list(remainder.size())).astype(np.float32)
+                            variation = torch.normal(torch.zeros(weight.size()), torch.full(weight.size(), self.vari)) # np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
                             X_decimal = torch.round((X_decimal-remainder)/cellRange)*mask
                             # Now also consider weight has on/off ratio effects
                             # Here remainder is the weight mapped to Hardware, so we introduce on/off ratio in this value
                             # the range of remainder is [0, cellRange-1], we truncate it to [lower, upper]*(cellRange-1)
                             remainderQ = (upper-lower)*(remainder-0)+(cellRange-1)*lower   # weight cannot map to 0, but to Gmin
-                            remainderQ = remainderQ + remainderQ*torch.from_numpy(variation).cuda()
+                            remainderQ = remainderQ + remainderQ*variation.cuda()
                             outputPartial= F.linear(inputB, remainderQ*mask, self.bias)
                             outputDummyPartial= F.linear(inputB, dummyP*mask, self.bias)
                             # Add ADC quanization effects here !!!
From 2e8bdb3720009715c7544f5a1356d74193d31f7c Mon Sep 17 00:00:00 2001
From: Marvin Dostal <st148727@stud.uni-stuttgart.de>
Date: Tue, 3 Aug 2021 12:25:55 +0200
Subject: [PATCH] Optimize

---
 .../modules/quantization_cpu_np_infer.py      | 122 +++++++++++-------
 1 file changed, 76 insertions(+), 46 deletions(-)

diff --git a/Inference_pytorch/modules/quantization_cpu_np_infer.py b/Inference_pytorch/modules/quantization_cpu_np_infer.py
index ab72dd5..d22629b 100644
--- a/Inference_pytorch/modules/quantization_cpu_np_infer.py
+++ b/Inference_pytorch/modules/quantization_cpu_np_infer.py
@@ -52,12 +52,12 @@ class QConv2d(nn.Conv2d):
             upper = 1
             lower = 1/onoffratio

-            output = torch.zeros_like(outputOrignal)
+            output = torch.zeros_like(outputOrignal, device='cuda')
             del outputOrignal
             cellRange = 2**self.cellBit   # cell precision is 4

             # Now consider on/off ratio
-            dummyP = torch.zeros_like(weight)
+            dummyP = torch.zeros_like(weight, device='cuda')
             dummyP[:,:,:,:] = (cellRange-1)*(upper+lower)/2

             for i in range (self.weight.shape[2]):
@@ -66,19 +66,25 @@ class QConv2d(nn.Conv2d):
                     numSubArray = int(weight.shape[1]/self.subArray)
                     # cut into different subArrays
                     if numSubArray == 0:
-                        mask = torch.zeros_like(weight)
+                        mask = torch.zeros_like(weight, device='cuda')
                         mask[:,:,i,j] = 1
                         if weight.shape[1] == 3:
                             # after get the spacial kernel, need to transfer floating weight [-1, 1] to binarized ones
                             X_decimal = torch.round((2**bitWeight - 1)/2 * (weight+1) + 0)*mask
-                            outputP = torch.zeros_like(output)
-                            outputD = torch.zeros_like(output)
+                            X_decimal = X_decimal.cuda()
+                            outputP = torch.zeros_like(output, device='cuda')
+                            outputD = torch.zeros_like(output, device='cuda')
                             for k in range (int(bitWeight/self.cellBit)):
                                 remainder = torch.fmod(X_decimal, cellRange)*mask
                                 # retention
                                 remainder = wage_quantizer.Retention(remainder,self.t,self.v,self.detect,self.target)
-                                variation = torch.normal(torch.zeros(weight.size()), torch.full(weight.size(), self.vari)) # np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
+                                remainder = remainder.cuda()
+                                variation = torch.zeros(weight.size(), device='cuda')
+                                if (self.vari != 0):
+                                    variation = torch.normal(torch.zeros(weight.size(), device='cuda'), torch.full(weight.size(), self.vari, device='cuda')) # np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
+                                variation = variation.cuda()
                                 X_decimal = torch.round((X_decimal-remainder)/cellRange)*mask
+                                X_decimal = X_decimal.cuda()
                                 # Now also consider weight has on/off ratio effects
                                 # Here remainder is the weight mapped to Hardware, so we introduce on/off ratio in this value
                                 # the range of remainder is [0, cellRange-1], we truncate it to [lower, upper]
@@ -94,19 +100,24 @@ class QConv2d(nn.Conv2d):
                         else:
                             # quantize input into binary sequence
                             inputQ = torch.round((2**bitActivation - 1)/1 * (input-0) + 0)
-                            outputIN = torch.zeros_like(output)
+                            inputQ = inputQ.cuda()
+                            outputIN = torch.zeros_like(output, device='cuda')
                             for z in range(bitActivation):
                                 inputB = torch.fmod(inputQ, 2)
                                 inputQ = torch.round((inputQ-inputB)/2)
-                                outputP = torch.zeros_like(output)
+                                outputP = torch.zeros_like(output, device='cuda')
                                 # after get the spacial kernel, need to transfer floating weight [-1, 1] to binarized ones
                                 X_decimal = torch.round((2**bitWeight - 1)/2 * (weight+1) + 0)*mask
-                                outputD = torch.zeros_like(output)
+                                X_decimal = X_decimal.cuda()
+                                outputD = torch.zeros_like(output, device='cuda')
                                 for k in range (int(bitWeight/self.cellBit)):
                                     remainder = torch.fmod(X_decimal, cellRange)*mask
                                     # retention
                                     remainder = wage_quantizer.Retention(remainder,self.t,self.v,self.detect,self.target)
-                                    variation = torch.normal(torch.zeros(weight.size()), torch.full(weight.size(), self.vari)) # np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
+                                    remainder = remainder.cuda()
+                                    variation = torch.zeros(weight.size(), device='cuda')
+                                    if (self.vari != 0):
+                                        variation = torch.normal(torch.zeros(weight.size(), device='cuda'), torch.full(weight.size(), self.vari, device='cuda')) # np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
                                     X_decimal = torch.round((X_decimal-remainder)/cellRange)*mask
                                     # Now also consider weight has on/off ratio effects
                                     # Here remainder is the weight mapped to Hardware, so we introduce on/off ratio in this value
@@ -127,24 +138,28 @@ class QConv2d(nn.Conv2d):
                     else:
                         # quantize input into binary sequence
                         inputQ = torch.round((2**bitActivation - 1)/1 * (input-0) + 0)
-                        outputIN = torch.zeros_like(output)
+                        inputQ = inputQ.cuda()
+                        outputIN = torch.zeros_like(output, device='cuda')
                         for z in range(bitActivation):
                             inputB = torch.fmod(inputQ, 2)
                             inputQ = torch.round((inputQ-inputB)/2)
-                            outputP = torch.zeros_like(output)
+                            outputP = torch.zeros_like(output, device='cuda')
                             for s in range(numSubArray):
-                                mask = torch.zeros_like(weight)
+                                mask = torch.zeros_like(weight, device='cuda')
                                 mask[:,(s*self.subArray):(s+1)*self.subArray, i, j] = 1
                                 # after get the spacial kernel, need to transfer floating weight [-1, 1] to binarized ones
                                 X_decimal = torch.round((2**bitWeight - 1)/2 * (weight+1) + 0)*mask
-                                outputSP = torch.zeros_like(output)
-                                outputD = torch.zeros_like(output)
+                                X_decimal = X_decimal.cuda()
+                                outputSP = torch.zeros_like(output, device='cuda')
+                                outputD = torch.zeros_like(output, device='cuda')
                                 for k in range (int(bitWeight/self.cellBit)):
                                     remainder = torch.fmod(X_decimal, cellRange)*mask
                                     # retention
                                     remainder = wage_quantizer.Retention(remainder,self.t,self.v,self.detect,self.target)
-                                    variation = torch.normal(torch.zeros(weight.size()), torch.full(weight.size(), self.vari)) # np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
+                                    remainder = remainder.cuda()
+                                    variation = torch.normal(torch.zeros(weight.size(), device='cuda'), torch.full(weight.size(), self.vari, device='cuda')) # np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
                                     X_decimal = torch.round((X_decimal-remainder)/cellRange)*mask
+                                    X_decimal = X_decimal.cuda()
                                     # Now also consider weight has on/off ratio effects
                                     # Here remainder is the weight mapped to Hardware, so we introduce on/off ratio in this value
                                     # the range of remainder is [0, cellRange-1], we truncate it to [lower, upper]*(cellRange-1)
@@ -167,16 +182,16 @@ class QConv2d(nn.Conv2d):
             output = output/(2**bitWeight)   # since weight range was convert from [-1, 1] to [-256, 256]

         elif self.inference == 1:
-            weight1 = self.weight * self.scale + (self.weight - self.weight * self.scale).detach()
-            weight = weight1 + (wage_quantizer.Q(weight1,self.wl_weight) -weight1).detach()
+            weight1 = self.weight * self.scale + (self.weight - self.weight * self.scale).detach().cuda()
+            weight = weight1 + (wage_quantizer.Q(weight1,self.wl_weight) -weight1).detach().cuda()
             weight = wage_quantizer.Retention(weight,self.t,self.v,self.detect,self.target)
             input = wage_quantizer.Q(input,self.wl_input)
             output= F.conv2d(input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
             output = wage_quantizer.LinearQuantizeOut(output, self.ADCprecision)
         else:
             # original WAGE QCov2d
-            weight1 = self.weight * self.scale + (self.weight - self.weight * self.scale).detach()
-            weight = weight1 + (wage_quantizer.Q(weight1,self.wl_weight) -weight1).detach()
+            weight1 = self.weight * self.scale + (self.weight - self.weight * self.scale).detach().cuda()
+            weight = weight1 + (wage_quantizer.Q(weight1,self.wl_weight) -weight1).detach().cuda()
             weight = wage_quantizer.Retention(weight,self.t,self.v,self.detect,self.target)
             output= F.conv2d(input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
         output = output/self.scale
@@ -217,10 +232,11 @@ class QLinear(nn.Linear):
     @weak_script_method
     def forward(self, input):

-        weight1 = self.weight * self.scale + (self.weight - self.weight * self.scale).detach()
-        weight = weight1 + (wage_quantizer.Q(weight1,self.wl_weight) -weight1).detach()
+        weight1 = self.weight * self.scale + (self.weight - self.weight * self.scale).detach().cuda()
+        weight = weight1 + (wage_quantizer.Q(weight1,self.wl_weight) -weight1).detach().cuda()
+        input = input.cuda()
         outputOrignal = F.linear(input, weight, self.bias)
-        output = torch.zeros_like(outputOrignal)
+        output = torch.zeros_like(outputOrignal, device='cuda')

         bitWeight = int(self.wl_weight)
         bitActivation = int(self.wl_input)
@@ -230,33 +246,40 @@ class QLinear(nn.Linear):
             onoffratio = self.onoffratio
             upper = 1
             lower = 1/onoffratio
-            output = torch.zeros_like(outputOrignal)
+            output = torch.zeros_like(outputOrignal, device='cuda')
             cellRange = 2**self.cellBit   # cell precision is 4
             # Now consider on/off ratio
-            dummyP = torch.zeros_like(weight)
+            dummyP = torch.zeros_like(weight, device='cuda')
             dummyP[:,:] = (cellRange-1)*(upper+lower)/2
             # need to divide to different subArray
             numSubArray = int(weight.shape[1]/self.subArray)

             if numSubArray == 0:
-                mask = torch.zeros_like(weight)
+                mask = torch.zeros_like(weight, device='cuda')
                 mask[:,:] = 1
                 # quantize input into binary sequence
                 inputQ = torch.round((2**bitActivation - 1)/1 * (input-0) + 0)
-                outputIN = torch.zeros_like(outputOrignal)
+                inputQ = inputQ.cuda()
+                outputIN = torch.zeros_like(outputOrignal, device='cuda')
                 for z in range(bitActivation):
                     inputB = torch.fmod(inputQ, 2)
                     inputQ = torch.round((inputQ-inputB)/2)
                     # after get the spacial kernel, need to transfer floating weight [-1, 1] to binarized ones
                     X_decimal = torch.round((2**bitWeight - 1)/2 * (weight+1) + 0)*mask
-                    outputP = torch.zeros_like(outputOrignal)
-                    outputD = torch.zeros_like(outputOrignal)
+                    outputP = torch.zeros_like(outputOrignal, device='cuda')
+                    outputD = torch.zeros_like(outputOrignal, device='cuda')
                     for k in range (int(bitWeight/self.cellBit)):
                         remainder = torch.fmod(X_decimal, cellRange)*mask
+                        remainder = remainder.cuda()
                         # retention
                         remainder = wage_quantizer.Retention(remainder,self.t,self.v,self.detect,self.target)
-                        variation = torch.normal(torch.zeros(weight.size()), torch.full(weight.size(), self.vari)) # np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
+                        remainder = remainder.cuda()
+                        variation = torch.zeros(weight.size(), device='cuda')
+                        if (self.vari != 0):
+                            variation = torch.normal(torch.zeros(weight.size(), device='cuda'), torch.full(weight.size(), self.vari, device='cuda')) # np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
+                        variation = variation.cuda()
                         X_decimal = torch.round((X_decimal-remainder)/cellRange)*mask
+                        X_decimal = X_decimal.cuda()
                         # Now also consider weight has on/off ratio effects
                         # Here remainder is the weight mapped to Hardware, so we introduce on/off ratio in this value
                         # the range of remainder is [0, cellRange-1], we truncate it to [lower, upper]
@@ -275,24 +298,31 @@ class QLinear(nn.Linear):
                 output = output + outputIN/(2**bitActivation)
             else:
                 inputQ = torch.round((2**bitActivation - 1)/1 * (input-0) + 0)
-                outputIN = torch.zeros_like(outputOrignal)
+                inputQ = inputQ.cuda()
+                outputIN = torch.zeros_like(outputOrignal, device='cuda')
                 for z in range(bitActivation):
                     inputB = torch.fmod(inputQ, 2)
                     inputQ = torch.round((inputQ-inputB)/2)
-                    outputP = torch.zeros_like(outputOrignal)
+                    outputP = torch.zeros_like(outputOrignal, device='cuda')
                     for s in range(numSubArray):
-                        mask = torch.zeros_like(weight)
+                        mask = torch.zeros_like(weight, device='cuda')
                         mask[:,(s*self.subArray):(s+1)*self.subArray] = 1
                         # after get the spacial kernel, need to transfer floating weight [-1, 1] to binarized ones
                         X_decimal = torch.round((2**bitWeight - 1)/2 * (weight+1) + 0)*mask
-                        outputSP = torch.zeros_like(outputOrignal)
-                        outputD = torch.zeros_like(outputOrignal)
+                        outputSP = torch.zeros_like(outputOrignal, device='cuda')
+                        outputD = torch.zeros_like(outputOrignal, device='cuda')
                         for k in range (int(bitWeight/self.cellBit)):
                             remainder = torch.fmod(X_decimal, cellRange)*mask
                             # retention
+                            remainder = remainder.cuda()
                             remainder = wage_quantizer.Retention(remainder,self.t,self.v,self.detect,self.target)
-                            variation = torch.normal(torch.zeros(weight.size()), torch.full(weight.size(), self.vari)) # np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
+                            remainder = remainder.cuda()
+                            variation = torch.zeros(weight.size(), device='cuda')
+                            if (self.vari != 0):
+                                variation = torch.normal(torch.zeros(weight.size(), device='cuda'), torch.full(weight.size(), self.vari, device='cuda')) # np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
+                            variation = variation.cuda()
                             X_decimal = torch.round((X_decimal-remainder)/cellRange)*mask
+                            X_decimal = X_decimal.cuda()
                             # Now also consider weight has on/off ratio effects
                             # Here remainder is the weight mapped to Hardware, so we introduce on/off ratio in this value
                             # the range of remainder is [0, cellRange-1], we truncate it to [lower, upper]*(cellRange-1)
@@ -314,18 +344,18 @@ class QLinear(nn.Linear):
             output = output/(2**bitWeight)

         elif self.inference == 1:
-            weight1 = self.weight * self.scale + (self.weight - self.weight * self.scale).detach()
-            weight = weight1 + (wage_quantizer.Q(weight1,self.wl_weight) -weight1).detach()
-            weight = wage_quantizer.Retention(weight,self.t,self.v,self.detect,self.target)
-            input = wage_quantizer.Q(input,self.wl_input)
-            output= F.linear(input, weight, self.bias)
-            output = wage_quantizer.LinearQuantizeOut(output, self.ADCprecision)
+            weight1 = self.weight * self.scale + (self.weight - self.weight * self.scale).detach().cuda()
+            weight = weight1 + (wage_quantizer.Q(weight1,self.wl_weight) -weight1).detach().cuda()
+            weight = wage_quantizer.Retention(weight,self.t,self.v,self.detect,self.target).cuda()
+            input = wage_quantizer.Q(input,self.wl_input).cuda()
+            output= F.linear(input, weight, self.bias).cuda()
+            output = wage_quantizer.LinearQuantizeOut(output, self.ADCprecision).cuda()
         else:
             # original WAGE QCov2d
-            weight1 = self.weight * self.scale + (self.weight - self.weight * self.scale).detach()
-            weight = weight1 + (wage_quantizer.Q(weight1,self.wl_weight) -weight1).detach()
-            weight = wage_quantizer.Retention(weight,self.t,self.v,self.detect,self.target)
-            output = F.linear(input, weight, self.bias)
+            weight1 = self.weight * self.scale + (self.weight - self.weight * self.scale).detach().cuda()
+            weight = weight1 + (wage_quantizer.Q(weight1,self.wl_weight) -weight1).detach().cuda()
+            weight = wage_quantizer.Retention(weight,self.t,self.v,self.detect,self.target).cuda()
+            output = F.linear(input, weight, self.bias).cuda()

         output = output/self.scale
         output = wage_quantizer.WAGEQuantizer_f(output,self.wl_activate, self.wl_error)
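One way to check a speedup like the one reported for these patches, without being misled by CUDA's asynchronous kernel launches, is to synchronize before reading the clock. The sketch below is hypothetical (the helper name `time_variation` and the tensor sizes are illustrative, not from NeuroSim); it compares the original path (NumPy sampling on the CPU followed by a copy) with sampling directly on the target device.

```python
import time

import numpy as np
import torch


def _sync(device: torch.device) -> None:
    # CUDA kernels run asynchronously; wait for them before reading the clock.
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def time_variation(shape, vari, device, on_device, iters=20):
    """Average seconds per call for one way of building the variation tensor (hypothetical helper)."""
    _sync(device)
    start = time.perf_counter()
    for _ in range(iters):
        if on_device:
            # Mean and std tensors live on `device`, so sampling happens there.
            v = torch.normal(torch.zeros(shape, device=device),
                             torch.full(shape, vari, device=device))
        else:
            # Original path: NumPy samples on the CPU, then a host-to-device copy.
            v = torch.from_numpy(
                np.random.normal(0, vari, list(shape)).astype(np.float32)
            ).to(device)
    _sync(device)
    return (time.perf_counter() - start) / iters


dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
shape = (512, 512)
print("numpy + copy:", time_variation(shape, 0.1, dev, on_device=False))
print("on device:   ", time_variation(shape, 0.1, dev, on_device=True))
```

On a CPU-only machine both paths stay on the CPU, so any difference there says nothing about GPU behavior.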
datMaffin commented 3 years ago

(Replying to: "But anyway, we also want to speed it up and would like to accept any useful ideas.")

I also have a patch that enables multi-GPU support:

From d0d097ecb6e6882b30ee9d5bf91a6aa569137888 Mon Sep 17 00:00:00 2001
From: Marvin Dostal <st148727@stud.uni-stuttgart.de>
Date: Tue, 3 Aug 2021 13:14:35 +0200
Subject: [PATCH] Implement multi gpu support

---
 Inference_pytorch/inference.py | 6 +++++-
 Inference_pytorch/utee/hook.py | 5 ++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/Inference_pytorch/inference.py b/Inference_pytorch/inference.py
index 5de29b7..797eba9 100644
--- a/Inference_pytorch/inference.py
+++ b/Inference_pytorch/inference.py
@@ -4,6 +4,7 @@ import time
 from utee import misc
 import numpy as np
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 from torch.autograd import Variable
@@ -83,6 +84,8 @@ if args.model == 'VGG8':
     from models import VGG
     model_path = './log/VGG8.pth'   # WAGE mode pretrained model
     modelCF = VGG.vgg8(args = args, logger=logger, pretrained = model_path)
+    modelCF = modelCF.cuda()
+    # modelCF = nn.DataParallel(modelCF, device_ids=[0,1])
 elif args.model == 'DenseNet40':
     from models import DenseNet
     model_path = './log/DenseNet40.pth'     # WAGE mode pretrained model
@@ -97,7 +100,7 @@ else:
     raise ValueError("Unknown model type")

 if args.cuda:
-   modelCF.cuda()
+   modelCF.cuda(); modelCF = nn.DataParallel(modelCF)

 best_acc, old_file = 0, None
 t_begin = time.time()
@@ -113,6 +116,7 @@ criterion = torch.nn.CrossEntropyLoss()

 # for data, target in test_loader:
 for i, (data, target) in enumerate(test_loader):
+    print("Data " + str(i) + " of " + str(len(test_loader)))
     if i==0:
         hook_handle_list = hook.hardware_evaluation(modelCF,args.wl_weight,args.wl_activate,args.model,args.mode)
     indx_target = target.clone()
diff --git a/Inference_pytorch/utee/hook.py b/Inference_pytorch/utee/hook.py
index a482f0d..fd64249 100644
--- a/Inference_pytorch/utee/hook.py
+++ b/Inference_pytorch/utee/hook.py
@@ -12,6 +12,9 @@ from utee import float_quantizer
 def Neural_Sim(self, input, output): 
     global model_n, FP

+    if input[0].get_device() != 0:
+        return
+
     print("quantize layer ", self.name)
     input_file_name =  './layer_record_' + str(model_n) + '/input' + str(self.name) + '.csv'
     weight_file_name =  './layer_record_' + str(model_n) + '/weight' + str(self.name) + '.csv'
@@ -121,4 +124,4 @@ def hardware_evaluation(model,wl_weight,wl_activation,model_name,mode):
     for i, layer in enumerate(model.modules()):
         if isinstance(layer, (FConv2d, QConv2d, nn.Conv2d)) or isinstance(layer, (FLinear, QLinear, nn.Linear)):
             hook_handle_list.append(layer.register_forward_hook(Neural_Sim))
-    return hook_handle_list
\ No newline at end of file
+    return hook_handle_list
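`nn.DataParallel` copies forward hooks onto every per-GPU replica, so without the early `return` in `Neural_Sim` each layer would be recorded once per GPU. Below is a minimal sketch of the same pattern on a toy model. `make_hook`, `recorded` and the model are hypothetical stand-ins for the CSV-writing hook in `utee/hook.py`. The sketch compares `.device` rather than calling `get_device()`: for CPU tensors, `get_device()` returns -1, so the patch's check would skip every layer on a CPU-only run.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the per-layer records Neural_Sim writes to CSV.
recorded = []

def make_hook(primary_device):
    """Return a forward hook that only records calls made on primary_device.

    nn.DataParallel copies forward hooks onto every replica, so without this
    check each layer would be recorded once per GPU.
    """
    def hook(module, inputs, output):
        if inputs[0].device != primary_device:
            return  # a replica on another GPU; skip it
        recorded.append(type(module).__name__)
    return hook

use_cuda = torch.cuda.is_available()
primary = torch.device("cuda", 0) if use_cuda else torch.device("cpu")

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for layer in model.modules():
    if isinstance(layer, nn.Linear):
        layer.register_forward_hook(make_hook(primary))

if use_cuda:
    # Split each batch across all visible GPUs, as in the patch.
    model = nn.DataParallel(model.cuda())

out = model(torch.randn(6, 4, device=primary))
print(recorded)  # ['Linear', 'Linear']
```

Each layer is recorded once however many GPUs take part, because only the replica on device 0 passes the check.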
neurosim commented 3 years ago

Thanks very much for your help! I think the main problem lies in the code that creates the normal distributions for the "variation". The code has been updated. Thanks again!
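For context, here is a before/after sketch of the variation step. The NumPy version samples on the host and needs a host-to-device copy on every call. The PyTorch version samples directly on the device that holds the weights. The values of `vari`, the tensor shape and the `remainderQ` stand-in are illustrative, not NeuroSim defaults. The sketch uses `.to(device)` instead of `.cuda()` so it also runs without a GPU.

```python
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vari = 0.1  # illustrative variation (standard deviation), not a NeuroSim default
weight = torch.randn(64, 128, device=device)
remainderQ = torch.rand_like(weight)  # stand-in for the mapped cell conductances

# Before: sample on the host with NumPy, then copy to the weight's device
# (the original code used .cuda(); .to(device) keeps this runnable on CPU).
variation_np = np.random.normal(0, vari, list(weight.size())).astype(np.float32)
before = remainderQ + remainderQ * torch.from_numpy(variation_np).to(device)

# After: sample directly on the weight's device; no host-to-device copy.
variation = torch.normal(0.0, vari, size=tuple(weight.size()), device=device)
after = remainderQ + remainderQ * variation
```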

datMaffin commented 3 years ago


OK, after further testing, the variation calculation seems to be the only part that was not executed on the GPU.

Am I correct in assuming that the following change was based on this issue discussion? https://github.com/neurosim/3D_NeuroSim_V1.0/commit/99b62756df414206a35fcd4747df7e3f486875a9

In addition to your change replacing NumPy with PyTorch, I had to add device='cuda' to the torch.full call to get the same acceleration and GPU utilization as with my (shotgun-approach :wink:) patches. See also: https://pytorch.org/docs/1.0.0/torch.html?highlight=full#torch.full
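Why the `device` argument matters here: like other factory functions, `torch.full` allocates on the default device (the CPU, unless changed) when no `device` is given. `torch.normal(0., std)` then samples on whatever device `std` lives on. The sketch below assumes a recent PyTorch and falls back to the CPU without a GPU. `vari` and the shape are illustrative. It passes `weight.device` instead of a hard-coded `'cuda'`, which is one possible design choice, and also shows `torch.randn_like` scaled by the standard deviation as an equivalent shortcut.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vari = 0.1  # illustrative standard deviation
weight = torch.randn(32, 16, device=device)

# Without device=..., torch.full allocates on the default device (CPU),
# so the sampling in torch.normal also happens on the CPU.
std_default = torch.full(weight.size(), vari)
print(std_default.device)  # cpu

# Passing the weight's device keeps allocation and sampling on the GPU.
std = torch.full(weight.size(), vari, device=weight.device)
variation = torch.normal(0.0, std)

# Equivalent shortcut: a standard normal on weight's device, scaled by vari.
variation_alt = torch.randn_like(weight) * vari
```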

The patch for DNN_NeuroSim V1.3 would look like this:

diff --git a/Inference_pytorch/modules/quantization_cpu_np_infer.py b/Inference_pytorch/modules/quantization_cpu_np_infer.py
index 8d13013..4faf7bd 100644
--- a/Inference_pytorch/modules/quantization_cpu_np_infer.py
+++ b/Inference_pytorch/modules/quantization_cpu_np_infer.py
@@ -77,13 +77,13 @@ class QConv2d(nn.Conv2d):
                                 remainder = torch.fmod(X_decimal, cellRange)*mask
                                 # retention
                                 remainder = wage_quantizer.Retention(remainder,self.t,self.v,self.detect,self.target)
-                                variation = np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
+                                variation = torch.normal(0., torch.full(weight.size(), self.vari, device='cuda')) 
                                 X_decimal = torch.round((X_decimal-remainder)/cellRange)*mask
                                 # Now also consider weight has on/off ratio effects
                                 # Here remainder is the weight mapped to Hardware, so we introduce on/off ratio in this value
                                 # the range of remainder is [0, cellRange-1], we truncate it to [lower, upper]
                                 remainderQ = (upper-lower)*(remainder-0)+(cellRange-1)*lower   # weight cannot map to 0, but to Gmin
-                                remainderQ = remainderQ + remainderQ*torch.from_numpy(variation).cuda()
+                                remainderQ = remainderQ + remainderQ*variation
                                 outputPartial= F.conv2d(input, remainderQ*mask, self.bias, self.stride, self.padding, self.dilation, self.groups)
                                 outputDummyPartial= F.conv2d(input, dummyP*mask, self.bias, self.stride, self.padding, self.dilation, self.groups)
                                 scaler = cellRange**k
@@ -106,13 +106,13 @@ class QConv2d(nn.Conv2d):
                                     remainder = torch.fmod(X_decimal, cellRange)*mask
                                     # retention
                                     remainder = wage_quantizer.Retention(remainder,self.t,self.v,self.detect,self.target)
-                                    variation = np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
+                                    variation = torch.normal(0., torch.full(weight.size(), self.vari, device='cuda')) 
                                     X_decimal = torch.round((X_decimal-remainder)/cellRange)*mask
                                     # Now also consider weight has on/off ratio effects
                                     # Here remainder is the weight mapped to Hardware, so we introduce on/off ratio in this value
                                     # the range of remainder is [0, cellRange-1], we truncate it to [lower, upper]
                                     remainderQ = (upper-lower)*(remainder-0)+(cellRange-1)*lower   # weight cannot map to 0, but to Gmin
-                                    remainderQ = remainderQ + remainderQ*torch.from_numpy(variation).cuda()
+                                    remainderQ = remainderQ + remainderQ*variation
                                     outputPartial= F.conv2d(inputB, remainderQ*mask, self.bias, self.stride, self.padding, self.dilation, self.groups)
                                     outputDummyPartial= F.conv2d(inputB, dummyP*mask, self.bias, self.stride, self.padding, self.dilation, self.groups)
                                     # Add ADC quanization effects here !!!
@@ -143,13 +143,13 @@ class QConv2d(nn.Conv2d):
                                     remainder = torch.fmod(X_decimal, cellRange)*mask
                                     # retention
                                     remainder = wage_quantizer.Retention(remainder,self.t,self.v,self.detect,self.target)
-                                    variation = np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
+                                    variation = torch.normal(0., torch.full(weight.size(), self.vari, device='cuda')) 
                                     X_decimal = torch.round((X_decimal-remainder)/cellRange)*mask
                                     # Now also consider weight has on/off ratio effects
                                     # Here remainder is the weight mapped to Hardware, so we introduce on/off ratio in this value
                                     # the range of remainder is [0, cellRange-1], we truncate it to [lower, upper]*(cellRange-1)
                                     remainderQ = (upper-lower)*(remainder-0)+(cellRange-1)*lower   # weight cannot map to 0, but to Gmin
-                                    remainderQ = remainderQ + remainderQ*torch.from_numpy(variation).cuda()
+                                    remainderQ = remainderQ + remainderQ*variation
                                     outputPartial= F.conv2d(inputB, remainderQ*mask, self.bias, self.stride, self.padding, self.dilation, self.groups)
                                     outputDummyPartial= F.conv2d(inputB, dummyP*mask, self.bias, self.stride, self.padding, self.dilation, self.groups)
                                     # Add ADC quanization effects here !!!
@@ -255,13 +255,13 @@ class QLinear(nn.Linear):
                         remainder = torch.fmod(X_decimal, cellRange)*mask
                         # retention
                         remainder = wage_quantizer.Retention(remainder,self.t,self.v,self.detect,self.target)
-                        variation = np.random.normal(0, self.vari, list(weight.size())).astype(np.float32)
+                        variation = torch.normal(0., torch.full(weight.size(), self.vari, device='cuda')) 
                         X_decimal = torch.round((X_decimal-remainder)/cellRange)*mask
                         # Now also consider weight has on/off ratio effects
                         # Here remainder is the weight mapped to Hardware, so we introduce on/off ratio in this value
                         # the range of remainder is [0, cellRange-1], we truncate it to [lower, upper]
                         remainderQ = (upper-lower)*(remainder-0)+(cellRange-1)*lower   # weight cannot map to 0, but to Gmin
-                        remainderQ = remainderQ + remainderQ*torch.from_numpy(variation).cuda()
+                        remainderQ = remainderQ + remainderQ*variation
                         outputPartial= F.linear(inputB, remainderQ*mask, self.bias)
                         outputDummyPartial= F.linear(inputB, dummyP*mask, self.bias)
                         # Add ADC quanization effects here !!!
@@ -291,13 +291,13 @@ class QLinear(nn.Linear):
                             remainder = torch.fmod(X_decimal, cellRange)*mask
                             # retention
                             remainder = wage_quantizer.Retention(remainder,self.t,self.v,self.detect,self.target)
-                            variation = np.random.normal(0, self.vari, list(remainder.size())).astype(np.float32)
+                            variation = torch.normal(0., torch.full(weight.size(), self.vari, device='cuda')) 
                             X_decimal = torch.round((X_decimal-remainder)/cellRange)*mask
                             # Now also consider weight has on/off ratio effects
                             # Here remainder is the weight mapped to Hardware, so we introduce on/off ratio in this value
                             # the range of remainder is [0, cellRange-1], we truncate it to [lower, upper]*(cellRange-1)
                             remainderQ = (upper-lower)*(remainder-0)+(cellRange-1)*lower   # weight cannot map to 0, but to Gmin
-                            remainderQ = remainderQ + remainderQ*torch.from_numpy(variation).cuda()
+                            remainderQ = remainderQ + remainderQ*variation
                             outputPartial= F.linear(inputB, remainderQ*mask, self.bias)
                             outputDummyPartial= F.linear(inputB, dummyP*mask, self.bias)
                             # Add ADC quanization effects here !!!
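The same two-line change is repeated in five places, so one option is to factor it into a helper. `apply_variation` below is a hypothetical name, not part of NeuroSim. It assumes the multiplicative variation model used in the patch: `remainderQ + remainderQ*variation`, i.e. `remainderQ * (1 + variation)`. In the patched code, each pair of lines would then become `remainderQ = apply_variation(remainderQ, self.vari)`. Sizing the noise from the tensor it perturbs also avoids the question of whether `weight.size()` or `remainder.size()` is the right shape, since the original code uses both.

```python
import torch

def apply_variation(remainderQ: torch.Tensor, vari: float) -> torch.Tensor:
    """Hypothetical helper (not part of NeuroSim).

    Scales each mapped conductance by (1 + N(0, vari)). The noise is sampled
    with randn_like, so it inherits remainderQ's shape, dtype and device.
    """
    if vari == 0:
        return remainderQ  # no variation requested; skip sampling entirely
    variation = torch.randn_like(remainderQ) * vari
    return remainderQ + remainderQ * variation

# Usage: with vari=0 the input tensor is returned unchanged.
torch.manual_seed(0)
remainderQ_demo = torch.rand(3, 4)
noisy = apply_variation(remainderQ_demo, 0.1)
unchanged = apply_variation(remainderQ_demo, 0.0)
```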
datMaffin commented 2 years ago

@neurosim Quick heads-up: your changes contain a typo; "deivce" should be "device".

neurosim commented 2 years ago

Oh no. Thanks for pointing it out.