herobd / handwriting_line_generation

Code for BMVC2020 paper "Text and Style Conditioned GAN for Generation of Offline Handwriting Lines"
Other

CUDA out of memory #22

Open AICampB4 opened 2 years ago

AICampB4 commented 2 years ago

Hi, thanks a lot for your efforts. It's great work. I was trying to train your model on the CVL dataset with Google Colab Pro. I converted the images from .tif to .png and fed them to the model, but after 200 iterations a CUDA out of memory error appears. This is the output:

NumExpr defaulting to 4 threads.
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 6 worker processes in total. Our suggested max number of worker in current system is 4, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  cpuset_checked))
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 6 worker processes in total. Our suggested max number of worker in current system is 4, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  cpuset_checked))
model style_extrator
Trainable parameters: 10136912
HWWithStyle(
  (hwr): CNNOnlyHWR(
    (cnn): Sequential(
      (conv0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu0): ReLU(inplace=True)
      (pooling0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu1): ReLU(inplace=True)
      (pooling1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (batchnorm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu3): ReLU(inplace=True)
      (pooling2): MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1), dilation=1, ceil_mode=False)
      (conv4): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (batchnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu4): ReLU(inplace=True)
      (conv5): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
      (relu5): ReLU(inplace=True)
      (pooling3): MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1), dilation=1, ceil_mode=False)
      (conv6): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
      (batchnorm6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu6): ReLU(inplace=True)
    )
    (cnn1d): Sequential(
      (0): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(4,))
      (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv1d(512, 512, kernel_size=(3,), stride=(1,))
      (7): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU(inplace=True)
      (9): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(8,), dilation=(8,))
      (10): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU(inplace=True)
      (12): Conv1d(512, 80, kernel_size=(3,), stride=(1,))
      (13): LogSoftmax(dim=1)
    )
  )
)
Begin training
WARNING: upsampling image to fit size
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
WARNING: upsampling image to fit size
WARNING: upsampling image to fit size
WARNING: upsampling image to fit size
WARNING: upsampling image to fit size
WARNING: upsampling image to fit size
Train iteration: 100, loss: 3.290166, recogLoss: 3.290166, CER: 1.000000, WER: 1.000000,
Train iteration: 200, loss: 2.875033, recogLoss: 2.875033, CER: 1.000000, WER: 1.000000, sec_per_iter: 0.623328, avg_loss: 3.082350, avg_recogLoss: 3.082350, avg_CER: 1.000000, avg_WER: 1.000000,
Traceback (most recent call last):
  File "train.py", line 133, in <module>
    main(config, args.resume)
  File "train.py", line 79, in main
    trainer.train()
  File "/content/drive/.shortcut-targets-by-id/1gLhWu0Me1satHwX83jnDJLd9nHAl9Mp6/mockproject/paper1/base/base_trainer.py", line 219, in train
    result = self._train_iteration(self.iteration)
  File "/content/drive/.shortcut-targets-by-id/1gLhWu0Me1satHwX83jnDJLd9nHAl9Mp6/mockproject/paper1/trainer/hw_with_style_trainer.py", line 378, in _train_iteration
    pred, recon, losses = self.run(instance)
  File "/content/drive/.shortcut-targets-by-id/1gLhWu0Me1satHwX83jnDJLd9nHAl9Mp6/mockproject/paper1/trainer/hw_with_style_trainer.py", line 736, in run
    pred = self.model.hwr(image, style)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/drive/.shortcut-targets-by-id/1gLhWu0Me1satHwX83jnDJLd9nHAl9Mp6/mockproject/paper1/model/cnn_only_hwr.py", line 131, in forward
    conv = self.cnn(input)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py", line 443, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py", line 440, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: CUDA out of memory. Tried to allocate 3.33 GiB (GPU 0; 15.90 GiB total capacity; 11.91 GiB already allocated; 3.02 GiB free; 11.97 GiB reserved in total by PyTorch)

[image: MicrosoftTeams-image]

I haven't changed anything in your architecture or loss functions. Could you advise me on how to fix this issue? Thank you again, and sorry for my bad English.

herobd commented 2 years ago

This is just trying to train the handwriting recognition network, right? Things to try:

  1. Be sure the images are getting resized correctly (64 pixels high).
  2. GPU size. I'd always have nvidia-smi as the first cell in your Colab notebooks so you know which GPU has been assigned to you. I trained the model on an 11GB GPU, although the CVL dataset may have longer images.
  3. Adjust max_width in the data_loader section of the config. The crash occurs on a particularly long image; this parameter forces the resizing to reduce the width to at most this value (it defaults to 3000). BE CAREFUL, though, as these adjusted images are now a bit outside the distribution of the rest of the images. (It may be better to just throw them out...)
  4. Reduce the batch size. Be careful, though, as the config has the CRNN using batch norm. You can change this to group norm if you need to drop the batch size significantly.
  5. Make the model smaller. As the handwriting recognition network is only an auxiliary network, reducing the number of channels shouldn't be terrible. Change the nm variable in handwriting_line_generation/model/cnn_only_hwr.py (assuming you are using the CNN-only handwriting network).
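The resizing advice in items 1 and 3 can be sketched as a small size calculation. This is a hypothetical helper, not code from this repo: it assumes lines are scaled to a fixed 64-pixel height with the aspect ratio preserved, treats `max_width` (default 3000, as mentioned above) as a hard cap, and flags images that would have to be squeezed horizontally so they can be dropped instead.

```python
from typing import Iterable, List, Tuple

TARGET_HEIGHT = 64  # line height the models expect, per item 1 (assumption)


def resized_width(width: int, height: int, target_height: int = TARGET_HEIGHT) -> int:
    """Width after scaling the image to target_height, preserving aspect ratio."""
    if width <= 0 or height <= 0:
        raise ValueError("image dimensions must be positive")
    return max(1, round(width * target_height / height))


def plan_resize(width: int, height: int, max_width: int = 3000) -> Tuple[int, int, bool]:
    """Return (new_width, new_height, squeezed).

    squeezed is True when the aspect-preserving width exceeds max_width, i.e. the
    image would be compressed horizontally and no longer look like the rest of
    the training data."""
    new_width = resized_width(width, height)
    if new_width > max_width:
        return max_width, TARGET_HEIGHT, True
    return new_width, TARGET_HEIGHT, False


def keep_in_distribution(records: Iterable[Tuple[str, int, int]],
                         max_width: int = 3000) -> List[str]:
    """Keep names of (name, width, height) records that need no squeezing,
    i.e. the 'just throw them out' option from item 3."""
    return [name for name, w, h in records if not plan_resize(w, h, max_width)[2]]
```

For example, a badly cropped 12000 x 100 scan would scale to 7680 pixels wide at 64 pixels high, far past the default cap, so `plan_resize(12000, 100)` reports it as squeezed and `keep_in_distribution` would drop it.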
AICampB4 commented 2 years ago

Thank you very much, I'll try those things and respond to you soon. Have a good day!

AICampB4 commented 2 years ago

Hello again! You're right, the problem was the width of some CVL lines. CVL provides line bounding boxes, and some lines were not bounded correctly, so when they were passed into the model their width exploded. I removed all of them from the dataset and the model worked well.
Sorry for the inconvenience. Could you give me some advice on using a Latin-script language like Vietnamese? I added a VN_charset.json to the data folder, which contains English and Latin characters and their indices. Thank you very much.

herobd commented 2 years ago

The charset file defines the characters the model can see and produce. Be sure each model component has data_loader: char_file updated in its config and is retrained. Also be sure to update the text file (trainer: text_data in the config) when training the generator, so it has text in the desired output language.
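A minimal sketch of producing a charset file in the `char_to_idx` / `idx_to_char` layout used in this thread, with indices starting at 1 and index 0 left free (assumed to be the CTC blank). The helper names and the idea of deriving the character set from the dataset transcriptions plus the `text_data` corpus are illustrative assumptions, not the repo's own tooling.

```python
import json
from typing import Dict, Iterable, Set


def build_charset(lines: Iterable[str]) -> Dict[str, Dict]:
    """Number every character seen in `lines` from 1 upward.

    Index 0 is deliberately left unused (assumed to be the CTC blank)."""
    chars = sorted({ch for line in lines for ch in line.rstrip("\r\n")})
    char_to_idx = {ch: i for i, ch in enumerate(chars, start=1)}
    # JSON object keys must be strings, hence "1", "2", ... in idx_to_char.
    idx_to_char = {str(i): ch for ch, i in char_to_idx.items()}
    return {"char_to_idx": char_to_idx, "idx_to_char": idx_to_char}


def missing_chars(lines: Iterable[str], charset: Dict[str, Dict]) -> Set[str]:
    """Characters in `lines` (e.g. the generator's text_data) the charset cannot encode."""
    known = charset["char_to_idx"]
    return {ch for line in lines for ch in line.rstrip("\r\n") if ch not in known}


def save_charset(charset: Dict[str, Dict], path: str) -> None:
    # ensure_ascii=True (the default) writes non-ASCII characters as \uXXXX
    # escapes, like the charset files shown in this thread.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(charset, f, ensure_ascii=True)
```

If the Vietnamese text might mix precomposed letters with combining-mark (NFD) sequences, normalizing each line with `unicodedata.normalize("NFC", line)` before building or checking the charset keeps the two forms from becoming separate classes.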

AICampB4 commented 2 years ago

Thank you so much, Sir

AICampB4 commented 2 years ago

Sorry to disturb you again, sir. I have finished training the model, but I don't know how to calculate the FID scores as in your paper. Could you give me some instructions, please? Thanks for your help.

herobd commented 2 years ago

I had modified someone's FID code to resize handwriting images properly: https://github.com/herobd/pytorch-fid-handwriting

You'll need to generate a bunch of handwriting images (which I think was the point of the Random option in generate.py) and then have another directory with the (testing) dataset line images.
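For reference, FID compares the mean $\mu$ and covariance $\Sigma$ of Inception features computed over the two image directories, here the real test lines ($r$) and the generated lines ($g$); lower is better. This is the standard definition, not something specific to the handwriting fork:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```

The upstream pytorch-fid package takes the two directories on the command line (`python -m pytorch_fid path/to/generated path/to/test_lines`); whether the handwriting fork keeps that exact entry point is an assumption worth checking in its README.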

AICampB4 commented 2 years ago

Thank you very much, Sir.

AICampB4 commented 2 years ago

Hello again. I've changed the charset.json file to contain the characters of my language's handwritten text database. Then I started the recognizer training, but all losses were NaN at the first iteration and the process stopped. The encoder gives the same result as the recognizer. Could you give me some advice, please? Thank you in advance.

herobd commented 2 years ago

Hmm, double-check the charset.json to be sure it looks like mine (especially starting at index 1 instead of 0). If nothing's wrong, I would print the inputs to the CTC loss call and be sure they look fine. You could even run my original setup to get an example of what the inputs should be. Also, this is unlikely, but if the prediction given to the CTC loss is shorter than the target, that causes problems (CTC can't be computed properly, as it assumes a longer input).
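Part of that check can be automated. Below is a hedged sketch of a hypothetical helper (not part of the repo) that takes a parsed charset dict in the `char_to_idx` / `idx_to_char` layout and reports problems that could plausibly upset training: use of index 0 (assumed reserved for the CTC blank), gaps in the numbering, and entries where the two maps disagree.

```python
from typing import Dict, List


def check_charset(charset: Dict[str, Dict]) -> List[str]:
    """Return a list of problems found in a {"char_to_idx", "idx_to_char"} dict."""
    problems: List[str] = []
    char_to_idx = charset["char_to_idx"]
    # JSON keys are strings, so convert idx_to_char keys back to ints.
    idx_to_char = {int(k): v for k, v in charset["idx_to_char"].items()}

    if 0 in char_to_idx.values() or 0 in idx_to_char:
        problems.append("index 0 is used (assumed reserved for the CTC blank)")

    used = set(char_to_idx.values())
    if used:
        gaps = sorted(set(range(1, max(used) + 1)) - used)
        if gaps:
            problems.append(f"char_to_idx skips indices {gaps}")

    # The two maps should be exact inverses of each other.
    for ch, idx in char_to_idx.items():
        if idx_to_char.get(idx) != ch:
            problems.append(f"idx_to_char[{idx}] is {idx_to_char.get(idx)!r}, expected {ch!r}")
    for idx, ch in idx_to_char.items():
        if char_to_idx.get(ch) != idx:
            problems.append(f"idx_to_char maps {idx} to {ch!r}, but char_to_idx maps it to {char_to_idx.get(ch)}")
    return problems
```

Usage would be along the lines of `check_charset(json.load(open("data/VN_charset.json", encoding="utf-8")))`, where the path is only an example; an empty list means none of these particular issues were found.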

AICampB4 commented 2 years ago

This is my charset.json:

{ "char_to_idx": { " ": 1, "!": 2, "\"": 3, "#": 4, "&": 5, "'": 6, "(": 7, ")": 8, "*": 9, "+": 10, ",": 11, "-": 12, ".": 13, "/": 14, "0": 15, "1": 16, "2": 17, "3": 18, "4": 19, "5": 20, "6": 21, "7": 22, "8": 23, "9": 24, ":": 25, ";": 26, "?": 27, "A": 28, "B": 29, "C": 30, "D": 31, "E": 32, "F": 33, "G": 34, "H": 35, "I": 36, "J": 37, "K": 38, "L": 39, "M": 40, "N": 41, "O": 42, "P": 43, "Q": 44, "R": 45, "S": 46, "T": 47, "U": 48, "V": 49, "W": 50, "X": 51, "Y": 52, "Z": 53, "a": 54, "b": 55, "c": 56, "d": 57, "e": 58, "f": 59, "g": 60, "h": 61, "i": 62, "j": 63, "k": 64, "l": 65, "m": 66, "n": 67, "o": 68, "p": 69, "q": 70, "r": 71, "s": 72, "t": 73, "u": 74, "v": 75, "w": 76, "x": 77, "y": 78, "z": 79, "\u00e1": 80, "\u00e0": 81, "\u1ea3": 83, "\u00e3": 84, "\u1ea1": 85, "\u0103": 86, "\u1eaf": 87, "\u1eb1": 88, "\u1eb3": 89, "\u1eb5": 90, "\u1eb7": 91, "\u00e2": 92, "\u1ea5": 93, "\u1ea7": 94, "\u1ea9": 95, "\u1eab": 96, "\u1ead": 97, "\u00ed": 98, "\u00ec": 99, "\u1ec9": 100, "\u0129": 101, "\u1ecb": 102, "\u00fa": 103, "\u00f9": 104, "\u1ee7": 105, "\u0169": 106, "\u1ee5": 107, "\u01b0": 108, "\u1ee9": 109, "\u1eeb": 110, "\u1eed": 111, "\u1eef": 112, "\u1ef1": 113, "\u00f3": 114, "\u00f2": 115, "\u1ecf": 116, "\u00f5": 117, "\u1ecd": 118, "\u00f4": 119, "\u1ed1": 120, "\u1ed3": 121, "\u1ed5": 122, "\u1ed7": 123, "\u1ed9": 124, "\u01a1": 125, "\u1edb": 126, "\u1edd": 127, "\u1edf": 128, "\u1ee1": 129, "\u1ee3": 130, "\u00e9": 131, "\u00e8": 132, "\u1ebb": 133, "\u1ebd": 134, "\u1eb9": 135, "\u00ea": 136, "\u1ebf": 137, "\u1ec1": 138, "\u1ec3": 139, "\u1ec5": 140, "\u1ec7": 141, "\u00fd": 142, "\u1ef3": 143, "\u1ef7": 144, "\u1ef9": 145, "\u1ef5": 146, "\u0111": 147, "\u00c1": 148, "\u00c0": 149, "\u1ea2": 151, "\u00c3": 152, "\u1ea0": 153, "\u0102": 154, "\u1eae": 155, "\u1eb0": 156, "\u1eb2": 157, "\u1eb4": 158, "\u1eb6": 159, "\u00c2": 160, "\u1ea4": 161, "\u1ea6": 162, "\u1ea8": 163, "\u1eaa": 164, "\u1eac": 165, "\u00cd": 166, "\u00cc": 167, "\u1ec8": 168, "\u0128": 169, "\u1eca": 170, "\u00da": 171, "\u00d9": 172, "\u1ee6": 173, "\u0168": 174, "\u1ee4": 175, "\u01af": 176, "\u1ee8": 177, "\u1eea": 178, "\u1eec": 179, "\u1eee": 180, "\u1ef0": 181, "\u00d3": 182, "\u00d2": 183, "\u1ece": 184, "\u00d5": 185, "\u1ecc": 186, "\u00d4": 187, "\u1ed0": 188, "\u1ed2": 189, "\u1ed4": 190, "\u1ed6": 191, "\u1ed8": 192, "\u01a0": 193, "\u1eda": 194, "\u1edc": 195, "\u1ede": 196, "\u1ee0": 197, "\u1ee2": 198, "\u00c9": 199, "\u00c8": 200, "\u1eba": 201, "\u1ebc": 202, "\u1eb8": 203, "\u00ca": 204, "\u1ebe": 205, "\u1ec0": 206, "\u1ec2": 207, "\u1ec4": 208, "\u1ec6": 209, "\u00dd": 210, "\u1ef2": 211, "\u1ef6": 212, "\u1ef8": 213, "\u1ef4": 214, "\u0110": 215},
"idx_to_char": { "1": " ", "2": "!", "3": "\"", "4": "#", "5": "&", "6": "'", "7": "(", "8": ")", "9": "*", "10": "+", "11": ",", "12": "-", "13": ".", "14": "/", "15": "0", "16": "1", "17": "2", "18": "3", "19": "4", "20": "5", "21": "6", "22": "7", "23": "8", "24": "9", "25": ":", "26": ";", "27": "?", "28": "A", "29": "B", "30": "C", "31": "D", "32": "E", "33": "F", "34": "G", "35": "H", "36": "I", "37": "J", "38": "K", "39": "L", "40": "M", "41": "N", "42": "O", "43": "P", "44": "Q", "45": "R", "46": "S", "47": "T", "48": "U", "49": "V", "50": "W", "51": "X", "52": "Y", "53": "Z", "54": "a", "55": "b", "56": "c", "57": "d", "58": "e", "59": "f", "60": "g", "61": "h", "62": "i", "63": "j", "64": "k", "65": "l", "66": "m", "67": "n", "68": "o", "69": "p", "70": "q", "71": "r", "72": "s", "73": "t", "74": "u", "75": "v", "76": "w", "77": "x", "78": "y", "79": "z", "80": "\u00e1", "81": "\u00e0", "82": "\u1ea3", "83": "\u1ea3", "84": "\u00e3", "85": "\u1ea1", "86": "\u0103", "87": "\u1eaf", "88": "\u1eb1", "89": "\u1eb3", "90": "\u1eb5", "91": "\u1eb7", "92": "\u00e2", "93": "\u1ea5", "94": "\u1ea7", "95": "\u1ea9", "96": "\u1eab", "97": "\u1ead", "98": "\u00ed", "99": "\u00ec", "100": "\u1ec9", "101": "\u0129", "102": "\u1ecb", "103": "\u00fa", "104": "\u00f9", "105": "\u1ee7", "106": "\u0169", "107": "\u1ee5", "108": "\u01b0", "109": "\u1ee9", "110": "\u1eeb", "111": "\u1eed", "112": "\u1eef", "113": "\u1ef1", "114": "\u00f3", "115": "\u00f2", "116": "\u1ecf", "117": "\u00f5", "118": "\u1ecd", "119": "\u00f4", "120": "\u1ed1", "121": "\u1ed3", "122": "\u1ed5", "123": "\u1ed7", "124": "\u1ed9", "125": "\u01a1", "126": "\u1edb", "127": "\u1edd", "128": "\u1edf", "129": "\u1ee1", "130": "\u1ee3", "131": "\u00e9", "132": "\u00e8", "133": "\u1ebb", "134": "\u1ebd", "135": "\u1eb9", "136": "\u00ea", "137": "\u1ebf", "138": "\u1ec1", "139": "\u1ec3", "140": "\u1ec5", "141": "\u1ec7", "142": "\u00fd", "143": "\u1ef3", "144": "\u1ef7", "145": "\u1ef9", "146": "\u1ef5", "147": "\u0111", "148": "\u00c1", "149": "\u00c0", "150": "\u1ea2", "151": "\u1ea2", "152": "\u00c3", "153": "\u1ea0", "154": "\u0102", "155": "\u1eae", "156": "\u1eb0", "157": "\u1eb2", "158": "\u1eb4", "159": "\u1eb6", "160": "\u00c2", "161": "\u1ea4", "162": "\u1ea6", "163": "\u1ea8", "164": "\u1eaa", "165": "\u1eac", "166": "\u00cd", "167": "\u00cc", "168": "\u1ec8", "169": "\u0128", "170": "\u1eca", "171": "\u00da", "172": "\u00d9", "173": "\u1ee6", "174": "\u0168", "175": "\u1ee4", "176": "\u01af", "177": "\u1ee8", "178": "\u1eea", "179": "\u1eec", "180": "\u1eee", "181": "\u1ef0", "182": "\u00d3", "183": "\u00d2", "184": "\u1ece", "185": "\u00d5", "186": "\u1ecc", "187": "\u00d4", "188": "\u1ed0", "189": "\u1ed2", "190": "\u1ed4", "191": "\u1ed6", "192": "\u1ed8", "193": "\u01a0", "194": "\u1eda", "195": "\u1edc", "196": "\u1ede", "197": "\u1ee0", "198": "\u1ee2", "199": "\u00c9", "200": "\u00c8", "201": "\u1eba", "202": "\u1ebc", "203": "\u1eb8", "204": "\u00ca", "205": "\u1ebe", "206": "\u1ec0", "207": "\u1ec2", "208": "\u1ec4", "209": "\u1ec6", "210": "\u00dd", "211": "\u1ef2", "212": "\u1ef6", "213": "\u1ef8", "214": "\u1ef4", "215": "\u0110" }}

I just appended our charset to your original charset.json and saved it as another file. The first index is of course 1, and in the config files I only replaced the charset file with the new one. The setup is otherwise the same, and the result is NaN for all losses.

AICampB4 commented 2 years ago

I went to CTCloss in model/loss.py, printed input_length and target_length, and got this result:
input_len: tensor([262, 262, 262, 262, 262, 262, 262, 262, 262, 262, 262, 262, 262, 262, 262, 262], dtype=torch.int32)
target_len: tensor([32, 32, 7, 48, 23, 30, 34, 23, 7, 19, 30, 4, 7, 41, 35, 25], dtype=torch.int32)
The input length is always greater than the target length. Could you explain what the input length and target length are, and why the target length is so short? Are these target_length values fine?

AICampB4 commented 2 years ago

And some of the inputs are NaN.

AICampB4 commented 2 years ago

I ran the recognizer training in PyCharm's debug mode and found that the problem is the predicted labels. After prediction and conversion with label2string_single, the output character list only contains English characters, even though we're using a Latin-script language database. I think the model for the IAM dataset does not have the capacity to learn these Latin characters. We are considering using the model for the RIMES dataset, because our language has some French characters. Another option is to add some layers to make the model deeper. Could you give me some advice? Thank you in advance.

herobd commented 2 years ago

Input length is the image width (after downsampling); the target length is the length of the target string. The model predicts something for each image unit (after the network downsamples it), so there should be more predictions than characters in the target string (as each written character is multiple image units long).
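To make the length condition concrete: CTC needs at least one output frame per target character, plus a blank frame between two identical adjacent characters (such as the "ll" in "hello") so they are not merged. The sketch below checks that, and also estimates the input length from the image width. That estimate is derived by hand from the CNNOnlyHWR printout in the first post and assumes nothing else in `forward()` changes the width; treat it as an approximation, not the repo's own formula.

```python
def min_ctc_frames(target: str) -> int:
    """Smallest number of output frames CTC needs to emit `target`: one frame
    per character plus one blank between each pair of identical adjacent
    characters (otherwise decoding would merge them)."""
    repeats = sum(1 for a, b in zip(target, target[1:]) if a == b)
    return len(target) + repeats


def ctc_feasible(input_length: int, target: str) -> bool:
    """True if `input_length` prediction frames are enough for `target`."""
    return input_length >= min_ctc_frames(target)


def estimated_input_length(image_width: int) -> int:
    """Frames produced for a 64-pixel-high line, per the module printout.

    The two stride-2 poolings give width // 4; each (2, 1)-stride pooling with
    width padding 1 adds 1; each unpadded width-3 convolution (conv5, conv6,
    cnn1d layers 6 and 12) removes 2: +1 - 2 + 1 - 2 - 2 - 2 = -6."""
    return image_width // 4 - 6
```

By this estimate, the 262 frames printed earlier correspond to resized lines about 1072 to 1075 pixels wide, and even the longest printed target (48 characters, at most 95 frames with every character repeated) fits comfortably.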

The capacity of the network won't be causing the NaNs. Trace those NaNs back to where they originate.

The models for the IAM and RIMES datasets are identical with the exception of the output classes. Be sure that model -> num_class in the config is set to the number of characters in charset.json (looks like 216 for yours?).
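A hedged illustration of keeping `num_class` in sync with the charset. Assuming index 0 is the CTC blank and characters use indices 1..N, the classifier needs the highest index plus one outputs (215 + 1 = 216 for the charset posted above). The final `Conv1d(512, 80, ...)` in the printout from the first post is consistent with this: 79 IAM characters plus a blank. The config keys follow the names used in this thread (`model` -> `num_class`, `data_loader` -> `char_file`); the helpers themselves are hypothetical.

```python
import copy
from typing import Dict


def num_class_for(charset: Dict[str, Dict]) -> int:
    """Output classes = highest character index + 1 (index 0 assumed to be the CTC blank)."""
    return max(charset["char_to_idx"].values()) + 1


def sync_config(config: Dict, charset: Dict[str, Dict], char_file: str) -> Dict:
    """Return a copy of a parsed JSON config with num_class and char_file
    updated to match the charset; the input config is left untouched."""
    updated = copy.deepcopy(config)
    updated.setdefault("model", {})["num_class"] = num_class_for(charset)
    updated.setdefault("data_loader", {})["char_file"] = char_file
    return updated
```

The same update would need to be applied to every component's config (recognizer, encoder, generator), since each one reads its own `char_file` and `num_class`.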

AICampB4 commented 2 years ago

Oh, I didn't notice num_class in the config. I changed num_class to 216 and the model ran smoothly. I really appreciate it.

AICampB4 commented 2 years ago

Hello, sorry to disturb you again. We have trained both the recognizer and the encoder, and they're OK, but when training the generator, at 40k iterations the recon_sample and recon_gt_mask images are all blank; only recon_gt has text. Could you give us some advice on this problem? Thank you very much.

herobd commented 2 years ago

You're referring to the output images? You're sure it's writing new images? I've never seen it generate blank images. Generally, at the beginning of training it generates weird, blurry images.

AICampB4 commented 2 years ago

[images: Capture1, Capture2]

Only recon_gt has words.

AICampB4 commented 2 years ago
[image: Capture1]
AICampB4 commented 2 years ago

This is our text file:

VN1.txt

herobd commented 2 years ago

What do the losses look like? You can use python graph.py -c path/to/latest/snapshot.pth to graph them.

herobd commented 2 years ago

Also, do you have any of the generated ("samples") images from the beginning of training?

AICampB4 commented 2 years ago

The image below shows the generated images.
[image]

herobd commented 2 years ago

Maybe turn down the learning rate? It would be good to see what the loss is doing.

AICampB4 commented 2 years ago

This is the result after running python graph.py -c path/to/latest/snapshot.pth:

loaded iteration 40750
summed loss max: 12.960875740402116, min 0.0
autoLoss max: 0.4286576807498932, min 0.0250049140304327
perceptualLoss max: 3.558936834335327, min 1.9109687805175781
reconRecogLoss max: 9.085136844078079e-05, min 4.423267000674969e-06
generatorLoss max: 9.976031303405762, min -0.29183000326156616
CER max: 0, min 0
WER max: 0, min 0
discriminatorLoss max: 0.013461818918585777, min 0.0
sec_per_iter max: 2.9882827183921474, min 2.3992390221759705
avg_loss max: 6.829719831926175, min 2.0683593205882085
avg_genRecogLoss max: 0.000307868676725775, min 0.00013851409091148524
avg_generatorLoss max: 5.6224011726379395, min 0.9937937784194947
avg_CER max: 0.0, min 0.0
avg_WER max: 0.0, min 0.0
avg_autoLoss max: 0.10185332185029984, min 0.01183332721143961
avg_perceptualLoss max: 0.9597496643066407, min 0.6026059198379516
avg_reconRecogLoss max: 5.132893413247075e-06, min 2.0161328866379337e-06
avg_discriminatorLoss max: 0.07236941555517842, min 0.0
avg_countLoss max: 0.6470937192440033, min 0.23567486727237702
genRecogLoss max: 0.0012050500372424722, min 0.00044575537322089076
countLoss max: 11.75268268585205, min 0.5220853686332703
val_loss max: 5.792953810724987, min 4.854049627309792
val_CER max: 0.0, min 0.0
val_WER max: 0.0, min 0.0
val_autoLoss max: 0.1761924303182871, min 0.04305112739081989
val_countLoss max: 2.304927395859567, min 2.159969479435741
val_perceptualLoss max: 3.311824187193767, min 2.6480215297375533
val_reconRecogLoss max: 1.3890943866481573e-05, min 8.966553432945605e-06

[image] [image]

AICampB4 commented 2 years ago

Can you give me advice about this? Thank you so much.

herobd commented 2 years ago

The losses seem more chaotic than what the IAM model had (you can use graph.py on the pre-trained snapshot to see). I would definitely turn down the learning rate and see if that helps.

AICampB4 commented 2 years ago

Hello, I've lowered the learning rate tenfold, but after about 30k iterations the gen_sample images become blank. These are the learning rates I changed and the resulting gen_sample images:

[images: MicrosoftTeams-image (1), Capture3]
herobd commented 2 years ago

OK, so it is generating good things at the beginning. Try dropping the main optimizer's learning rate while keeping the discriminator's the same, and then the reverse. The model is hitting a very common problem with GANs: unstable training. I'll admit I didn't run into this much after settling on good hyperparameters, so I'm surprised it's happening with a new dataset. In general, getting the right balance of the learning rates, as well as how often the discriminator gets updated, should do the trick.
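The advice here (lower the main optimizer's learning rate with the discriminator's fixed, then the reverse, and vary how often the discriminator is updated) amounts to a small hyperparameter sweep. The sketch below only enumerates such runs; the key names `main_lr`, `disc_lr` and `disc_every` are placeholders, not the repo's actual config keys, and the base rates passed in are up to you.

```python
from itertools import product
from typing import Dict, List, Sequence


def gan_balance_sweep(main_lr: float, disc_lr: float,
                      factors: Sequence[float] = (1.0, 0.1),
                      disc_every: Sequence[int] = (1, 2)) -> List[Dict[str, float]]:
    """Enumerate runs that scale the main and discriminator learning rates
    independently and vary how often the discriminator is updated."""
    return [
        {"main_lr": main_lr * m, "disc_lr": disc_lr * d, "disc_every": k}
        for m, d, k in product(factors, factors, disc_every)
    ]


def update_discriminator(iteration: int, disc_every: int) -> bool:
    """True on iterations where the discriminator takes a step
    (disc_every=2 means every other iteration)."""
    return iteration % disc_every == 0
```

With the default factors, the runs with scales (0.1, 1.0) and (1.0, 0.1) are the two experiments suggested above; the (1.0, 1.0) runs serve as a baseline, and `disc_every` covers the update-frequency knob.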