
Conversation

@hishambarakat16

#################Summary#################
Fixed a bug in the PReLU module in jittor/nn.py where broadcasting the weight parameter caused errors when num_parameters was greater than 1. The previous implementation broadcast the weight with the hard-coded dimensions [0, 2, 3], which assumes a 4D (N, C, H, W) input, so any input of a different rank triggered a runtime error.

#################Changes Made#################
Modified the execute method of the PReLU class so that, when num_parameters is greater than 1, the weight's broadcast shape is built from the input's actual shape rather than from hard-coded 4D dimensions.

#################Code Changes#################
#################Original Code:#################

def __init__(self, num_parameters=1, init_=0.25):
    self.num_parameters = num_parameters
    self.weight = init.constant((num_parameters,), "float32", init_)

def execute(self, x):
    if self.num_parameters != 1:
        assert self.num_parameters == x.size(1), f"num_parameters does not match input channels in PReLU"
        return jt.maximum(0, x) + self.weight.broadcast(x, [0,2,3]) * jt.minimum(0, x)
    else:
        return jt.maximum(0, x) + self.weight * jt.minimum(0, x)

############Updated Code:##############

def __init__(self, num_parameters=1, init_=0.25):
    self.num_parameters = num_parameters
    self.weight = init.constant((num_parameters,), "float32", init_)

def execute(self, x):
    if self.num_parameters != 1:
        assert self.num_parameters == x.shape[1], f"num_parameters does not match input channels in PReLU"
        weight_broadcasted = self.weight.broadcast([x.shape[0], self.num_parameters, *([1] * (len(x.shape) - 2))])
        return jt.maximum(0, x) + weight_broadcasted * jt.minimum(0, x)
    else:
        return jt.maximum(0, x) + self.weight * jt.minimum(0, x)
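For reference, the shape expression in the updated code builds a broadcast target of the form [N, C, 1, ..., 1] from the input's own shape, which is what makes the fix work for any input rank. A minimal plain-Python sketch of just that computation (the helper name is illustrative, not part of the patch):

def prelu_broadcast_shape(input_shape, num_parameters):
    # Mirrors the patch: batch dim, channel dim, then a 1 for each remaining dim.
    return [input_shape[0], num_parameters, *([1] * (len(input_shape) - 2))]

print(prelu_broadcast_shape((5, 5), 5))        # -> [5, 5]
print(prelu_broadcast_shape((2, 3, 4, 4), 3))  # -> [2, 3, 1, 1]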

#################Testing#################
Tested the updated PReLU function with various configurations to ensure proper functionality:

import jittor as jt
from jittor import nn

# Create input data with the specified shape
def create_input_data(shape):
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    return jt.array(list(range(-num_elements // 2, num_elements // 2)), dtype=jt.float32).reshape(shape)

# Test the PReLU activation function
def test_prelu(num_parameters, input_shape):
    prelu_layer = nn.PReLU(num_parameters=num_parameters)
    input_data = create_input_data(input_shape)
    print(f"Testing PReLU with num_parameters={num_parameters} and input_shape={input_shape}")
    print(f"Input Data:\n{input_data.numpy()}")
    output_data = prelu_layer(input_data)
    print(f"Output Data (PReLU):\n{output_data.numpy()}\n")

if __name__ == "__main__":
    test_configs = [
        (1, (5,)),      # Single parameter
        (5, (5, 5)),    # Five parameters matching the number of channels
        (3, (3, 3)),    # Three parameters matching the number of channels
    ]
    for num_parameters, input_shape in test_configs:
        test_prelu(num_parameters, input_shape)
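As an extra sanity check beyond the printed runs (not part of the PR's test script), the output can be compared against a NumPy reference of the PReLU formula max(0, x) + a * min(0, x). A minimal sketch, reusing create_input_data from above and assuming the default init_=0.25:

import numpy as np

def prelu_reference(x, a=0.25):
    # Elementwise PReLU: positives pass through, negatives are scaled by a.
    return np.maximum(0, x) + a * np.minimum(0, x)

x_np = create_input_data((5, 5)).numpy()
out = nn.PReLU(num_parameters=5)(jt.array(x_np)).numpy()
# All weights are initialized to 0.25, so a scalar reference also matches the per-channel case.
assert np.allclose(out, prelu_reference(x_np)), "PReLU output does not match reference"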

#################Test Results:#################

Testing PReLU with num_parameters=1 and input_shape=(5,)
Input Data:
[-3. -2. -1.  0.  1.]
Output Data (PReLU):
[-0.75 -0.5  -0.25  0.    1.  ]

Testing PReLU with num_parameters=5 and input_shape=(5, 5)
Input Data:
[[-13. -12. -11. -10.  -9.]
 [ -8.  -7.  -6.  -5.  -4.]
 [ -3.  -2.  -1.   0.   1.]
 [  2.   3.   4.   5.   6.]
 [  7.   8.   9.  10.  11.]]
Output Data (PReLU):
[[-3.25 -3.   -2.75 -2.5  -2.25]
 [-2.   -1.75 -1.5  -1.25 -1.  ]
 [-0.75 -0.5  -0.25  0.    1.  ]
 [ 2.    3.    4.    5.    6.  ]
 [ 7.    8.    9.   10.   11.  ]]

Testing PReLU with num_parameters=3 and input_shape=(3, 3)
Input Data:
[[-5. -4. -3.]
 [-2. -1.  0.]
 [ 1.  2.  3.]]
Output Data (PReLU):
[[-1.25 -1.   -0.75]
 [-0.5  -0.25  0.  ]
 [ 1.    2.    3.  ]]

##################################
This fix ensures that the PReLU activation function can handle multiple parameters correctly by properly broadcasting the weight parameter to match the input tensor dimensions.

co63oc and others added 30 commits May 22, 2023 13:18
add complex matmul, inv, qr, eig, and svd
fix issue 531,530;update jt.nn.PixelShuffle/jt.histc
fix issue 525;update jt.nn.Reflection2d/Replication2d
fix issue 527,526;update jt.zeros/ones/full/randn/randint/random
fix issue 529;update contrib.argmax_pool()
fix issue 528;update conv_transpose
fix issue 521;update jt.nn.MaxUnpool2d/MaxUnpool3d
fix issue 522,520,519,516; update jt.Pool/Pool3d
fix issue 523;update jt.nn.Conv1d/Conv3d/conv2d/conv3d
Update ACL library and fix bugs in ACL integration
fix: some function&class illegal input parameters
LDYang694 and others added 11 commits June 5, 2024 22:31
fix numpy version
check parameters' positive in jt.nn.fold