Closed leeys888 closed 6 years ago
/// Prepares this layer's weights and bias and hands them to `makeConv(device:weights:bias:)`.
///
/// Returns immediately when `network` is `nil`.
///
/// - On iOS 11 and later, `dataSource` is assigned a `ConvolutionDataSource` built either
///   from the in-memory `weightsPointer` / `biasPointer`, or from `network.parameterLoader`
///   when no `weightsPointer` is set; `makeConv` is then called with `nil` weights and bias.
/// - On earlier systems, raw pointers are resolved (in-memory pointer first, parameter
///   loader as fallback) and passed straight to `makeConv`. The bias is only resolved when
///   `useBias` is `true`.
///
/// - Parameter device: The Metal device forwarded to `makeConv`.
open func updateWeights(device: MTLDevice) {
    guard let network = network else { return }

    guard #available(iOS 11.0, *) else {
        // Pre-iOS 11 path: pass raw pointers directly to makeConv.
        let resolvedWeights = weightsPointer?.pointer()
            ?? network.parameterLoader.loadWeights(for: id,
                                                   modifier: Convolution.weightModifier,
                                                   size: getWeightsSize())
        let resolvedBias: UnsafePointer<Float>?
        if useBias {
            resolvedBias = biasPointer?.pointer()
                ?? network.parameterLoader.loadWeights(for: id,
                                                       modifier: Convolution.biasModifier,
                                                       size: convSize.outputChannels)
        } else {
            resolvedBias = nil
        }
        makeConv(device: device, weights: resolvedWeights, bias: resolvedBias)
        return
    }

    // iOS 11+ path: populate `dataSource`, then call makeConv without explicit pointers.
    if let explicitWeights = weightsPointer {
        dataSource = ConvolutionDataSource(
            cnnDescriptor: cnnDescriptor,
            weights: UnsafeMutableRawPointer(mutating: explicitWeights.pointer()),
            bias: UnsafeMutablePointer(mutating: biasPointer?.pointer() as UnsafePointer<Float>?)
        )
    } else {
        // NOTE(review): biasCount is always convSize.outputChannels here, independent of
        // `useBias` (unlike the pre-iOS 11 path above) — confirm whether this is intended.
        dataSource = ConvolutionDataSource(
            cnnDescriptor: cnnDescriptor,
            parameterLoader: network.parameterLoader,
            layerId: id,
            weightCount: getWeightsSize(),
            biasCount: convSize.outputChannels
        )
    }
    makeConv(device: device, weights: nil, bias: nil)
}
I think the code that creates the "ConvolutionDataSource" should be modified. It does not consider whether a bias is used or not, so I suggest the following code.
else { dataSource = ConvolutionDataSource(cnnDescriptor: cnnDescriptor, parameterLoader: network.parameterLoader, layerId: id, weightCount: getWeightsSize(), biasCount: useBias ? convSize.outputChannels : 0) }
Thank you.
It seems you are right. @mats-claassen can you take a look when you are back?
I think the code that creates the "ConvolutionDataSource" should be modified. It does not consider whether a bias is used or not, so I suggest the following code.
Thank you.