argmin-rs / argmin

Numerical optimization in pure Rust
http://argmin-rs.org
Apache License 2.0

`owl-qn` example does not converge #453

Closed: stefan-k closed this issue 7 months ago

stefan-k commented 7 months ago

The owl-qn example does not converge. Starting from [-1.2, 1.0] for the Rosenbrock test function, it actually moves away from the optimal point and gets stuck at [0.24, 0.054]. Regular L-BFGS does not exhibit this behaviour.

Output:

Feb 14 07:09:30.224 INFO L-BFGS
Feb 14 07:09:30.226 INFO iter: 0, cost: 6.58849571151903, best_cost: 6.58849571151903, gradient_count: 6, cost_count: 5, gamma: 1, time: 0.001579498
Feb 14 07:09:30.227 INFO iter: 1, cost: 6.22357846075582, best_cost: 6.22357846075582, gradient_count: 8, cost_count: 6, time: 0.000867555, gamma: 0.0008093585709663768
Feb 14 07:09:30.228 INFO iter: 2, cost: 6.199967551126373, best_cost: 6.199967551126373, gradient_count: 10, cost_count: 7, time: 0.001147719, gamma: 0.0009907673749112687
Feb 14 07:09:30.230 INFO iter: 3, cost: 6.185825435563604, best_cost: 6.185825435563604, gradient_count: 12, cost_count: 8, time: 0.00157656, gamma: 0.0009639410797237515
Feb 14 07:09:30.233 INFO iter: 4, cost: 6.084202470096303, best_cost: 6.084202470096303, gradient_count: 14, cost_count: 9, gamma: 0.000959222616990754, time: 0.001724485
Feb 14 07:09:30.234 INFO iter: 5, cost: 6.0186454111042185, best_cost: 6.0186454111042185, gradient_count: 17, cost_count: 11, gamma: 0.0009947875172352386, time: 0.002248385
Feb 14 07:09:30.235 INFO iter: 6, cost: 5.987978044206608, best_cost: 5.987978044206608, gradient_count: 20, cost_count: 13, time: 0.000972447, gamma: 0.0009904381049670756
Feb 14 07:09:30.236 INFO iter: 7, cost: 5.957200409071776, best_cost: 5.957200409071776, gradient_count: 23, cost_count: 15, time: 0.000574982, gamma: 0.0010021805566172747
Feb 14 07:09:30.237 INFO iter: 8, cost: 5.9262132392184, best_cost: 5.9262132392184, gradient_count: 26, cost_count: 17, gamma: 0.0010089540205003356, time: 0.000673917
Feb 14 07:09:30.237 INFO iter: 9, cost: 5.89519588186027, best_cost: 5.89519588186027, gradient_count: 29, cost_count: 19, time: 0.000469076, gamma: 0.0010165270531326798
Feb 14 07:09:30.238 INFO iter: 10, cost: 5.863913062201508, best_cost: 5.863913062201508, gradient_count: 32, cost_count: 21, gamma: 0.0010235474152182125, time: 0.000515433
Feb 14 07:09:30.238 INFO iter: 11, cost: 5.832647875802045, best_cost: 5.832647875802045, gradient_count: 35, cost_count: 23, gamma: 0.001031361549617527, time: 0.000577604
Feb 14 07:09:30.241 INFO iter: 12, cost: 5.800906261479851, best_cost: 5.800906261479851, gradient_count: 38, cost_count: 25, gamma: 0.001038668347140358, time: 0.000552869
Feb 14 07:09:30.242 INFO iter: 13, cost: 5.769381515638788, best_cost: 5.769381515638788, gradient_count: 41, cost_count: 27, time: 0.001179234, gamma: 0.0010467486877914533
Feb 14 07:09:30.244 INFO iter: 14, cost: 5.737687208561294, best_cost: 5.737687208561294, gradient_count: 44, cost_count: 29, gamma: 0.001054270608621585, time: 0.00112414
Feb 14 07:09:30.245 INFO iter: 15, cost: 5.705895411200603, best_cost: 5.705895411200603, gradient_count: 47, cost_count: 31, time: 0.001073339, gamma: 0.0010625899116485636
Feb 14 07:09:30.247 INFO iter: 16, cost: 5.673926939908691, best_cost: 5.673926939908691, gradient_count: 50, cost_count: 33, time: 0.001165698, gamma: 0.0010703932554199875
Feb 14 07:09:30.250 INFO iter: 17, cost: 5.641857778005761, best_cost: 5.641857778005761, gradient_count: 53, cost_count: 35, time: 0.001453321, gamma: 0.0010789895337989447
Feb 14 07:09:30.251 INFO iter: 18, cost: 5.609604329729833, best_cost: 5.609604329729833, gradient_count: 56, cost_count: 37, time: 0.002525351, gamma: 0.0010870909925395872
Feb 14 07:09:30.253 INFO iter: 19, cost: 5.577246809849281, best_cost: 5.577246809849281, gradient_count: 59, cost_count: 39, gamma: 0.0010959797070832905, time: 0.002345301
Feb 14 07:09:30.254 INFO iter: 20, cost: 5.544696842048378, best_cost: 5.544696842048378, gradient_count: 62, cost_count: 41, time: 0.001147985, gamma: 0.0011043973549547316
Feb 14 07:09:30.255 INFO iter: 21, cost: 5.512039225882255, best_cost: 5.512039225882255, gradient_count: 65, cost_count: 43, time: 0.000905634, gamma: 0.0011135952136723966
Feb 14 07:09:30.255 INFO iter: 22, cost: 5.4791803933585665, best_cost: 5.4791803933585665, gradient_count: 68, cost_count: 45, time: 0.000426415, gamma: 0.0011223486564211794
Feb 14 07:09:30.256 INFO iter: 23, cost: 5.4462101251910395, best_cost: 5.4462101251910395, gradient_count: 71, cost_count: 47, gamma: 0.001131873749096574, time: 0.000433809
Feb 14 07:09:30.257 INFO iter: 24, cost: 5.413029202280484, best_cost: 5.413029202280484, gradient_count: 74, cost_count: 49, time: 0.000699562, gamma: 0.0011409842928039385
Feb 14 07:09:30.260 INFO iter: 25, cost: 5.379732827133597, best_cost: 5.379732827133597, gradient_count: 77, cost_count: 51, time: 0.00043798, gamma: 0.0011508562436373221
Feb 14 07:09:30.262 INFO iter: 26, cost: 5.3462156161702765, best_cost: 5.3462156161702765, gradient_count: 80, cost_count: 53, time: 0.000375197, gamma: 0.0011603470873517846
Feb 14 07:09:30.263 INFO iter: 27, cost: 5.312578687104351, best_cost: 5.312578687104351, gradient_count: 83, cost_count: 55, gamma: 0.0011705872294129915, time: 0.000395576
Feb 14 07:09:30.264 INFO iter: 28, cost: 5.278709915360938, best_cost: 5.278709915360938, gradient_count: 86, cost_count: 57, time: 0.000538215, gamma: 0.0011804836848568966
Feb 14 07:09:30.265 INFO iter: 29, cost: 5.244716888126674, best_cost: 5.244716888126674, gradient_count: 89, cost_count: 59, time: 0.001292655, gamma: 0.001191115259756563
Feb 14 07:09:30.266 INFO iter: 30, cost: 5.210480090260123, best_cost: 5.210480090260123, gradient_count: 92, cost_count: 61, time: 0.001785796, gamma: 0.0012014450029235748
Feb 14 07:09:30.266 INFO iter: 31, cost: 5.176114203128059, best_cost: 5.176114203128059, gradient_count: 95, cost_count: 63, time: 0.00150833, gamma: 0.0012124933903012163
Feb 14 07:09:30.266 INFO iter: 32, cost: 5.141491586557158, best_cost: 5.141491586557158, gradient_count: 98, cost_count: 65, time: 0.001409648, gamma: 0.0012232867505638854
Feb 14 07:09:30.267 INFO iter: 33, cost: 5.106734722694281, best_cost: 5.106734722694281, gradient_count: 101, cost_count: 67, gamma: 0.0012347797328984974, time: 0.00138985
Feb 14 07:09:30.267 INFO iter: 34, cost: 5.071707012777161, best_cost: 5.071707012777161, gradient_count: 104, cost_count: 69, gamma: 0.0012460700264207624, time: 0.000665142
Feb 14 07:09:30.267 INFO iter: 35, cost: 5.036539541970521, best_cost: 5.036539541970521, gradient_count: 107, cost_count: 71, gamma: 0.0012580380958974954, time: 0.000730645
Feb 14 07:09:30.267 INFO iter: 36, cost: 5.001085803147246, best_cost: 5.001085803147246, gradient_count: 110, cost_count: 73, time: 0.000600395, gamma: 0.0012698620116106227
Feb 14 07:09:30.268 INFO iter: 37, cost: 4.965486398957118, best_cost: 4.965486398957118, gradient_count: 113, cost_count: 75, gamma: 0.001282338727329345, time: 0.000559845
Feb 14 07:09:30.268 INFO iter: 38, cost: 4.929583827130609, best_cost: 4.929583827130609, gradient_count: 116, cost_count: 77, gamma: 0.001294736775562112, time: 0.000651402
Feb 14 07:09:30.269 INFO iter: 39, cost: 4.893529254646268, best_cost: 4.893529254646268, gradient_count: 119, cost_count: 79, time: 0.001046858, gamma: 0.0013077591813510827
Feb 14 07:09:30.271 INFO iter: 40, cost: 4.857152934934762, best_cost: 4.857152934934762, gradient_count: 122, cost_count: 81, gamma: 0.0013207762175130287, time: 0.001473837
Feb 14 07:09:30.272 INFO iter: 41, cost: 4.820617803143742, best_cost: 4.820617803143742, gradient_count: 125, cost_count: 83, time: 0.001470479, gamma: 0.0013343853331500455
Feb 14 07:09:30.274 INFO iter: 42, cost: 4.783740425668066, best_cost: 4.783740425668066, gradient_count: 128, cost_count: 85, time: 0.001459462, gamma: 0.001348071171793901
Feb 14 07:09:30.276 INFO iter: 43, cost: 4.746696896952292, best_cost: 4.746696896952292, gradient_count: 131, cost_count: 87, gamma: 0.0013623125737246171, time: 0.001931905
Feb 14 07:09:30.278 INFO iter: 44, cost: 4.709288421409103, best_cost: 4.709288421409103, gradient_count: 134, cost_count: 89, gamma: 0.001376722712051374, time: 0.001964458
Feb 14 07:09:30.280 INFO iter: 45, cost: 4.671705868734218, best_cost: 4.671705868734218, gradient_count: 137, cost_count: 91, time: 0.001474785, gamma: 0.0013916472239905737
Feb 14 07:09:30.281 INFO iter: 46, cost: 4.633733125994796, best_cost: 4.633733125994796, gradient_count: 140, cost_count: 93, gamma: 0.0014068436986696263, time: 0.001424652
Feb 14 07:09:30.282 INFO iter: 47, cost: 4.5955777258033725, best_cost: 4.5955777258033725, gradient_count: 143, cost_count: 95, gamma: 0.0014225082181386305, time: 0.000624446
Feb 14 07:09:30.282 INFO iter: 48, cost: 4.557003941444094, best_cost: 4.557003941444094, gradient_count: 146, cost_count: 97, gamma: 0.0014385606255499699, time: 0.000328232
Feb 14 07:09:30.283 INFO iter: 49, cost: 4.518238186873647, best_cost: 4.518238186873647, gradient_count: 149, cost_count: 99, time: 0.000514702, gamma: 0.001455029119960876
Feb 14 07:09:30.283 INFO iter: 50, cost: 4.47902240707014, best_cost: 4.47902240707014, gradient_count: 152, cost_count: 101, time: 0.000548171, gamma: 0.0014720158381191094
Feb 14 07:09:30.284 INFO iter: 51, cost: 4.439604521570657, best_cost: 4.439604521570657, gradient_count: 155, cost_count: 103, gamma: 0.0014893605542072553, time: 0.000543754
Feb 14 07:09:30.285 INFO iter: 52, cost: 4.399700915708991, best_cost: 4.399700915708991, gradient_count: 158, cost_count: 105, time: 0.000573701, gamma: 0.001507370215381989
Feb 14 07:09:30.286 INFO iter: 53, cost: 4.359584140963058, best_cost: 4.359584140963058, gradient_count: 161, cost_count: 107, time: 0.000470282, gamma: 0.0015256731597012945
Feb 14 07:09:30.286 INFO iter: 54, cost: 4.318941146952946, best_cost: 4.318941146952946, gradient_count: 164, cost_count: 109, gamma: 0.0015448064370716952, time: 0.000308103
Feb 14 07:09:30.286 INFO iter: 55, cost: 4.278072870512574, best_cost: 4.278072870512574, gradient_count: 167, cost_count: 111, gamma: 0.001564161204518049, time: 0.000319802
Feb 14 07:09:30.286 INFO iter: 56, cost: 4.236632137099813, best_cost: 4.236632137099813, gradient_count: 170, cost_count: 113, time: 0.000383083, gamma: 0.0015845329954697202
Feb 14 07:09:30.287 INFO iter: 57, cost: 4.194952813308875, best_cost: 4.194952813308875, gradient_count: 173, cost_count: 115, time: 0.000367186, gamma: 0.001605047049824034
Feb 14 07:09:30.287 INFO iter: 58, cost: 4.152647877100838, best_cost: 4.152647877100838, gradient_count: 176, cost_count: 117, time: 0.000448394, gamma: 0.0016267891647017327
Feb 14 07:09:30.288 INFO iter: 59, cost: 4.110089678081848, best_cost: 4.110089678081848, gradient_count: 179, cost_count: 119, time: 0.00060093, gamma: 0.0016485867137420048
Feb 14 07:09:30.291 INFO iter: 60, cost: 4.066844289058501, best_cost: 4.066844289058501, gradient_count: 182, cost_count: 121, time: 0.001437755, gamma: 0.001671851214922245
Feb 14 07:09:30.292 INFO iter: 61, cost: 4.023329398327402, best_cost: 4.023329398327402, gradient_count: 185, cost_count: 123, gamma: 0.0016950768786386451, time: 0.001564587
Feb 14 07:09:30.293 INFO iter: 62, cost: 3.9790553723770374, best_cost: 3.9790553723770374, gradient_count: 188, cost_count: 125, gamma: 0.0017200402649889565, time: 0.001571564
Feb 14 07:09:30.294 INFO iter: 63, cost: 3.934493798051691, best_cost: 3.934493798051691, gradient_count: 191, cost_count: 127, gamma: 0.0017448638183263396, time: 0.001003116
Feb 14 07:09:30.294 INFO iter: 64, cost: 3.8890882221035743, best_cost: 3.8890882221035743, gradient_count: 194, cost_count: 129, gamma: 0.0017717323209405053, time: 0.000658218
Feb 14 07:09:30.295 INFO iter: 65, cost: 3.843374953213498, best_cost: 3.843374953213498, gradient_count: 197, cost_count: 131, time: 0.000633723, gamma: 0.0017983549179081397
Feb 14 07:09:30.296 INFO iter: 66, cost: 3.7967164869944106, best_cost: 3.7967164869944106, gradient_count: 200, cost_count: 133, time: 0.000338817, gamma: 0.0018273712742855924
Feb 14 07:09:30.296 INFO iter: 67, cost: 3.7497277341838178, best_cost: 3.7497277341838178, gradient_count: 203, cost_count: 135, time: 0.000489699, gamma: 0.001856033754688677
Feb 14 07:09:30.297 INFO iter: 68, cost: 3.701671623674794, best_cost: 3.701671623674794, gradient_count: 206, cost_count: 137, time: 0.000729304, gamma: 0.0018874859754082164
Feb 14 07:09:30.299 INFO iter: 69, cost: 3.6532597557049615, best_cost: 3.6532597557049615, gradient_count: 209, cost_count: 139, time: 0.001628506, gamma: 0.0019184801651724042
Feb 14 07:09:30.301 INFO iter: 70, cost: 3.6036309631722405, best_cost: 3.6036309631722405, gradient_count: 212, cost_count: 141, gamma: 0.0019527130231494307, time: 0.002032972
Feb 14 07:09:30.301 INFO iter: 71, cost: 3.5536175390691307, best_cost: 3.5536175390691307, gradient_count: 215, cost_count: 143, gamma: 0.0019863974481288886, time: 0.000663457
Feb 14 07:09:30.302 INFO iter: 72, cost: 3.5022010397845635, best_cost: 3.5022010397845635, gradient_count: 218, cost_count: 145, time: 0.000515144, gamma: 0.0020238277425229947
Feb 14 07:09:30.303 INFO iter: 73, cost: 3.450366979601956, best_cost: 3.450366979601956, gradient_count: 221, cost_count: 147, gamma: 0.002060650042085171, time: 0.000515925
Feb 14 07:09:30.304 INFO iter: 74, cost: 3.3968936521347843, best_cost: 3.3968936521347843, gradient_count: 224, cost_count: 149, gamma: 0.0021017871737855015, time: 0.001293822
Feb 14 07:09:30.306 INFO iter: 75, cost: 3.3429649623470636, best_cost: 3.3429649623470636, gradient_count: 227, cost_count: 151, time: 0.001425186, gamma: 0.0021423170370887348
Feb 14 07:09:30.307 INFO iter: 76, cost: 3.2870903518662615, best_cost: 3.2870903518662615, gradient_count: 230, cost_count: 153, time: 0.000933628, gamma: 0.0021877911651070987
Feb 14 07:09:30.307 INFO iter: 77, cost: 3.2307166661804496, best_cost: 3.2307166661804496, gradient_count: 233, cost_count: 155, time: 0.000592144, gamma: 0.0022327704718591145
Feb 14 07:09:30.308 INFO iter: 78, cost: 3.171987666900195, best_cost: 3.171987666900195, gradient_count: 236, cost_count: 157, gamma: 0.002283371609465784, time: 0.000580573
Feb 14 07:09:30.308 INFO iter: 79, cost: 3.1127086101657344, best_cost: 3.1127086101657344, gradient_count: 239, cost_count: 159, time: 0.000508441, gamma: 0.0023337940862202216
Feb 14 07:09:30.309 INFO iter: 80, cost: 3.0505084650259655, best_cost: 3.0505084650259655, gradient_count: 242, cost_count: 161, time: 0.000436585, gamma: 0.002390527009821323
Feb 14 07:09:30.309 INFO iter: 81, cost: 2.987698139492217, best_cost: 2.987698139492217, gradient_count: 245, cost_count: 163, gamma: 0.0024477715816427995, time: 0.000549621
Feb 14 07:09:30.310 INFO iter: 82, cost: 2.9211496490202045, best_cost: 2.9211496490202045, gradient_count: 248, cost_count: 165, gamma: 0.002511933010056234, time: 0.000570224
Feb 14 07:09:30.311 INFO iter: 83, cost: 2.8539188215733517, best_cost: 2.8539188215733517, gradient_count: 251, cost_count: 167, time: 0.000563097, gamma: 0.0025780022507546716
Feb 14 07:09:30.311 INFO iter: 84, cost: 2.7816992893007875, best_cost: 2.7816992893007875, gradient_count: 254, cost_count: 169, time: 0.000583398, gamma: 0.0026512857261964807
Feb 14 07:09:30.312 INFO iter: 85, cost: 2.7087076156960967, best_cost: 2.7087076156960967, gradient_count: 257, cost_count: 171, gamma: 0.0027292702844516326, time: 0.000527552
Feb 14 07:09:30.312 INFO iter: 86, cost: 2.6286531469534724, best_cost: 2.6286531469534724, gradient_count: 260, cost_count: 173, gamma: 0.002813885420430187, time: 0.000316243
Feb 14 07:09:30.312 INFO iter: 87, cost: 2.5477036372899096, best_cost: 2.5477036372899096, gradient_count: 263, cost_count: 175, time: 0.000312321, gamma: 0.00290897933942485
Feb 14 07:09:30.313 INFO iter: 88, cost: 2.456064811825737, best_cost: 2.456064811825737, gradient_count: 266, cost_count: 177, time: 0.000307948, gamma: 0.003007644573128093
Feb 14 07:09:30.313 INFO iter: 89, cost: 2.3527324652180748, best_cost: 2.3527324652180748, gradient_count: 269, cost_count: 179, time: 0.000368376, gamma: 0.003129765209795245
Feb 14 07:09:30.314 INFO iter: 90, cost: 2.301810464861445, best_cost: 2.301810464861445, gradient_count: 272, cost_count: 181, gamma: 0.0032549880239466974, time: 0.000484389
Feb 14 07:09:30.315 INFO iter: 91, cost: 1.9251228338098287, best_cost: 1.9251228338098287, gradient_count: 274, cost_count: 182, gamma: 0.003433559731775239, time: 0.000413395
Feb 14 07:09:30.316 INFO iter: 92, cost: 1, best_cost: 1, gradient_count: 276, cost_count: 183, gamma: 0.003498440927062263, time: 0.000411253
Feb 14 07:09:30.317 INFO iter: 93, cost: 0.9196079384099718, best_cost: 0.9196079384099718, gradient_count: 279, cost_count: 185, time: 0.00051874, gamma: 0.004632374211594264
Feb 14 07:09:30.317 INFO iter: 94, cost: 0.9002430277478184, best_cost: 0.9002430277478184, gradient_count: 282, cost_count: 187, time: 0.000514279, gamma: 0.013663699443665227
Feb 14 07:09:30.318 INFO iter: 95, cost: 0.890379346947646, best_cost: 0.890379346947646, gradient_count: 285, cost_count: 189, time: 0.00051492, gamma: 0.019677349359792273
Feb 14 07:09:30.318 INFO iter: 96, cost: 0.8766110124590296, best_cost: 0.8766110124590296, gradient_count: 287, cost_count: 190, time: 0.000481042, gamma: 0.0045593023865662975
Feb 14 07:09:30.318 INFO iter: 97, cost: 0.8757426776987766, best_cost: 0.8757426776987766, gradient_count: 290, cost_count: 192, gamma: 0.027107596927414865, time: 0.000712073
Feb 14 07:09:30.319 INFO iter: 98, cost: 0.8727072804981657, best_cost: 0.8727072804981657, gradient_count: 292, cost_count: 193, gamma: 0.004251868256317531, time: 0.000522592
Feb 14 07:09:30.319 INFO iter: 99, cost: 0.8725632888164387, best_cost: 0.8725632888164387, gradient_count: 295, cost_count: 195, time: 0.00066335, gamma: 0.043684520414275736
OptimizationResult:
    Solver:        L-BFGS
    param (best):  [0.2443746637238865, 0.05471897627016246], shape=[2], strides=[1], layout=CFcf (0xf), const ndim=1
    cost (best):   0.8725632888164387
    iters (best):  99
    iters (total): 100
    termination:   Maximum number of iterations reached
    time:          94.836463ms
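One quick check is whether the reported `cost` already includes the L1 term. To test that, evaluate both objectives at the reported best parameters. The sketch below is standalone. `rosenbrock_l1` and its weight `lambda = 1.0` are assumptions that mirror `rosenbrock_with_l1` further down in this thread, not argmin code.

```rust
/// Plain Rosenbrock function f(x, y) = (1 - x)^2 + 100 (y - x^2)^2.
fn rosenbrock(x: f64, y: f64) -> f64 {
    (1.0 - x).powi(2) + 100.0 * (y - x.powi(2)).powi(2)
}

/// Rosenbrock plus an L1 penalty with weight `lambda`.
/// Hypothetical helper: lambda = 1.0 is an assumption about the example.
fn rosenbrock_l1(x: f64, y: f64, lambda: f64) -> f64 {
    rosenbrock(x, y) + lambda * (x.abs() + y.abs())
}

fn main() {
    // Best parameters and cost as printed in the OptimizationResult above.
    let (x, y) = (0.2443746637238865_f64, 0.05471897627016246_f64);
    let reported_cost = 0.8725632888164387_f64;

    println!("plain Rosenbrock:    {}", rosenbrock(x, y));
    println!("with L1 (lambda=1):  {}", rosenbrock_l1(x, y, 1.0));
    println!("reported cost:       {reported_cost}");
}
```

The L1 version reproduces the reported 0.8725632888... up to floating-point rounding. The plain Rosenbrock value at the same point is only about 0.573. This means the solver was minimizing the regularized objective, not the plain one.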

This run uses `rosenbrock_derivative` to compute the gradient. With finitediff it ends up at a similar point, but much more quickly.
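The two gradient options can be compared in a standalone sketch. It compares the analytic Rosenbrock gradient with a central finite-difference approximation at the starting point [-1.2, 1.0]. The code is hand-written and is neither argmin's `rosenbrock_derivative` nor the `finitediff` crate's API. `rosenbrock_grad`, `central_diff_grad` and the step `h = 1e-6` are hypothetical choices.

```rust
/// Rosenbrock function f(x, y) = (1 - x)^2 + 100 (y - x^2)^2.
fn rosenbrock(p: [f64; 2]) -> f64 {
    let [x, y] = p;
    (1.0 - x).powi(2) + 100.0 * (y - x * x).powi(2)
}

/// Hand-derived analytic gradient of the Rosenbrock function.
fn rosenbrock_grad(p: [f64; 2]) -> [f64; 2] {
    let [x, y] = p;
    [
        -2.0 * (1.0 - x) - 400.0 * x * (y - x * x), // df/dx
        200.0 * (y - x * x),                        // df/dy
    ]
}

/// Central finite-difference gradient with step `h`:
/// g_i ~ (f(p + h e_i) - f(p - h e_i)) / (2h), two cost evaluations per coordinate.
fn central_diff_grad(f: impl Fn([f64; 2]) -> f64, p: [f64; 2], h: f64) -> [f64; 2] {
    let mut g = [0.0; 2];
    for i in 0..2 {
        let mut plus = p;
        let mut minus = p;
        plus[i] += h;
        minus[i] -= h;
        g[i] = (f(plus) - f(minus)) / (2.0 * h);
    }
    g
}

fn main() {
    let start = [-1.2, 1.0];
    println!("analytic:          {:?}", rosenbrock_grad(start));
    println!("finite difference: {:?}", central_diff_grad(rosenbrock, start, 1e-6));
}
```

At the start the analytic gradient is [-215.6, -88.0], and the central difference should match it to several digits. Both gradient sources should therefore steer the solver toward the same optimum.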

stefan-k commented 7 months ago

Maybe the original author of OWL-QN @vbkaisetsu can help with this.

vbkaisetsu commented 7 months ago

The objective function of the L1 regularization is different (a regularization term is added), so the best parameter is different in the first place. I would like to check whether OWL-QN behaves correctly here, but I may not be able to address this right away.
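The following short derivation makes the first point concrete by locating the regularized optimum. It assumes an L1 weight of 1, the same penalty as `rosenbrock_with_l1` in the next comment. It also restricts attention to the quadrant x > 0, y > 0, where |x| and |y| are differentiable.

```latex
\begin{aligned}
f_\lambda(x, y) &= (1 - x)^2 + 100\,(y - x^2)^2 + \lambda\,(|x| + |y|), \qquad \lambda = 1,\\
\frac{\partial f_\lambda}{\partial y} &= 200\,(y - x^2) + 1 = 0
  \;\Rightarrow\; y - x^2 = -\tfrac{1}{200},\\
\frac{\partial f_\lambda}{\partial x} &= -2\,(1 - x) - 400\,x\,(y - x^2) + 1
  = -2 + 2x + 2x + 1 = 0 \;\Rightarrow\; x = \tfrac{1}{4},\\
y &= \tfrac{1}{16} - \tfrac{1}{200} = 0.0575,\\
f_\lambda\!\left(\tfrac{1}{4},\, 0.0575\right) &= 0.5625 + 0.0025 + 0.3075 = 0.8725.
\end{aligned}
```

At that point the Hessian of the smooth part is [[54, -100], [-100, 200]]. Its diagonal is positive and its determinant is 800, so this is a local minimizer. It lies close to the [0.244, 0.0547] (cost 0.87256) reached after 100 iterations. By contrast, the unregularized minimizer is (1, 1) with cost 0.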

vbkaisetsu commented 7 months ago

I tried a grid search for the parameters that minimize the following two functions:

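fn rosenbrock(x: f32, y: f32) -> f32 {
    (1.0 - x).powi(2) + 100.0 * (y - x.powi(2)).powi(2)
}

fn rosenbrock_with_l1(x: f32, y: f32) -> f32 {
    (1.0 - x).powi(2) + 100.0 * (y - x.powi(2)).powi(2) + x.abs() + y.abs()
}

/// Exhaustive search on a 2001 x 2001 grid covering [0, 1] x [0, 1] (step 1/2000).
/// Returns (x, y, f(x, y)) at the best grid point.
fn grid_search(f: fn(f32, f32) -> f32) -> (f32, f32, f32) {
    let mut result = (0.0, 0.0, f32::INFINITY);
    // Inclusive range so that the unregularized optimum (1, 1) lies on the grid.
    for i in 0..=2000 {
        let x = i as f32 / 2000.0;
        for j in 0..=2000 {
            let y = j as f32 / 2000.0;
            let z = f(x, y);
            if z < result.2 {
                result = (x, y, z);
            }
        }
    }
    result
}

fn main() {
    // Search both objectives, not only the unregularized one.
    dbg!(grid_search(rosenbrock));
    dbg!(grid_search(rosenbrock_with_l1));
}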

Results:

I think this result is similar to the example.
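OWL-QN handles the non-differentiable |x| terms through a pseudo-gradient (Andrew and Gao, 2007). Where this pseudo-gradient is zero, the regularized objective is stationary. The sketch below is a hand-written illustration of that definition and does not reproduce argmin's internal code. `smooth_grad`, `pseudo_gradient` and `lambda` are hypothetical names, and λ = 1 is assumed to match `rosenbrock_with_l1` above.

```rust
/// Gradient of the smooth part (the plain Rosenbrock function).
fn smooth_grad(p: [f64; 2]) -> [f64; 2] {
    let [x, y] = p;
    [
        -2.0 * (1.0 - x) - 400.0 * x * (y - x * x),
        200.0 * (y - x * x),
    ]
}

/// OWL-QN pseudo-gradient of f(p) + lambda * ||p||_1.
/// Away from zero the L1 term is differentiable; at zero we take the
/// one-sided derivative that points downhill, or 0 if neither does.
fn pseudo_gradient(p: [f64; 2], lambda: f64) -> [f64; 2] {
    let g = smooth_grad(p);
    let mut pg = [0.0; 2];
    for i in 0..2 {
        pg[i] = if p[i] > 0.0 {
            g[i] + lambda
        } else if p[i] < 0.0 {
            g[i] - lambda
        } else if g[i] + lambda < 0.0 {
            g[i] + lambda // right derivative negative: moving p[i] up decreases the cost
        } else if g[i] - lambda > 0.0 {
            g[i] - lambda // left derivative positive: moving p[i] down decreases the cost
        } else {
            0.0 // zero is optimal for this coordinate
        };
    }
    pg
}

fn main() {
    let reported = [0.2443746637238865, 0.05471897627016246];
    for p in [[1.0, 1.0], [0.25, 0.0575], reported] {
        println!("{p:?} -> {:?}", pseudo_gradient(p, 1.0));
    }
}
```

At (1, 1) the pseudo-gradient is [1, 1], so the plain Rosenbrock optimum is not stationary once the L1 term is added. At (0.25, 0.0575) it vanishes up to rounding. At the reported point its x component is still slightly negative, which is consistent with x creeping toward 0.25 when the iteration limit was hit.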

stefan-k commented 7 months ago

The objective function of the L1 regularization is different (a regularization term is added), so the best parameter is different in the first place.

Ah yeah, that absolutely makes sense! Now that you mention it, I faintly remember having this confusion already in the past. Sorry about the fuss, and huge thanks for the quick response! :)