Jittor / jittor

Jittor is a high-performance deep learning framework based on JIT compiling and meta-operators.
Apache License 2.0

Fix PReLU Broadcasting Bug for Multiple Parameters #565

Open hishambarakat16 opened 2 months ago


#################Summary#################

Fixes a bug in the PReLU module in jittor/nn.py: when num_parameters is greater than 1, broadcasting the weight parameter caused runtime errors. The previous implementation, which calls broadcast with the fixed dimension list [0,2,3], did not correctly broadcast the weights to match the input dimensions.

#################Changes Made#################

Modified the execute method of the PReLU class to broadcast the weight parameter correctly when num_parameters is greater than 1.

#################Original Code:#################

    def __init__(self, num_parameters=1, init_=0.25):
        self.num_parameters = num_parameters
        self.weight = init.constant((num_parameters,), "float32", init_)

    def execute(self, x):
        if self.num_parameters != 1:
            assert self.num_parameters == x.size(1), f"num_parameters does not match input channels in PReLU"
            return jt.maximum(0, x) + self.weight.broadcast(x, [0,2,3]) * jt.minimum(0, x)
        else:
            return jt.maximum(0, x) + self.weight * jt.minimum(0, x)

############Updated Code:##############

    def __init__(self, num_parameters=1, init_=0.25):
        self.num_parameters = num_parameters
        self.weight = init.constant((num_parameters,), "float32", init_)

    def execute(self, x):
        if self.num_parameters != 1:
            assert self.num_parameters == x.shape[1], f"num_parameters does not match input channels in PReLU"
            # Target shape: [N, C, 1, ..., 1], one trailing 1 per dimension after the channel axis
            weight_broadcasted = self.weight.broadcast([x.shape[0], self.num_parameters, *([1] * (len(x.shape) - 2))])
            return jt.maximum(0, x) + weight_broadcasted * jt.minimum(0, x)
        else:
            return jt.maximum(0, x) + self.weight * jt.minimum(0, x)
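The shape list passed to broadcast in the updated code puts the batch size first, the channel count second, and a 1 for every remaining input dimension. A minimal pure-Python sketch of that computation follows; `target_weight_shape` is a hypothetical helper written for illustration and is not part of Jittor, and it says nothing about how Jittor's broadcast then aligns the weight with that shape.

```python
def target_weight_shape(input_shape, num_parameters):
    """Return the shape list built in the updated execute(): [N, C, 1, ..., 1].

    Hypothetical helper for illustration only; not part of jittor.
    """
    if len(input_shape) < 2:
        raise ValueError("per-channel PReLU needs an input with at least 2 dimensions")
    if input_shape[1] != num_parameters:
        raise ValueError("num_parameters does not match input channels in PReLU")
    # One trailing 1 for every dimension after the channel axis (axis 1).
    trailing_ones = [1] * (len(input_shape) - 2)
    return [input_shape[0], num_parameters, *trailing_ones]
```

For a 2-D input such as (5, 5) no trailing ones are added, so the result is simply [5, 5]; for a 4-D input such as (8, 3, 32, 32) it is [8, 3, 1, 1].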

#################Testing#################

Tested the updated PReLU module with several configurations to ensure proper functionality:

    import jittor as jt
    from jittor import nn

    # Create input data with the specified shape
    def create_input_data(shape):
        num_elements = 1
        for dim in shape:
            num_elements *= dim
        return jt.array(list(range(-num_elements // 2, num_elements // 2)), dtype=jt.float32).reshape(shape)

    # Test the PReLU activation function
    def test_prelu(num_parameters, input_shape):
        prelu_layer = nn.PReLU(num_parameters=num_parameters)
        input_data = create_input_data(input_shape)
        print(f"Testing PReLU with num_parameters={num_parameters} and input_shape={input_shape}")
        print(f"Input Data:\n{input_data.numpy()}")
        output_data = prelu_layer(input_data)
        print(f"Output Data (PReLU):\n{output_data.numpy()}\n")

    if __name__ == "__main__":
        test_configs = [
            (1, (5,)),     # Single parameter
            (5, (5, 5)),   # Five parameters matching the number of channels
            (3, (3, 3)),   # Three parameters matching the number of channels
        ]
        for num_parameters, input_shape in test_configs:
            test_prelu(num_parameters, input_shape)

#################Test Results:#################

    Testing PReLU with num_parameters=1 and input_shape=(5,)
    Input Data:
    [-3. -2. -1.  0.  1.]
    Output Data (PReLU):
    [-0.75 -0.5  -0.25  0.    1.  ]

    Testing PReLU with num_parameters=5 and input_shape=(5, 5)
    Input Data:
    [[-13. -12. -11. -10.  -9.]
     [ -8.  -7.  -6.  -5.  -4.]
     [ -3.  -2.  -1.   0.   1.]
     [  2.   3.   4.   5.   6.]
     [  7.   8.   9.  10.  11.]]
    Output Data (PReLU):
    [[-3.25 -3.   -2.75 -2.5  -2.25]
     [-2.   -1.75 -1.5  -1.25 -1.  ]
     [-0.75 -0.5  -0.25  0.    1.  ]
     [ 2.    3.    4.    5.    6.  ]
     [ 7.    8.    9.   10.   11.  ]]

    Testing PReLU with num_parameters=3 and input_shape=(3, 3)
    Input Data:
    [[-5. -4. -3.]
     [-2. -1.  0.]
     [ 1.  2.  3.]]
    Output Data (PReLU):
    [[-1.25 -1.   -0.75]
     [-0.5  -0.25  0.  ]
     [ 1.    2.    3.  ]]
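Because every weight is initialised to 0.25, the outputs above would look the same whichever axis the weight were aligned with. A minimal pure-Python reference for a 2-D (N, C) input, using distinct per-channel weights, gives expected values against which a layer with non-uniform weights could be compared. The helper name `prelu_reference` and the sample weights are assumptions made for this sketch; they do not come from Jittor or from the tests above.

```python
def prelu_reference(rows, weights):
    """Per-channel PReLU for a 2-D input given as a list of rows (N x C).

    Computes max(0, x) + w[c] * min(0, x), where c is the column index
    (the channel axis, axis 1). Hypothetical helper, not part of jittor.
    """
    out = []
    for row in rows:
        if len(row) != len(weights):
            raise ValueError("num_parameters does not match input channels in PReLU")
        out.append([max(0.0, x) + w * min(0.0, x) for x, w in zip(row, weights)])
    return out


# Same 3x3 input as the last test case above.
rows = [[-5.0, -4.0, -3.0],
        [-2.0, -1.0, 0.0],
        [1.0, 2.0, 3.0]]

# Uniform weights reproduce the reported output for num_parameters=3.
uniform = prelu_reference(rows, [0.25, 0.25, 0.25])

# Distinct weights (illustrative values) make the channel axis visible:
# each column is scaled by its own weight.
distinct = prelu_reference(rows, [0.5, 0.25, 0.125])
```

With the distinct weights, the first row becomes [-2.5, -1.0, -0.375], so a weight aligned with the wrong axis would produce visibly different values.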

##################################

With this fix, the PReLU activation function handles num_parameters greater than 1 by broadcasting the weight parameter to match the input tensor's dimensions.