giammy677dev opened this issue 10 months ago
Maybe the error is related to user_embeddings. It should be a NumPy array of shape (n, m), where n is the number of nodes and m is the dimension of the user embeddings. The embedding values should be between 0 and 1.
Try normalizing the embeddings. If that doesn't work, give me more information about the user embeddings:
print(users_embeddings[0, :])
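The reply above asks for values in [0, 1] but doesn't say which min-max scaling to use. Here is a sketch of two options for an (n, m) NumPy array: per embedding dimension across all users (axis=0), or per user vector (axis=1, the variant used later in get_data). minmax_normalize is a hypothetical helper, not part of the repo, and this thread doesn't confirm which variant EchoGAE expects.

```python
import numpy as np

def minmax_normalize(X, axis=0):
    """Rescale X into [0, 1] with min-max scaling (hypothetical helper).

    axis=0: scale each embedding dimension across users (column-wise).
    axis=1: scale each user's vector on its own (row-wise).
    Constant slices are mapped to 0 instead of dividing by zero.
    """
    X = np.asarray(X, dtype=np.float32)
    lo = X.min(axis=axis, keepdims=True)
    rng = X.max(axis=axis, keepdims=True) - lo
    rng[rng == 0] = 1.0  # avoid division by zero for constant rows/columns
    return (X - lo) / rng

X = np.array([[0.5, -1.0, 2.0],
              [1.5,  3.0, 2.0]])
print(minmax_normalize(X, axis=0))  # every column spans [0, 1]; the constant column becomes 0
print(minmax_normalize(X, axis=1))  # every row spans [0, 1]
```

Column-wise scaling keeps each feature comparable across users; row-wise scaling discards each vector's overall offset and spread, which changes the geometry of the embeddings.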
Hi @faalatawi, and thanks for the quick response :)
I normalized users_embeddings as you suggested, and the code now seems to get past that line. Now I'm getting this error:
Traceback (most recent call last):
File "C:\Users\ogiam\PycharmProjects\ECS\1_echo_chamber_score.py", line 97, in <module>
user_emb = EchoGAE_algorithm(G, user_embeddings=users_embeddings, show_progress=False, hidden_channels=20, out_channels=10, epochs=300)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\ogiam\PycharmProjects\ECS\src\EchoGAE.py", line 49, in EchoGAE_algorithm
model, x, train_pos_edge_index = run(
^^^^
File "C:\Users\ogiam\PycharmProjects\ECS\src\GAE.py", line 68, in run
auc, ap = __test(
^^^^^^^
File "C:\Users\ogiam\PycharmProjects\ECS\src\GAE.py", line 43, in __test
return model.test(z, pos_edge_index, neg_edge_index)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\ogiam\PycharmProjects\ECS\venv\Lib\site-packages\torch_geometric\nn\models\autoencoder.py", line 136, in test
return roc_auc_score(y, pred), average_precision_score(y, pred)
^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\ogiam\PycharmProjects\ECS\venv\Lib\site-packages\sklearn\utils\_param_validation.py", line 214, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\ogiam\PycharmProjects\ECS\venv\Lib\site-packages\sklearn\metrics\_ranking.py", line 605, in roc_auc_score
y_true = check_array(y_true, ensure_2d=False, dtype=None)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\ogiam\PycharmProjects\ECS\venv\Lib\site-packages\sklearn\utils\validation.py", line 967, in check_array
raise ValueError(
ValueError: Found array with 0 sample(s) (shape=(0,)) while a minimum of 1 is required.
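The final ValueError means roc_auc_score received an empty label array: the test step had zero positive and zero negative edges. GAE.py isn't shown in this thread, so the following is an assumption. If the edges are split with counts computed as floor(ratio × number of undirected edges), as PyG's train_test_split_edges does with its defaults val_ratio=0.05 and test_ratio=0.1, a graph with only a few edges ends up with no test edges at all. The arithmetic, as a sketch (split_sizes and min_edges_for_test_set are hypothetical helpers):

```python
import math

def split_sizes(num_edges, val_ratio=0.05, test_ratio=0.1):
    """Validation/test positive-edge counts under a floor-based split.

    Mirrors the floor(ratio * num_edges) rule of PyG's train_test_split_edges;
    whether GAE.py splits edges this way is an assumption.
    """
    n_val = math.floor(val_ratio * num_edges)
    n_test = math.floor(test_ratio * num_edges)
    return n_val, n_test

def min_edges_for_test_set(test_ratio=0.1):
    """Smallest number of undirected edges that yields at least one test edge."""
    num_edges = 1
    while math.floor(test_ratio * num_edges) < 1:
        num_edges += 1
    return num_edges

# A 3-edge graph gets no validation or test edges, so the test labels are empty
print(split_sizes(3))            # (0, 0)
print(min_edges_for_test_set())  # 10
```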
To be clear, load_data.py converts users_embeddings to a dictionary:
users_embeddings = df.set_index("user_id")["embeddings"].to_dict()
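If EchoGAE_algorithm expects an (n, m) NumPy array, as the first reply suggests, the dictionary can be stacked into one. This is a sketch assuming the dict keys are exactly the integer node indices 0..n-1. embeddings_dict_to_array is a hypothetical helper, not part of the repo.

```python
import numpy as np

def embeddings_dict_to_array(users_embeddings):
    """Stack {node_index: 1-D vector} into an (n, m) array where row i is node i.

    Assumes the keys are exactly the integers 0..n-1.
    """
    n = len(users_embeddings)
    missing = [i for i in range(n) if i not in users_embeddings]
    if missing:
        raise KeyError(f"no embedding for node indices {missing[:5]}")
    rows = [np.asarray(users_embeddings[i], dtype=np.float32) for i in range(n)]
    # np.vstack raises ValueError if the vectors have different lengths
    return np.vstack(rows)

# Toy example with two 384-dimensional vectors (the output size of all-MiniLM-L6-v2)
X = embeddings_dict_to_array({0: np.zeros(384), 1: np.ones(384)})
print(X.shape)   # (2, 384)
print(X[0, :5])  # row indexing such as users_embeddings[0, :] now works
```

Because np.vstack refuses vectors of different lengths, this conversion also doubles as a check that every user has an embedding of the same size.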
Here is users_embeddings after normalization, as you suggested:
{0: array([0.5903886 , 0.5463845 , 0.6314844 , 0.2994573 , 0.53178114,
0.72305346, 0.3940116 , 0.5983963 , 0.36165282, 0.46718055,
0.371573 , 0.86012286, 0.75624 , 0.65208215, 0.31125316,
0.56961894, 0.36351144, 0.40996528, 0.23478517, 0.77074856,
0.21189088, 0.6337178 , 0.53798616, 0.49903655, 0.37051472,
0.3854286 , 0.4718904 , 0.5163373 , 0.19723426, 0.39895633,
0.8976927 , 0.21628626, 0.3593885 , 0.48840365, 0.48226994,
0.43893492, 0.48543227, 0.5892171 , 0.38826406, 0.40871355,
0.59276766, 0.08683099, 0.30721337, 0.5162417 , 0.6030433 ,
0.57679516, 0.37648603, 0.5437432 , 0.7303272 , 0.27150497,
0.33576313, 0.49856034, 0.57819676, 0.4675867 , 0.33486864,
0.5247566 , 0.6328391 , 0.6870526 , 0.38876027, 0.4387255 ,
0.5059428 , 0.4151606 , 0.56148404, 0.45583224, 0.26991713,
0.5404365 , 0.36613652, 0.50812584, 0.27858657, 0.77044636,
0.44390696, 0.7116696 , 0.74979055, 0.53280395, 0.5741853 ,
0.50218266, 0.5360777 , 0.53478837, 1. , 0.41538575,
0.9712991 , 0.6728056 , 0.63580185, 0.41534263, 0.58824044,
0.39682275, 0.35553297, 0.12713297, 0.4352373 , 0.4530455 ,
0.21358362, 0.44481987, 0.4569379 , 0.6189847 , 0.41602683,
0.5464685 , 0.44565925, 0.75979394, 0. , 0.5848383 ,
0.02218434, 0.3135237 , 0.29259095, 0.1951891 , 0.59517765,
0.46691978, 0.64640707, 0.34184644, 0.4921255 , 0.5788995 ,
0.6034956 , 0.6086504 , 0.26671427, 0.35178533, 0.7536429 ,
0.32784587, 0.39693376, 0.36548433, 0.3502349 , 0.41376254,
0.6751696 , 0.32571805, 0.29139057, 0.5199206 , 0.7914655 ,
0.35013318, 0.40610346, 0.474595 , 0.715731 , 0.55157787,
0.69133556, 0.6740565 , 0.2891 , 0.58775306, 0.62619764,
0.22411384, 0.52803564, 0.5504468 , 0.42701572, 0.49767405,
0.58265716, 0.8593093 , 0.19029467, 0.29489756, 0.31142318,
0.58041364, 0.35520896, 0.32476845, 0.49472684, 0.5327737 ,
0.5418649 , 0.63241345, 0.7387447 , 0.22186379, 0.69398755,
0.49795955, 0.44908583, 0.54026526, 0.25971207, 0.6299726 ,
0.36077514, 0.67199713, 0.5334661 , 0.1798632 , 0.39948094,
0.55486876, 0.43986547, 0.22933358, 0.7387832 , 0.92368567,
0.29325098, 0.58843464, 0.75875854, 0.5492905 , 0.5206712 ,
0.47314996, 0.37598154, 0.36545646, 0.5015214 , 0.6523398 ,
0.56226677, 0.596226 , 0.4993254 , 0.3105751 , 0.3684766 ,
0.5790526 , 0.42430234, 0.21187836, 0.47814125, 0.65961665,
0.41557905, 0.50058234, 0.55442435, 0.42274082, 0.1365884 ,
0.1645132 , 0.29398423, 0.6055696 , 0.77949506, 0.35396424,
0.15519348, 0.26841366, 0.28413072, 0.40042803, 0.7029442 ,
0.14210477, 0.6940655 , 0.19428387, 0.11259534, 0.4314326 ,
0.8772073 , 0.5604053 , 0.54182166, 0.38905093, 0.57011956,
0.43929538, 0.6414684 , 0.50285524, 0.42010292, 0.5579324 ,
0.72515714, 0.5751192 , 0.10681102, 0.474595 , 0.38437125,
0.4455422 , 0.5125746 , 0.70831937, 0.5460206 , 0.57248586,
0.34819818, 0.52376235, 0.5426327 , 0.33171818, 0.38510123,
0.6092513 , 0.6695501 , 0.67054087, 0.31687284, 0.527404 ,
0.5847972 , 0.3198751 , 0.4546798 , 0.42352444, 0.70881397,
0.9764365 , 0.40056378, 0.7918312 , 0.3442903 , 0.45757028,
0.6102687 , 0.55319154, 0.71293557, 0.17177549, 0.18230942,
0.66884625, 0.7329294 , 0.46095172, 0.4112841 , 0.8137524 ,
0.69304144, 0.16537918, 0.39652187, 0.38389784, 0.51303506,
0.07949679, 0.34835663, 0.19845116, 0.51711726, 0.5326349 ,
0.54591066, 0.57445216, 0.4025225 , 0.53146696, 0.35923636,
0.47883376, 0.22849476, 0.75719845, 0.28265658, 0.37521958,
0.40089855, 0.33905256, 0.35004425, 0.6059098 , 0.33998072,
0.4290143 , 0.45183882, 0.36477107, 0.58499044, 0.664054 ,
0.5714948 , 0.33190066, 0.675558 , 0.27499285, 0.39646444,
0.1741357 , 0.43009242, 0.41168192, 0.6453223 , 0.4173889 ,
0.2664299 , 0.48379079, 0.59220153, 0.7751846 , 0.41976035,
0.02754701, 0.47559944, 0.4104384 , 0.36888206, 0.530058 ,
0.5023671 , 0.13987543, 0.33606586, 0.7380448 , 0.3089485 ,
0.50874645, 0.25042897, 0.4580534 , 0.5849977 , 0.4745948 ,
0.7601535 , 0.33843175, 0.3520388 , 0.45736083, 0.4769383 ,
0.6366737 , 0.01269261, 0.35455778, 0.73761207, 0.7977838 ,
0.8040379 , 0.59207374, 0.549212 , 0.37760028, 0.46691293,
0.53902924, 0.34738073, 0.5861576 , 0.25502285, 0.5389182 ,
0.3297547 , 0.583315 , 0.32766637, 0.569069 , 0.8713058 ,
0.3744244 , 0.22712737, 0.44884762, 0.44120434, 0.34806576,
0.2522678 , 0.46934766, 0.30780926, 0.5544345 , 0.43582988,
0.44257188, 0.36943206, 0.5749146 , 0.7811406 , 0.42525542,
0.7214907 , 0.5741349 , 0.4951715 , 0.39221266, 0.26793188,
0.24854879, 0.32631084, 0.77015775, 0.5320774 , 0.23697053,
0.08882525, 0.54427993, 0.5777584 , 0.4866187 , 0.61439764,
0.6713501 , 0.542743 , 0.5242594 , 0.65211165, 0.5329402 ,
0.50373816, 0.3688049 , 0.35005218, 0.4660658 ], dtype=float32), 1: array([0.5504856 , 0.5994761 , 0.7668127 , 0.40479252, 0.6021334 ,
0.65550673, 0.49918935, 0.5916094 , 0.38146687, 0.51254594,
0.5037622 , 0.82190436, 0.69192666, 0.7278631 , 0.40627387,
0.6097601 , 0.47449362, 0.52652866, 0.38925773, 0.79302955,
0.3318664 , 0.63591677, 0.5368127 , 0.5571479 , 0.4492227 ,
0.5301211 , 0.4952623 , 0.53548306, 0.33398268, 0.49999142,
0.85397255, 0.2610326 , 0.483789 , 0.48730516, 0.5092014 ,
0.47482398, 0.6071238 , 0.6307133 , 0.52939606, 0.58217114,
0.5603396 , 0.21029054, 0.4193466 , 0.5934032 , 0.7729158 ,
0.61854917, 0.3926683 , 0.60682935, 0.71347207, 0.4038011 ,
0.38194218, 0.49573156, 0.636773 , 0.60056764, 0.47029474,
0.47534442, 0.60573375, 0.5910211 , 0.40359557, 0.47699314,
0.49011517, 0.46399236, 0.53056073, 0.593339 , 0.41638085,
0.5195405 , 0.43436658, 0.52034014, 0.41826865, 0.6791163 ,
0.5859266 , 0.6711173 , 0.6697913 , 0.5201541 , 0.58509666,
0.6142568 , 0.54922146, 0.55967075, 1. , 0.6134324 ,
0.8166271 , 0.5827241 , 0.69778234, 0.46431488, 0.6503128 ,
0.42055455, 0.48216805, 0.29052272, 0.4626449 , 0.5175474 ,
0.31328282, 0.4980816 , 0.54858106, 0.7439629 , 0.3890496 ,
0.6421289 , 0.5331454 , 0.6628377 , 0.19336227, 0.6261516 ,
0.18801202, 0.32611975, 0.3060235 , 0.26275033, 0.5535011 ,
0.56534064, 0.62437934, 0.42243594, 0.5483634 , 0.6074568 ,
0.56832886, 0.6379465 , 0.43839157, 0.47260424, 0.7210698 ,
0.41984734, 0.4971962 , 0.31881064, 0.47105277, 0.36896485,
0.64241767, 0.42088464, 0.38107595, 0.5341317 , 0.72590035,
0.5268472 , 0.63523835, 0.53221124, 0.7368571 , 0.4755117 ,
0.79182297, 0.70873946, 0.54206204, 0.69074744, 0.6341343 ,
0.39392188, 0.5516004 , 0.5458956 , 0.47847182, 0.46813655,
0.5993048 , 0.81980145, 0.35527578, 0.34878874, 0.4191267 ,
0.5242655 , 0.5142229 , 0.3976383 , 0.5661604 , 0.5192766 ,
0.57318205, 0.7227642 , 0.8097223 , 0.5119465 , 0.6859499 ,
0.50233614, 0.53391176, 0.61771804, 0.38389373, 0.6997867 ,
0.34473568, 0.6623587 , 0.53183115, 0.32129094, 0.36221623,
0.5237297 , 0.42702368, 0.37889093, 0.66812634, 0.8407371 ,
0.42116776, 0.5665252 , 0.7053944 , 0.5736386 , 0.6230023 ,
0.48117504, 0.52796155, 0.39941242, 0.54547876, 0.7272116 ,
0.54897714, 0.523994 , 0.6533455 , 0.38077196, 0.5519111 ,
0.66316503, 0.46431178, 0.37412134, 0.5933377 , 0.61004925,
0.49943402, 0.4806319 , 0.5991593 , 0.51578337, 0.44524017,
0.24100226, 0.42497262, 0.63646454, 0.78476584, 0.55473834,
0.26551068, 0.36232346, 0.29171205, 0.4763685 , 0.6601026 ,
0.34422097, 0.67474854, 0.39421585, 0.34440812, 0.45544848,
0.8161871 , 0.6905303 , 0.6365791 , 0.5285045 , 0.6099448 ,
0.54091954, 0.56154066, 0.5089898 , 0.5446285 , 0.643386 ,
0.6753814 , 0.5724523 , 0.35792008, 0.53221124, 0.41221312,
0.55671316, 0.5096273 , 0.6705215 , 0.51981735, 0.548849 ,
0.46995398, 0.64089704, 0.6630525 , 0.5049213 , 0.4652524 ,
0.7100901 , 0.7665413 , 0.7090804 , 0.38844725, 0.60974854,
0.6509926 , 0.3871108 , 0.48560327, 0.48087615, 0.64458025,
0.9014838 , 0.54654825, 0.73597574, 0.44388396, 0.5535791 ,
0.61524373, 0.651733 , 0.5902918 , 0.3039532 , 0.32333356,
0.6485903 , 0.6703676 , 0.5282404 , 0.42160285, 0.7147951 ,
0.69491166, 0.2628424 , 0.47986296, 0.51505476, 0.6353423 ,
0.31095037, 0.4648145 , 0.41581687, 0.5836198 , 0.6675648 ,
0.5347364 , 0.6146405 , 0.47276402, 0.603103 , 0.46342447,
0.55651903, 0.33279827, 0.66820097, 0.42250937, 0.4014731 ,
0.5020434 , 0.30073088, 0.4639544 , 0.52969277, 0.42189834,
0.45939073, 0.48187277, 0.46560508, 0.60856736, 0.69641584,
0.56047803, 0.46933538, 0.69886035, 0.37440833, 0.5953375 ,
0.38355142, 0.40880513, 0.45499676, 0.600989 , 0.47509772,
0.4255842 , 0.5842295 , 0.6488961 , 0.8021985 , 0.44763464,
0.24923256, 0.52268636, 0.48713988, 0.5499488 , 0.5343191 ,
0.55189824, 0.19829603, 0.43111655, 0.76310563, 0.28973398,
0.569523 , 0.4592044 , 0.5501238 , 0.5594205 , 0.5322111 ,
0.84726304, 0.48171085, 0.45779154, 0.5082828 , 0.5513334 ,
0.64164245, 0. , 0.4149034 , 0.6931229 , 0.79822356,
0.75499356, 0.71429014, 0.6479542 , 0.5913222 , 0.6086662 ,
0.51924384, 0.51713663, 0.5809727 , 0.33458105, 0.47934353,
0.413319 , 0.6978297 , 0.36808777, 0.5374097 , 0.75435174,
0.4699698 , 0.33652395, 0.5289879 , 0.48408607, 0.31341746,
0.27841452, 0.56569135, 0.33595476, 0.5754073 , 0.47615036,
0.59850264, 0.45045155, 0.53011894, 0.7813892 , 0.46000174,
0.6293877 , 0.57028383, 0.5635631 , 0.51294714, 0.32131955,
0.41615418, 0.43640524, 0.6300865 , 0.56392336, 0.24798852,
0.34310782, 0.54155445, 0.64739555, 0.52224094, 0.6394095 ,
0.6416259 , 0.62210846, 0.4284557 , 0.70014524, 0.5149384 ,
0.522662 , 0.37624723, 0.36317435, 0.5684601 ], dtype=float32), 2: array([0.6748316 , 0.4074247 , 0.49105936, 0.38992038, 0.46978638,
0.6357716 , 0.42606366, 0.560135 , 0.54508513, 0.5517799 ,
0.3964912 , 0.78342354, 0.8053925 , 0.5808751 , 0.3884107 ,
0.56678355, 0.37066388, 0.54382163, 0.22759368, 0.70864516,
0.303966 , 0.60365826, 0.5034732 , 0.52999544, 0.46359798,
0.2859369 , 0.52990705, 0.5909621 , 0.32454422, 0.38926983,
0.96485156, 0.37908873, 0.4566861 , 0.6039012 , 0.59427416,
0.5869093 , 0.57286465, 0.5639263 , 0.48871076, 0.3544438 ,
0.7356649 , 0.20563804, 0.42331544, 0.5544179 , 0.3522051 ,
0.6289837 , 0.5469201 , 0.5465153 , 0.7104316 , 0.30148742,
0.21848506, 0.5294103 , 0.5647303 , 0.51238114, 0.28752846,
0.68530625, 0.72127384, 0.682508 , 0.49102718, 0.5946483 ,
0.61741525, 0.46993878, 0.6558486 , 0.3242415 , 0.38155767,
0.7227691 , 0.51612824, 0.58191395, 0.28775153, 0.6944261 ,
0.5493823 , 0.67946756, 0.7234838 , 0.54425246, 0.41901666,
0.5975187 , 0.627099 , 0.52148986, 0.9563764 , 0.2641179 ,
0.7301533 , 0.67143786, 0.55293715, 0.56072325, 0.49951297,
0.53529453, 0.48332396, 0.08366083, 0.4848326 , 0.65085953,
0.39826554, 0.3728425 , 0.45496678, 0.43883774, 0.41533807,
0.6146676 , 0.56835836, 0.72247344, 0. , 0.57956624,
0.27458906, 0.43997896, 0.55270195, 0.20345142, 0.6813145 ,
0.7134562 , 0.7397493 , 0.31485173, 0.5345157 , 0.5753244 ,
0.64990854, 0.62176925, 0.47460815, 0.5424521 , 0.8495754 ,
0.41225788, 0.48881984, 0.54615325, 0.3140851 , 0.40762606,
0.7431823 , 0.48576972, 0.45932463, 0.70246595, 0.619881 ,
0.29678798, 0.51290536, 0.5251198 , 0.77270365, 0.65206325,
0.55996567, 0.63707215, 0.14858927, 0.60790217, 0.56439173,
0.24563964, 0.60371107, 0.5484797 , 0.4412493 , 0.7600152 ,
0.5376845 , 0.79639024, 0.2144368 , 0.5354967 , 0.24194936,
0.58767736, 0.4875134 , 0.42291513, 0.55643535, 0.43211702,
0.6417125 , 0.56925964, 0.6627311 , 0.17544743, 0.7613894 ,
0.6579407 , 0.57729125, 0.5627387 , 0.3440774 , 0.63208085,
0.5659609 , 0.7209287 , 0.6067615 , 0.40275 , 0.5164946 ,
0.8186469 , 0.6339206 , 0.20342632, 0.8481145 , 0.82366467,
0.56067413, 0.71509784, 0.80045533, 0.7285235 , 0.4451461 ,
0.6246485 , 0.45633486, 0.42407495, 0.43958852, 0.58433235,
0.69412637, 0.5768258 , 0.32418138, 0.454048 , 0.3948432 ,
0.6179795 , 0.50213104, 0.23949648, 0.48517033, 0.7483972 ,
0.36877486, 0.61136216, 0.44356537, 0.54707456, 0.10315315,
0.29949546, 0.2710408 , 0.5831354 , 0.707144 , 0.4151 ,
0.24692956, 0.39056236, 0.35587728, 0.52209836, 0.7207828 ,
0.34318405, 0.8759284 , 0.18325937, 0.03065613, 0.49207333,
0.8686864 , 0.5794197 , 0.6828631 , 0.36431122, 0.6268849 ,
0.5038582 , 0.66722894, 0.63037086, 0.3770932 , 0.48218024,
0.81755686, 0.6661568 , 0.12384431, 0.5251198 , 0.44799998,
0.3825945 , 0.71873266, 0.9796641 , 0.62351716, 0.6530301 ,
0.33182004, 0.57896644, 0.45177144, 0.41930673, 0.3770389 ,
0.431782 , 0.6692333 , 0.5824252 , 0.6155698 , 0.40630376,
0.6964858 , 0.32052276, 0.4700297 , 0.4917665 , 0.8633192 ,
0.82269835, 0.3568557 , 0.985096 , 0.38960674, 0.51275676,
0.53697646, 0.39290482, 0.5948498 , 0.40631393, 0.4710755 ,
0.75382733, 0.8678897 , 0.5860656 , 0.50644547, 0.6931877 ,
0.827078 , 0.29435402, 0.39100865, 0.5259532 , 0.50417584,
0.18836886, 0.45406368, 0.16322745, 0.3951799 , 0.48666334,
0.43524384, 0.4730288 , 0.23024385, 0.50140816, 0.5175587 ,
0.48539448, 0.21951239, 0.76472694, 0.36810192, 0.6813255 ,
0.41129318, 0.57596487, 0.63458204, 0.8033597 , 0.36706468,
0.45644128, 0.50169414, 0.3441729 , 0.60120344, 0.6332276 ,
0.63546056, 0.4909976 , 0.5266029 , 0.58087826, 0.4846154 ,
0.16037425, 0.62975156, 0.3920663 , 0.6932027 , 0.5640802 ,
0.22320414, 0.3731415 , 0.6093587 , 0.6346745 , 0.6098801 ,
0.16370593, 0.5321208 , 0.36934385, 0.4892848 , 0.66808945,
0.66030204, 0.23884499, 0.36510858, 0.64053303, 0.54897684,
0.3590053 , 0.2972839 , 0.34753802, 0.6453297 , 0.5251196 ,
0.6814555 , 0.39543784, 0.44776392, 0.68364346, 0.4058828 ,
0.54477197, 0.25028777, 0.3434456 , 0.7371401 , 0.86635256,
0.686606 , 0.49274966, 0.4893725 , 0.42395604, 0.46989986,
0.61714303, 0.2987194 , 0.76872516, 0.40346637, 0.67223847,
0.5341008 , 0.48688528, 0.38268352, 0.666827 , 1. ,
0.43178546, 0.2497689 , 0.584798 , 0.69876546, 0.7022596 ,
0.56139475, 0.535826 , 0.44755712, 0.51189864, 0.55382276,
0.5954705 , 0.5261249 , 0.627666 , 0.6882771 , 0.4707281 ,
0.7680467 , 0.76645 , 0.4647652 , 0.404939 , 0.53761715,
0.29035944, 0.35611305, 0.8760471 , 0.6254243 , 0.38855845,
0.34764704, 0.5626235 , 0.6937429 , 0.48045602, 0.68960315,
0.6259925 , 0.58327866, 0.7423123 , 0.5612407 , 0.49681857,
0.6235863 , 0.39066392, 0.51519734, 0.49413523], dtype=float32), 3: array([0.6542318 , 0.5377808 , 0.39427352, 0.42770526, 0.47166738,
0.6680032 , 0.44329754, 0.49493238, 0.46237895, 0.57495785,
0.42989856, 0.7036364 , 0.7427337 , 0.6434955 , 0.50641924,
0.62676185, 0.41936448, 0.5544305 , 0.2921245 , 0.5734323 ,
0.2935471 , 0.6822767 , 0.5622412 , 0.62485796, 0.48756942,
0.22544909, 0.49136087, 0.679961 , 0.35861352, 0.35035822,
0.8163559 , 0.520737 , 0.4385027 , 0.6636616 , 0.6539404 ,
0.60592294, 0.49439794, 0.5566711 , 0.43656635, 0.3411835 ,
0.69705296, 0.3407699 , 0.4261004 , 0.62751544, 0.2921626 ,
0.5072126 , 0.48436025, 0.55037826, 0.6580973 , 0.32724953,
0.19243304, 0.62950337, 0.50237507, 0.4228691 , 0.21795171,
0.6045182 , 0.7464016 , 0.7525598 , 0.54767984, 0.56992143,
0.5775037 , 0.48859718, 0.74667466, 0.23965628, 0.4473735 ,
0.6389479 , 0.45452487, 0.58664054, 0.30382884, 0.7368003 ,
0.56128067, 0.7119723 , 0.6852876 , 0.5870384 , 0.3799207 ,
0.6283663 , 0.5649955 , 0.6069957 , 0.94719225, 0.22929955,
0.6486107 , 0.65455157, 0.5466983 , 0.4873031 , 0.40211558,
0.52241105, 0.45032373, 0.13322203, 0.5490885 , 0.6202621 ,
0.4511717 , 0.3063856 , 0.42220318, 0.39039963, 0.38010105,
0.6651868 , 0.5328936 , 0.6131631 , 0. , 0.52032447,
0.42317504, 0.5114572 , 0.6621108 , 0.28805304, 0.6526411 ,
0.7281291 , 0.76102555, 0.32751876, 0.5521735 , 0.5289147 ,
0.69779736, 0.68449605, 0.4677217 , 0.46570984, 0.74059623,
0.6687386 , 0.526675 , 0.5870634 , 0.28352338, 0.4921353 ,
0.7379906 , 0.43511122, 0.50613725, 0.7821665 , 0.6494672 ,
0.33980587, 0.5046495 , 0.5283351 , 0.8038778 , 0.52949345,
0.55786353, 0.68334764, 0.27564308, 0.3695339 , 0.6317535 ,
0.13243376, 0.6835329 , 0.52257854, 0.45744082, 0.8174785 ,
0.57084996, 0.76096934, 0.15712643, 0.56485105, 0.2633177 ,
0.5813226 , 0.5298225 , 0.5427938 , 0.586514 , 0.43465352,
0.6102504 , 0.5177347 , 0.7217669 , 0.21697025, 0.6146182 ,
0.6744388 , 0.60285586, 0.59281015, 0.26246607, 0.6146374 ,
0.58215857, 0.6584 , 0.59804606, 0.38756725, 0.52817786,
0.78565884, 0.70261556, 0.28947195, 0.8879533 , 0.75424474,
0.56979764, 0.79933727, 0.70712245, 0.6947445 , 0.46565795,
0.63337314, 0.44335625, 0.5567024 , 0.45160365, 0.4859375 ,
0.5825016 , 0.47205833, 0.29501644, 0.54978055, 0.4105344 ,
0.53806275, 0.52932763, 0.30232495, 0.49972665, 0.7279396 ,
0.2684043 , 0.6308848 , 0.43253362, 0.39378652, 0.11930857,
0.3749699 , 0.32635775, 0.55045974, 0.70638657, 0.35426024,
0.21915436, 0.3507852 , 0.43327543, 0.48687625, 0.71844566,
0.37366402, 0.90822446, 0.1718719 , 0.18446386, 0.55817807,
0.8028598 , 0.5387753 , 0.6225522 , 0.3418844 , 0.57539505,
0.41666165, 0.7595317 , 0.5946971 , 0.32267916, 0.4584917 ,
0.8040467 , 0.6062819 , 0.13489926, 0.5283351 , 0.4586317 ,
0.473593 , 0.7651325 , 0.98452383, 0.6477622 , 0.6048426 ,
0.34395784, 0.59981287, 0.3959062 , 0.42500147, 0.4247128 ,
0.44683853, 0.53576064, 0.6344432 , 0.67739755, 0.37684232,
0.72056973, 0.4385622 , 0.43668 , 0.6016548 , 0.79325795,
0.66228557, 0.31493598, 0.97571874, 0.33136997, 0.578599 ,
0.46057224, 0.29644817, 0.5894011 , 0.34435746, 0.507949 ,
0.6673075 , 0.8544765 , 0.5088211 , 0.73321015, 0.56988007,
0.8791954 , 0.46055466, 0.2512141 , 0.54515755, 0.51327956,
0.2897278 , 0.5352044 , 0.19738676, 0.33345237, 0.47579262,
0.47217768, 0.45885113, 0.32030272, 0.54774165, 0.4643789 ,
0.50881547, 0.16045596, 0.7847433 , 0.30525547, 0.7076869 ,
0.42714575, 0.6004711 , 0.66534877, 0.9246982 , 0.37249017,
0.470351 , 0.50644106, 0.27928612, 0.7095356 , 0.5893712 ,
0.65979886, 0.55490106, 0.53423333, 0.5893161 , 0.5291243 ,
0.18188052, 0.7221448 , 0.41399443, 0.6156911 , 0.51341146,
0.35419738, 0.40884674, 0.63282645, 0.5193399 , 0.6381396 ,
0.25020102, 0.43971163, 0.4374941 , 0.44072276, 0.71287936,
0.62531126, 0.36002475, 0.33606008, 0.5655713 , 0.68675596,
0.27631754, 0.33216003, 0.5067967 , 0.63783395, 0.5283349 ,
0.61430395, 0.4524425 , 0.3432733 , 0.6720959 , 0.39899063,
0.55632484, 0.31208792, 0.30387443, 0.7800658 , 0.8409841 ,
0.6674822 , 0.5534705 , 0.475604 , 0.41495994, 0.4002237 ,
0.6579579 , 0.33016756, 0.752389 , 0.3981897 , 0.67810756,
0.5619784 , 0.44760135, 0.41618544, 0.6703766 , 1. ,
0.46168402, 0.245408 , 0.65896124, 0.8166215 , 0.7485484 ,
0.594238 , 0.6279917 , 0.32763523, 0.5644418 , 0.6252166 ,
0.4871536 , 0.6665477 , 0.6146433 , 0.686448 , 0.49793965,
0.7148773 , 0.75429183, 0.47640997, 0.38003728, 0.5814651 ,
0.3123399 , 0.32535496, 0.7765912 , 0.58655864, 0.5247335 ,
0.27610427, 0.5483681 , 0.71346617, 0.5215375 , 0.71260816,
0.5932541 , 0.57799846, 0.7147718 , 0.5404635 , 0.445882 ,
0.69863516, 0.37483338, 0.59227234, 0.5595903 ], dtype=float32)}
What am I missing? Thanks a lot for the help!
Thank you so much for pointing out this issue. I should have documented how to use this code better; once your issue is fixed, I will go back and document it much better.
I need more information about the graph and users_embeddings. Could you please run this code?
import networkx as nx
import numpy as np
G = # Your Graph
users_embeddings = # Your embeddings
# The graph
print(f"Number of nodes: {G.number_of_nodes()}")
print(f"Number of edges: {G.number_of_edges()}")
print(f"Is directed?: {G.is_directed()}")
print(f"Is connected?: {nx.is_connected(G)}")
nodes = G.nodes(data=True)
print(f"Nodes: {nodes[0]}")
print(f"Length of users_embeddings: {len(users_embeddings)}")
print(f"Type of users_embeddings: {type(users_embeddings)}")
i = users_embeddings.keys()[0]
first_embedding_shape = users_embeddings[i].shape
print(f"Shape of the first embedding: {first_embedding_shape}")
# Check if all the embeddings have the same shape
ans = all(emb == first_embedding_shape for emb in users_embeddings.values())
print(f"Are all the embeddings of the same shape?: {ans}")
Hi @faalatawi, I experimented with your code using a very small dummy dataset: a graph with 4 nodes and 4 edges. My graph is initially directed; however, after these lines in your "get_data" function run, it becomes undirected:
G = nx.read_gml(dataset_path + "graph.gml")
if nx.is_directed(G):
G = G.to_undirected()
So, after obtaining the user embeddings, I ran the "verification code" you posted in your last reply. I made a few minor modifications because there were some errors. Here is the code I ran:
import networkx as nx
import numpy as np
def alatawi_check(G, users_embeddings):
    # The graph
    print(f"Number of nodes: {G.number_of_nodes()}")
    print(f"Number of edges: {G.number_of_edges()}")
    print(f"Is directed?: {G.is_directed()}")
    print(f"Is connected?: {nx.is_connected(G)}")
    nodes = G.nodes(data=True)
    print(nodes)
    print(f"Nodes: {nodes[0]}")
    print(f"Length of users_embeddings: {len(users_embeddings)}")
    print(f"Type of users_embeddings: {type(users_embeddings)}")
    i = list(users_embeddings.keys())[0]
    first_embedding_shape = users_embeddings[i].shape
    print(f"Shape of the first embedding: {first_embedding_shape}")
    # Check if all the embeddings have the same shape
    ans = all(np.array_equal(emb, first_embedding_shape) for emb in users_embeddings.values())
    print(f"Are all the embeddings of the same shape?: {ans}")
Here is the output of this run:
Number of nodes: 4
Number of edges: 3
Is directed?: False
Is connected?: True
[(0, {}), (1, {}), (2, {}), (3, {})]
Nodes: {}
Length of users_embeddings: 4
Type of users_embeddings: <class 'dict'>
Shape of the first embedding: (384,)
Are all the embeddings of the same shape?: False
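As written, this last check compares values rather than shapes: np.array_equal(emb, first_embedding_shape) compares the 384 embedding values with the 1-element array [384], and np.array_equal returns False whenever the two arrays have different shapes. So it prints False even when every embedding is 384-dimensional. A minimal shape-only check, as a sketch (all_same_shape is a hypothetical helper name):

```python
import numpy as np

def all_same_shape(users_embeddings):
    """Return True if every embedding in the dict has the same shape."""
    shapes = {np.shape(emb) for emb in users_embeddings.values()}
    return len(shapes) <= 1

# Toy data: four random 384-dimensional embeddings
emb = {i: np.random.rand(384).astype(np.float32) for i in range(4)}
first = emb[0]

# The comparison used above: 384 values vs. the 1-element array [384]
print(np.array_equal(first, first.shape))  # False, although all shapes are identical
print(all_same_shape(emb))                 # True
```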
I reviewed your response in issue #1 and noticed that all the tweets belonging to the same user must be grouped together for your code to work properly. I added this as a preprocessing step for my original dataset and then modified the "preprocess_tweets" function accordingly.
Additionally, I added normalization of the users' embeddings, as you suggested in your first reply in this thread. However, I noticed that in your example datasets, such as the "gun" one, the embeddings are not normalized, so I'm not sure whether normalization is actually required for your code to work correctly.
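The per-user grouping mentioned above can be sketched with pandas. This assumes the raw data has one row per tweet with columns named user_id and tweet (the tweet column name and the toy rows are hypothetical) and produces one row per user whose tweets column holds a list, which is what the commented-out `for tw in tweets:` loop in preprocess_tweets iterates over.

```python
import pandas as pd

# Hypothetical input: one row per tweet
raw = pd.DataFrame({
    "user_id": ["User1", "User2", "User1", "User3"],
    "tweet": ["first tweet", "hello", "second tweet", "hi there"],
})

# One row per user; all of that user's tweets collected in a list (original order kept)
tweets_per_user = (
    raw.groupby("user_id")["tweet"]
    .agg(list)
    .reset_index(name="tweets")
)
print(tweets_per_user)
```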
To facilitate debugging, I am providing the complete modified code for the "get_data" function:
import networkx as nx
import numpy as np
import pandas as pd
from networkx.algorithms.community.louvain import louvain_communities
from sentence_transformers import SentenceTransformer
from .tweet_preprocessing import preprocess_tweet_for_bert
from sklearn.utils import check_random_state
import torch
import random
import os
from tqdm import tqdm
tqdm.pandas()
# disable warnings
import warnings
warnings.filterwarnings("ignore")
def get_data(dataset_path: str, louvain_resolution=0.05, SBert_model_name="all-MiniLM-L6-v2"):
    # Set the random seed for python
    # seed = 42
    # random.seed(seed)
    # np.random.seed(seed)
    # pd.np.random.seed(seed)
    # check_random_state(seed)
    # torch.manual_seed(seed)
    # 1: Get the Graph G
    G = nx.read_gml(dataset_path + "graph.gml")
    if nx.is_directed(G):
        G = G.to_undirected()
    # 2: Get the user embeddings
    df = pd.read_feather(dataset_path + "tweets.feather")
    if os.path.exists(dataset_path + "embeddings.feather"):
        print("Loading embeddings")
        df_emb = pd.read_feather(dataset_path + "embeddings.feather")
        df = df.merge(df_emb, on="user_id", how="inner")
    else:
        print("No embeddings found, creating them")

        def preprocess_tweets(tweets):
            # print(tweets)
            out = []
            # for tw in tweets:
            #     print(tw)
            tw = preprocess_tweet_for_bert(tweets)
            # print(tw)
            if len(tw) > 1:
                out.append(" ".join(tw))
            return out

        print("Preprocessing tweets")
        # print(df["tweets"])
        # print(df.tweets)
        df["tweets"] = df.tweets.progress_apply(preprocess_tweets)
        # print(df["tweets"])
        # Remove users with no tweets
        df = df[df.tweets.apply(len) > 0]
        model = SentenceTransformer(SBert_model_name)

        def embed_user_tweets(tweets):
            emb = model.encode(tweets)
            emb = np.mean(emb, axis=0)
            return emb

        print("Embedding tweets")
        df["embeddings"] = df.tweets.progress_apply(embed_user_tweets)
        df_emb = df[["user_id", "embeddings"]]
        df_emb.reset_index(drop=True, inplace=True)
        df_emb.to_feather(dataset_path + "embeddings.feather")
    # 3: Filter the graph
    G = G.subgraph(df.user_id)
    print(f'Graph: {G}')
    connected_components = list(nx.connected_components(G))
    lcc_nodes = max(nx.connected_components(G), key=len)
    G = G.subgraph(lcc_nodes)
    df = df[df.user_id.isin(G.nodes())]
    # 5: Find the communities
    community = louvain_communities(G, resolution=louvain_resolution, seed=42)
    community = list(community)

    def which_community(node):
        for i, c in enumerate(community):
            if node in c:
                return i
        return -1

    df["community"] = df.user_id.apply(which_community)
    # allsides
    df_allsides = pd.read_feather(dataset_path + "allsides.feather")
    # 4: Make a map from user_id to index
    node_id_map = {node: i for i, node in enumerate(G.nodes())}
    G = nx.relabel_nodes(G, node_id_map)
    # 6: Get the embeddings and labels
    print(df.set_index("user_id")["embeddings"])
    users_embeddings = df.set_index("user_id")["embeddings"].to_dict()
    labels = df.set_index("user_id")["community"].to_dict()
    allsides_scores = df_allsides.set_index("user_id")["allsides_score"].to_dict()
    allsides_scores = {
        node_id_map[user_id]: score
        for user_id, score in allsides_scores.items()
        if user_id in node_id_map
    }
    users_embeddings_tmp = {}
    labels_tmp = {}
    for user_id, index in node_id_map.items():
        users_embeddings_tmp[index] = users_embeddings[user_id]
        labels_tmp[index] = labels[user_id]
    users_embeddings = users_embeddings_tmp
    # print(users_embeddings)
    # Normalization
    normalized_array = {key: (users_embeddings[key] - np.min(users_embeddings[key])) / (np.max(users_embeddings[key]) - np.min(users_embeddings[key])) for key in users_embeddings.keys()}
    # Printing normalized array
    # for key, value in normalized_array.items():
    #     print(f"{key}: {value}")
    # print(normalized_array)
    users_embeddings = normalized_array
    # print(users_embeddings[0, :])
    labels = labels_tmp
    labels = np.array(list(labels.values()))
    return G, users_embeddings, labels, allsides_scores, node_id_map
For completeness, I'm also sharing an image and the .gml file (anonymized version) of the graph I'm using to test your code:
graph [
  directed 1
  node [
    id 0
    label "User1"
  ]
  node [
    id 1
    label "User2"
  ]
  node [
    id 2
    label "User3"
  ]
  node [
    id 3
    label "User4"
  ]
  edge [
    source 0
    target 1
    tweet "Tweet2"
    iteration 5
  ]
  edge [
    source 0
    target 2
    tweet "Tweet1"
    iteration 5
  ]
  edge [
    source 1
    target 0
    tweet "Tweet1"
    iteration 4
  ]
  edge [
    source 2
    target 3
    tweet "Tweet3"
    iteration 4
  ]
]
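For reference, the drop from 4 to 3 edges in the check output follows from this edge list: networkx's to_undirected() keeps a single edge per node pair, so the reciprocal edges 0->1 and 1->0 merge into one. A plain-Python sketch of that count (it mimics the behavior and doesn't call networkx):

```python
# Directed edges (source, target) from the .gml file above
directed_edges = [(0, 1), (0, 2), (1, 0), (2, 3)]

# An undirected graph stores one edge per unordered node pair,
# so the reciprocal edges 0->1 and 1->0 collapse into a single edge
undirected_edges = {frozenset(edge) for edge in directed_edges}

print(len(directed_edges))    # 4
print(len(undirected_edges))  # 3
```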
Thanks again for the help and let me know if you need any other info :)
Hi @faalatawi, do you have any news?
Thanks :)
I'm sorry, I haven't had time to look into this issue yet. However, I think your graph is too small for this algorithm.
I think the main problem is that the user embeddings are not the same size for all the nodes. Make sure that the embeddings have the same shape for every node.
By the way, the algorithm does not need user embeddings. To experiment, you could pass just the graph without the embeddings, and you will still get node embeddings.
So my advice is to make sure that the embedding is the same size for all nodes.
I will look into this problem more later this week. Sorry, things are a bit busy for me right now.
Hi all, I'm trying to replicate your work on the Echo Chamber Score on some datasets I have, and I'm getting the following error:
Can you suggest how to solve it?
Thanks in advance.