matlab-deep-learning / pix2pix

Image to Image Translation Using Generative Adversarial Networks

Multi Channel Issue #26

Open posadad1 opened 8 months ago

posadad1 commented 8 months ago

Hi,

I am trying to train with six-channel input images; the output images have three channels. As a quick fix, I added an extra folder parameter to the training function and passed it down to PairedImageDatastore. Images from this third folder are read alongside the A images and appended to them as channels 4 to 6:

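        function data = read(obj)
            imagesA = obj.ImagesA.read();
            imagesB = obj.ImagesB.read();
            imagesC = obj.ImagesC.read();

            % For a batch size of 1, imageDatastore returns the image
            % itself rather than a cell array, so wrap all three
            if ~iscell(imagesA)
                imagesA = {imagesA};
                imagesB = {imagesB};
                imagesC = {imagesC};
            end

            % Append each C image to its A image along the channel
            % dimension (3 + 3 = 6 channels), for any batch size
            imagesA = cellfun(@(a, c) cat(3, a, c), imagesA, imagesC, ...
                              'UniformOutput', false);

            [transformedA, transformedB] = ...
                p2p.data.transformImagePair(imagesA, imagesB, ...
                                            obj.PreSize, obj.CropSize, ...
                                            obj.Augmenter);
            [A, B] = obj.normaliseImages(transformedA, transformedB);
            data = table(A, B);
        end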

Before training, I set the options as follows:

options = p2p.trainingOptions('InputChannels',6,'OutputChannels',3);

Training then starts and runs until the first progress update (iteration 50 of the first epoch), where it fails with the following output:

epoch: 1, it: 50, G: 77.888550 (L1: 0.771494, GAN: 0.739171), D: 0.664209
Error using dlnetwork/forward
Layer 'inputImage': Invalid input data. Invalid size of channel dimension. Layer expects input with channel dimension size 6 but received input with size 3.

Error in p2p.vis.TrainingPlot/updateImages (line 62)
            output = tanh(generator.forward(obj.ExampleInputs));

Error in p2p.vis.TrainingPlot/update (line 47)
            obj.updateImages(generator)

Error in p2p.train (line 100)
                    trainingPlot.update(logArgs{:}, g);

Error in trainDepth (line 11)
p2pModel = p2p.train(labelFolder, targetFolder,options);
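The following minimal sketch reproduces the same channel check outside p2p. It assumes the Deep Learning Toolbox (R2019b or later, for `dlnetwork`); the 64-by-64 image size and the single convolution layer are hypothetical, and only the layer name 'inputImage' and the 6-in/3-out channel counts come from this issue. It suggests that the failing call is not the training step itself but the plot's example images, which appear to still have only three channels.

        % Hypothetical stand-in for the generator: 6-channel input, 3-channel output
        layers = [
            imageInputLayer([64 64 6], 'Normalization', 'none', 'Name', 'inputImage')
            convolution2dLayer(3, 3, 'Padding', 'same', 'Name', 'conv')];
        net = dlnetwork(layerGraph(layers));

        % A six-channel batch matches the input layer and runs
        x6 = dlarray(rand(64, 64, 6, 1, 'single'), 'SSCB');
        y = forward(net, x6);

        % A three-channel batch (like the example inputs) is rejected
        x3 = dlarray(rand(64, 64, 3, 1, 'single'), 'SSCB');
        failed = false;
        try
            forward(net, x3);
        catch err
            failed = true;
            disp(err.message)
        end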