remicres / otbtf

Deep learning with otb (mirror of https://forgemia.inra.fr/remi.cresson/otbtf)
Apache License 2.0

Combine two training datasets #84

Closed bjyberg closed 1 year ago

bjyberg commented 1 year ago

Hi Remi (or whoever else knows), I have data from two different sites which I would like to use to create a single training dataset for a CNN. Is this possible using the patch extraction tool? If not, is it possible to input two sets of patches to train a CNN within OTBTF? Cheers, Brayden

remicres commented 1 year ago

Hi Brayden,

Sorry for the long reply! This is basically part of the new doc, which will be updated very soon.

TL;DR

The PatchesExtraction application extracts patches from one or multiple sources that cover the same area. When you have different sites, you have to run PatchesExtraction over each site. Then you can use the OTBTF Python API. It is quite new, and we made a huge step toward very easy use with tf.keras. The documentation is incoming; you can check the work in progress here (it's nearly finished, maybe just a bit of refactoring left, without any impact on how things work).

We should publish the new release (3.3.0) before Wednesday! In the meantime, you can give the development feature branch docker image a try: docker pull gitlab-registry.irstea.fr/remi.cresson/otbtf:29-add-a-model-class-in-otbtf (CPU only... the GPU-flavored docker images will be published on Docker Hub with the 3.3.0 release next week)

Details

Here is what you can do.

Say you have two XS images acquired over sites A and B, with some terrain truth. First, you extract your patches; the output of PatchesExtraction results in the following files: xs_A and labels_A for site A, and xs_B and labels_B for site B.

Then you can use the OTBTF Python API:

from otbtf import DatasetFromPatchesImages
ds = DatasetFromPatchesImages(filenames_dict={"xs_patches": ["xs_A", "xs_B"], 
                                              "labels_patches": ["labels_A", "labels_B"]})
tf_ds = ds.get_tf_dataset(batch_size=8)  # This is a TensorFlow dataset !

Such a TensorFlow dataset can be used with the TensorFlow v1 or v2 API. If you want to use Keras, it's more convenient if your dataset splits its samples in two (inputs_dict, target_dict). You just have to pass the target name to the targets_keys parameter of get_tf_dataset to get the samples in the right format for Keras:

# This is a TensorFlow dataset, ready to be used with keras !
tf_ds = ds.get_tf_dataset(batch_size=8, targets_keys="predictions")

From here, you can build and train your model with Keras. But I strongly encourage you to use the new otbtf.model.ModelBase: it really helps to build a model with the right settings to be used later in TensorflowModelServe for inference in production. For instance, here is a small model implementation:

from otbtf.model import ModelBase
import tensorflow as tf

class MyModel(ModelBase):
    def normalize_inputs(self, inputs):
        """ Normalize the remote sensing image into [0, 1] range prior to convolutions, etc. 
            Optional, but quite convenient in production. Return dict of normalized inputs """
        return {"xs": tf.cast(inputs["xs"], tf.float32) * 0.0001}  # Say input is uint16 

    def get_outputs(self, normalized_inputs):
        """ Build the model. Return the dict of outputs (must include target!) """
        inp = normalized_inputs["xs"]
        net = tf.keras.layers.Conv2D(filters=32, kernel_size=3, activation="relu")(inp)
        net = tf.keras.layers.Conv2D(filters=64, kernel_size=3, activation="relu")(net)
        net = tf.keras.layers.Conv2D(filters=N_CLASSES, kernel_size=3)(net)  # N_CLASSES: number of classes, assumed defined elsewhere
        predictions = tf.keras.layers.Softmax(name="predictions_softmax")(net)
        return {"predictions": predictions}

One last thing is to prepare the labels with one-hot encoding, in case you have integer values in your labels. For instance, you can do:

def dataset_preprocessing_fn(examples):
    def _to_categorical(x):
        return tf.one_hot(tf.squeeze(tf.cast(x, tf.int32), axis=-1), depth=N_CLASSES)
    return {"xs": examples["xs_patches"], "predictions": _to_categorical(examples["labels_patches"])}

Then you ask the dataset to (1) use this function to prepare the input patches so they match the model input and target names, and (2) split the dataset examples into (inputs_dict, target_dict) so they can be used in the Keras logic.

tf_ds = ds.get_tf_dataset(batch_size=batch_size,
                          preprocessing_fn=dataset_preprocessing_fn,
                          targets_keys="predictions")

Then you just have to train the model like any other keras model.

strategy = tf.distribute.MirroredStrategy()  # For single or multi-GPUs
with strategy.scope():
    model = MyModel(dataset_element_spec=tf_ds.element_spec)
    model.compile(loss=tf.keras.losses.CategoricalCrossentropy(),
                  optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4))
    model.fit(ds_train, epochs=params.nb_epochs, validation_data=ds_valid)
    model.save("/path/to/savedmodel")
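
Note that model.fit above expects training and validation datasets (ds_train and ds_valid), which are not defined in the snippet. Here is a minimal sketch of how they could be built with the same API, assuming some patches images are held out for validation (the file names below are hypothetical):

# Hypothetical split: separate patches images for training and validation
ds_train = DatasetFromPatchesImages(
    filenames_dict={"xs_patches": ["xs_A_train", "xs_B_train"],
                    "labels_patches": ["labels_A_train", "labels_B_train"]}
).get_tf_dataset(batch_size=batch_size,
                 preprocessing_fn=dataset_preprocessing_fn,
                 targets_keys="predictions")

ds_valid = DatasetFromPatchesImages(
    filenames_dict={"xs_patches": ["xs_A_valid", "xs_B_valid"],
                    "labels_patches": ["labels_A_valid", "labels_B_valid"]}
).get_tf_dataset(batch_size=batch_size,
                 preprocessing_fn=dataset_preprocessing_fn,
                 targets_keys="predictions")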

In production, you want to use the cropped outputs, named <node_name>_crop<int_crop_value> (in doubt, you can retrieve their names using saved_model_cli show --dir /path/to/savedmodel --all), to avoid blocking artifacts of your fully convolutional model. That's another story, we won't cover it here, but you know that convolutions have side effects. otbtf.model.ModelBase automatically adds these extra cropped outputs so you can avoid blocking artifacts in inference production pipelines.

In CLI you will typically do:

otbcli_TensorflowModelServe               \
-source1.il "fullsize_xs_image.tif"       \
-source1.rfieldx 64                       \
-source1.rfieldy 64                       \
-source1.placeholder "xs"                 \
-model.dir "/path/to/savedmodel"          \
-model.fullyconv on                       \
-output.names "predictions_softmax_crop16" \
-output.efieldx 32                        \
-output.efieldy 32                        \
-out "output.tif"

Hope that helps, and that you will enjoy the new API. Also, do not hesitate to tell us if something isn't going as expected, or if you find some nasty bug :+1:

Regards,

Rémi

bjyberg commented 1 year ago

Amazing, Thank you so much!!

remicres commented 1 year ago

Oops.

The link looks wrong. Go here and you should be able to pull the container:

docker pull gitlab-registry.irstea.fr/remi.cresson/otbtf:29-add-a-model-class-in-otbtf

bjyberg commented 1 year ago

Hi Rémi, A few more questions if you have a moment. I am using 2 datasets from each site (multispectral and DSM). I think I have the general architecture of the model (based on the example from your book, it's incredibly helpful!), but I'm struggling to determine the best method of creating a dataset + implementing the labels for each site. I would assume that I would want to create two datasets in total, one for multispectral data and one for the DSM, as the patches are different sizes due to differing dataset resolutions. Is that correct? If so, how would I implement the two labeled patches for the prediction? Is this still possible using ModelBase?

Thanks for all the help, Brayden

remicres commented 1 year ago

Hi @bjyberg ,

You have to create one single dataset using PatchesExtraction, with OTB_TF_NSOURCES=3 (3 sources: labels, DSM, and XS) if you have densely annotated labels like in the "semantic segmentation" chapter of the book, or OTB_TF_NSOURCES=2 (2 sources: DSM and XS) if you have sparsely annotated points like in the early chapters.
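
As an illustration only, here is a rough sketch of what the 3-source extraction could look like from Python with the OTB application API (file names, patch sizes and the "class" field name are hypothetical; parameter keys should be checked against the PatchesExtraction help):

import os
os.environ["OTB_TF_NSOURCES"] = "3"  # number of sources, read when the application is created
import otbApplication

app = otbApplication.Registry.CreateApplication("PatchesExtraction")
# Source 1: XS image
app.SetParameterStringList("source1.il", ["xs_A.tif"])
app.SetParameterInt("source1.patchsizex", 16)
app.SetParameterInt("source1.patchsizey", 16)
app.SetParameterString("source1.out", "xs_patches_A.tif")
# Source 2: DSM (resampled beforehand onto the XS grid)
app.SetParameterStringList("source2.il", ["dsm_A.tif"])
app.SetParameterInt("source2.patchsizex", 16)
app.SetParameterInt("source2.patchsizey", 16)
app.SetParameterString("source2.out", "dsm_patches_A.tif")
# Source 3: densely annotated labels raster
app.SetParameterStringList("source3.il", ["labels_A.tif"])
app.SetParameterInt("source3.patchsizex", 16)
app.SetParameterInt("source3.patchsizey", 16)
app.SetParameterString("source3.out", "labels_patches_A.tif")
# Vector data giving the sampling positions, and its class field
app.SetParameterString("vec", "terrain_truth_A.gpkg")
app.SetParameterString("field", "class")
app.ExecuteAndWriteOutput()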

Deep nets often perform some internal downsampling and upsampling of the computed features. Generally, these resampling operations work with ratios of powers of 2 (i.e. halving or doubling the feature maps in the spatial dimensions). So generally speaking, for fully convolutional models, you'd better work with sources whose physical spacing ratios are 2, 4, 8, etc. In your case, what I would do prior to PatchesExtraction (and later TensorflowModelServe at inference time) is resample your DSM and XS at the highest resolution (using the SuperImpose OTB application, for instance). If the DSM resolution is, say, 2.5 times coarser than the XS, then you could resample it at exactly 2 times the XS resolution, and tweak your network to input this source after the first feature map downsampling (you can take a look at our paper, where we did that for optical image reconstruction using XS and DSM at different resolutions: 10 m for XS, 20 m for DSM). However, this approach is more complex, since you will have to deal with adapting the model to your inputs' resolutions.
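
For reference, here is a minimal sketch of how the DSM could be resampled onto the XS grid with OTB's Superimpose application from Python (the file names are hypothetical; check the application help for the full list of parameters):

import otbApplication

# Resample the DSM onto the grid of the XS reference image
app = otbApplication.Registry.CreateApplication("Superimpose")
app.SetParameterString("inr", "xs_A.tif")              # reference image (defines the output grid)
app.SetParameterString("inm", "dsm_A.tif")             # image to resample
app.SetParameterString("interpolator", "bco")          # bicubic interpolation
app.SetParameterString("out", "dsm_A_on_xs_grid.tif")
app.ExecuteAndWriteOutput()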

If your XS and DSM resolutions are close, I would recommend resampling the coarsest one to the highest resolution, and concatenating the inputs in the network before doing any convolution.
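
To illustrate this second option, here is a minimal sketch of a two-source model following the same ModelBase pattern as the earlier example, assuming the dataset preprocessing exposes the two sources under the hypothetical names "xs" and "dsm", both resampled onto the same grid (N_CLASSES and the scaling factors are placeholders):

from otbtf.model import ModelBase
import tensorflow as tf

class MyTwoSourceModel(ModelBase):
    def normalize_inputs(self, inputs):
        """ Scale each source into a comparable range (factors are just examples) """
        return {"xs": tf.cast(inputs["xs"], tf.float32) * 0.0001,
                "dsm": tf.cast(inputs["dsm"], tf.float32) * 0.01}

    def get_outputs(self, normalized_inputs):
        """ Both sources share the same grid, so stack them along the channel axis """
        net = tf.keras.layers.Concatenate(axis=-1)([normalized_inputs["xs"],
                                                    normalized_inputs["dsm"]])
        net = tf.keras.layers.Conv2D(filters=32, kernel_size=3, activation="relu")(net)
        net = tf.keras.layers.Conv2D(filters=64, kernel_size=3, activation="relu")(net)
        net = tf.keras.layers.Conv2D(filters=N_CLASSES, kernel_size=3)(net)  # N_CLASSES: number of classes
        predictions = tf.keras.layers.Softmax(name="predictions_softmax")(net)
        return {"predictions": predictions}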

Hope that helps

bjyberg commented 1 year ago

Hi Remi, Sorry to keep bothering you... I'm trying to use this for my dissertation, and am still really struggling. I simplified things, and am just trying to get a model using the multispectral data at each site. I've finally made a basic CNN that I can train, but when I try to use otbcli_TensorflowModelServe, I receive the attached error. I'm guessing this has something to do with the -output.efieldx values? They are currently set to 1, with the input values set to 12 (the size of my patches). The architecture of the basic CNN is:

inp = normalized_inputs["xs"]
con1 = tf.keras.layers.Conv2D(filters=16, kernel_size=3, activation="relu")(inp)
con2 = tf.keras.layers.Conv2D(filters=32, kernel_size=3, activation="relu")(con1)
con3 = tf.keras.layers.Conv2D(filters=64, kernel_size=3, activation="relu")(con2)
con4 = tf.keras.layers.Conv2D(filters=64, kernel_size=3, activation="relu")(con3)
con5 = tf.keras.layers.Conv2D(filters=64, kernel_size=3, activation="relu")(con4)
con6 = tf.keras.layers.Conv2D(filters=N_CLASSES, kernel_size=2, activation="relu")(con5)
predictions = tf.keras.layers.Softmax(name="predictions_softmax")(con6)
return {"predictions": predictions}

Screenshot_2022-07-26_03-55-34

remicres commented 1 year ago

Hi @bjyberg ,

I believe that you have to enable the fully convolutional mode in TensorflowModelServe, with -model.fullyconv on! Here you are in patch-based mode, so the application forms batches of 16x16 = 256 patches, which is really slow to compute.

(Here it fails because it looks like the output tensor is not the right shape for the patch-based mode, but I don't understand why.)

Hope that helps

bjyberg commented 1 year ago

Hi Rémi, I get a similar error with fully convolutional mode on. The only difference is that the tensor shape goes from {256, 0, 0, 5} in the patch-based error to {1, 0, 0, 5} in the fully convolutional error. Any other suggestions?

Cheers, Brayden

remicres commented 1 year ago

I don't know, it is a bit strange. I suspect that there is something wrong with your network. Are you sure you posted the right code? If you have (0, 0) elements in (x, y), then I don't know how you could have trained your network successfully. I think that you should try a receptive field of 13 instead of 12, since there is no element in the output. I suspect that you used a kernel of size 3 in the last convolution, instead of 2?
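
As a quick sanity check (not part of the original exchange), the spatial output size of a stack of unpadded ("valid") convolutions is the input size minus the sum of (kernel_size - 1) over the layers. A tiny Python helper illustrates why the last kernel size matters here:

def valid_conv_output_size(input_size, kernel_sizes):
    """ Spatial output size after a stack of unpadded ("valid") 2D convolutions """
    return input_size - sum(k - 1 for k in kernel_sizes)

# Five 3x3 convolutions followed by one 2x2 convolution, as in the posted code
print(valid_conv_output_size(12, [3, 3, 3, 3, 3, 2]))  # 1 -> a 12x12 patch yields one output pixel
# If the last kernel were 3 instead of 2, a 12x12 patch would yield no output at all
print(valid_conv_output_size(12, [3, 3, 3, 3, 3, 3]))  # 0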

bjyberg commented 1 year ago

So I've re-run my model script, ensuring it matches the code I sent (it is definitely kernel_size=2 in my final convolution), and then tested it with both 12 and 13 as the receptive field. I still receive the same error, which is quite odd. I also attempted to use the FCN from model 2 in the book (using 4 bands of my multispectral data from 1 site), though I adjusted the number of classes in the code. This model trained (though the accuracy metrics seemed odd in hindsight...), but labeled everything as 0 in my output image. Could these issues be related to the number of classes in my data (5) or my input/training data (5 bands)?

bjyberg commented 1 year ago

If it helps, here is my full python script. Sorry, it's not terribly pretty/well organized, I'm rather new to Python haha! Screenshot_2022-07-26_16-26-36

remicres commented 1 year ago

> Hi Rémi, I get a similar error with fully convolutional mode on. The only difference is that the tensor shape goes from {256, 0, 0, 5} in the patch-based error to {1, 0, 0, 5} in the fully convolutional error. Any other suggestions?
>
> Cheers, Brayden

What name did you use for the model output? It should be "predictions_softmax"

bjyberg commented 1 year ago

Amazing, that seemed to work!! Such a simple fix... sorry I missed it! One final question, looking at my outputs, the accuracy appears to be really low... would an increased patch size, more layers, more training data, and adjusting the number of epochs be my best options to increase accuracy?

Thanks again for all the help, I truly appreciate it :)

remicres commented 1 year ago

I am glad that you made it work.

There is no general rule, but in short: for a given model, you should find a good value for the batch size and the learning rate. After that, the model might be tuned a bit, but ultimately its performance is directly linked to the quality and quantity of available terrain truth data. Another useful rule of thumb for beginners: small dataset --> small model, big dataset --> big model. I cannot really help you more at this point, but you should find plenty of things in blogs, papers, etc.

Good luck!

bjyberg commented 1 year ago

Sorry to keep bothering you Rémi - I'm getting what appear to be good results from my model; however, the issue comes when trying to map it. As an output, I get a 5-layer raster (though it looks like all layer values are about the same). The values are extremely small, ranging from 2.4011909902082e-23 to 8.9866404096028e-08. On inspection, it appears classes are determined by the size of the number (i.e. 1.19688e-15 and 5.18925e-15 are of the same class). I've played with the scaling in QGIS, though it still is not producing useful class separation. Any suggestions?

remicres commented 1 year ago

The number of neurons in the last layer must be equal to the number of classes.

bjyberg commented 1 year ago

I have checked this on two different models (with 12x12 and 32x32 inputs) and receive the same result. Can you think of anything else that could be causing this? I've played around with different ModelServe parameters in case I had missed something there; however, still no success. Would the addition of an argmax layer to my pre-trained model make any sort of difference? I wondered if it was an issue with the labels, but after trying multiple sets of patches, I got similar results...

bjyberg commented 1 year ago

Update: it appears it was something to do with my labels from the patch extraction tool. There was some issue with the shapefile I had used for creating the data: the patch extraction tool was able to extract patches, but the labels had no values. After fixing this, it appears all is well!

It appears that the Patch Extraction tool was not extracting the class values from the shapefile points if they were characters. Once I converted the characters to numbers (1-5) in the shapefile, the extraction and labels worked fine. Not sure if this has been a prior issue, if it was a mistake on my part when exporting the shapefile, or just a weird chaotic occurrence, but it could be worth noting.

Cheers! Brayden