edwbarker opened this issue 6 years ago
Very much appreciate the time taken to read the code and provide feedback :)
No probs. I checked, and yes, re pt 2, I think I was misinterpreting what the cardsToX function did.
I'm not sure whether your problem at present is you're unable to actually SOLVE blackjack, or whether your network doesn't appear to be learning at all, but I've had a play and I can at least get the network to do better than random ... (see average reward per 10,000 iterations below ... yeah, it pretty much sucks, but it's at least learning the way it's supposed to ... ran it for 1M iterations). This is the bljk.py file.
I used the 20-input x, any card possible, even though, as we've noted, you'd probably get a smarter agent with, e.g., a binary input.
Anyway, not sure if this helps at all!! @;D
Oh, also, I removed the episodes ... treated it like 1 giant episode ... not for any good reason, just made it easier for me to visualise / code. Leaving in episodes w/ discount = 1 is probably better.
I got a working result! The issue was with something that I always thought looked a bit fishy in the code:
```python
targetQ = np.copy(Q1)
targetQ[0, act] = reward + discount * maxQ2
```
The problem is that Q1 is a prediction with 2 values: the value for stand (0) and the value for hit (1). targetQ only overwrites the entry for the chosen action, so for the non-chosen action the target equals the prediction exactly, i.e. the network is told its error there is zero. In theory the weight update should then improve the prediction for the chosen action while leaving the non-chosen action unchanged.
I don't know if this makes sense mathematically, but I do know that I changed it so it now loops through all possible actions from a state (in this case, just 2), calculates the discounted reward for both actions, and updates each entry of targetQ accordingly.
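For concreteness, here is a minimal sketch of that per-action target construction. It is not the original code; the names (`make_target`, `rewards`, `maxQ_next`) and the toy numbers are my own assumptions, chosen to show the idea of filling every entry of the target vector rather than copying the prediction and overwriting one slot:

```python
import numpy as np

def make_target(Q1, rewards, maxQ_next, discount=1.0):
    """Build a full target vector instead of overwriting one action.

    Q1: (1, n_actions) network prediction.
    rewards[a], maxQ_next[a]: reward and bootstrapped next-state value
    observed after simulating action a from the current state.
    """
    targetQ = np.zeros_like(Q1)
    for a in range(Q1.shape[1]):
        targetQ[0, a] = rewards[a] + discount * maxQ_next[a]
    return targetQ

# toy usage: 2 actions (stand=0, hit=1)
Q1 = np.array([[0.1, -0.2]])
target = make_target(Q1, rewards=[1.0, -1.0], maxQ_next=[0.0, 0.3])
print(target)  # target for stand: 1.0, target for hit: -0.7
```

The key difference from the snippet above is that no entry of `targetQ` is ever just a copy of the prediction, so the loss is informative for both actions.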
This is the result I got after 120k iterations on the actual blackjack game (not the version where the only card in the deck is 5):
Mean reward (rolling window = 5000)
Q_Values for each state of the epsilon-greedy policy:
{'H10_10': '-0.4465 -0.3595', 'H10_11': '-0.8504 -0.2956', 'H10_2': '-0.3573 -0.3901', 'H10_3': '-0.4502 -0.3709', 'H10_4': '-0.1454 -0.4285', 'H10_5': '-0.1978 -0.4233', 'H10_6': '-0.4365 -0.3390', 'H10_7': '-0.5847 -0.3554', 'H10_8': '-0.4283 -0.3661', 'H10_9': '-0.2250 -0.4127', 'H11_10': '-0.6099 -0.2863', 'H11_11': '-1.0202 -0.2188', 'H11_2': '-0.6188 -0.3050', 'H11_3': '-0.5196 -0.3689', 'H11_4': '0.0569 0.3629', 'H11_5': '-0.0127 0.5686', 'H11_6': '-0.4267 0.2155', 'H11_7': '-0.1751 -0.3960', 'H11_8': '0.0032 -0.4381', 'H11_9': '-0.4976 -0.3483', 'H12_10': '-0.6479 -0.3089', 'H12_11': '-1.1561 -0.6266', 'H12_2': '-0.4022 -0.3915', 'H12_3': '-0.1128 -0.4057', 'H12_4': '-0.1215 -0.0077', 'H12_5': '-0.0927 -0.5174', 'H12_6': '-0.1126 -0.4017', 'H12_7': '-0.5824 -0.3607', 'H12_8': '-0.5655 -0.3745', 'H12_9': '-0.5762 -0.2389', 'H13_10': '-0.5278 -0.3290', 'H13_11': '-0.7024 -0.2701', 'H13_2': '-0.3812 -0.4006', 'H13_3': '-0.3208 -0.3762', 'H13_4': '-0.2333 -0.4086', 'H13_5': '0.0279 -0.4644', 'H13_6': '-0.2470 -0.3751', 'H13_7': '-0.7624 -0.3316', 'H13_8': '-0.4457 -0.3528', 'H13_9': '-0.4771 -0.3523', 'H14_10': '-0.6449 -0.3044', 'H14_11': '-1.0260 -0.5809', 'H14_2': '-0.4793 -0.3883', 'H14_3': '-0.3873 -0.3505', 'H14_4': '-0.4056 -0.3813', 'H14_5': '-0.2887 -0.3960', 'H14_6': '-0.2185 -0.4206', 'H14_7': '-0.4418 -0.3730', 'H14_8': '-0.7000 -0.3257', 'H14_9': '-0.2217 -0.4141', 'H15_10': '-0.5471 -0.3177', 'H15_11': '-1.0032 -0.5895', 'H15_2': '-0.3770 -0.3828', 'H15_3': '-0.3327 -0.3667', 'H15_4': '-0.0993 -0.4431', 'H15_5': '0.1172 -0.4565', 'H15_6': '0.1319 -0.4736', 'H15_7': '-0.5351 -0.3633', 'H15_8': '-0.5238 -0.3207', 'H15_9': '-0.5904 -0.3487', 'H16_10': '-0.6130 -0.3035', 'H16_11': '-0.7704 -0.3198', 'H16_2': '-0.2136 -0.3884', 'H16_3': '-0.2878 -0.3722', 'H16_4': '0.0348 -0.4623', 'H16_5': '-0.2399 -0.4190', 'H16_6': '-0.2588 -0.4118', 'H16_7': '-0.4394 -0.3648', 'H16_8': '-0.3165 -0.3660', 'H16_9': '-0.3384 -0.3954', 'H17_10': '-0.4633 -0.3502', 
'H17_11': '-0.5977 -0.3424', 'H17_2': '-0.3289 -0.4026', 'H17_3': '-0.0719 -0.4205', 'H17_4': '0.1487 -0.4727', 'H17_5': '0.0955 -0.4738', 'H17_6': '-0.1119 -0.4415', 'H17_7': '-0.0498 -0.4451', 'H17_8': '-0.3663 -0.3689', 'H17_9': '-0.2255 -0.4193', 'H18_10': '-0.2659 -0.6736', 'H18_11': '-0.1905 -0.4109', 'H18_2': '0.0844 -0.4785', 'H18_3': '0.2795 -0.5068', 'H18_4': '0.4536 -0.5369', 'H18_5': '0.3379 -0.5145', 'H18_6': '0.3483 -0.5097', 'H18_7': '0.0831 -0.4693', 'H18_8': '0.3041 -0.5018', 'H18_9': '-0.1706 -0.7252', 'H19_10': '0.0225 -0.4404', 'H19_11': '-0.3307 -0.3995', 'H19_2': '0.4555 -0.5366', 'H19_3': '0.4548 -0.5364', 'H19_4': '0.4538 -0.5368', 'H19_5': '0.3320 -0.5162', 'H19_6': '0.4550 -0.5363', 'H19_7': '0.4560 -0.5361', 'H19_8': '0.4562 -0.5361', 'H19_9': '0.1222 -0.4692', 'H20_10': '0.4510 -0.5335', 'H20_11': '0.2662 -0.5074', 'H20_2': '0.4539 -0.5370', 'H20_3': '0.4558 -0.5362', 'H20_4': '0.4536 -0.5370', 'H20_5': '0.4457 -0.5355', 'H20_6': '0.4564 -0.5356', 'H20_7': '0.4554 -0.5362', 'H20_8': '0.4223 -0.5292', 'H20_9': '0.4281 -0.5325', 'H21_10': '0.4568 -0.5350', 'H21_11': '0.4422 -0.5316', 'H21_2': '0.4537 -0.5370', 'H21_3': '0.4758 -0.5244', 'H21_4': '0.4540 -0.5370', 'H21_5': '0.4564 -0.5356', 'H21_6': '0.4572 -0.5348', 'H21_7': '0.4541 -0.5369', 'H21_8': '0.4562 -0.5361', 'H21_9': '0.4542 -0.5361', 'H6_10': '-0.7365 -0.3099', 'H6_11': '-0.3680 -0.3667', 'H6_2': '-0.4175 -0.3901', 'H6_3': '-0.1294 -0.4141', 'H6_4': '-0.1986 -0.3835', 'H6_5': '-0.1569 -0.3849', 'H6_6': '-0.3292 -0.3945', 'H6_7': '-0.7953 -0.3906', 'H6_8': '-0.6051 -0.3226', 'H6_9': '-0.1527 -0.4120', 'H7_10': '-0.6568 -0.3439', 'H7_11': '-0.1344 -0.4391', 'H7_2': '-0.3519 -0.4160', 'H7_3': '-0.4621 -0.4147', 'H7_4': '0.3082 -0.4854', 'H7_5': '-0.0870 -0.4487', 'H7_6': '-0.1393 -0.3944', 'H7_7': '-0.4230 -0.3868', 'H7_8': '-0.4647 -0.3590', 'H7_9': '-0.8199 -0.1856', 'H8_10': '-0.6691 -0.3463', 'H8_11': '-0.1587 -0.4293', 'H8_2': '-0.1202 -0.4049', 'H8_3': '-0.2901 -0.3650', 
'H8_4': '-0.0273 -0.4514', 'H8_5': '0.0575 -0.4554', 'H8_6': '-0.4661 -0.3104', 'H8_7': '-0.1370 -0.3852', 'H8_8': '-0.7071 -0.2829', 'H8_9': '-0.3805 -0.3541', 'H9_10': '-0.5621 -0.3244', 'H9_11': '-0.6539 -0.3362', 'H9_2': '-0.0347 -0.4333', 'H9_3': '-0.6611 -0.3368', 'H9_4': '-0.2613 -0.3786', 'H9_5': '0.2122 -0.4943', 'H9_6': '-0.2445 -0.3811', 'H9_7': '-0.1521 -0.4187', 'H9_8': '-0.4892 -0.3360', 'H9_9': '-0.8418 -0.2804', 'S13_10': '-0.6717 -0.3481', 'S13_11': '-0.4032 -0.3454', 'S13_2': '0.2262 -0.4567', 'S13_3': '-0.3202 -0.3737', 'S13_4': '-0.0807 -0.4262', 'S13_5': '0.3411 -0.5195', 'S13_6': '0.1540 -0.4685', 'S13_7': '-0.2783 -0.3666', 'S13_8': '-0.2844 -0.4322', 'S13_9': '-0.3481 -0.4455', 'S14_10': '-0.9915 -0.2140', 'S14_11': '-0.3477 -0.3583', 'S14_2': '-0.0480 -0.4513', 'S14_3': '-0.4516 -0.3615', 'S14_4': '0.0227 -0.4297', 'S14_5': '0.2693 -0.4811', 'S14_6': '-0.2118 -0.4022', 'S14_7': '-0.0321 -0.4566', 'S14_8': '-0.6108 -0.3267', 'S14_9': '-0.4140 -0.3459', 'S15_10': '-0.9563 -0.2969', 'S15_11': '-0.2623 -0.3433', 'S15_2': '-0.1710 -0.3849', 'S15_3': '-0.7560 -0.3394', 'S15_4': '-0.1527 -0.4174', 'S15_5': '0.0349 -0.4418', 'S15_6': '0.1786 -0.4733', 'S15_7': '-0.7728 -0.2704', 'S15_8': '-0.1071 -0.4061', 'S15_9': '-0.7964 -0.2542', 'S16_10': '-0.6308 -0.3521', 'S16_11': '-0.7214 -0.2789', 'S16_2': '-0.4062 -0.3868', 'S16_3': '-0.4847 -0.3637', 'S16_4': '0.0128 -0.4472', 'S16_5': '-0.3346 -0.3852', 'S16_6': '-0.5251 -0.3718', 'S16_7': '-0.6608 -0.3435', 'S16_8': '-0.4846 -0.3848', 'S16_9': '-0.4670 -0.3203', 'S17_10': '-0.3942 -0.3528', 'S17_11': '0.0507 -0.4625', 'S17_2': '-0.2532 -0.4113', 'S17_3': '-0.0481 -0.4128', 'S17_4': '0.3906 -0.5027', 'S17_5': '0.3336 -0.5161', 'S17_6': '-0.1147 -0.4531', 'S17_7': '-0.2537 -0.4200', 'S17_8': '-0.2371 -0.3992', 'S17_9': '0.0711 -0.4476', 'S18_10': '0.0434 -0.4638', 'S18_11': '-0.0542 -0.4506', 'S18_2': '0.4139 -0.5263', 'S18_3': '0.2710 -0.4887', 'S18_4': '0.4555 -0.5333', 'S18_5': '0.1306 -0.4627', 
'S18_6': '0.2665 -0.4896', 'S18_7': '0.1396 -0.4597', 'S18_8': '0.4570 -0.5348', 'S18_9': '0.4555 -0.5367', 'S19_10': '-0.1576 -0.4371', 'S19_11': '0.1679 -0.4834', 'S19_2': '0.3102 -0.5106', 'S19_3': '0.3498 -0.5176', 'S19_4': '0.4549 -0.5376', 'S19_5': '0.3779 -0.5170', 'S19_6': '0.4553 -0.5362', 'S19_7': '0.2546 -0.4843', 'S19_8': '0.4028 -0.5281', 'S19_9': '-0.0820 -0.4082', 'S20_10': '0.4554 -0.5325', 'S20_11': '0.3790 -0.5125', 'S20_2': '0.4546 -0.5354', 'S20_3': '0.9619 -0.2244', 'S20_4': '0.4562 -0.5337', 'S20_5': '1.0788 -0.1726', 'S20_6': '0.4000 -0.5235', 'S20_7': '0.4406 -0.5296', 'S20_8': '0.1630 -0.4756', 'S20_9': '0.4555 -0.5348', 'S21_10': '0.4531 -0.5368', 'S21_11': '0.4568 -0.5350', 'S21_2': '0.4564 -0.5334', 'S21_3': '1.4788 0.0588', 'S21_4': '0.4552 -0.5351', 'S21_5': '0.4526 -0.5354', 'S21_6': '0.4526 -0.5350', 'S21_7': '0.4455 -0.5343', 'S21_8': '1.5208 0.0877', 'S21_9': '0.4343 -0.5291'}
Ah yes, nice. Interesting about the need to cycle through every action. I don't think that's standard (or perhaps I should say, I don't think that's essential in general), but, hey, if it's working ... !
The other interesting thing is that it doesn't converge to the optimal result, even if I always take the optimal action. The expected return if you play optimally is -2%, but the network only gets as good as approx -10%. I wonder why? Even if I make the fully connected network highly complicated, it still doesn't converge. In theory it should just overfit, since it's the same game with the same rules every time ... so strange!!
I guess it's fair to expect it to find an optimal solution if you have a sufficiently complex network (provided discount = 1). There's no formal guarantee though (unlike in the tabular case): the NN can fall into a local maximum. But yeah, that's unlikely; maybe it's just an error someplace.
http://cs230.stanford.edu/files_winter_2018/projects/6940282.pdf is basically an exact copy of what I did, and they seemed to observe similar results. Interesting huh.
I had a look at the code. Interesting! There are a couple of things which could be causing problems:
The way you've encoded the state might make it quite difficult for the AI to figure out what's going on, particularly for the 5's-only problem. This is because it's looking for a linear relationship between Q and each state element, i.e. Q = a*x_i. Have a look at how they encoded the state here: http://incompleteideas.net/book/bookdraft2017nov5.pdf (pages 129-130, or just ctrl-f "blackjack"). I think it will make things easier if each element of x is typically just 0 or 1 (not always, but typically);
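A mostly-binary encoding along those lines might look like the sketch below. The exact feature ranges (player sum 4-21, dealer upcard 2-11, usable-ace flag) are my assumptions in the spirit of the Sutton & Barto blackjack state, not taken from the original code:

```python
import numpy as np

def encode_state(player_sum, dealer_card, usable_ace):
    """One-hot-style encoding: 18 (player sum 4..21) + 10 (dealer
    upcard 2..11) + 1 (usable-ace flag) = 29 elements, each 0 or 1."""
    x = np.zeros(18 + 10 + 1)
    x[player_sum - 4] = 1.0          # one-hot player total
    x[18 + (dealer_card - 2)] = 1.0  # one-hot dealer upcard
    x[28] = 1.0 if usable_ace else 0.0
    return x

x = encode_state(16, 10, False)
print(x.sum())  # 2.0 -> exactly two active bits, no usable ace
```

With this kind of input, each weight only has to learn the value contribution of one discrete situation, rather than a linear trend across card values.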
I may have missed something, but the way you've coded the update to Q, it appears you skip every second step. You calculate x(t) and Q1(t), then x(t+1) and Q2(t+1), and compare Q2 and Q1 (i.e. calculate the loss); then you generate x(t+2) and Q1(t+2) and compare them with x(t+3) and Q2(t+3). Time steps t+1 and t+2 never get compared. This would break the chain the network needs to predict the future. Again, I'm not 100% sure and may have missed something, but if you're missing a step in the chain like this it will definitely break;
The loss function you defined might work, but it's not the most obvious / direct way to use the NN. See formula 16.3 from the link above, which is the more typical method for combining a NN with RL (at least for value-function estimation). I think what you've done is called "residual" learning or something like it. Another example is the code I uploaded. Also, I think technically you're using Q-learning, not SARSA (not that it really matters); the NN code I wrote uses SARSA.
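The distinction I'm drawing can be sketched like this (my own toy linear example, in the spirit of the semi-gradient updates in the Sutton & Barto draft, not formula 16.3 verbatim): in the usual semi-gradient update the bootstrapped target is held fixed, whereas in a residual-style update the gradient also flows through the next-state value.

```python
import numpy as np

def semi_gradient_step(w, x, r, x_next, alpha=0.1, gamma=1.0):
    """Standard semi-gradient TD: target treated as a constant."""
    target = r + gamma * (w @ x_next)        # no gradient through this
    td_error = target - (w @ x)
    return w + alpha * td_error * x          # gradient only via v(x)

def residual_step(w, x, r, x_next, alpha=0.1, gamma=1.0):
    """Residual-style update: gradient flows through both value terms."""
    td_error = r + gamma * (w @ x_next) - (w @ x)
    return w + alpha * td_error * (x - gamma * x_next)

w = np.zeros(3)
x, x_next = np.array([1., 0., 0.]), np.array([0., 1., 0.])
print(semi_gradient_step(w, x, 1.0, x_next))  # only v(x) moves
print(residual_step(w, x, 1.0, x_next))       # v(x_next) is pushed too
```

Both minimise a squared TD error, but they take different gradients of it, and the semi-gradient form is the one the textbook (and most DQN-style code) uses.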
Anyway, my best guess is that either 1 or 2 is making it fail to work. If you fix 1, I can try to run my NN code against it (I can run it already, but haven't investigated closely, since I'm sceptical it will get decent results with the input defined the way it is at the moment, or at least not without a very big network and lots of training).