TommasoBelluzzo / PyDTMC

A library for discrete-time Markov chains analysis.
MIT License

Better Handling of Fractions and Normalization #4

Closed: mike-duran-mitchell closed this issue 5 years ago

mike-duran-mitchell commented 5 years ago

Description

I'm getting validation errors about rows summing to 1. I build my transition matrices from fractions that are converted to floats, and some rows end up with sums that fall outside the tolerances of the numpy functions used for validation. For example:

transition_matrix_sum 0.9999999999999999
[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.11627907 0.
 0.         0.         0.         0.04651163 0.         0.04651163
 0.65116279 0.13953488 0.        ]

transition_matrix_sum 1.0000000000000002
[0.00436535 0.38415044 0.00011193 0.00145512 0.06525632 0.00223864
 0.00044773 0.00570853 0.0003358  0.00111932 0.26393553 0.
 0.         0.00873069 0.02070741 0.09032908 0.00391762 0.14651892
 0.         0.         0.00067159]
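The nonzero entries of the first printed row match 5/43, 2/43, 2/43, 28/43 and 6/43 to the printed precision. The sketch below assumes the rows were built from transition counts like these. The counts are a hypothetical reconstruction, not data from the report. It shows why float rows can drift from 1 in the last digit, while the same row computed with exact `fractions.Fraction` arithmetic cannot:

```python
from fractions import Fraction

import numpy as np

# Hypothetical counts reconstructed from the first printed row:
# 5/43, 2/43, 2/43, 28/43 and 6/43 at positions 10, 15, 17, 18 and 19
counts = np.zeros(21, dtype=np.int64)
counts[[10, 15, 17, 18, 19]] = [5, 2, 2, 28, 6]
total = int(counts.sum())

# Float row: every division is rounded to the nearest double,
# so the row sum may land an ulp or two away from 1.0
float_row = counts / total
float_sum = float(float_row.sum())

# Exact row: Fraction arithmetic carries no rounding error
exact_row = [Fraction(int(c), total) for c in counts]
exact_sum = sum(exact_row)

print("float sum:", repr(float_sum))
print("exact sum:", exact_sum)
print("within numpy default tolerance:", bool(np.isclose(float_sum, 1.0)))
```

Whatever the float sum turns out to be, its deviation from 1 is far below numpy's default tolerances, while the `Fraction` sum is exactly 1.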

Motivation and Context

Being able to use the Python fractions library, or a helper function that normalizes each row to sum to one without losing much information, would be helpful for matrices built from larger numbers and with more states.
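One way such a helper could look is sketched below. It assumes a matrix of nonnegative weights or counts, and `normalize_rows` is a hypothetical name, not part of PyDTMC. The helper rescales each row by its sum, then moves any remaining floating-point residual onto the row's largest entry, where it changes the probabilities least in relative terms:

```python
import numpy as np


def normalize_rows(matrix: np.ndarray) -> np.ndarray:
    """Hypothetical helper: rescale each row to sum to one, then move the
    remaining floating-point residual onto that row's largest entry."""
    p = np.asarray(matrix, dtype=float)
    row_sums = p.sum(axis=1, keepdims=True)
    if np.any(row_sums <= 0):
        raise ValueError("every row needs a positive sum")
    p = p / row_sums
    # Residual left after rescaling; typically zero or a few ulps
    residual = 1.0 - p.sum(axis=1)
    # Adding the same absolute amount to the largest entry gives the
    # smallest relative change and cannot make that entry negative
    largest = p.argmax(axis=1)
    p[np.arange(p.shape[0]), largest] += residual
    return p


print(normalize_rows(np.array([[5, 2, 2, 28, 6], [1, 1, 1, 0, 0]])).sum(axis=1))
```

Rescaling alone already leaves each row within a few ulps of one. The residual step only moves that leftover to where it distorts the probabilities least. It still cannot guarantee a sum of exactly 1.0 once the row is summed again in floating point, so a tolerance-based check remains necessary.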

Possible Implementation

I tried implementing the approach from this answer (https://stackoverflow.com/a/34959983) and also rolled my own function (below), but neither of them resolves the errors:

import numpy as np


def force_transition_matrix_row_to_one(transition_list: np.ndarray, default_index_to_add_for_sum: int) -> np.ndarray:
    # Work on a float copy so the caller's array is left untouched
    row = np.array(transition_list, dtype=float)
    difference_in_values = 1.0 - row.sum()
    if difference_in_values == 0.0:
        return row
    if difference_in_values > 0:
        # Row sums to less than one: add the shortfall to the default entry
        row[default_index_to_add_for_sum] += difference_in_values
    else:
        # Row sums to more than one: take the excess from the first entry
        # large enough to absorb it without turning negative
        index_to_update = next(i for i, p in enumerate(row) if p > -difference_in_values)
        row[index_to_update] += difference_in_values
    return row
TommasoBelluzzo commented 5 years ago

Hi Mike, thanks for your feedback. Could you provide a few reproducible examples of the issue? Those don't look like 2D square matrices, since each one contains only 21 elements.

mike-duran-mitchell commented 5 years ago

Sorry if I was unclear: those are individual rows of the full matrix whose sums come out as something other than exactly 1. An example of a full matrix is:

[[0.022222222222222223, 0.5111111111111111, 0.0, 0.0, 0.011111111111111112, 0.0, 0.0, 0.0, 0.0, 0.0, 0.36666666666666664, 0.0, 0.0, 0.03333333333333333, 0.022222222222222223, 0.0, 0.0, 0.03333333333333333, 0.0, 0.0, 0.0],
 [0.005440552016985138, 0.35589171974522293, 0.0001326963906581741, 0.0003980891719745223, 0.06621549893842887, 0.0025212314225053077, 0.0009288747346072187, 0.0035828025477707007, 0.0001326963906581741, 0.0001326963906581741, 0.2867569002123142, 0.0, 0.0, 0.008890658174097664, 0.017914012738853503, 0.08386411889596602, 0.007032908704883227, 0.16003184713375795, 0.0, 0.0, 0.0001326963906581741],
 [0.0, 0.031578947368421054, 0.0, 0.0, 0.3684210526315789, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3263157894736842, 0.0, 0.0, 0.021052631578947368, 0.08421052631578947, 0.021052631578947368, 0.0, 0.14736842105263157, 0.0, 0.0, 0.0],
 [0.0027434842249657062, 0.0027434842249657062, 0.0, 0.30315500685871055, 0.2674897119341564, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1906721536351166, 0.0, 0.0, 0.03429355281207133, 0.03292181069958848, 0.09602194787379972, 0.03566529492455418, 0.03292181069958848, 0.0, 0.0, 0.0013717421124828531],
 [0.00012187195320116997, 0.021733831654208644, 0.0017062073448163796, 0.004915502112447189, 0.4666883327916802, 0.004306142346441339, 0.0004874878128046799, 0.004509262268443289, 0.00012187195320116997, 0.0003656158596035099, 0.263649658758531, 0.0, 0.0, 0.004549886252843679, 0.008571660708482289, 0.06223594410139747, 0.012349691257718558, 0.14336204094897628, 4.062398440038999e-05, 0.0, 0.0002843678908027299],
 [0.0, 0.014557670772676373, 0.0011198208286674132, 0.0011198208286674132, 0.040313549832026875, 0.02127659574468085, 0.0, 0.03471444568868981, 0.0, 0.0, 0.6842105263157895, 0.0, 0.0, 0.0033594624860022394, 0.013437849944008958, 0.051511758118701005, 0.083986562150056, 0.05039193729003359, 0.0, 0.0, 0.0],
 [0.0, 0.0032679738562091504, 0.0, 0.0032679738562091504, 0.03594771241830065, 0.0, 0.4738562091503268, 0.0915032679738562, 0.0, 0.0, 0.29411764705882354, 0.0, 0.0, 0.013071895424836602, 0.00980392156862745, 0.0457516339869281, 0.0, 0.029411764705882353, 0.0, 0.0, 0.0],
 [0.0, 0.017576664173522813, 0.00037397157816005983, 0.0, 0.02692595362752431, 0.13911742707554225, 0.0033657442034405387, 0.20830216903515333, 0.0, 0.00037397157816005983, 0.4487658937920718, 0.0, 0.0, 0.007853403141361256, 0.008975317875841436, 0.04712041884816754, 0.03178758414360509, 0.05946148092744952, 0.0, 0.0, 0.0],
 [0.0, 0.009009009009009009, 0.0, 0.0, 0.02702702702702703, 0.0, 0.0, 0.0, 0.11711711711711711, 0.24324324324324326, 0.21621621621621623, 0.0, 0.0, 0.02702702702702703, 0.05405405405405406, 0.1891891891891892, 0.036036036036036036, 0.08108108108108109, 0.0, 0.0, 0.0],
 [0.0, 0.014446227929373997, 0.0, 0.0, 0.0, 0.0, 0.0, 0.004815409309791332, 0.04012841091492777, 0.4044943820224719, 0.18298555377207062, 0.0, 0.0, 0.006420545746388443, 0.011235955056179775, 0.2568218298555377, 0.008025682182985553, 0.06902086677367576, 0.0, 0.0, 0.0016051364365971107],
 [0.00023623907394283014, 0.01440296289522416, 0.00021337722807739497, 0.0013488489060606754, 0.04619616987875601, 0.0021109104349085147, 0.0007163378371169688, 0.01052406971338866, 0.00016765353634652462, 0.0006553729148091417, 0.6976749502754852, 0.0, 0.0, 0.00815405835867188, 0.014585857662147641, 0.07974211837863789, 0.012276811229738688, 0.1104913010676482, 6.096492230782713e-05, 3.810307644239196e-05, 0.00040389261028935474],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7377049180327869, 0.0, 0.0, 0.0, 0.0, 0.26229508196721313, 0.0, 0.0, 0.0, 0.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7526881720430108, 0.0, 0.0, 0.0, 0.24731182795698925, 0.0, 0.0, 0.0, 0.0],
 [0.00027731558513588466, 0.019412090959511925, 0.00027731558513588466, 0.0074875207986688855, 0.019966722129783693, 0.0011092623405435386, 0.0008319467554076539, 0.0027731558513588465, 0.0, 0.0005546311702717693, 0.34165280088740985, 0.0, 0.0, 0.12950637825845812, 0.12201885745978924, 0.21575152523571825, 0.003882418191902385, 0.13422074320576816, 0.0, 0.0, 0.00027731558513588466],
 [0.0009108505066605943, 0.017420015939883866, 0.00045542525333029714, 0.00717294773995218, 0.01878629169987476, 0.00159398838665604, 0.0005692815666628715, 0.005123534099965843, 0.00045542525333029714, 0.001480132073323466, 0.2864624843447569, 0.0, 0.0, 0.03142434247979051, 0.18080382557212799, 0.34077194580439485, 0.00159398838665604, 0.10360924513264261, 0.0, 0.0, 0.0013662757599908915],
 [1.3318771476519006e-05, 0.010948030153698622, 3.995631442955702e-05, 0.0007858075171146213, 0.026491036466796302, 0.0011587331184571535, 0.00033296928691297517, 0.002463972723156016, 0.0004794757731546842, 0.002703710609733358, 0.16633813697024585, 0.0, 0.0, 0.009509602834234571, 0.04805412748728057, 0.5407953970325777, 0.004701526331211209, 0.18462481020750646, 2.663754295303801e-05, 1.3318771476519006e-05, 0.0005194320875842412],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
 [6.783225084366361e-05, 0.021384117078464958, 0.00020349675253099087, 0.0006444063830148044, 0.0686631959164985, 0.0006952805711475522, 0.00011870643897641133, 0.00215367396428632, 0.00013566450168732723, 0.0004409096304838135, 0.1658837694381794, 0.0, 0.0, 0.014516101680544015, 0.014007359799216537, 0.29829232308501075, 0.0034255286676050127, 0.409164137088979, 0.0, 1.6958062710915903e-05, 0.00018653868982007496],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1111111111111111, 0.0, 0.0, 0.0, 0.0, 0.07407407407407407, 0.0, 0.07407407407407407, 0.037037037037037035, 0.5925925925925926, 0.1111111111111111],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2222222222222222, 0.0, 0.0, 0.0, 0.0, 0.037037037037037035, 0.0, 0.0, 0.5555555555555556, 0.07407407407407407, 0.1111111111111111],
 [0.0, 0.01932367149758454, 0.0, 0.0, 0.01932367149758454, 0.0, 0.0, 0.01932367149758454, 0.0, 0.0, 0.3961352657004831, 0.0, 0.0, 0.004830917874396135, 0.014492753623188406, 0.1642512077294686, 0.0821256038647343, 0.07246376811594203, 0.0, 0.00966183574879227, 0.19806763285024154]]
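The sketch below reproduces the kind of check involved on a full matrix like this one. It assumes the validation amounts to a square-shape test plus a row-sum tolerance test. `find_invalid_rows` is a hypothetical name, not part of PyDTMC. It uses `numpy.isclose` per row, rather than a single `allclose`, only so that the offending rows can be reported:

```python
import numpy as np


def find_invalid_rows(matrix, rtol: float = 1e-05, atol: float = 1e-08) -> list[int]:
    """Hypothetical check: return the indices of rows that contain negative
    entries or whose sums are not close to one under numpy.isclose."""
    p = np.asarray(matrix, dtype=float)
    # A single 21-element row is 1-D and fails here; a transition
    # matrix must be 2-D and square
    if p.ndim != 2 or p.shape[0] != p.shape[1]:
        raise ValueError(f"expected a square 2-D matrix, got shape {p.shape}")
    sums_ok = np.isclose(p.sum(axis=1), 1.0, rtol=rtol, atol=atol)
    nonnegative = np.all(p >= 0.0, axis=1)
    return [int(i) for i in np.flatnonzero(~(sums_ok & nonnegative))]


print(find_invalid_rows([[0.5, 0.5], [0.2, 0.7]]))
```

Passing one printed row on its own raises the shape error, which is the distinction drawn earlier in the thread. Passing the full 21 x 21 matrix lists only the rows that fall outside the chosen tolerances.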
TommasoBelluzzo commented 5 years ago

Hmm, I tried that matrix on my PC and the validation error didn't show up. The check is performed with numpy `allclose`, which includes an absolute tolerance, so it should accept matrices whose rows sum to something very close to one in either direction. In the meantime, the best I can do is include the sanitization function you proposed, so that MarkovChain instances can be initialized from non-conforming data.
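numpy documents `allclose` (like `isclose`) as testing `absolute(a - b) <= atol + rtol * absolute(b)` element-wise, with defaults `rtol=1e-05` and `atol=1e-08`. The minimal sketch below applies that test to the two row sums from the original report and compares it with exact equality:

```python
import numpy as np

# The two row sums printed in the original report
reported_sums = [0.9999999999999999, 1.0000000000000002]

rtol, atol = 1e-05, 1e-08  # numpy.allclose defaults

for s in reported_sums:
    # Same inequality numpy applies element-wise: |a - b| <= atol + rtol * |b|
    manual = abs(s - 1.0) <= atol + rtol * abs(1.0)
    print(s, "manual:", manual, "allclose:", np.allclose(s, 1.0), "exact:", s == 1.0)
```

Both sums pass the tolerance test and fail exact equality. This is consistent with the observation above that `allclose` should accept such rows. A rejection would more likely come from a stricter comparison elsewhere than from `allclose` itself.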

TommasoBelluzzo commented 5 years ago

I implemented a "from_matrix" method that can be used to initialize DTMC objects with non-conforming matrices. Try it out and tell me whether it fits your needs. If it doesn't, I can try to implement your sanitization algorithm, but it needs a few tweaks before it can go into a release. Feel free to open a new issue if needed.