faalatawi / echo-chamber-score

ECS (Echo Chamber Score) is a method to measure the echo chamber and polarization in social media.
MIT License

Issue on Forward function #2

Open giammy677dev opened 10 months ago

giammy677dev commented 10 months ago

Hi all, I'm trying to replicate your work about Echo Chamber Score on some datasets I have. I'm getting the following error:

Traceback (most recent call last):
  File "C:\Users\ogiam\PycharmProjects\ECS\1_echo_chamber_score.py", line 88, in <module>
    user_emb = EchoGAE_algorithm(G, user_embeddings=users_embeddings, show_progress=False, hidden_channels=20, out_channels=10, epochs=300)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ogiam\PycharmProjects\ECS\src\EchoGAE.py", line 50, in EchoGAE_algorithm
    model, x, train_pos_edge_index = run(
                                     ^^^^
  File "C:\Users\ogiam\PycharmProjects\ECS\src\GAE.py", line 66, in run
    loss = __train(model, optimizer, x, train_pos_edge_index)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ogiam\PycharmProjects\ECS\src\GAE.py", line 26, in __train
    z = model.encode(x, train_pos_edge_index)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ogiam\PycharmProjects\ECS\venv\Lib\site-packages\torch_geometric\nn\models\autoencoder.py", line 80, in encode
    return self.encoder(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ogiam\PycharmProjects\ECS\venv\Lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ogiam\PycharmProjects\ECS\venv\Lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ogiam\PycharmProjects\ECS\src\GAE.py", line 18, in forward
    x = self.conv1(x, edge_index).relu()
        ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ogiam\PycharmProjects\ECS\venv\Lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ogiam\PycharmProjects\ECS\venv\Lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ogiam\PycharmProjects\ECS\venv\Lib\site-packages\torch_geometric\nn\conv\gcn_conv.py", line 224, in forward
    edge_index, edge_weight, x.size(self.node_dim),
                             ^^^^^^^^^^^^^^^^^^^^^
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2)

Can you suggest how to solve it?

Thanks in advance.

faalatawi commented 10 months ago

Maybe the error is related to user_embeddings. It should be a NumPy array of shape (n, m), where n is the number of nodes and m is the dimension of the user embedding. The embedding values should be between 0 and 1.

Try normalizing the embeddings. If that doesn't work, give me more information about the user embeddings.
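
For reference, the `IndexError` in the traceback above is raised at `x.size(self.node_dim)` with dimension -2, which suggests the feature tensor that reached `GCNConv` was one-dimensional rather than of shape (n, m). A minimal sketch of building and checking such a matrix from a dict of per-node vectors follows. The helper name `stack_embeddings` and the assumption that node ids are the integers 0..n-1 are illustrative and are not taken from the repository:

```python
import numpy as np


def stack_embeddings(embeddings):
    """Stack a {node_index: 1-D vector} dict into an (n, m) float32 array.

    Assumes (hypothetically) that the keys are the integers 0..n-1.
    """
    n = len(embeddings)
    if sorted(embeddings) != list(range(n)):
        raise ValueError("expected node indices 0..n-1 as keys")
    rows = [np.asarray(embeddings[i], dtype=np.float32) for i in range(n)]
    shapes = {row.shape for row in rows}
    # Every row must be 1-D and all rows must have the same length.
    if len(shapes) != 1 or rows[0].ndim != 1:
        raise ValueError(f"embeddings must be 1-D and equally sized, got {shapes}")
    return np.stack(rows)  # row i holds the embedding of node i


# Toy input: 3 nodes with 4-dimensional embeddings.
toy = {0: np.zeros(4), 1: np.ones(4), 2: np.full(4, 0.5)}
X = stack_embeddings(toy)
print(X.shape, X.ndim)
```

Whether `EchoGAE_algorithm` expects the dict or the stacked array is not stated in this thread, so the sketch only shows how to verify the (n, m) layout described above.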

giammy677dev commented 10 months ago

Hi @faalatawi and thanks for the rapid response :)

I normalized the user_embeddings as you suggested, and the code now seems to get past that point. However, I'm now getting this error:

Traceback (most recent call last):
  File "C:\Users\ogiam\PycharmProjects\ECS\1_echo_chamber_score.py", line 97, in <module>
    user_emb = EchoGAE_algorithm(G, user_embeddings=users_embeddings, show_progress=False, hidden_channels=20, out_channels=10, epochs=300)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ogiam\PycharmProjects\ECS\src\EchoGAE.py", line 49, in EchoGAE_algorithm
    model, x, train_pos_edge_index = run(
                                     ^^^^
  File "C:\Users\ogiam\PycharmProjects\ECS\src\GAE.py", line 68, in run
    auc, ap = __test(
              ^^^^^^^
  File "C:\Users\ogiam\PycharmProjects\ECS\src\GAE.py", line 43, in __test
    return model.test(z, pos_edge_index, neg_edge_index)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ogiam\PycharmProjects\ECS\venv\Lib\site-packages\torch_geometric\nn\models\autoencoder.py", line 136, in test
    return roc_auc_score(y, pred), average_precision_score(y, pred)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ogiam\PycharmProjects\ECS\venv\Lib\site-packages\sklearn\utils\_param_validation.py", line 214, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ogiam\PycharmProjects\ECS\venv\Lib\site-packages\sklearn\metrics\_ranking.py", line 605, in roc_auc_score
    y_true = check_array(y_true, ensure_2d=False, dtype=None)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ogiam\PycharmProjects\ECS\venv\Lib\site-packages\sklearn\utils\validation.py", line 967, in check_array
    raise ValueError(
ValueError: Found array with 0 sample(s) (shape=(0,)) while a minimum of 1 is required.

To be clear, in load_data.py you convert user_embeddings to a dictionary:

users_embeddings = df.set_index("user_id")["embeddings"].to_dict()

These are the users_embeddings I got after normalization, as you suggested:

{0: array([0.5903886 , 0.5463845 , 0.6314844 , 0.2994573 , 0.53178114,
       0.72305346, 0.3940116 , 0.5983963 , 0.36165282, 0.46718055,
       0.371573  , 0.86012286, 0.75624   , 0.65208215, 0.31125316,
       0.56961894, 0.36351144, 0.40996528, 0.23478517, 0.77074856,
       0.21189088, 0.6337178 , 0.53798616, 0.49903655, 0.37051472,
       0.3854286 , 0.4718904 , 0.5163373 , 0.19723426, 0.39895633,
       0.8976927 , 0.21628626, 0.3593885 , 0.48840365, 0.48226994,
       0.43893492, 0.48543227, 0.5892171 , 0.38826406, 0.40871355,
       0.59276766, 0.08683099, 0.30721337, 0.5162417 , 0.6030433 ,
       0.57679516, 0.37648603, 0.5437432 , 0.7303272 , 0.27150497,
       0.33576313, 0.49856034, 0.57819676, 0.4675867 , 0.33486864,
       0.5247566 , 0.6328391 , 0.6870526 , 0.38876027, 0.4387255 ,
       0.5059428 , 0.4151606 , 0.56148404, 0.45583224, 0.26991713,
       0.5404365 , 0.36613652, 0.50812584, 0.27858657, 0.77044636,
       0.44390696, 0.7116696 , 0.74979055, 0.53280395, 0.5741853 ,
       0.50218266, 0.5360777 , 0.53478837, 1.        , 0.41538575,
       0.9712991 , 0.6728056 , 0.63580185, 0.41534263, 0.58824044,
       0.39682275, 0.35553297, 0.12713297, 0.4352373 , 0.4530455 ,
       0.21358362, 0.44481987, 0.4569379 , 0.6189847 , 0.41602683,
       0.5464685 , 0.44565925, 0.75979394, 0.        , 0.5848383 ,
       0.02218434, 0.3135237 , 0.29259095, 0.1951891 , 0.59517765,
       0.46691978, 0.64640707, 0.34184644, 0.4921255 , 0.5788995 ,
       0.6034956 , 0.6086504 , 0.26671427, 0.35178533, 0.7536429 ,
       0.32784587, 0.39693376, 0.36548433, 0.3502349 , 0.41376254,
       0.6751696 , 0.32571805, 0.29139057, 0.5199206 , 0.7914655 ,
       0.35013318, 0.40610346, 0.474595  , 0.715731  , 0.55157787,
       0.69133556, 0.6740565 , 0.2891    , 0.58775306, 0.62619764,
       0.22411384, 0.52803564, 0.5504468 , 0.42701572, 0.49767405,
       0.58265716, 0.8593093 , 0.19029467, 0.29489756, 0.31142318,
       0.58041364, 0.35520896, 0.32476845, 0.49472684, 0.5327737 ,
       0.5418649 , 0.63241345, 0.7387447 , 0.22186379, 0.69398755,
       0.49795955, 0.44908583, 0.54026526, 0.25971207, 0.6299726 ,
       0.36077514, 0.67199713, 0.5334661 , 0.1798632 , 0.39948094,
       0.55486876, 0.43986547, 0.22933358, 0.7387832 , 0.92368567,
       0.29325098, 0.58843464, 0.75875854, 0.5492905 , 0.5206712 ,
       0.47314996, 0.37598154, 0.36545646, 0.5015214 , 0.6523398 ,
       0.56226677, 0.596226  , 0.4993254 , 0.3105751 , 0.3684766 ,
       0.5790526 , 0.42430234, 0.21187836, 0.47814125, 0.65961665,
       0.41557905, 0.50058234, 0.55442435, 0.42274082, 0.1365884 ,
       0.1645132 , 0.29398423, 0.6055696 , 0.77949506, 0.35396424,
       0.15519348, 0.26841366, 0.28413072, 0.40042803, 0.7029442 ,
       0.14210477, 0.6940655 , 0.19428387, 0.11259534, 0.4314326 ,
       0.8772073 , 0.5604053 , 0.54182166, 0.38905093, 0.57011956,
       0.43929538, 0.6414684 , 0.50285524, 0.42010292, 0.5579324 ,
       0.72515714, 0.5751192 , 0.10681102, 0.474595  , 0.38437125,
       0.4455422 , 0.5125746 , 0.70831937, 0.5460206 , 0.57248586,
       0.34819818, 0.52376235, 0.5426327 , 0.33171818, 0.38510123,
       0.6092513 , 0.6695501 , 0.67054087, 0.31687284, 0.527404  ,
       0.5847972 , 0.3198751 , 0.4546798 , 0.42352444, 0.70881397,
       0.9764365 , 0.40056378, 0.7918312 , 0.3442903 , 0.45757028,
       0.6102687 , 0.55319154, 0.71293557, 0.17177549, 0.18230942,
       0.66884625, 0.7329294 , 0.46095172, 0.4112841 , 0.8137524 ,
       0.69304144, 0.16537918, 0.39652187, 0.38389784, 0.51303506,
       0.07949679, 0.34835663, 0.19845116, 0.51711726, 0.5326349 ,
       0.54591066, 0.57445216, 0.4025225 , 0.53146696, 0.35923636,
       0.47883376, 0.22849476, 0.75719845, 0.28265658, 0.37521958,
       0.40089855, 0.33905256, 0.35004425, 0.6059098 , 0.33998072,
       0.4290143 , 0.45183882, 0.36477107, 0.58499044, 0.664054  ,
       0.5714948 , 0.33190066, 0.675558  , 0.27499285, 0.39646444,
       0.1741357 , 0.43009242, 0.41168192, 0.6453223 , 0.4173889 ,
       0.2664299 , 0.48379079, 0.59220153, 0.7751846 , 0.41976035,
       0.02754701, 0.47559944, 0.4104384 , 0.36888206, 0.530058  ,
       0.5023671 , 0.13987543, 0.33606586, 0.7380448 , 0.3089485 ,
       0.50874645, 0.25042897, 0.4580534 , 0.5849977 , 0.4745948 ,
       0.7601535 , 0.33843175, 0.3520388 , 0.45736083, 0.4769383 ,
       0.6366737 , 0.01269261, 0.35455778, 0.73761207, 0.7977838 ,
       0.8040379 , 0.59207374, 0.549212  , 0.37760028, 0.46691293,
       0.53902924, 0.34738073, 0.5861576 , 0.25502285, 0.5389182 ,
       0.3297547 , 0.583315  , 0.32766637, 0.569069  , 0.8713058 ,
       0.3744244 , 0.22712737, 0.44884762, 0.44120434, 0.34806576,
       0.2522678 , 0.46934766, 0.30780926, 0.5544345 , 0.43582988,
       0.44257188, 0.36943206, 0.5749146 , 0.7811406 , 0.42525542,
       0.7214907 , 0.5741349 , 0.4951715 , 0.39221266, 0.26793188,
       0.24854879, 0.32631084, 0.77015775, 0.5320774 , 0.23697053,
       0.08882525, 0.54427993, 0.5777584 , 0.4866187 , 0.61439764,
       0.6713501 , 0.542743  , 0.5242594 , 0.65211165, 0.5329402 ,
       0.50373816, 0.3688049 , 0.35005218, 0.4660658 ], dtype=float32), 1: array([0.5504856 , 0.5994761 , 0.7668127 , 0.40479252, 0.6021334 ,
       0.65550673, 0.49918935, 0.5916094 , 0.38146687, 0.51254594,
       0.5037622 , 0.82190436, 0.69192666, 0.7278631 , 0.40627387,
       0.6097601 , 0.47449362, 0.52652866, 0.38925773, 0.79302955,
       0.3318664 , 0.63591677, 0.5368127 , 0.5571479 , 0.4492227 ,
       0.5301211 , 0.4952623 , 0.53548306, 0.33398268, 0.49999142,
       0.85397255, 0.2610326 , 0.483789  , 0.48730516, 0.5092014 ,
       0.47482398, 0.6071238 , 0.6307133 , 0.52939606, 0.58217114,
       0.5603396 , 0.21029054, 0.4193466 , 0.5934032 , 0.7729158 ,
       0.61854917, 0.3926683 , 0.60682935, 0.71347207, 0.4038011 ,
       0.38194218, 0.49573156, 0.636773  , 0.60056764, 0.47029474,
       0.47534442, 0.60573375, 0.5910211 , 0.40359557, 0.47699314,
       0.49011517, 0.46399236, 0.53056073, 0.593339  , 0.41638085,
       0.5195405 , 0.43436658, 0.52034014, 0.41826865, 0.6791163 ,
       0.5859266 , 0.6711173 , 0.6697913 , 0.5201541 , 0.58509666,
       0.6142568 , 0.54922146, 0.55967075, 1.        , 0.6134324 ,
       0.8166271 , 0.5827241 , 0.69778234, 0.46431488, 0.6503128 ,
       0.42055455, 0.48216805, 0.29052272, 0.4626449 , 0.5175474 ,
       0.31328282, 0.4980816 , 0.54858106, 0.7439629 , 0.3890496 ,
       0.6421289 , 0.5331454 , 0.6628377 , 0.19336227, 0.6261516 ,
       0.18801202, 0.32611975, 0.3060235 , 0.26275033, 0.5535011 ,
       0.56534064, 0.62437934, 0.42243594, 0.5483634 , 0.6074568 ,
       0.56832886, 0.6379465 , 0.43839157, 0.47260424, 0.7210698 ,
       0.41984734, 0.4971962 , 0.31881064, 0.47105277, 0.36896485,
       0.64241767, 0.42088464, 0.38107595, 0.5341317 , 0.72590035,
       0.5268472 , 0.63523835, 0.53221124, 0.7368571 , 0.4755117 ,
       0.79182297, 0.70873946, 0.54206204, 0.69074744, 0.6341343 ,
       0.39392188, 0.5516004 , 0.5458956 , 0.47847182, 0.46813655,
       0.5993048 , 0.81980145, 0.35527578, 0.34878874, 0.4191267 ,
       0.5242655 , 0.5142229 , 0.3976383 , 0.5661604 , 0.5192766 ,
       0.57318205, 0.7227642 , 0.8097223 , 0.5119465 , 0.6859499 ,
       0.50233614, 0.53391176, 0.61771804, 0.38389373, 0.6997867 ,
       0.34473568, 0.6623587 , 0.53183115, 0.32129094, 0.36221623,
       0.5237297 , 0.42702368, 0.37889093, 0.66812634, 0.8407371 ,
       0.42116776, 0.5665252 , 0.7053944 , 0.5736386 , 0.6230023 ,
       0.48117504, 0.52796155, 0.39941242, 0.54547876, 0.7272116 ,
       0.54897714, 0.523994  , 0.6533455 , 0.38077196, 0.5519111 ,
       0.66316503, 0.46431178, 0.37412134, 0.5933377 , 0.61004925,
       0.49943402, 0.4806319 , 0.5991593 , 0.51578337, 0.44524017,
       0.24100226, 0.42497262, 0.63646454, 0.78476584, 0.55473834,
       0.26551068, 0.36232346, 0.29171205, 0.4763685 , 0.6601026 ,
       0.34422097, 0.67474854, 0.39421585, 0.34440812, 0.45544848,
       0.8161871 , 0.6905303 , 0.6365791 , 0.5285045 , 0.6099448 ,
       0.54091954, 0.56154066, 0.5089898 , 0.5446285 , 0.643386  ,
       0.6753814 , 0.5724523 , 0.35792008, 0.53221124, 0.41221312,
       0.55671316, 0.5096273 , 0.6705215 , 0.51981735, 0.548849  ,
       0.46995398, 0.64089704, 0.6630525 , 0.5049213 , 0.4652524 ,
       0.7100901 , 0.7665413 , 0.7090804 , 0.38844725, 0.60974854,
       0.6509926 , 0.3871108 , 0.48560327, 0.48087615, 0.64458025,
       0.9014838 , 0.54654825, 0.73597574, 0.44388396, 0.5535791 ,
       0.61524373, 0.651733  , 0.5902918 , 0.3039532 , 0.32333356,
       0.6485903 , 0.6703676 , 0.5282404 , 0.42160285, 0.7147951 ,
       0.69491166, 0.2628424 , 0.47986296, 0.51505476, 0.6353423 ,
       0.31095037, 0.4648145 , 0.41581687, 0.5836198 , 0.6675648 ,
       0.5347364 , 0.6146405 , 0.47276402, 0.603103  , 0.46342447,
       0.55651903, 0.33279827, 0.66820097, 0.42250937, 0.4014731 ,
       0.5020434 , 0.30073088, 0.4639544 , 0.52969277, 0.42189834,
       0.45939073, 0.48187277, 0.46560508, 0.60856736, 0.69641584,
       0.56047803, 0.46933538, 0.69886035, 0.37440833, 0.5953375 ,
       0.38355142, 0.40880513, 0.45499676, 0.600989  , 0.47509772,
       0.4255842 , 0.5842295 , 0.6488961 , 0.8021985 , 0.44763464,
       0.24923256, 0.52268636, 0.48713988, 0.5499488 , 0.5343191 ,
       0.55189824, 0.19829603, 0.43111655, 0.76310563, 0.28973398,
       0.569523  , 0.4592044 , 0.5501238 , 0.5594205 , 0.5322111 ,
       0.84726304, 0.48171085, 0.45779154, 0.5082828 , 0.5513334 ,
       0.64164245, 0.        , 0.4149034 , 0.6931229 , 0.79822356,
       0.75499356, 0.71429014, 0.6479542 , 0.5913222 , 0.6086662 ,
       0.51924384, 0.51713663, 0.5809727 , 0.33458105, 0.47934353,
       0.413319  , 0.6978297 , 0.36808777, 0.5374097 , 0.75435174,
       0.4699698 , 0.33652395, 0.5289879 , 0.48408607, 0.31341746,
       0.27841452, 0.56569135, 0.33595476, 0.5754073 , 0.47615036,
       0.59850264, 0.45045155, 0.53011894, 0.7813892 , 0.46000174,
       0.6293877 , 0.57028383, 0.5635631 , 0.51294714, 0.32131955,
       0.41615418, 0.43640524, 0.6300865 , 0.56392336, 0.24798852,
       0.34310782, 0.54155445, 0.64739555, 0.52224094, 0.6394095 ,
       0.6416259 , 0.62210846, 0.4284557 , 0.70014524, 0.5149384 ,
       0.522662  , 0.37624723, 0.36317435, 0.5684601 ], dtype=float32), 2: array([0.6748316 , 0.4074247 , 0.49105936, 0.38992038, 0.46978638,
       0.6357716 , 0.42606366, 0.560135  , 0.54508513, 0.5517799 ,
       0.3964912 , 0.78342354, 0.8053925 , 0.5808751 , 0.3884107 ,
       0.56678355, 0.37066388, 0.54382163, 0.22759368, 0.70864516,
       0.303966  , 0.60365826, 0.5034732 , 0.52999544, 0.46359798,
       0.2859369 , 0.52990705, 0.5909621 , 0.32454422, 0.38926983,
       0.96485156, 0.37908873, 0.4566861 , 0.6039012 , 0.59427416,
       0.5869093 , 0.57286465, 0.5639263 , 0.48871076, 0.3544438 ,
       0.7356649 , 0.20563804, 0.42331544, 0.5544179 , 0.3522051 ,
       0.6289837 , 0.5469201 , 0.5465153 , 0.7104316 , 0.30148742,
       0.21848506, 0.5294103 , 0.5647303 , 0.51238114, 0.28752846,
       0.68530625, 0.72127384, 0.682508  , 0.49102718, 0.5946483 ,
       0.61741525, 0.46993878, 0.6558486 , 0.3242415 , 0.38155767,
       0.7227691 , 0.51612824, 0.58191395, 0.28775153, 0.6944261 ,
       0.5493823 , 0.67946756, 0.7234838 , 0.54425246, 0.41901666,
       0.5975187 , 0.627099  , 0.52148986, 0.9563764 , 0.2641179 ,
       0.7301533 , 0.67143786, 0.55293715, 0.56072325, 0.49951297,
       0.53529453, 0.48332396, 0.08366083, 0.4848326 , 0.65085953,
       0.39826554, 0.3728425 , 0.45496678, 0.43883774, 0.41533807,
       0.6146676 , 0.56835836, 0.72247344, 0.        , 0.57956624,
       0.27458906, 0.43997896, 0.55270195, 0.20345142, 0.6813145 ,
       0.7134562 , 0.7397493 , 0.31485173, 0.5345157 , 0.5753244 ,
       0.64990854, 0.62176925, 0.47460815, 0.5424521 , 0.8495754 ,
       0.41225788, 0.48881984, 0.54615325, 0.3140851 , 0.40762606,
       0.7431823 , 0.48576972, 0.45932463, 0.70246595, 0.619881  ,
       0.29678798, 0.51290536, 0.5251198 , 0.77270365, 0.65206325,
       0.55996567, 0.63707215, 0.14858927, 0.60790217, 0.56439173,
       0.24563964, 0.60371107, 0.5484797 , 0.4412493 , 0.7600152 ,
       0.5376845 , 0.79639024, 0.2144368 , 0.5354967 , 0.24194936,
       0.58767736, 0.4875134 , 0.42291513, 0.55643535, 0.43211702,
       0.6417125 , 0.56925964, 0.6627311 , 0.17544743, 0.7613894 ,
       0.6579407 , 0.57729125, 0.5627387 , 0.3440774 , 0.63208085,
       0.5659609 , 0.7209287 , 0.6067615 , 0.40275   , 0.5164946 ,
       0.8186469 , 0.6339206 , 0.20342632, 0.8481145 , 0.82366467,
       0.56067413, 0.71509784, 0.80045533, 0.7285235 , 0.4451461 ,
       0.6246485 , 0.45633486, 0.42407495, 0.43958852, 0.58433235,
       0.69412637, 0.5768258 , 0.32418138, 0.454048  , 0.3948432 ,
       0.6179795 , 0.50213104, 0.23949648, 0.48517033, 0.7483972 ,
       0.36877486, 0.61136216, 0.44356537, 0.54707456, 0.10315315,
       0.29949546, 0.2710408 , 0.5831354 , 0.707144  , 0.4151    ,
       0.24692956, 0.39056236, 0.35587728, 0.52209836, 0.7207828 ,
       0.34318405, 0.8759284 , 0.18325937, 0.03065613, 0.49207333,
       0.8686864 , 0.5794197 , 0.6828631 , 0.36431122, 0.6268849 ,
       0.5038582 , 0.66722894, 0.63037086, 0.3770932 , 0.48218024,
       0.81755686, 0.6661568 , 0.12384431, 0.5251198 , 0.44799998,
       0.3825945 , 0.71873266, 0.9796641 , 0.62351716, 0.6530301 ,
       0.33182004, 0.57896644, 0.45177144, 0.41930673, 0.3770389 ,
       0.431782  , 0.6692333 , 0.5824252 , 0.6155698 , 0.40630376,
       0.6964858 , 0.32052276, 0.4700297 , 0.4917665 , 0.8633192 ,
       0.82269835, 0.3568557 , 0.985096  , 0.38960674, 0.51275676,
       0.53697646, 0.39290482, 0.5948498 , 0.40631393, 0.4710755 ,
       0.75382733, 0.8678897 , 0.5860656 , 0.50644547, 0.6931877 ,
       0.827078  , 0.29435402, 0.39100865, 0.5259532 , 0.50417584,
       0.18836886, 0.45406368, 0.16322745, 0.3951799 , 0.48666334,
       0.43524384, 0.4730288 , 0.23024385, 0.50140816, 0.5175587 ,
       0.48539448, 0.21951239, 0.76472694, 0.36810192, 0.6813255 ,
       0.41129318, 0.57596487, 0.63458204, 0.8033597 , 0.36706468,
       0.45644128, 0.50169414, 0.3441729 , 0.60120344, 0.6332276 ,
       0.63546056, 0.4909976 , 0.5266029 , 0.58087826, 0.4846154 ,
       0.16037425, 0.62975156, 0.3920663 , 0.6932027 , 0.5640802 ,
       0.22320414, 0.3731415 , 0.6093587 , 0.6346745 , 0.6098801 ,
       0.16370593, 0.5321208 , 0.36934385, 0.4892848 , 0.66808945,
       0.66030204, 0.23884499, 0.36510858, 0.64053303, 0.54897684,
       0.3590053 , 0.2972839 , 0.34753802, 0.6453297 , 0.5251196 ,
       0.6814555 , 0.39543784, 0.44776392, 0.68364346, 0.4058828 ,
       0.54477197, 0.25028777, 0.3434456 , 0.7371401 , 0.86635256,
       0.686606  , 0.49274966, 0.4893725 , 0.42395604, 0.46989986,
       0.61714303, 0.2987194 , 0.76872516, 0.40346637, 0.67223847,
       0.5341008 , 0.48688528, 0.38268352, 0.666827  , 1.        ,
       0.43178546, 0.2497689 , 0.584798  , 0.69876546, 0.7022596 ,
       0.56139475, 0.535826  , 0.44755712, 0.51189864, 0.55382276,
       0.5954705 , 0.5261249 , 0.627666  , 0.6882771 , 0.4707281 ,
       0.7680467 , 0.76645   , 0.4647652 , 0.404939  , 0.53761715,
       0.29035944, 0.35611305, 0.8760471 , 0.6254243 , 0.38855845,
       0.34764704, 0.5626235 , 0.6937429 , 0.48045602, 0.68960315,
       0.6259925 , 0.58327866, 0.7423123 , 0.5612407 , 0.49681857,
       0.6235863 , 0.39066392, 0.51519734, 0.49413523], dtype=float32), 3: array([0.6542318 , 0.5377808 , 0.39427352, 0.42770526, 0.47166738,
       0.6680032 , 0.44329754, 0.49493238, 0.46237895, 0.57495785,
       0.42989856, 0.7036364 , 0.7427337 , 0.6434955 , 0.50641924,
       0.62676185, 0.41936448, 0.5544305 , 0.2921245 , 0.5734323 ,
       0.2935471 , 0.6822767 , 0.5622412 , 0.62485796, 0.48756942,
       0.22544909, 0.49136087, 0.679961  , 0.35861352, 0.35035822,
       0.8163559 , 0.520737  , 0.4385027 , 0.6636616 , 0.6539404 ,
       0.60592294, 0.49439794, 0.5566711 , 0.43656635, 0.3411835 ,
       0.69705296, 0.3407699 , 0.4261004 , 0.62751544, 0.2921626 ,
       0.5072126 , 0.48436025, 0.55037826, 0.6580973 , 0.32724953,
       0.19243304, 0.62950337, 0.50237507, 0.4228691 , 0.21795171,
       0.6045182 , 0.7464016 , 0.7525598 , 0.54767984, 0.56992143,
       0.5775037 , 0.48859718, 0.74667466, 0.23965628, 0.4473735 ,
       0.6389479 , 0.45452487, 0.58664054, 0.30382884, 0.7368003 ,
       0.56128067, 0.7119723 , 0.6852876 , 0.5870384 , 0.3799207 ,
       0.6283663 , 0.5649955 , 0.6069957 , 0.94719225, 0.22929955,
       0.6486107 , 0.65455157, 0.5466983 , 0.4873031 , 0.40211558,
       0.52241105, 0.45032373, 0.13322203, 0.5490885 , 0.6202621 ,
       0.4511717 , 0.3063856 , 0.42220318, 0.39039963, 0.38010105,
       0.6651868 , 0.5328936 , 0.6131631 , 0.        , 0.52032447,
       0.42317504, 0.5114572 , 0.6621108 , 0.28805304, 0.6526411 ,
       0.7281291 , 0.76102555, 0.32751876, 0.5521735 , 0.5289147 ,
       0.69779736, 0.68449605, 0.4677217 , 0.46570984, 0.74059623,
       0.6687386 , 0.526675  , 0.5870634 , 0.28352338, 0.4921353 ,
       0.7379906 , 0.43511122, 0.50613725, 0.7821665 , 0.6494672 ,
       0.33980587, 0.5046495 , 0.5283351 , 0.8038778 , 0.52949345,
       0.55786353, 0.68334764, 0.27564308, 0.3695339 , 0.6317535 ,
       0.13243376, 0.6835329 , 0.52257854, 0.45744082, 0.8174785 ,
       0.57084996, 0.76096934, 0.15712643, 0.56485105, 0.2633177 ,
       0.5813226 , 0.5298225 , 0.5427938 , 0.586514  , 0.43465352,
       0.6102504 , 0.5177347 , 0.7217669 , 0.21697025, 0.6146182 ,
       0.6744388 , 0.60285586, 0.59281015, 0.26246607, 0.6146374 ,
       0.58215857, 0.6584    , 0.59804606, 0.38756725, 0.52817786,
       0.78565884, 0.70261556, 0.28947195, 0.8879533 , 0.75424474,
       0.56979764, 0.79933727, 0.70712245, 0.6947445 , 0.46565795,
       0.63337314, 0.44335625, 0.5567024 , 0.45160365, 0.4859375 ,
       0.5825016 , 0.47205833, 0.29501644, 0.54978055, 0.4105344 ,
       0.53806275, 0.52932763, 0.30232495, 0.49972665, 0.7279396 ,
       0.2684043 , 0.6308848 , 0.43253362, 0.39378652, 0.11930857,
       0.3749699 , 0.32635775, 0.55045974, 0.70638657, 0.35426024,
       0.21915436, 0.3507852 , 0.43327543, 0.48687625, 0.71844566,
       0.37366402, 0.90822446, 0.1718719 , 0.18446386, 0.55817807,
       0.8028598 , 0.5387753 , 0.6225522 , 0.3418844 , 0.57539505,
       0.41666165, 0.7595317 , 0.5946971 , 0.32267916, 0.4584917 ,
       0.8040467 , 0.6062819 , 0.13489926, 0.5283351 , 0.4586317 ,
       0.473593  , 0.7651325 , 0.98452383, 0.6477622 , 0.6048426 ,
       0.34395784, 0.59981287, 0.3959062 , 0.42500147, 0.4247128 ,
       0.44683853, 0.53576064, 0.6344432 , 0.67739755, 0.37684232,
       0.72056973, 0.4385622 , 0.43668   , 0.6016548 , 0.79325795,
       0.66228557, 0.31493598, 0.97571874, 0.33136997, 0.578599  ,
       0.46057224, 0.29644817, 0.5894011 , 0.34435746, 0.507949  ,
       0.6673075 , 0.8544765 , 0.5088211 , 0.73321015, 0.56988007,
       0.8791954 , 0.46055466, 0.2512141 , 0.54515755, 0.51327956,
       0.2897278 , 0.5352044 , 0.19738676, 0.33345237, 0.47579262,
       0.47217768, 0.45885113, 0.32030272, 0.54774165, 0.4643789 ,
       0.50881547, 0.16045596, 0.7847433 , 0.30525547, 0.7076869 ,
       0.42714575, 0.6004711 , 0.66534877, 0.9246982 , 0.37249017,
       0.470351  , 0.50644106, 0.27928612, 0.7095356 , 0.5893712 ,
       0.65979886, 0.55490106, 0.53423333, 0.5893161 , 0.5291243 ,
       0.18188052, 0.7221448 , 0.41399443, 0.6156911 , 0.51341146,
       0.35419738, 0.40884674, 0.63282645, 0.5193399 , 0.6381396 ,
       0.25020102, 0.43971163, 0.4374941 , 0.44072276, 0.71287936,
       0.62531126, 0.36002475, 0.33606008, 0.5655713 , 0.68675596,
       0.27631754, 0.33216003, 0.5067967 , 0.63783395, 0.5283349 ,
       0.61430395, 0.4524425 , 0.3432733 , 0.6720959 , 0.39899063,
       0.55632484, 0.31208792, 0.30387443, 0.7800658 , 0.8409841 ,
       0.6674822 , 0.5534705 , 0.475604  , 0.41495994, 0.4002237 ,
       0.6579579 , 0.33016756, 0.752389  , 0.3981897 , 0.67810756,
       0.5619784 , 0.44760135, 0.41618544, 0.6703766 , 1.        ,
       0.46168402, 0.245408  , 0.65896124, 0.8166215 , 0.7485484 ,
       0.594238  , 0.6279917 , 0.32763523, 0.5644418 , 0.6252166 ,
       0.4871536 , 0.6665477 , 0.6146433 , 0.686448  , 0.49793965,
       0.7148773 , 0.75429183, 0.47640997, 0.38003728, 0.5814651 ,
       0.3123399 , 0.32535496, 0.7765912 , 0.58655864, 0.5247335 ,
       0.27610427, 0.5483681 , 0.71346617, 0.5215375 , 0.71260816,
       0.5932541 , 0.57799846, 0.7147718 , 0.5404635 , 0.445882  ,
       0.69863516, 0.37483338, 0.59227234, 0.5595903 ], dtype=float32)}

What am I missing? Thanks a lot for the help!

faalatawi commented 10 months ago

Thank you so much for pointing out this issue. I should have documented how to use this code better. After fixing your issue I will go back and document it much better.

I need more information about the graph and users_embeddings. Could you please run this code?

import networkx as nx
import numpy as np

G = # Your Graph
users_embeddings = # Your embeddings

# The graph
print(f"Number of nodes: {G.number_of_nodes()}")
print(f"Number of edges: {G.number_of_edges()}")
print(f"Is directed?: {G.is_directed()}")
print(f"Is connected?: {nx.is_connected(G)}")

nodes = G.nodes(data=True)
print(f"Nodes: {nodes[0]}")

print(f"Length of users_embeddings: {len(users_embeddings)}")
print(f"Type of users_embeddings: {type(users_embeddings)}")

i = users_embeddings.keys()[0]
first_embedding_shape = users_embeddings[i].shape
print(f"Shape of the first embedding: {first_embedding_shape}")

# Check if all the embeddings have the same shape
ans = all(emb == first_embedding_shape for emb in users_embeddings.values())
print(f"Are all the embeddings of the same shape?: {ans}")

giammy677dev commented 10 months ago

Hi @faalatawi, I tested your code on a very small dummy dataset. The graph has 4 nodes and 4 edges. Initially, my graph is directed; however, after the following lines in your "get_data" function are executed, it becomes undirected:

G = nx.read_gml(dataset_path + "graph.gml")
if nx.is_directed(G):
    G = G.to_undirected()

So, after obtaining the user embeddings, I executed the "verification code" you posted in your last reply. I made a few minor modifications as there were some errors. Here is the code I ran:

import networkx as nx
import numpy as np

def alatawi_check(G, users_embeddings):

    # The graph
    print(f"Number of nodes: {G.number_of_nodes()}")
    print(f"Number of edges: {G.number_of_edges()}")
    print(f"Is directed?: {G.is_directed()}")
    print(f"Is connected?: {nx.is_connected(G)}")

    nodes = G.nodes(data=True)
    print(nodes)
    print(f"Nodes: {nodes[0]}")

    print(f"Length of users_embeddings: {len(users_embeddings)}")
    print(f"Type of users_embeddings: {type(users_embeddings)}")

    i = list(users_embeddings.keys())[0]
    first_embedding_shape = users_embeddings[i].shape
    print(f"Shape of the first embedding: {first_embedding_shape}")

    # Check if all the embeddings have the same shape
    ans = all(np.array_equal(emb, first_embedding_shape) for emb in users_embeddings.values())
    print(f"Are all the embeddings of the same shape?: {ans}")

Here are the results I get from this execution:

Number of nodes: 4
Number of edges: 3
Is directed?: False
Is connected?: True
[(0, {}), (1, {}), (2, {}), (3, {})]
Nodes: {}
Length of users_embeddings: 4
Type of users_embeddings: <class 'dict'>
Shape of the first embedding: (384,)
Are all the embeddings of the same shape?: False
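
One caveat about the last line of this output: `np.array_equal(emb, first_embedding_shape)` compares the 384 embedding values with the one-element tuple `(384,)`, so it returns `False` whether or not the shapes agree. A minimal sketch of a check that compares the `.shape` attributes instead (the helper name `embedding_shapes` and the toy data are illustrative):

```python
import numpy as np


def embedding_shapes(embeddings):
    """Return the set of distinct shapes found among the embedding vectors."""
    return {np.asarray(emb).shape for emb in embeddings.values()}


# Toy stand-in for users_embeddings: two 384-dimensional vectors.
toy = {0: np.zeros(384, dtype=np.float32), 1: np.ones(384, dtype=np.float32)}

shapes = embedding_shapes(toy)
print(f"Distinct shapes: {shapes}")
print(f"Are all the embeddings of the same shape?: {len(shapes) == 1}")
```

A single element in the returned set means every embedding has the same shape.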

Additional Info

Modifications in load_data function

I reviewed your response in thread #1 and observed that it's essential to group together all the tweets belonging to the same user to ensure the proper functioning of your code. I incorporated this as a preprocessing phase in my original dataset and subsequently modified the "preprocess_tweets" function.

Additionally, I included the normalization of the users' embeddings, as you mentioned in your initial response in this thread. However, I noticed that in your example dataset, such as the "gun" one, the embeddings are not normalized. For this reason, I am uncertain whether normalization is a necessary step for your code to work correctly.
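
The thread does not say which kind of normalization is intended. As a sketch of two common readings: per-vector min-max scaling (what the modified "get_data" below does) rescales each user's vector by its own minimum and maximum, while per-feature scaling rescales each column across all users. The function names are illustrative, and neither variant is confirmed as the one the repository expects:

```python
import numpy as np


def minmax_per_vector(X):
    """Scale each row to [0, 1] using that row's own min and max."""
    X = np.asarray(X, dtype=np.float64)
    lo = X.min(axis=1, keepdims=True)
    span = X.max(axis=1, keepdims=True) - lo
    span[span == 0] = 1.0  # constant rows map to 0 instead of dividing by zero
    return (X - lo) / span


def minmax_per_feature(X):
    """Scale each column to [0, 1] using that column's min and max over all rows."""
    X = np.asarray(X, dtype=np.float64)
    lo = X.min(axis=0, keepdims=True)
    span = X.max(axis=0, keepdims=True) - lo
    span[span == 0] = 1.0  # constant columns map to 0 instead of dividing by zero
    return (X - lo) / span


X = np.array([[1.0, 2.0, 3.0],
              [2.0, 4.0, 6.0]])
per_vector = minmax_per_vector(X)
per_feature = minmax_per_feature(X)
print(per_vector)
print(per_feature)
```

Both results lie in [0, 1], but they are different matrices, so it is worth confirming which one the example datasets were prepared with.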

To facilitate debugging, I am providing the complete modified code for the "get_data" function:

import networkx as nx
import numpy as np
import pandas as pd
from networkx.algorithms.community.louvain import louvain_communities
from sentence_transformers import SentenceTransformer
from .tweet_preprocessing import preprocess_tweet_for_bert
from sklearn.utils import check_random_state
import torch
import random

import os

from tqdm import tqdm

tqdm.pandas()

# disable warnings
import warnings

warnings.filterwarnings("ignore")

def get_data(dataset_path: str, louvain_resolution=0.05, SBert_model_name="all-MiniLM-L6-v2"):
    # Set the random seed for python
    # seed = 42
    # random.seed(seed)
    # np.random.seed(seed)
    # pd.np.random.seed(seed)
    # check_random_state(seed)
    # torch.manual_seed(seed)

    # 1: Get the Graph G
    G = nx.read_gml(dataset_path + "graph.gml")
    if nx.is_directed(G):
        G = G.to_undirected()

    # 2: Get the user embeddings
    df = pd.read_feather(dataset_path + "tweets.feather")

    if os.path.exists(dataset_path + "embeddings.feather"):
        print("Loading embeddings")
        df_emb = pd.read_feather(dataset_path + "embeddings.feather")
        df = df.merge(df_emb, on="user_id", how="inner")
    else:
        print("No embeddings found, creating them")

        def preprocess_tweets(tweets):
            # print(tweets)
            out = []
            #for tw in tweets:
                # print(tw)
            tw = preprocess_tweet_for_bert(tweets)
            # print(tw)
            if len(tw) > 1:
                out.append(" ".join(tw))
            return out

        print("Preprocessing tweets")
        # print(df["tweets"])
        # print(df.tweets)
        df["tweets"] = df.tweets.progress_apply(preprocess_tweets)
        # print(df["tweets"])
        # Remove users with no tweets
        df = df[df.tweets.apply(len) > 0]

        model = SentenceTransformer(SBert_model_name)

        def embed_user_tweets(tweets):
            emb = model.encode(tweets)
            emb = np.mean(emb, axis=0)
            return emb

        print("Embedding tweets")
        df["embeddings"] = df.tweets.progress_apply(embed_user_tweets)

        df_emb = df[["user_id", "embeddings"]]
        df_emb.reset_index(drop=True, inplace=True)
        df_emb.to_feather(dataset_path + "embeddings.feather")

    # 3: Filter the graph
    G = G.subgraph(df.user_id)
    print(f'Graph: {G}')
    connected_components = list(nx.connected_components(G))
    lcc_nodes = max(nx.connected_components(G), key=len)
    G = G.subgraph(lcc_nodes)
    df = df[df.user_id.isin(G.nodes())]

    # 5: Find the communities
    community = louvain_communities(G, resolution=louvain_resolution, seed=42)
    community = list(community)

    def which_community(node):
        for i, c in enumerate(community):
            if node in c:
                return i
        return -1

    df["community"] = df.user_id.apply(which_community)

    # allsides
    df_allsides = pd.read_feather(dataset_path + "allsides.feather")

    # 4: Make a map from user_id to index
    node_id_map = {node: i for i, node in enumerate(G.nodes())}
    G = nx.relabel_nodes(G, node_id_map)

    # 6: Get the embeddings and labels
    print(df.set_index("user_id")["embeddings"])
    users_embeddings = df.set_index("user_id")["embeddings"].to_dict()
    labels = df.set_index("user_id")["community"].to_dict()
    allsides_scores = df_allsides.set_index("user_id")["allsides_score"].to_dict()
    allsides_scores = {
        node_id_map[user_id]: score
        for user_id, score in allsides_scores.items()
        if user_id in node_id_map
    }

    users_embeddings_tmp = {}
    labels_tmp = {}

    for user_id, index in node_id_map.items():
        users_embeddings_tmp[index] = users_embeddings[user_id]
        labels_tmp[index] = labels[user_id]

    users_embeddings = users_embeddings_tmp
    # print(users_embeddings)
    # Normalization
    normalized_array = {key: (users_embeddings[key] - np.min(users_embeddings[key])) / (np.max(users_embeddings[key]) - np.min(users_embeddings[key])) for key in users_embeddings.keys()}

    # Printing normalized array
    # for key, value in normalized_array.items():
        #print(f"{key}: {value}")
    # print(normalized_array)
    users_embeddings = normalized_array
    # print(users_embeddings[0, :])
    labels = labels_tmp
    labels = np.array(list(labels.values()))

    return G, users_embeddings, labels, allsides_scores, node_id_map

Graph and .gml

For completeness, I am also providing an image and the .gml file (anonymized version) of the graph I'm using to test your code:

[Image: Figure_1]

graph [
  directed 1
  node [
    id 0
    label "User1"
  ]
  node [
    id 1
    label "User2"
  ]
  node [
    id 2
    label "User3"
  ]
  node [
    id 3
    label "User4"
  ]
  edge [
    source 0
    target 1
    tweet "Tweet2"
    iteration 5
  ]
  edge [
    source 0
    target 2
    tweet "Tweet1"
    iteration 5
  ]
  edge [
    source 1
    target 0
    tweet "Tweet1"
    iteration 4
  ]
  edge [
    source 2
    target 3
    tweet "Tweet3"
    iteration 4
  ]
]
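
The file lists 4 directed edges, while the diagnostic output earlier in this comment reports 3. That is consistent with the reciprocal edges 0 -> 1 and 1 -> 0 collapsing into a single edge once direction is dropped. A plain-Python sketch of the edge list illustrates this (it does not call networkx, and the edge list is copied by hand from the file above):

```python
# (source, target) pairs copied from the .gml file.
directed_edges = [(0, 1), (0, 2), (1, 0), (2, 3)]

# Dropping direction: (0, 1) and (1, 0) become the same unordered pair.
undirected_edges = {frozenset(edge) for edge in directed_edges}

print(f"Directed edges: {len(directed_edges)}")
print(f"Undirected edges: {len(undirected_edges)}")
print(sorted(tuple(sorted(edge)) for edge in undirected_edges))
```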

Thanks again for the help and let me know if you need any other info :)

giammy677dev commented 10 months ago

Hi @faalatawi, Do you have any news?

Thanks :)

faalatawi commented 10 months ago

I'm sorry I did not have time to look into this issue. However, I think your graph is too small for this algorithm.

I think the main problem is that the user embeddings are not the same size for all the nodes. Make sure that every node's embedding has the same shape.

By the way, the algorithm does not need user embeddings. To experiment, you could just pass a graph without the embeddings. You will get node embeddings.

So my advice is to make sure that the embedding is the same size for all nodes.

I will look more into this problem later this week. I'm sorry, I'm a little busy at the moment.
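
On the graph-size point: the `ValueError` from `roc_auc_score` earlier in this thread means the evaluation step received zero labelled edges. One plausible cause, assuming the training code holds out a fixed fraction of edges for validation and testing (the thread does not show how `src/GAE.py` splits them), is that the fraction rounds down to zero on a 3-edge graph. The 5% and 10% ratios below are illustrative assumptions, not values taken from the repository:

```python
import math


def held_out_edge_counts(num_undirected_edges, val_ratio=0.05, test_ratio=0.1):
    """Return (train, val, test) edge counts when the held-out sizes are floored.

    The ratios are hypothetical defaults chosen for illustration only.
    """
    n_val = int(math.floor(val_ratio * num_undirected_edges))
    n_test = int(math.floor(test_ratio * num_undirected_edges))
    n_train = num_undirected_edges - n_val - n_test
    return n_train, n_val, n_test


for num_edges in (3, 20):
    train, val, test = held_out_edge_counts(num_edges)
    print(f"{num_edges} edges -> train={train}, val={val}, test={test}")
```

Under these assumptions a 3-edge graph leaves no test edges at all, so any metric computed on the test set has nothing to score, whereas a 20-edge graph keeps at least one edge in each split.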