mantasu / cs231n

Shortest solutions for CS231n 2021-2024

The problem in the conv_backward_naive #3

Status: closed by ChuyuWang949 1 year ago

ChuyuWang949 commented 1 year ago

In this function, the author uses np.insert() to fill in the rows and columns of zeros that are missing from dout, like this:

dout = np.insert(dout, range(1, H_o), [[0]]*(stride-1), axis=2) if stride > 1 else dout

But this does not work when the stride is 3 or more, because dout then needs at least two zero rows between each pair of adjacent rows, and likewise for columns. Repeating the values, as [[0]]*(stride-1) does, does not make np.insert() insert several rows at each position: called with one index per gap, it inserts only one row at each index.

mantasu commented 1 year ago

You're right. To insert multiple rows/cols, indices should be repeated rather than values.
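A minimal sketch of the repeated-indices idea, assuming dout has shape (N, F, H_o, W_o) as in the snippet above. The helper name dilate_dout is hypothetical and this is not necessarily how the repository fixed it:

```python
import numpy as np

def dilate_dout(dout, stride):
    """Insert (stride - 1) zeros between neighbouring rows and columns of dout.

    Hypothetical helper for illustration; assumes dout is (N, F, H_o, W_o).
    """
    if stride == 1:
        return dout
    H_o, W_o = dout.shape[2], dout.shape[3]
    # Repeat each insertion index (stride - 1) times, so np.insert places
    # that many zeros before the same original row/column.
    row_idx = np.repeat(np.arange(1, H_o), stride - 1)
    col_idx = np.repeat(np.arange(1, W_o), stride - 1)
    dout = np.insert(dout, row_idx, 0, axis=2)
    dout = np.insert(dout, col_idx, 0, axis=3)
    return dout

# Small check with stride 3: each spatial dimension grows from 3 to (3 - 1) * 3 + 1.
x = np.arange(1, 10, dtype=float).reshape(1, 1, 3, 3)
y = dilate_dout(x, stride=3)
print(y.shape)
```

The original values end up at every stride-th position (y[..., ::3, ::3] in the check above), with zeros everywhere else.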

ChuyuWang949 commented 1 year ago

Yes! That's how it works