JisongXie / StegaStamp_pytorch

StegaStamp of pytorch version
MIT License

Problems in training #5

Open ghost opened 2 years ago

ghost commented 2 years ago

Hello, I've been studying your code recently, and I'm very grateful because it has been very helpful to me. However, I found some small problems in the training process:

  1. I noticed that the code in model.py https://github.com/JisongXie/StegaStamp_pytorch/blob/184642e61d2fa95bcaf4131e64963202267c875d/model.py#L264-L290 can push the pixel values of encoded_image outside the range [0, 1], which produces visible noise when the image is displayed.

Since the training procedure itself already tolerates a certain amount of color distortion, I think it is necessary to add a torch.clamp call that restricts encoded_image to [0, 1] before it is fed into the decoder network. If you agree, I would like to submit a merge request for this.

  2. The original author's program excludes the image loss from the total loss during the first 1500 iterations so that the bit accuracy can improve first. In your program, however, this improvement happens only in some runs: in others, the bit accuracy just fluctuates around 0.5 throughout the first 1500 iterations. This is not a big problem. My current workaround is to average the bit accuracy over the 10 iterations before and after the 1500th iteration and restart training if the average is below a threshold of 0.7. Still, have you ever encountered this situation? How do you usually solve it?
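
The restart check described in point 2 could be sketched roughly as follows. This is only an illustration: the function name `should_restart` and the list-based accuracy history are assumptions, not code from the repository; the numbers 1500, 10, and 0.7 are the ones mentioned above.

```python
from statistics import mean

WARMUP_STEPS = 1500  # iterations trained without the image loss (from the thread)
WINDOW = 10          # iterations averaged on each side of the warm-up boundary
THRESHOLD = 0.7      # minimum acceptable average bit accuracy


def should_restart(bit_acc_history, warmup_steps=WARMUP_STEPS,
                   window=WINDOW, threshold=THRESHOLD):
    """Return True if the run looks stuck near chance level.

    bit_acc_history is a hypothetical list in which index i holds the
    bit accuracy measured at iteration i.
    """
    # Wait until the full window around the warm-up boundary is available.
    if len(bit_acc_history) < warmup_steps + window:
        return False
    around_boundary = bit_acc_history[warmup_steps - window:warmup_steps + window]
    return mean(around_boundary) < threshold


# A run stuck at chance level versus a run whose accuracy has taken off.
stuck_run = [0.5] * 1510
good_run = [0.5] * 1400 + [0.9] * 110
restart_stuck = should_restart(stuck_run)
restart_good = should_restart(good_run)
```

A training loop would call `should_restart` once enough iterations have been logged and reinitialize the model when it returns True.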
JisongXie commented 2 years ago

@Woodley-Griffith Hi, thank you very much for finding these problems in my code.

  1. The function torchgeometry.warp_perspective transforms the source image using the specified transformation matrix. It seems that small numerical fluctuations can push the transformed result outside the range [0, 1]. I agree that adding a torch.clamp is necessary; it is the more reasonable and careful approach. More merge requests are very welcome to make the code work better.
  2. As far as I remember, I have never encountered this situation, which is strange. So, from what you say, training sometimes succeeds, and sometimes the bit accuracy fluctuates around 0.5 and training fails? Maybe the initialization of the model weights is to blame. I think your solution works as a temporary fix, but a more principled solution still needs to be found.
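
The clamp discussed in point 1 might look like the sketch below. It is not the repository's actual patch: the helper name `clamp_encoded_image` and the simulated tensors are assumptions made for illustration; only `torch.clamp` itself is a real PyTorch call.

```python
import torch


def clamp_encoded_image(encoded_image: torch.Tensor) -> torch.Tensor:
    """Restrict pixel values to [0, 1] before the image reaches the decoder."""
    return torch.clamp(encoded_image, min=0.0, max=1.0)


# Simulated encoder output: an image plus a residual that overshoots [0, 1].
torch.manual_seed(0)
image = torch.rand(1, 3, 8, 8)
residual = 0.5 * torch.randn(1, 3, 8, 8)
encoded_image = image + residual

clamped = clamp_encoded_image(encoded_image)
```

Note that `torch.clamp` passes zero gradient for values that fall outside the range, so whether to clamp during training or only before display and decoding is a design choice worth testing.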
sunutf commented 2 years ago

@Woodley-Griffith Nice suggestion, and I agree with what you are saying. Is there any update on this issue?

I cannot find a commit related to it.

JisongXie commented 2 years ago

@sunutf Not updated or tested yet.

lschirmer commented 2 years ago

The accuracy problem is directly related to the weight initialization and the spatial transformer block. I simply removed them, and training works well. However, we still need to deal with warping.
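
As an alternative to removing the block entirely, one common practice is to initialize a spatial transformer so that it starts as the identity transform. Whether this fixes the instability reported here is an assumption that has not been tested in this thread. The class `TinySTN` below is a hypothetical, simplified module, not the repository's `SpatialTransformerNetwork`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinySTN(nn.Module):
    """Hypothetical minimal spatial transformer with identity initialization."""

    def __init__(self, channels=3):
        super().__init__()
        # Small localization network that summarizes the image into 8 features.
        self.localization = nn.Sequential(
            nn.Conv2d(channels, 8, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.fc = nn.Linear(8, 6)
        # Zero weights plus an identity bias: the block is a no-op at step 0.
        nn.init.zeros_(self.fc.weight)
        with torch.no_grad():
            self.fc.bias.copy_(torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))

    def forward(self, image):
        # Predict a 2x3 affine matrix per image and resample the input with it.
        theta = self.fc(self.localization(image)).view(-1, 2, 3)
        grid = F.affine_grid(theta, image.size(), align_corners=False)
        return F.grid_sample(image, grid, align_corners=False)


torch.manual_seed(0)
stn = TinySTN()
x = torch.rand(2, 3, 16, 16)
out = stn(x)  # at initialization the output matches the input up to float error
```

With this initialization the decoder initially sees the unwarped image, and the transformer only departs from the identity as training updates its final layer.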

Asphocarp commented 4 months ago

@lschirmer Hi there, your reply was very helpful. Just to clarify, did you mean you removed the entire spatial transformer block, like removing this?

# in class StegaStampDecoder(nn.Module):
self.stn = SpatialTransformerNetwork()
...
transformed_image = self.stn(image)

The original paper seems to emphasize using a spatial transformer: "A spatial transformer network [24] is used to develop robustness against small perspective changes that are introduced while capturing and rectifying the encoded image."

Git-CYQ commented 4 months ago

When I use a mobile phone to photograph encoded images displayed on a screen, the decoding accuracy is much lower than with the original author's (tancik) TensorFlow model. Has anyone else had this problem?