Open dq0309 opened 2 years ago
What compiler are you using? You may want to be careful with __int128 when using non-GNU compilers.
You are right. The compilation succeeds on Linux. Another question: I have successfully run train.py. How do I select the final model for compression, given that train.py does not use a validation dataset?
Pick any checkpoint after 5000 epochs. There will not be a significant difference.
It takes about 340 seconds/epoch, with a 1080 Ti. How much time do you spend on this training?
I don't remember it being this slow. Check whether you are fully utilizing the GPU. Are you using the DIV2K dataset? You may also reduce the number of training epochs once the training log shows convergence.
The GPU is fully utilized. The input is 3619 images of size 1600×1152×3.
That might be too many. I trained with 1600 256x256 patches per epoch. My training would take about 2 days. You may also want to reduce the number of epochs because it almost converges at 3000.
@dq0309 I used DIV2K with the settings in the repo. On a 3080 (10 GB) with channels=384 and a batch size of 8, one epoch takes around 160 s, so running 3000 epochs may need 4 days. On another machine with a Titan XP (12 GB), it takes longer. Epochs in the first stage are shorter; after the first stage, the allocated GPU memory increases substantially.
@huzi96 Hi Hu. As the training process is very time-consuming, I changed the number of channels to 192, as in your network_low. After almost 4 days of training (4000 epochs), the bpp is now 1.5 and the corresponding PSNR is 32, which is still far from the performance reported in your paper. Is this normal?
@achel-x @huzi96 I trained the model with a single GPU (1080 Ti). With 2 GPUs and the batch size set to 16, one epoch takes 190 seconds. Since my own dataset has 3619 images, it may be reasonable to train this model for 1500 epochs, which would take about 3 days.
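For comparing the timings quoted so far, here is a minimal sketch that converts seconds per epoch into wall-clock days. It assumes the per-epoch time stays constant over the whole run, which the thread suggests is not strictly true (first-stage epochs are reported to be shorter). The function name and the pairing of each timing with an epoch count are my own choices for illustration.

```python
def training_days(seconds_per_epoch: float, epochs: int) -> float:
    """Wall-clock days for a run, assuming a constant time per epoch."""
    return seconds_per_epoch * epochs / 86_400  # 86,400 seconds per day


# (seconds per epoch, epochs) as quoted in this thread; the pairing of
# 340 s/epoch with 5000 epochs is an assumption made for illustration.
runs = {
    "1080 Ti, 3619 full-size images": (340, 5000),
    "3080, DIV2K, batch size 8": (160, 3000),
    "2 GPUs, batch size 16": (190, 1500),
}

for name, (seconds, epochs) in runs.items():
    print(f"{name}: {training_days(seconds, epochs):.1f} days")
```

The constant-time estimate for the 3080 run comes out somewhat above the 4 days quoted, which would be consistent with the shorter first-stage epochs mentioned above.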
@huzi96 Due to my limited GPUs, I plan to load your pre-trained weights and fine-tune the model with my own input images. Could you give some advice about this fine-tuning, e.g. the learning rate and the training stages? Should I start training at stage 1 or at the last stage? Thanks!
If you fine-tune from pre-trained weights, you should go directly to the last stage. It is best to keep the bit-rate range when you fine-tune with new data. However, given the same lambda, different data has a different rate-distortion tradeoff and consequently a different range of bit-rates. So you may want to try different lambdas for a given set of pre-trained weights and your data. If you have to alter the bit-rate range for a given pre-trained model, the best practice is to fine-tune from a higher-bit-rate model to a lower bit-rate.
You can just keep the learning rate 1e-4. I found it makes no significant difference.
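As a rough illustration of the advice to try different lambdas, the sketch below picks, from a set of short fine-tuning trials, the lambda whose measured bit-rate on your own data is closest to the bit-rate of the pre-trained model. The function name is hypothetical and the bpp numbers are placeholders, not measurements from this repository.

```python
def closest_lambda(bpp_by_lambda: dict, target_bpp: float) -> float:
    """Return the lambda whose measured bpp is nearest to target_bpp."""
    return min(bpp_by_lambda, key=lambda lam: abs(bpp_by_lambda[lam] - target_bpp))


# PLACEHOLDER values: bpp you would measure on your own data after a
# short fine-tuning trial with each candidate lambda.
measured = {8e-3: 0.62, 1.5e-2: 0.85, 2.0e-2: 1.01}

# target_bpp: the bit-rate the pre-trained model reaches on its original data
# (also a placeholder here).
print(closest_lambda(measured, target_bpp=0.9))
```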
I get some errors when I load these pretrained models from the TensorFlow platform.
RuntimeError: Error(s) in loading state_dict for Net:
    size mismatch for a_model.transform.1.weight: copying a param with shape torch.Size([192, 3, 5, 5]) from checkpoint, the shape in current model is torch.Size([384, 3, 5, 5]).
    size mismatch for a_model.transform.1.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for a_model.transform.2.beta: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for a_model.transform.2.gamma: copying a param with shape torch.Size([192, 192]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismatch for a_model.transform.4.weight: copying a param with shape torch.Size([192, 192, 5, 5]) from checkpoint, the shape in current model is torch.Size([384, 384, 5, 5]).
    size mismatch for a_model.transform.4.bias: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for a_model.transform.5.beta: copying a param with shape torch.Size([192]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for a_model.transform.5.gamma: copying a param with shape torch.Size([192, 192]) from checkpoint, the shape in current model is torch.Size([384, 384]).
    size mismat
Could you provide pretrained models that were trained on the PyTorch platform?
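The error above says that a 192-channel checkpoint is being loaded into a 384-channel network. A small sketch for spotting such mismatches before calling load_state_dict is shown below. It works on plain dicts that map parameter names to shape tuples, so it runs without PyTorch; the helper name and the second parameter name in the example are hypothetical.

```python
def split_by_shape(checkpoint_shapes: dict, model_shapes: dict):
    """Split checkpoint entries into loadable names and shape mismatches."""
    loadable, mismatched = [], {}
    for name, shape in checkpoint_shapes.items():
        if name not in model_shapes:
            continue  # key unknown to the model; not a shape problem
        if tuple(shape) == tuple(model_shapes[name]):
            loadable.append(name)
        else:
            mismatched[name] = (tuple(shape), tuple(model_shapes[name]))
    return loadable, mismatched


# First entry is taken from the error message above; the second is hypothetical.
ckpt = {"a_model.transform.1.weight": (192, 3, 5, 5), "hypothetical.scale": (1,)}
model = {"a_model.transform.1.weight": (384, 3, 5, 5), "hypothetical.scale": (1,)}

loadable, mismatched = split_by_shape(ckpt, model)
for name, (ckpt_shape, model_shape) in mismatched.items():
    print(f"{name}: checkpoint {ckpt_shape} vs model {model_shape}")
```

With PyTorch, the shape dict for the model can be built as `{k: tuple(v.shape) for k, v in net.state_dict().items()}`. Doing the same for the checkpoint assumes it is also a dict of tensors.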
This is the code that loads the model:
    if args.load_weights != '':
        model_type = args.model_type
        qp = args.qp
        net.apply(weight_init)
        sd = net.state_dict()
        d = load_tf_weights(sd, f'models/model{model_type}_qp{qp}.pk')
        nd = {}
        for k in d.keys():
            if 'proj_head_z2' in k:
                continue
            nd[k] = d[k]
        net.load_state_dict(nd, strict=False)
        print('load weights')
    else:
        net.apply(weight_init)
        print('training from scratch')
I figured it out. It seems I need to load the high-QP models (trained on the TensorFlow platform) as the pre-trained weights for this model on the PyTorch platform. I still have some questions.
Thank you for kindly answering my questions; I understand now. It seems you only provide NetHigh in train/train.py, so I need to add NetLow to train/train.py for fine-tuning. Am I right?
Yes
self.sampler_3 = NeighborSample((b, 384, h // 8, w // 8))
self.proj_head_z3 = ProjHead(384, 384)
These lines are in NetHigh. Should I modify them to get the code for NetLow? My modified lines are as follows:
self.sampler_3 = NeighborSample((b, 192, h // 8, w // 8))
self.proj_head_z3 = ProjHead(192, 192)
When I load the trained models and test, I get some IncompatibleKeys warnings. Is it because some parameters for the helping loss (auxiliary loss) do not need to be loaded during evaluation?
Yes. Since you are loading the weights, you will go directly to the last stage. So you may just ignore the warning (use strict=False).
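To make the warning concrete: with strict=False, PyTorch's load_state_dict reports the keys that the model expects but the checkpoint lacks (missing) and the keys that the checkpoint has but the model does not (unexpected). The sketch below reproduces that bookkeeping on plain key lists and checks whether every reported key belongs to an auxiliary head. The helper names, the default "proj_head" marker, and the full parameter names in the example are assumptions for illustration.

```python
def incompatible_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, as strict=False would report."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)      # in model, not in checkpoint
    unexpected = sorted(checkpoint_keys - model_keys)   # in checkpoint, not in model
    return missing, unexpected


def only_auxiliary(keys, aux_markers=("proj_head",)):
    """True if every key contains one of the auxiliary-head markers."""
    return all(any(marker in key for marker in aux_markers) for key in keys)


# Hypothetical parameter names: a training checkpoint with an auxiliary
# projection head, loaded into an inference model without it.
model_keys = ["a_model.transform.1.weight"]
ckpt_keys = ["a_model.transform.1.weight", "proj_head_z2.fc.weight"]

missing, unexpected = incompatible_keys(model_keys, ckpt_keys)
print("safe to ignore:", only_auxiliary(missing + unexpected))
```

If the check fails, some non-auxiliary weights were not loaded and the warning should not be ignored.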
https://github.com/huzi96/Coarse2Fine-PyTorch/issues/11#issuecomment-1075973610 Is this right? Thanks!
Yes. You may check the netlow inference code for a reference.
A reminder about the reconstructed frames: in _AppEncDecTrained.py, the compressed frames are not reconstructed, but I can get the reconstructed frames by decompressing. So this is just a reminder. Thanks a lot!
Could you provide the lambda values corresponding to the provided models, including model0_qp1, model0_qp2, model0_qp3, model0_qp4, model0_qp5, model0_qp6, and model0_qp7?
@dq0309
Hi, have you got the values of lambda? If you have, could you please share them with me?
MSE: λ ∈ {1.2×10^-3, 1.5×10^-3, 2.5×10^-3, 8×10^-3, 1.5×10^-2, 2.0×10^-2, 3.0×10^-2}
MS-SSIM loss: λ ∈ {10, 25, 45, 70, 100, 200, 300, 360}
This is also in the paper. Note that if you modify the model, the correspondence between bit-rate and lambda will change.
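For scripting, the values above can be kept in a small lookup. The sketch below assumes that model0_qp1 to model0_qp7 correspond to the seven MSE lambdas in ascending order; the thread does not state this mapping explicitly, so treat it as an assumption and check it against the paper. The function name is hypothetical.

```python
# Lambda values as listed above.
MSE_LAMBDAS = [1.2e-3, 1.5e-3, 2.5e-3, 8e-3, 1.5e-2, 2.0e-2, 3.0e-2]
MS_SSIM_LAMBDAS = [10, 25, 45, 70, 100, 200, 300, 360]


def mse_lambda_for_qp(qp: int) -> float:
    """Look up the MSE lambda for a qp index.

    ASSUMPTION: qp1..qp7 map to MSE_LAMBDAS in ascending order.
    """
    if not 1 <= qp <= len(MSE_LAMBDAS):
        raise ValueError(f"qp must be in 1..{len(MSE_LAMBDAS)}, got {qp}")
    return MSE_LAMBDAS[qp - 1]


for qp in range(1, len(MSE_LAMBDAS) + 1):
    print(f"model0_qp{qp}: lambda = {mse_lambda_for_qp(qp)}")
```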
Appreciate your helpful reply! I am all gratitude.
There is something wrong with the compilation of the arithmetic coder.
module_arithmeticcoding.cpp:11:9: error: expected unqualified-id before '__int128'
    typedef __int128 comint;
module_arithmeticcoding.cpp:12:7: error: 'comint' does not name a type
    const comint STATE_SIZE = 64;
module_arithmeticcoding.cpp:13:7: error: 'comint' does not name a type
    const comint MAX_RANGE = (comint)1 << STATE_SIZE;
module_arithmeticcoding.cpp:14:7: error: 'comint' does not name a type
    const comint MIN_RANGE = (MAX_RANGE >> 2) + 2;
module_arithmeticcoding.cpp:15:7: error: 'comint' does not name a type
    const comint MAX_TOTAL = MIN_RANGE;
module_arithmeticcoding.cpp:16:7: error: 'comint' does not name a type
    const comint MASK = MAX_RANGE - 1;
module_arithmeticcoding.cpp:17:7: error: 'comint' does not name a type
    const comint TOP_MASK = MAX_RANGE >> 1;
module_arithmeticcoding.cpp:18:7: error: 'comint' does not name a type
    const comint SECOND_MASK = TOP_MASK >> 1;
module_arithmeticcoding.cpp:142:13: error: 'comint' has not been declared
    int get(comint symbol) {
module_arithmeticcoding.cpp:155:5: error: 'comint' does not name a type
    comint get_low(comint symbol) {
module_arithmeticcoding.cpp:162:5: error: 'comint' does not name a type
    comint get_high(comint symbol) {
module_arithmeticcoding.cpp:169:24: error: 'comint' has not been declared
    void _check_symbol(comint symbol) {
module_arithmeticcoding.cpp:183:5: error: 'comint' does not name a type
    comint _low, _high;
module_arithmeticcoding.cpp:189:53: error: 'comint' has not been declared
    virtual void update(ModelFrequencyTable freqs, comint symbol) {
module_arithmeticcoding.cpp: In constructor 'ArithmeticCoderBase::ArithmeticCoderBase()':
module_arithmeticcoding.cpp:184:27: error: class 'ArithmeticCoderBase' does not have any field named '_low'
    ArithmeticCoderBase():_low(0),_high(MASK) {}
module_arithmeticcoding.cpp:184:35: error: class 'ArithmeticCoderBase' does not have any field named '_high'
    ArithmeticCoderBase():_low(0),_high(MASK) {}
module_arithmeticcoding.cpp:184:41: error: 'MASK' was not declared in this scope
    ArithmeticCoderBase():_low(0),_high(MASK) {}
module_arithmeticcoding.cpp: In member function 'virtual void ArithmeticCoderBase::update(ModelFrequencyTable, int)':
module_arithmeticcoding.cpp:190:9: error: 'comint' was not declared in this scope
    comint low=_low, high=_high;
module_arithmeticcoding.cpp:191:15: error: 'low' was not declared in this scope
    if ( (low >= high) || ((low & MASK) != low) || ((high & MASK) != high) ) {
module_arithmeticcoding.cpp:191:22: error: 'high' was not declared in this scope
    if ( (low >= high) || ((low & MASK) != low) || ((high & MASK) != high) ) {
module_arithmeticcoding.cpp:191:39: error: 'MASK' was not declared in this scope
    if ( (low >= high) || ((low & MASK) != low) || ((high & MASK) != high) ) {
module_arithmeticcoding.cpp:195:16: error: expected ';' before 'range_'
    comint range_ = high - low + 1;
module_arithmeticcoding.cpp:196:15: error: 'MIN_RANGE' was not declared in this scope
    if (!(MIN_RANGE <= range_ && range_ <= MAX_RANGE)) {
module_arithmeticcoding.cpp:196:28: error: 'range_' was not declared in this scope
    if (!(MIN_RANGE <= range_ && range_ <= MAX_RANGE)) {
module_arithmeticcoding.cpp:196:48: error: 'MAX_RANGE' was not declared in this scope
    if (!(MIN_RANGE <= range_ && range_ <= MAX_RANGE)) {
module_arithmeticcoding.cpp:201:16: error: expected ';' before 'total'
    comint total = freqs -> get_total();
module_arithmeticcoding.cpp:202:16: error: expected ';' before 'symlow'
    comint symlow = freqs -> get_low(symbol);
module_arithmeticcoding.cpp:203:16: error: expected ';' before 'symhigh'
    comint symhigh = freqs -> get_high(symbol);
module_arithmeticcoding.cpp:204:13: error: 'symlow' was not declared in this scope
    if (symlow == symhigh) {
module_arithmeticcoding.cpp:204:23: error: 'symhigh' was not declared in this scope
    if (symlow == symhigh) {
module_arithmeticcoding.cpp:208:55: error: 'class ModelFrequencyTable' has no member named 'get_low'
    fprintf(stderr, "symlow: %ld", freqs->get_low(idx));
module_arithmeticcoding.cpp:209:56: error: 'class ModelFrequencyTable' has no member named 'get_high'; did you mean 'set_sigma'?
    fprintf(stderr, "symhigh: %ld", freqs->get_high(idx));
module_arithmeticcoding.cpp:216:13: error: 'total' was not declared in this scope
    if (total > MAX_TOTAL) {
module_arithmeticcoding.cpp:216:21: error: 'MAX_TOTAL' was not declared in this scope
    if (total > MAX_TOTAL) {
module_arithmeticcoding.cpp:220:16: error: expected ';' before 'newlow'
    comint newlow = low + symlow * range_ / total;
module_arithmeticcoding.cpp:221:16: error: expected ';' before 'newhigh'
    comint newhigh = low + symhigh * range_ / total - 1;
module_arithmeticcoding.cpp:222:9: error: '_low' was not declared in this scope
    _low = newlow;
module_arithmeticcoding.cpp:222:16: error: 'newlow' was not declared in this scope
    _low = newlow;
module_arithmeticcoding.cpp:223:9: error: '_high' was not declared in this scope
    _high = newhigh;
module_arithmeticcoding.cpp:223:17: error: 'newhigh' was not declared in this scope
    _high = newhigh;
module_arithmeticcoding.cpp:227:35: error: 'TOP_MASK' was not declared in this scope
    while ( ((_low ^ _high) & TOP_MASK) == 0 ) {
module_arithmeticcoding.cpp:229:34: error: 'MASK' was not declared in this scope
    _low = (_low << 1) & MASK;
module_arithmeticcoding.cpp:233:35: error: 'SECOND_MASK' was not declared in this scope
    while ((( _low & ~_high & SECOND_MASK )) != 0) {
module_arithmeticcoding.cpp:235:35: error: 'MASK' was not declared in this scope
    _low = (_low << 1) & (MASK >> 1);
module_arithmeticcoding.cpp:236:52: error: 'TOP_MASK' was not declared in this scope
    _high = ((_high << 1) & (MASK >> 1)) | TOP_MASK | 1;
module_arithmeticcoding.cpp: At global scope:
module_arithmeticcoding.cpp:246:5: error: 'comint' does not name a type
    comint num_underflow;
module_arithmeticcoding.cpp:252:44: error: 'comint' has not been declared
    void write(ModelFrequencyTable freqs, comint symbol) {
module_arithmeticcoding.cpp: In constructor 'ArithmeticEncoder::ArithmeticEncoder(BitOutputStream)':
module_arithmeticcoding.cpp:249:9: error: 'num_underflow' was not declared in this scope
    num_underflow = 0;
module_arithmeticcoding.cpp: In member function 'virtual void ArithmeticEncoder::shift()':
module_arithmeticcoding.cpp:261:9: error: 'comint' was not declared in this scope
    comint bit = _low >> (STATE_SIZE - 1);
module_arithmeticcoding.cpp:262:23: error: 'bit' was not declared in this scope
    output->write(bit);
module_arithmeticcoding.cpp:263:29: error: 'num_underflow' was not declared in this scope
    for (int i = 0; i < num_underflow; i++) {
module_arithmeticcoding.cpp:266:9: error: 'num_underflow' was not declared in this scope
    num_underflow = 0;
module_arithmeticcoding.cpp: In member function 'virtual void ArithmeticEncoder::underflow()':
module_arithmeticcoding.cpp:270:9: error: 'num_underflow' was not declared in this scope
    num_underflow++;
module_arithmeticcoding.cpp: At global scope:
module_arithmeticcoding.cpp:278:5: error: 'comint' does not name a type
    comint code;
module_arithmeticcoding.cpp:287:5: error: 'comint' does not name a type
    comint read(ModelFrequencyTable freqs) {
module_arithmeticcoding.cpp:358:5: error: 'comint' does not name a type
    comint read_code_bit() {
module_arithmeticcoding.cpp:364:5: error: 'comint' does not name a type
    comint getlow() {return _low;}
module_arithmeticcoding.cpp:365:5: error: 'comint' does not name a type
    comint gethigh() {return _high;}
module_arithmeticcoding.cpp: In constructor 'ArithmeticDecoder::ArithmeticDecoder(BitInputStream)':
module_arithmeticcoding.cpp:281:9: error: 'code' was not declared in this scope
    code = 0;
module_arithmeticcoding.cpp:282:29: error: 'STATE_SIZE' was not declared in this scope
    for (int i = 0; i < STATE_SIZE; i++) {
module_arithmeticcoding.cpp:283:48: error: 'read_code_bit' was not declared in this scope
    code = (code << 1) | read_code_bit();
module_arithmeticcoding.cpp: In member function 'virtual void ArithmeticDecoder::shift()':
module_arithmeticcoding.cpp:350:43: error: 'read_code_bit' was not declared in this scope
    unsigned char bit = read_code_bit();
module_arithmeticcoding.cpp:351:9: error: 'code' was not declared in this scope
    code = ((code << 1) & MASK ) | bit;
module_arithmeticcoding.cpp:351:31: error: 'MASK' was not declared in this scope
    code = ((code << 1) & MASK ) | bit;
module_arithmeticcoding.cpp: In member function 'virtual void ArithmeticDecoder::underflow()':
module_arithmeticcoding.cpp:355:9: error: 'code' was not declared in this scope
    code = ((code & TOP_MASK) | (code << 1) & (MASK >> 1)) | read_code_bit();
module_arithmeticcoding.cpp:355:25: error: 'TOP_MASK' was not declared in this scope
    code = ((code & TOP_MASK) | (code << 1) & (MASK >> 1)) | read_code_bit();
module_arithmeticcoding.cpp:355:52: error: 'MASK' was not declared in this scope
    code = ((code & TOP_MASK) | (code << 1) & (MASK >> 1)) | read_code_bit();
module_arithmeticcoding.cpp:355:80: error: 'read_code_bit' was not declared in this scope
    code = ((code & TOP_MASK) | (code << 1) & (MASK >> 1)) | read_code_bit();
module_arithmeticcoding.cpp: In function 'int main(int, char**)':
module_arithmeticcoding.cpp:423:32: error: 'class ArithmeticDecoder' has no member named 'read'
    short symbol = dec.read(&freq);