xmartlabs / Bender

Easily craft fast Neural Networks on iOS! Use TensorFlow models. Metal under the hood.
https://xmartlabs.github.io/Bender/
MIT License

Convolution updateWeights function needs to be modified. #106

Closed: leeys888 closed this issue 6 years ago

leeys888 commented 6 years ago
open func updateWeights(device: MTLDevice) {
    guard let network = network else {
        return
    }

    if #available(iOS 11.0, *) {
        if let weightsPointer = weightsPointer {
            dataSource = ConvolutionDataSource(cnnDescriptor: cnnDescriptor,
                                               weights: UnsafeMutableRawPointer(mutating: weightsPointer.pointer()),
                                               bias: UnsafeMutablePointer(mutating: biasPointer?.pointer() as UnsafePointer<Float>?))
        } else {
            dataSource = ConvolutionDataSource(cnnDescriptor: cnnDescriptor, parameterLoader: network.parameterLoader,
                                               layerId: id, weightCount: getWeightsSize(), biasCount: convSize.outputChannels)
        }
        makeConv(device: device, weights: nil, bias: nil)
    } else {
        let weights = weightsPointer?.pointer() ?? network.parameterLoader.loadWeights(for: id,
                                                                                       modifier: Convolution.weightModifier,
                                                                                       size: getWeightsSize())

        var bias: UnsafePointer<Float>? = nil
        if useBias {
            bias = biasPointer?.pointer() ?? network.parameterLoader.loadWeights(for: id,
                                                                                 modifier: Convolution.biasModifier,
                                                                                 size: convSize.outputChannels)
        }
        makeConv(device: device, weights: weights, bias: bias)
    }
}

I think the "ConvolutionDataSource" initialization needs to be modified. It does not consider whether the layer uses a bias: it always passes convSize.outputChannels as biasCount, even when useBias is false. So I suggest the following code:

else {
    dataSource = ConvolutionDataSource(cnnDescriptor: cnnDescriptor, parameterLoader: network.parameterLoader,
                                       layerId: id, weightCount: getWeightsSize(), biasCount: useBias ? convSize.outputChannels : 0)
}
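The effect of the proposed change can be sketched in plain Swift, without Metal. Everything here is hypothetical: RecordingParameterLoader, SketchConvolutionDataSource and the biasCount(useBias:outputChannels:) helper are illustrative stand-ins, not Bender's API. The sketch assumes the data source loads bias values only when biasCount is nonzero and follows the convention of MPSCNNConvolutionDataSource.biasTerms(), where returning nil means the convolution has no bias.

```swift
// Hypothetical stand-in for Bender's parameter loader: it records how many
// floats each request asked for and returns zero-filled buffers.
final class RecordingParameterLoader {
    private(set) var requests: [(modifier: String, size: Int)] = []

    func loadWeights(for id: String, modifier: String, size: Int) -> [Float] {
        requests.append((modifier: modifier, size: size))
        return [Float](repeating: 0, count: size)
    }
}

// Hypothetical, simplified data source playing the role of ConvolutionDataSource.
// Assumption: bias is loaded lazily and only when biasCount > 0.
struct SketchConvolutionDataSource {
    let layerId: String
    let weightCount: Int
    let biasCount: Int
    let loader: RecordingParameterLoader

    func weights() -> [Float] {
        return loader.loadWeights(for: layerId, modifier: "weights", size: weightCount)
    }

    // Mirrors the MPSCNNConvolutionDataSource convention: nil means "no bias".
    func biasTerms() -> [Float]? {
        guard biasCount > 0 else { return nil }
        return loader.loadWeights(for: layerId, modifier: "bias", size: biasCount)
    }
}

// The proposed fix in isolation: request bias values only when the layer uses a bias.
func biasCount(useBias: Bool, outputChannels: Int) -> Int {
    return useBias ? outputChannels : 0
}

// Demo: a 3x3 convolution with 16 input and 32 output channels and no bias.
let demoLoader = RecordingParameterLoader()
let withoutBias = SketchConvolutionDataSource(
    layerId: "conv1",
    weightCount: 3 * 3 * 16 * 32,
    biasCount: biasCount(useBias: false, outputChannels: 32),
    loader: demoLoader)
_ = withoutBias.weights()
print(withoutBias.biasTerms() == nil)        // prints "true"
print(demoLoader.requests.map { $0.size })   // prints "[4608]"
```

With useBias set to true, the same data source would request outputChannels bias values, which matches what the pre-iOS 11 branch of updateWeights already does by checking useBias before loading the bias.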

Thank you.

bryant1410 commented 6 years ago

It seems you are right. @mats-claassen, can you take a look when you are back?