google-deepmind / neural_networks_chomsky_hierarchy

Neural Networks and the Chomsky Hierarchy
https://arxiv.org/abs/2207.02098
Apache License 2.0

Reproducing results and arguments, expanding on issue #2 #7

Closed: kirk86 closed this issue 9 months ago

kirk86 commented 11 months ago

Hi, I've tried to reproduce the results on the even pairs task, similar to what another user asked in issue #2. I've noticed you've posted the hyperparameters to reproduce the results, but some arguments cannot be set directly from the command line, e.g. act_tau: where/how do we set this argument? Even after setting most hyperparameters to the values you've provided, I still can't get the results reported in the paper. Any ideas what I'm doing wrong or missing?

For instance, to reproduce the results I call the example.py script with the following parameters:

--architecture transformer_encoder \
--task even_pairs \
--ckpt_frequency 10000 \
--computation_steps_mult 0 \
--compute_full_range_test True \
--curriculum uniform \
--curriculum_kwargs.training_range 40 \
--eval_log_frequency 500 \
--is_autoregressive False \
--learning_rate 0.0005 \
--max_range_test_length 500 \
--model_init_seed 3 \
--range_test_sub_batch_size 32 \
--range_test_total_batch_size 512 \
--save_frequency 100000 \
--task even_pairs \
--task_level regular \
--training_data_seed 1 \
--training_log_frequency 500 \
--training_steps 1000000 \
--validation_batch_size 128 \
--validation_data_seed 1 \
--validation_lengths [60, 200, 500] \
--batch_size 128
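Before digging into the model, it may be worth checking how the shell tokenizes a command like this. The sketch below runs Python's standard shlex module (POSIX splitting rules) on a fragment of it: the unquoted list after --validation_lengths reaches the script as separate tokens, and True/False after a boolean flag are tokens of their own. How the script's flag parser treats those stray tokens is an assumption worth verifying; with absl flags, the --flag=value form (with the list quoted) avoids the ambiguity.

```python
import shlex

# A fragment of the command above, written exactly as it was typed.
fragment = (
    "--compute_full_range_test True "
    "--is_autoregressive False "
    "--validation_lengths [60, 200, 500]"
)

# shlex.split applies POSIX shell word-splitting, like bash would.
tokens = shlex.split(fragment)
print(tokens)
# ['--compute_full_range_test', 'True', '--is_autoregressive', 'False',
#  '--validation_lengths', '[60,', '200,', '500]']

# Using --flag=value and quoting the list keeps each value in a single token.
safer = shlex.split(
    "--compute_full_range_test=true --is_autoregressive=false "
    "--validation_lengths='[60, 200, 500]'"
)
print(safer)
# ['--compute_full_range_test=true', '--is_autoregressive=false',
#  '--validation_lengths=[60, 200, 500]']
```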

Then I get the following results, which in my interpretation are quite bad, assuming that "Network score: 0.494" corresponds to the network's accuracy on the even pairs task. Any suggestions?

I1229 13:14:48.570926 137960357544576 range_evaluation.py:103] {'length': 1, 'accuracy': 0.85546875}
I1229 13:14:52.590799 137960357544576 range_evaluation.py:103] {'length': 2, 'accuracy': 0.5703125}
I1229 13:14:57.770567 137960357544576 range_evaluation.py:103] {'length': 3, 'accuracy': 0.53125}
I1229 13:15:01.790874 137960357544576 range_evaluation.py:103] {'length': 4, 'accuracy': 0.4921875}
I1229 13:15:05.636651 137960357544576 range_evaluation.py:103] {'length': 5, 'accuracy': 0.501953125}
I1229 13:15:10.938885 137960357544576 range_evaluation.py:103] {'length': 6, 'accuracy': 0.498046875}
I1229 13:15:14.831170 137960357544576 range_evaluation.py:103] {'length': 7, 'accuracy': 0.498046875}
I1229 13:15:18.885272 137960357544576 range_evaluation.py:103] {'length': 8, 'accuracy': 0.49609375}
I1229 13:15:24.155033 137960357544576 range_evaluation.py:103] {'length': 9, 'accuracy': 0.51953125}
I1229 13:15:28.306296 137960357544576 range_evaluation.py:103] {'length': 10, 'accuracy': 0.5078125}
I1229 13:15:32.190845 137960357544576 range_evaluation.py:103] {'length': 11, 'accuracy': 0.4921875}
I1229 13:15:37.482550 137960357544576 range_evaluation.py:103] {'length': 12, 'accuracy': 0.48046875}
I1229 13:15:41.339636 137960357544576 range_evaluation.py:103] {'length': 13, 'accuracy': 0.515625}
I1229 13:15:45.471857 137960357544576 range_evaluation.py:103] {'length': 14, 'accuracy': 0.498046875}
I1229 13:15:50.920557 137960357544576 range_evaluation.py:103] {'length': 15, 'accuracy': 0.5}
I1229 13:15:55.305159 137960357544576 range_evaluation.py:103] {'length': 16, 'accuracy': 0.470703125}
I1229 13:15:59.485405 137960357544576 range_evaluation.py:103] {'length': 17, 'accuracy': 0.52734375}
I1229 13:16:05.230954 137960357544576 range_evaluation.py:103] {'length': 18, 'accuracy': 0.48828125}
I1229 13:16:09.539948 137960357544576 range_evaluation.py:103] {'length': 19, 'accuracy': 0.521484375}
I1229 13:16:13.983446 137960357544576 range_evaluation.py:103] {'length': 20, 'accuracy': 0.498046875}
I1229 13:16:19.752009 137960357544576 range_evaluation.py:103] {'length': 21, 'accuracy': 0.49609375}
I1229 13:16:24.136376 137960357544576 range_evaluation.py:103] {'length': 22, 'accuracy': 0.45703125}
I1229 13:16:28.429387 137960357544576 range_evaluation.py:103] {'length': 23, 'accuracy': 0.515625}
I1229 13:16:33.991695 137960357544576 range_evaluation.py:103] {'length': 24, 'accuracy': 0.498046875}
I1229 13:16:38.159836 137960357544576 range_evaluation.py:103] {'length': 25, 'accuracy': 0.53515625}
I1229 13:16:42.974402 137960357544576 range_evaluation.py:103] {'length': 26, 'accuracy': 0.525390625}
I1229 13:16:47.857553 137960357544576 range_evaluation.py:103] {'length': 27, 'accuracy': 0.49609375}
I1229 13:16:52.188318 137960357544576 range_evaluation.py:103] {'length': 28, 'accuracy': 0.5390625}
I1229 13:16:58.110971 137960357544576 range_evaluation.py:103] {'length': 29, 'accuracy': 0.498046875}
I1229 13:17:02.604572 137960357544576 range_evaluation.py:103] {'length': 30, 'accuracy': 0.501953125}
I1229 13:17:06.803083 137960357544576 range_evaluation.py:103] {'length': 31, 'accuracy': 0.521484375}
I1229 13:17:12.299283 137960357544576 range_evaluation.py:103] {'length': 32, 'accuracy': 0.494140625}
I1229 13:17:16.324559 137960357544576 range_evaluation.py:103] {'length': 33, 'accuracy': 0.49609375}
I1229 13:17:20.660620 137960357544576 range_evaluation.py:103] {'length': 34, 'accuracy': 0.46484375}
I1229 13:17:26.030283 137960357544576 range_evaluation.py:103] {'length': 35, 'accuracy': 0.451171875}
I1229 13:17:30.409441 137960357544576 range_evaluation.py:103] {'length': 36, 'accuracy': 0.498046875}
I1229 13:17:34.422365 137960357544576 range_evaluation.py:103] {'length': 37, 'accuracy': 0.478515625}
I1229 13:17:39.923307 137960357544576 range_evaluation.py:103] {'length': 38, 'accuracy': 0.490234375}
I1229 13:17:43.979161 137960357544576 range_evaluation.py:103] {'length': 39, 'accuracy': 0.5}
I1229 13:17:48.113711 137960357544576 range_evaluation.py:103] {'length': 40, 'accuracy': 0.4765625}
I1229 13:17:54.553531 137960357544576 range_evaluation.py:103] {'length': 41, 'accuracy': 0.5}
I1229 13:17:59.742008 137960357544576 range_evaluation.py:103] {'length': 42, 'accuracy': 0.455078125}
I1229 13:18:05.911395 137960357544576 range_evaluation.py:103] {'length': 43, 'accuracy': 0.50390625}
I1229 13:18:11.101922 137960357544576 range_evaluation.py:103] {'length': 44, 'accuracy': 0.482421875}
I1229 13:18:16.174727 137960357544576 range_evaluation.py:103] {'length': 45, 'accuracy': 0.513671875}
I1229 13:18:22.604817 137960357544576 range_evaluation.py:103] {'length': 46, 'accuracy': 0.482421875}
I1229 13:18:27.482460 137960357544576 range_evaluation.py:103] {'length': 47, 'accuracy': 0.486328125}
I1229 13:18:33.887342 137960357544576 range_evaluation.py:103] {'length': 48, 'accuracy': 0.45703125}
I1229 13:18:38.783443 137960357544576 range_evaluation.py:103] {'length': 49, 'accuracy': 0.533203125}
I1229 13:18:44.269059 137960357544576 range_evaluation.py:103] {'length': 50, 'accuracy': 0.466796875}
I1229 13:18:50.191066 137960357544576 range_evaluation.py:103] {'length': 51, 'accuracy': 0.513671875}
I1229 13:18:55.434187 137960357544576 range_evaluation.py:103] {'length': 52, 'accuracy': 0.525390625}
I1229 13:19:01.725342 137960357544576 range_evaluation.py:103] {'length': 53, 'accuracy': 0.509765625}
I1229 13:19:06.966732 137960357544576 range_evaluation.py:103] {'length': 54, 'accuracy': 0.455078125}
I1229 13:19:12.832618 137960357544576 range_evaluation.py:103] {'length': 55, 'accuracy': 0.48828125}
I1229 13:19:18.343561 137960357544576 range_evaluation.py:103] {'length': 56, 'accuracy': 0.53125}
I1229 13:19:23.306194 137960357544576 range_evaluation.py:103] {'length': 57, 'accuracy': 0.501953125}
I1229 13:19:30.069247 137960357544576 range_evaluation.py:103] {'length': 58, 'accuracy': 0.490234375}
I1229 13:19:35.114671 137960357544576 range_evaluation.py:103] {'length': 59, 'accuracy': 0.4609375}
I1229 13:19:41.482457 137960357544576 range_evaluation.py:103] {'length': 60, 'accuracy': 0.51171875}
I1229 13:19:46.581363 137960357544576 range_evaluation.py:103] {'length': 61, 'accuracy': 0.458984375}
I1229 13:19:53.469707 137960357544576 range_evaluation.py:103] {'length': 62, 'accuracy': 0.52734375}
I1229 13:19:58.645372 137960357544576 range_evaluation.py:103] {'length': 63, 'accuracy': 0.521484375}
I1229 13:20:03.946299 137960357544576 range_evaluation.py:103] {'length': 64, 'accuracy': 0.46484375}
I1229 13:20:10.534994 137960357544576 range_evaluation.py:103] {'length': 65, 'accuracy': 0.51171875}
I1229 13:20:15.717426 137960357544576 range_evaluation.py:103] {'length': 66, 'accuracy': 0.50390625}
I1229 13:20:22.130596 137960357544576 range_evaluation.py:103] {'length': 67, 'accuracy': 0.486328125}
I1229 13:20:27.386350 137960357544576 range_evaluation.py:103] {'length': 68, 'accuracy': 0.46484375}
I1229 13:20:32.920934 137960357544576 range_evaluation.py:103] {'length': 69, 'accuracy': 0.48046875}
I1229 13:20:40.808766 137960357544576 range_evaluation.py:103] {'length': 70, 'accuracy': 0.51171875}
I1229 13:20:45.946934 137960357544576 range_evaluation.py:103] {'length': 71, 'accuracy': 0.498046875}
I1229 13:20:52.315445 137960357544576 range_evaluation.py:103] {'length': 72, 'accuracy': 0.5546875}
I1229 13:20:57.379544 137960357544576 range_evaluation.py:103] {'length': 73, 'accuracy': 0.466796875}
I1229 13:21:03.938791 137960357544576 range_evaluation.py:103] {'length': 74, 'accuracy': 0.4765625}
I1229 13:21:09.187290 137960357544576 range_evaluation.py:103] {'length': 75, 'accuracy': 0.498046875}
I1229 13:21:15.498011 137960357544576 range_evaluation.py:103] {'length': 76, 'accuracy': 0.46484375}
I1229 13:21:20.657211 137960357544576 range_evaluation.py:103] {'length': 77, 'accuracy': 0.447265625}
I1229 13:21:26.026159 137960357544576 range_evaluation.py:103] {'length': 78, 'accuracy': 0.455078125}
I1229 13:21:32.214586 137960357544576 range_evaluation.py:103] {'length': 79, 'accuracy': 0.49609375}
I1229 13:21:37.564054 137960357544576 range_evaluation.py:103] {'length': 80, 'accuracy': 0.529296875}
I1229 13:21:43.964729 137960357544576 range_evaluation.py:103] {'length': 81, 'accuracy': 0.482421875}
I1229 13:21:49.275739 137960357544576 range_evaluation.py:103] {'length': 82, 'accuracy': 0.517578125}
I1229 13:21:55.442925 137960357544576 range_evaluation.py:103] {'length': 83, 'accuracy': 0.515625}
I1229 13:22:01.015446 137960357544576 range_evaluation.py:103] {'length': 84, 'accuracy': 0.447265625}
I1229 13:22:06.539398 137960357544576 range_evaluation.py:103] {'length': 85, 'accuracy': 0.515625}
I1229 13:22:13.119949 137960357544576 range_evaluation.py:103] {'length': 86, 'accuracy': 0.5234375}
I1229 13:22:18.448452 137960357544576 range_evaluation.py:103] {'length': 87, 'accuracy': 0.515625}
I1229 13:22:25.389026 137960357544576 range_evaluation.py:103] {'length': 88, 'accuracy': 0.48828125}
I1229 13:22:30.793943 137960357544576 range_evaluation.py:103] {'length': 89, 'accuracy': 0.48828125}
I1229 13:22:37.409508 137960357544576 range_evaluation.py:103] {'length': 90, 'accuracy': 0.5625}
I1229 13:22:42.677580 137960357544576 range_evaluation.py:103] {'length': 91, 'accuracy': 0.4921875}
I1229 13:22:48.850975 137960357544576 range_evaluation.py:103] {'length': 92, 'accuracy': 0.470703125}
I1229 13:22:54.618408 137960357544576 range_evaluation.py:103] {'length': 93, 'accuracy': 0.486328125}
I1229 13:23:00.022013 137960357544576 range_evaluation.py:103] {'length': 94, 'accuracy': 0.494140625}
I1229 13:23:06.688524 137960357544576 range_evaluation.py:103] {'length': 95, 'accuracy': 0.462890625}
I1229 13:23:12.773394 137960357544576 range_evaluation.py:103] {'length': 96, 'accuracy': 0.546875}
I1229 13:23:19.476184 137960357544576 range_evaluation.py:103] {'length': 97, 'accuracy': 0.50390625}
I1229 13:23:24.680352 137960357544576 range_evaluation.py:103] {'length': 98, 'accuracy': 0.478515625}
I1229 13:23:30.576612 137960357544576 range_evaluation.py:103] {'length': 99, 'accuracy': 0.494140625}
I1229 13:23:36.006940 137960357544576 range_evaluation.py:103] {'length': 100, 'accuracy': 0.478515625}
Network score: 0.4941075211864407
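The printed Network score is consistent with averaging only the accuracies beyond the training range: the 59 accuracies for lengths 42 to 100 above sum to 14926/512, and 14926 / (512 × 59) = 0.4941075211864407, matching the printed value. That behaves like np.mean(accuracies[sequence_length + 1:]) with sequence_length = 40 over a list starting at length 1; this is inferred from the numbers, not read from example.py. Note also that the log stops at length 100 even though --max_range_test_length 500 was passed, which suggests that flag was not picked up. A hypothetical helper that parses such a log and reproduces the score:

```python
import re
import statistics

# Matches the dict printed by range_evaluation.py on each log line.
_LINE_RE = re.compile(r"\{'length': (\d+), 'accuracy': ([0-9.]+)\}")


def parse_accuracies(log_text: str) -> dict[int, float]:
  """Maps test length -> accuracy for every range-evaluation line in the log."""
  return {int(m.group(1)): float(m.group(2))
          for m in _LINE_RE.finditer(log_text)}


def network_score(accuracies: dict[int, float], sequence_length: int) -> float:
  """Mean accuracy over lengths > sequence_length + 1 (hypothetical helper).

  With sequence_length=40 and lengths 1..100, this averages lengths 42..100,
  i.e. 59 values, which reproduces the score printed above.
  """
  ood = [acc for length, acc in sorted(accuracies.items())
         if length > sequence_length + 1]
  return statistics.mean(ood)


sample_log = """
I1229 range_evaluation.py:103] {'length': 41, 'accuracy': 0.5}
I1229 range_evaluation.py:103] {'length': 42, 'accuracy': 0.455078125}
I1229 range_evaluation.py:103] {'length': 43, 'accuracy': 0.50390625}
"""
print(network_score(parse_accuracies(sample_log), sequence_length=40))
# 0.4794921875
```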
fazega commented 11 months ago

First, note that the average score for the Transformer on this task is 67.7% ± 15.2 (standard deviation over 10 seeds). The number we report is the maximum score we got over 10 seeds. Unfortunately, perfect reproducibility (even with the given best seed) is impossible, simply because we don't use the exact same hardware. And OOD performance is very brittle, especially for Transformers.

Second, I think Anian leaked some parameters from internal experiments that we removed from the final paper. You can ignore those (such as act_tau, used for Adaptive Computation Time, which we also tried).

In your case, the accuracy is very low even within the range of training lengths, so there is a problem. Can you try a lower learning rate (like 1e-4)? What do the losses look like? I will run the exact same set of parameters on my own machine to compare, ASAP.

kirk86 commented 11 months ago

Thanks for your prompt reply.

Can you try a lower learning rate (like 1e-4)? What do the losses look like?

I did try it and got the same results. Here's a plot of the losses and accuracies; something else is going on, but I can't tell what.

[Figure: training loss and accuracy curves]
print(loss)
[0.7015707492828369, 0.6994349956512451, 0.691344141960144, 0.710098147392273, 0.7004635334014893, 0.697927713394165, 0.6934951543807983, 0.6899349689483643, 0.6963257789611816, 0.8071784377098083, 0.6978791952133179, 0.6934381127357483, 0.6925368309020996, 0.6989807486534119, 0.692752480506897, 0.691596269607544, 0.692746639251709, 0.6924327611923218, 0.6999204158782959, 0.6862495541572571, 0.7146303653717041, 0.6909568309783936, 0.6929700374603271, 0.6936980485916138, 0.7018345594406128, 0.6954343914985657, 0.6980022192001343, 0.6930896043777466, 0.6976462602615356, 0.6946387887001038, 0.6949437260627747, 0.6964828968048096, 0.6988976001739502, 0.6928001046180725, 0.690857470035553, 0.6941013932228088, 0.702369213104248, 0.6930474042892456, 0.6926758289337158, 0.6965607404708862, 0.6927996277809143, 0.708030104637146, 0.6915740966796875, 0.6944513320922852, 0.6907310485839844, 0.6866137981414795, 0.692477285861969, 0.6942601203918457, 0.6928719878196716, 0.6932602524757385, 0.6924253106117249, 0.694024920463562, 0.6923019289970398, 0.691500186920166, 0.6931975483894348, 0.694494366645813, 0.6943548917770386, 0.6919704675674438, 0.6952944397926331, 0.6955577731132507, 0.7023453712463379, 0.6944880485534668, 0.6943597793579102, 0.6948365569114685, 0.6914238929748535, 0.6941031217575073, 0.6935892105102539, 0.6953279972076416, 0.6933053731918335, 0.6920351982116699, 0.6933887004852295, 0.6932429075241089, 0.6937698125839233, 0.6930730938911438, 0.6935482621192932, 0.6916822195053101, 0.6918131113052368, 0.6913727521896362, 0.6911622881889343, 0.6927202939987183, 0.6933171153068542, 0.6943994164466858, 0.6925961971282959, 0.6929115653038025, 0.6933043003082275, 0.6912857294082642, 0.6890400648117065, 0.6919862031936646, 0.6934765577316284, 0.6931062936782837, 0.6958458423614502, 0.6948913931846619, 0.694998025894165, 0.6943366527557373, 0.6915961503982544, 0.695002019405365, 0.6928426623344421, 0.691338062286377, 0.6941516399383545, 0.6893787384033203, 
0.6922582387924194]
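The loss values above hover around 0.693, which is ln 2: the cross-entropy of a two-class classifier that predicts 50/50 on every example. Assuming the logged loss is a mean cross-entropy over the two output classes of even pairs, this means the model is not fitting the training data at all, consistent with the ~50% accuracies. A quick check (is_near_chance is a hypothetical helper):

```python
import math

# Binary cross-entropy when the model predicts p = 0.5 for every example.
chance_loss = -math.log(0.5)  # equals ln 2
print(round(chance_loss, 4))  # 0.6931


def is_near_chance(losses: list[float], tol: float = 0.01) -> bool:
  """True if the mean of the last (up to) 20 losses is within tol of ln 2."""
  tail = losses[-20:]
  return abs(sum(tail) / len(tail) - chance_loss) < tol


# The last three values of the printed loss list.
print(is_near_chance([0.6941516399383545, 0.6893787384033203,
                      0.6922582387924194]))  # True
```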
kirk86 commented 10 months ago

Hi folks, @anianruoss @fazega, any updates on this?

I also have another clarifying question. The training loop runs for 10K steps, and at each step you reserve a PRNG key which you pass to the sample_batch function. I suppose this is done to get samples that differ from those of previous steps, right? Also, it seems you're not using any validation set, because the range_evaluation function is executed only once, using a single PRNG key (seed 1).

My question is: if you wanted to create separate, non-overlapping train/validation/test datasets, how would you do that? Would it suffice to use a different PRNG key for each set?

anianruoss commented 10 months ago

Note that for the paper we train for 1M steps -- the script provided here is just an example.

sample_batch requires a PRNG key to create data: if you always provide the same key, you will always get the same data; if you want different data, pass a different key.
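This can be illustrated with a self-contained sketch. It uses Python's random.Random, seeded per call, as a stand-in for the PRNG key the repo passes around, and a hypothetical sample_batch that draws strings over {a, b}; neither is the repo's code. It also bears on the non-overlap question: distinct keys give independent streams, not disjoint sets. One way to guarantee disjoint train/valid/test data (an assumption, not something described in this thread) is to assign each sequence to a split by hashing its content and to discard sequences that land in the wrong split.

```python
import hashlib
import random


def sample_batch(seed: int, batch_size: int, length: int) -> list[str]:
  """Hypothetical stand-in for sample_batch: random strings over {a, b}."""
  rng = random.Random(seed)  # stand-in for a PRNG key
  return ["".join(rng.choice("ab") for _ in range(length))
          for _ in range(batch_size)]


def split_of(sequence: str, valid_pct: int = 10, test_pct: int = 10) -> str:
  """Assigns a sequence to a split from a hash of its content."""
  bucket = int(hashlib.sha256(sequence.encode()).hexdigest(), 16) % 100
  if bucket < test_pct:
    return "test"
  if bucket < test_pct + valid_pct:
    return "valid"
  return "train"


# Same seed -> identical batch; a different seed gives an independent batch.
same = sample_batch(0, 8, 12) == sample_batch(0, 8, 12)

# Different seeds alone do not give disjoint sets: at length 3 there are only
# 2**3 = 8 distinct strings, so two batches of 64 will almost surely overlap.
overlap = set(sample_batch(0, 64, 3)) & set(sample_batch(1, 64, 3))

# Filtering by split_of keeps train and test disjoint, whatever the seeds.
train = [s for s in sample_batch(0, 256, 12) if split_of(s) == "train"]
test = [s for s in sample_batch(1, 256, 12) if split_of(s) == "test"]
print(same, len(overlap) > 0, set(train).isdisjoint(test))
```

One trade-off of hashing: at very short lengths there are so few distinct strings that some splits may end up empty for those lengths.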

kirk86 commented 10 months ago

@anianruoss Thanks a lot for the reply and clarification. If example.py is just an example, where should I look in order to reproduce the tasks in the paper? Do you have all the hyperparameters for the different tasks listed somewhere? For clarity, could you add to the README how to run example.py for each task? E.g., for the even_pairs task, run python example.py --training_steps 1M --hyperparam2 x --hyperparam y, etc. That way we know we are running with the same hyperparameters you used to obtain the results in the paper. BTW, I've looked at both Sections A and B of the paper, and there's nothing there indicating which hyperparameters were used for which task (except for A.4 PROBLEM SETUP), and that is still quite vague.

Indicating something like the following in the README would really help:

- Task binary arithmetic
python example.py --hyperparam x --hyperparam y --hyperparam z

- Task reverse string
python example.py --hyperparam x --hyperparam y --hyperparam z

- Task X
python example.py --hyperparam x --hyperparam y --hyperparam z

- Task Y
python example.py --hyperparam x --hyperparam y --hyperparam z
kirk86 commented 10 months ago

Note that for the paper we train for 1M steps -- the script provided here is just an example.

Still the same: even with 1M steps I'm still getting plots like the one above. Accuracy stays at chance level, 49-50%.

anianruoss commented 10 months ago

For even pairs, we used the following hyperparameters:

'model_init_seed': range(10)
'learning_rate': sweep([1e-4, 3e-4, 5e-4])
'batch_size': 128
'training_steps': 1_000_000
'max_grad_norm': 1.0
'l2_weight': 0.0
'max_range_test_length': 500
'curriculum':  'uniform'
'curriculum_kwargs': {'training_range': 40}

'transformer_encoder':
- 'num_layers': 5
- 'embedding_dim': 64
- 'positional_encodings': sweep(['NONE', 'SIN_COS', 'RELATIVE', 'ALIBI', 'ROTARY'])
- 'attention_window': None
- 'dropout_prob': 0.1
- 'is_autoregressive': False

Let me know if that works for you
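The configuration above expands into a grid of independent training runs. A minimal sketch with itertools.product, treating each sweep([...]) as a plain list of values (an assumption about what the internal sweep helper means). The keys and values are copied from the comment; grouping them into FIXED, ARCHITECTURE_PARAMS and SWEEP is only for illustration, and is_autoregressive is left at its default (False) rather than listed.

```python
import itertools
from typing import Any, Iterator

FIXED = {
    "batch_size": 128,
    "training_steps": 1_000_000,
    "max_grad_norm": 1.0,
    "l2_weight": 0.0,
    "max_range_test_length": 500,
    "curriculum": "uniform",
    "curriculum_kwargs": {"training_range": 40},  # shared by all runs (shallow)
}
ARCHITECTURE_PARAMS = {
    "num_layers": 5,
    "embedding_dim": 64,
    "attention_window": None,
    "dropout_prob": 0.1,
}
SWEEP = {
    "model_init_seed": list(range(10)),
    "learning_rate": [1e-4, 3e-4, 5e-4],
    "positional_encodings": ["NONE", "SIN_COS", "RELATIVE", "ALIBI", "ROTARY"],
}


def expand(sweep: dict[str, list[Any]]) -> Iterator[dict[str, Any]]:
  """Yields one dict per point of the Cartesian product of the sweep values."""
  keys = list(sweep)
  for values in itertools.product(*(sweep[k] for k in keys)):
    yield dict(zip(keys, values))


def runs() -> Iterator[dict[str, Any]]:
  """Merges each sweep point with the fixed and architecture parameters."""
  for point in expand(SWEEP):
    arch = dict(ARCHITECTURE_PARAMS,
                positional_encodings=point.pop("positional_encodings"))
    yield {**FIXED, **point, "architecture_params": arch}


all_runs = list(runs())
print(len(all_runs))  # 150 = 10 seeds x 3 learning rates x 5 encodings
```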

kirk86 commented 10 months ago

@anianruoss Thanks a lot for the reply.

It seems that the is_autoregressive hyperparam should not go into the _ARCHITECTURE_PARAMS dictionary.

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/content/neural_networks_chomsky_hierarchy/experiments/example.py", line 143, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/content/neural_networks_chomsky_hierarchy/experiments/example.py", line 87, in main
    model = constants.MODEL_BUILDERS[_ARCHITECTURE.value](
TypeError: make_transformer_encoder() got an unexpected keyword argument 'is_autoregressive'

Isn't the is_autoregressive hyperparameter False by default? It seems to work if you pass it as a flag to example.py, but not if you add it to the _ARCHITECTURE_PARAMS dictionary.
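The TypeError above is standard Python behavior: unpacking a dictionary into a function call fails as soon as the dictionary holds a key the function does not declare. A defensive way to pass an architecture dictionary to a builder is to keep only the keys its signature accepts, using inspect.signature. In the sketch below, make_transformer_encoder is a hypothetical stand-in, not the repo's builder, and call_with_supported_kwargs is a hypothetical helper. Because silently dropping keys can hide typos, the helper returns the leftovers for the caller to inspect.

```python
import inspect
from typing import Any, Callable


def make_transformer_encoder(num_layers: int = 5, embedding_dim: int = 64,
                             dropout_prob: float = 0.1) -> dict[str, Any]:
  """Hypothetical stand-in for a model builder; it has no is_autoregressive."""
  return {"num_layers": num_layers, "embedding_dim": embedding_dim,
          "dropout_prob": dropout_prob}


def call_with_supported_kwargs(builder: Callable[..., Any],
                               params: dict[str, Any]):
  """Calls builder with the params it declares; returns (result, leftovers).

  Note: a builder taking **kwargs would accept any key, but this helper still
  only forwards explicitly named parameters.
  """
  accepted = inspect.signature(builder).parameters
  supported = {k: v for k, v in params.items() if k in accepted}
  leftovers = {k: v for k, v in params.items() if k not in accepted}
  return builder(**supported), leftovers


params = {"num_layers": 5, "is_autoregressive": False}
model, leftover = call_with_supported_kwargs(make_transformer_encoder, params)
print(leftover)  # {'is_autoregressive': False}
```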

fazega commented 10 months ago

Hello, I think Anian sent you an internal config with extra parameters that we haven't included in this open-sourced version. Anyway, we're pretty busy at the moment with the ICML deadline, so we'll get back to you after that (in a week). I'll have a look at the training script: it's always possible there is a bug somewhere. Or are you also working towards this deadline?

kirk86 commented 10 months ago

@fazega Thanks for the answer. I'm also working towards the deadline.

Given the correct hyperparameters and seeds, shouldn't one be able to replicate the results? I understand that there might be deviations due to hardware, CUDA versions and so on, but one should at least be able to get close to what is reported in the paper.

Providing a Google Colab notebook replicating even some of the easiest tasks would be really helpful.

In example.py a lot of things are hard-coded, e.g., ClassicTrainingParams or SinCosParams. Should max_time be 10_000 for every task?

import chex

@chex.dataclass
class SinCosParams:
  """Parameters for the classical sin/cos positional encoding."""
  # The maximum wavelength used.
  max_time: int = 10_000
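For context on what max_time controls, here is a plain-Python sketch of the standard sinusoidal encoding from the original Transformer paper, where 10,000 is the conventional base. It illustrates the technique and is not the repo's implementation (which may, for example, concatenate the sin and cos halves instead of interleaving them). The formula is defined for any integer position, and max_time only sets the range of wavelengths: with max_time = 10_000 the longest wavelength is roughly 2π × 10,000 ≈ 62,800 positions, far beyond the test lengths of up to 500 discussed in this thread.

```python
import math


def sin_cos_encoding(position: int, dim: int,
                     max_time: float = 10_000.0) -> list[float]:
  """Sinusoidal encoding of one position, interleaving sin and cos.

  Component pair k uses the angle position / max_time ** (2k / dim), so the
  wavelengths grow geometrically from 2*pi towards 2*pi*max_time.
  """
  encoding = []
  for i in range(0, dim, 2):
    angle = position / max_time ** (i / dim)
    encoding.append(math.sin(angle))
    encoding.append(math.cos(angle))
  return encoding[:dim]  # trims the extra cos term when dim is odd


print(sin_cos_encoding(0, 4))  # [0.0, 1.0, 0.0, 1.0]
```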

Could you please also explain this:

if is_autoregressive:
  outputs = model_apply_fn(
      params, rng_key, batch["input"], batch["output"], sample=False)
else:
  outputs = model_apply_fn(params, rng_key, batch["input"])

Where does the sample argument come from? I suppose model_apply_fn is roughly equivalent to model(x, y, sample=False), no? Because if that's the case, searching for a sample argument in the models directory finds nothing.
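In Haiku, the apply function of hk.transform(f) has the signature apply(params, rng, *args, **kwargs) and passes everything after params and rng straight to f. If model_apply_fn comes from such a transform (an assumption here), then sample=False is simply forwarded to the forward function of the autoregressive model, so it would be declared there rather than on model_apply_fn itself. The toy stand-in below mimics only this argument forwarding; it is not Haiku, does not manage parameters, and autoregressive_forward is hypothetical.

```python
from typing import Any, Callable, NamedTuple


class Transformed(NamedTuple):
  """Minimal stand-in for haiku.Transformed (not the real class)."""
  init: Callable[..., Any]
  apply: Callable[..., Any]


def transform(forward_fn: Callable[..., Any]) -> Transformed:
  """Toy hk.transform: apply forwards *args and **kwargs to forward_fn."""
  def init(rng, *args, **kwargs):
    return {}  # a real init would create the model parameters here

  def apply(params, rng, *args, **kwargs):
    return forward_fn(*args, **kwargs)

  return Transformed(init=init, apply=apply)


def autoregressive_forward(x, y, sample=False):
  """Hypothetical forward function; this is where `sample` is declared."""
  return {"x": x, "y": y, "sample": sample}


model_apply_fn = transform(autoregressive_forward).apply
print(model_apply_fn({}, None, [1, 2], [3], sample=False))
# {'x': [1, 2], 'y': [3], 'sample': False}
```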

anianruoss commented 10 months ago

Note that we are sweeping over 10 seeds and 4 learning rates and taking the best-performing model. Are you doing that too?
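Taking the best run over a sweep, as described here, and summarizing the same sweep as mean ± standard deviation (as quoted for the Transformer at the start of the thread) can be computed side by side. A minimal sketch with toy numbers that are not real results; summarize_sweep is a hypothetical helper, not part of the repo. Which standard deviation convention the paper uses is not stated in this thread.

```python
import statistics
from typing import Any, Hashable, Mapping


def summarize_sweep(scores: Mapping[Hashable, float]) -> dict[str, Any]:
  """Best run plus mean and sample standard deviation over all runs."""
  best_key = max(scores, key=scores.get)
  values = list(scores.values())
  return {
      "best_config": best_key,
      "best_score": scores[best_key],
      "mean": statistics.mean(values),
      # statistics.stdev is the sample (n - 1) standard deviation;
      # statistics.pstdev would give the population version.
      "stdev": statistics.stdev(values),
  }


# Toy scores keyed by (model_init_seed, learning_rate).
toy = {(0, 1e-4): 0.5, (1, 1e-4): 0.5, (0, 3e-4): 1.0, (1, 3e-4): 0.5}
print(summarize_sweep(toy))
```

The gap between best_score and mean is exactly what the "best of N seeds" versus "average over seeds" discussion below is about.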

anianruoss commented 10 months ago

You can ignore the is_autoregressive parameter. It should be set to False unless you're doing anything funky.

kirk86 commented 10 months ago

Note that we are sweeping over 10 seeds and 4 learning rates and taking the best-performing model. Are you doing that too?

Yes. So far I've run the 3 different learning rates you've provided ('learning_rate': sweep([1e-4, 3e-4, 5e-4])) with 6 different seeds, and all models seem to perform at chance level (accuracy 49-50%). How can we be sure the reported result is not a fluke? In other words, aren't we effectively hacking the seed?

Usually, most models work with the default parameter initialization and almost any seed. When we report results over multiple seeds, what we are trying to avoid is getting lucky with one very good seed; reporting results over multiple seeds minimizes the chance of relying on a lottery ticket.

Here it seems to be quite the opposite, as if we were trying to maximize the probability of getting that lottery-ticket seed.

and taking the best-performing model.

Unless example.py does that internally, I'm directly reporting the results of running example.py, which in essence trains the model and evaluates it with the range evaluation function.

Since the code is online, could you provide a Google Colab notebook replicating the results? That shouldn't be too much of a hassle, should it?

fazega commented 10 months ago

That is a big hassle. Everything we do is Google-internal. We would have to spend several days porting our codebase externally, and writing such a Colab will definitely take us some time. Also, as we explained, we train literally thousands of models in this paper, so it is not that easy to replicate. There might be a bug somewhere, though: random accuracy all along looks wrong. I'll investigate this weekend, as you're also working towards the deadline.

kirk86 commented 10 months ago

@fazega Thanks appreciate it!

we train literally thousands of models in this paper

I don't need all of them; for starters, I just need the transformer_encoder on the simple even_pairs task. Thanks!
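For a sanity check of the even_pairs data pipeline, a reference labeler is handy. The sketch below assumes the task as described in the paper: strings over {a, b}, labeled by whether the number of adjacent ab and ba pairs is even. Every such pair marks a switch between symbols, so the count is even exactly when the string starts and ends with the same symbol. Under this definition every length-1 string is labeled even, so a model that had learned the task would score 100% at length 1. The helper names are hypothetical; the brute-force loop checks the equivalence for all strings up to length 10.

```python
import itertools


def even_pairs_label(s: str) -> bool:
  """True if the number of adjacent 'ab' and 'ba' pairs in s is even."""
  switches = sum(1 for x, y in zip(s, s[1:]) if x != y)
  return switches % 2 == 0


def even_pairs_shortcut(s: str) -> bool:
  """Equivalent check: the count is even iff first and last symbols match."""
  return s[0] == s[-1]


all_match = all(
    even_pairs_label("".join(p)) == even_pairs_shortcut("".join(p))
    for n in range(1, 11)
    for p in itertools.product("ab", repeat=n))
print(all_match)  # True
```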

anianruoss commented 10 months ago

I checked the model parameters for the transformer_encoder, and they are the same.

anianruoss commented 10 months ago

Re "seed hacking": the purpose of our paper is to show whether an architecture is capable of learning a task at all, not whether it does so on average. To that end, we take the best-performing model over 10 seeds.

kirk86 commented 10 months ago

our paper is to show whether an architecture is capable of learning a task at all not on average

I don't disagree with that. As a matter of fact, the paper shows that the model is not only capable of learning even_pairs but also does a phenomenal job at it! That is contrary to what I've observed so far (maybe I'm doing something wrong...).

fazega commented 9 months ago

I was able to reproduce the bug. I will try to fix over the weekend.

fazega commented 9 months ago

I triple-checked the code and everything is correct. Using the following arguments (keeping everything else at its default) works well for me. If you still have a problem, it's likely a hardware issue. Closing.

seed=0
model_init_seed=3
training_steps=10_000
log_frequency=100
learning_rate=1e-4
sequence_length=40
task=even_pairs
batch_size=128
architecture_params={}
architecture='transformer_encoder'