alan-turing-institute / reprosyn

configpath not working with pategan #63

Open lbeziaud opened 1 year ago

lbeziaud commented 1 year ago

Hello. The rsyn executable fails when used with PateGan, but works with MST (I have not tried the other generators).

Here --configpath is passed as a keyword argument to PateGan:

rsyn --dataset data.csv --size 10 --configpath config.json --metadata meta.json pategan
Traceback (most recent call last):
  File "./venv/bin/rsyn", line 8, in <module>
    sys.exit(cli())
  File "./venv/lib64/python3.9/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "./venv/lib64/python3.9/site-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "./venv/lib64/python3.9/site-packages/click/core.py", line 1657, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "./venv/lib64/python3.9/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "./venv/lib64/python3.9/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "./venv/lib64/python3.9/site-packages/click/decorators.py", line 26, in new_func
    return f(get_current_context(), *args, **kwargs)
  File "./venv/lib64/python3.9/site-packages/reprosyn/cli_utils.py", line 31, in wrapper
    func(ctx, **kwargs)
  File "./venv/lib64/python3.9/site-packages/reprosyn/methods/gans/cli.py", line 102, in cmd_pategan
    generator.run()
  File "./venv/lib64/python3.9/site-packages/reprosyn/generator.py", line 113, in run
    self.generate()
  File "./venv/lib64/python3.9/site-packages/reprosyn/methods/gans/gans.py", line 151, in generate
    self.gen = PateGan(self.meta, **self.params)
TypeError: __init__() got an unexpected keyword argument 'configpath'
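
The last frame of the traceback shows the collected parameters being unpacked into the constructor with `**self.params`. A minimal sketch of that failure mode follows, assuming the parameter dict still contains CLI-level options. `StrictModel` and the contents of `params` are hypothetical stand-ins, not the reprosyn code:

```python
# Hypothetical stand-in for a model whose __init__ only accepts known keywords.
class StrictModel:
    def __init__(self, metadata, batch_size=128):
        self.metadata = metadata
        self.batch_size = batch_size


# Assumed shape of the parameters gathered from the command line:
# a real hyperparameter mixed with a CLI-only option.
params = {"batch_size": 64, "configpath": "config.json"}

message = None
try:
    # Unpacking forwards every key as a keyword argument, including 'configpath'.
    StrictModel({"columns": []}, **params)
except TypeError as err:
    message = str(err)
    print(message)
```

Any key in the dict that the constructor does not name raises the same `TypeError`, which would explain why removing `--configpath` only moves the failure to the next unexpected option.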

If the option is removed, generateconfig is passed as a keyword argument instead:

rsyn --dataset data.csv --size 10 --metadata data.json pategan
Traceback (most recent call last):
  File "./venv/bin/rsyn", line 8, in <module>
    sys.exit(cli())
  File "./venv/lib64/python3.9/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "./venv/lib64/python3.9/site-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "./venv/lib64/python3.9/site-packages/click/core.py", line 1657, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "./venv/lib64/python3.9/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "./venv/lib64/python3.9/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "./venv/lib64/python3.9/site-packages/click/decorators.py", line 26, in new_func
    return f(get_current_context(), *args, **kwargs)
  File "./venv/lib64/python3.9/site-packages/reprosyn/cli_utils.py", line 31, in wrapper
    func(ctx, **kwargs)
  File "./venv/lib64/python3.9/site-packages/reprosyn/methods/gans/cli.py", line 102, in cmd_pategan
    generator.run()
  File "./venv/lib64/python3.9/site-packages/reprosyn/generator.py", line 113, in run
    self.generate()
  File "./venv/lib64/python3.9/site-packages/reprosyn/methods/gans/gans.py", line 151, in generate
    self.gen = PateGan(self.meta, **self.params)
TypeError: __init__() got an unexpected keyword argument 'generateconfig'

I'll update the issue after a proper investigation.

lbeziaud commented 1 year ago

This can be fixed by adding a catch-all **kw parameter to PateGan's __init__, as MST and PipelineBase already do (I am not sure why PateGan does not derive from PipelineBase).

diff --git a/src/reprosyn/methods/gans/pate_gan.py b/src/reprosyn/methods/gans/pate_gan.py
index 5ce8e15..46dacb8 100644
--- a/src/reprosyn/methods/gans/pate_gan.py
+++ b/src/reprosyn/methods/gans/pate_gan.py
@@ -47,6 +47,7 @@ class PateGan(GenerativeModel):
         batch_size=128,
         learning_rate=1e-4,
         multiprocess=False,
+        **kw
     ):
         """
         :param metadata: dict: Attribute metadata describing the data domain of the synthetic target data
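
The effect of the catch-all can be sketched in isolation. This is a minimal illustration, assuming the extra options arrive as plain keyword arguments. `TolerantModel` and the `params` dict are hypothetical; only the keyword names `batch_size`, `learning_rate`, `multiprocess`, and `**kw` come from the diff above:

```python
# Hypothetical stand-in mirroring the patched signature: known keywords plus **kw.
class TolerantModel:
    def __init__(
        self,
        metadata,
        batch_size=128,
        learning_rate=1e-4,
        multiprocess=False,
        **kw
    ):
        self.metadata = metadata
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.multiprocess = multiprocess
        # Kept here only to make the swallowed options visible.
        self.ignored = dict(kw)


# Assumed parameter dict: one real hyperparameter and two CLI-only options.
params = {"batch_size": 64, "configpath": "config.json", "generateconfig": False}

model = TolerantModel({"columns": []}, **params)
print(model.batch_size, sorted(model.ignored))
```

One trade-off of this approach: a catch-all also absorbs misspelled hyperparameter names without raising an error, so a typo such as `batchsize=64` would be silently ignored and the default used instead.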