ahmedmalaa / Symbolic-Metamodeling

Codebase for "Demystifying Black-box Models with Symbolic Metamodels", NeurIPS 2019.

Very high loss and unstable training when trying new functions or time spans #3

Closed ghost closed 3 years ago

ghost commented 3 years ago

Hey Ahmed!

I wanted to thank you for making this repository open for others to use and explore. The concept is fascinating, and I have started experimenting with it, but so far I haven't had much luck getting good results outside the examples provided.

For example, if I add these two functions into the benchmark


import numpy as np

def random_test1_function(X):
    """
    Benchmark function number 1: f(x) = 0.5*x / (6 + x)
    """
    return (0.5 * X) / (6 + X)

def random_test2_function(X):
    """
    Benchmark function number 2: f(x) = 0.5*x*sqrt(x) / (6 + x + exp(x)*sin(x))
    """
    return (0.5 * X * np.sqrt(X)) / (6 + X + np.exp(X) * np.sin(X))

and change the time span I want to train them on, I get wildly large losses. The same thing happens even with the original examples: if I just change the x-range from 0.01:1 to 0.01:5, the loss blows up (as high as 1e+25!). Nothing I try helps, e.g. increasing the number of data points or playing with the batches in the fit function in symbolic_metamodeling.py; the results just blow up. Is there something about this implementation in particular that's causing this? Is the concept itself prone to being this unstable? I read in a closed issue that the gradient calculations were custom and not generalised; has this been the problem? Looking forward to hearing back!

ahmedmalaa commented 3 years ago

Hi abstractingagent,

Thank you for your interest in our work. I think that the problem is with the time span. Meijer G functions are defined over [0, 1], so you have to normalize the feature space whenever you're dealing with larger values. Let me know if this fixes your problem.
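For reference, a minimal sketch of this kind of feature normalization (plain numpy; the function name is illustrative and not part of the repo's API): rescale each feature to (0, 1] before fitting the metamodel, and apply the same transform at evaluation time.

import numpy as np

def normalize_to_unit_interval(X, eps=1e-6):
    """Rescale each feature column to (0, 1] before fitting the metamodel."""
    X = np.asarray(X, dtype=float)
    X_min = X.min(axis=0)
    X_max = X.max(axis=0)
    # Map [min, max] -> [0, 1], then clip away from 0 since the G function
    # is only defined away from 0.
    X_scaled = (X - X_min) / (X_max - X_min + eps)
    return np.clip(X_scaled, eps, 1.0)

# e.g. x = np.linspace(0.01, 5, 500).reshape(-1, 1); fit on normalize_to_unit_interval(x)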

Thanks.

ghost commented 3 years ago

Hey Ahmed!

I really appreciate your reply! Just to clarify, are you saying the co-domain of the G functions is only defined between [0, 1]? I tried something simpler, 0.5x/(x+6), whose values lie in that range (I just replaced the rational part of the true function), and it wasn't able to capture the form: it got 0.007xe^-0.23. It did fit well over the domain it was trained on, but it stubbornly stayed in the expression form (ax)e^(-bx). Any ideas on how to better capture this? I tried using BFGS instead of CG; the loss dropped much more quickly and I got a lower final loss than with CG, but I still couldn't get it to capture the true rational form.

Also, while watching the loss and the expression get printed during initialisation and training, I noticed that for the rational-function and exponential examples, the initialisation before training already gave the function forms we were supposed to identify. Was this purposeful? Given a random initialisation, with no prior knowledge of the target form, can the method still identify the forms?

One more question, to clarify the implementation: are the Meijer G functions essentially being used as adaptive activation functions, where the shift, scale and power parameters in the gamma functions are what gets trained, while the weights leading into the nodes are held constant? I am trying to understand the neural-network analogue of this.

Once again I am grateful for your time and response!

ahmedmalaa commented 3 years ago

To be precise, the G function is defined over the domain (0, 1] so avoid inputting a 0 to it (I think I already do this by adding an epsilon to the G function inputs automatically).

The algorithm will fit the input function well but is not guaranteed to recover the ground-truth expression, because symbolic expressions are generally unidentifiable (there are probably infinitely many possible symbolic forms that can generate any given smooth function). The expression recovered will depend on the hyper-parameters of the G function; that's why you keep getting the form (ax)*e^(-bx). You can play with the G function's number of poles and zeros to make sure a rational form is covered.
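For a concrete illustration of why the order (poles and zeros) matters, here is a quick check using mpmath directly (not the repo's code) that a low-order G function reproduces the rational form exactly, via the identity G^{1,1}_{1,1}(z | 1; 1) = z/(1+z), so that 0.5*G^{1,1}_{1,1}(x/6 | 1; 1) = 0.5x/(6+x). The function names here are only for the demo.

import numpy as np
from mpmath import meijerg

def target(x):
    # the rational benchmark from earlier in the thread
    return 0.5 * x / (6 + x)

def g_form(x):
    # G^{1,1}_{1,1}(z | 1; 1) = z / (1 + z); with z = x/6 this equals x/(6+x)
    return 0.5 * float(meijerg([[1], []], [[1], []], x / 6.0))

for x in [0.1, 0.5, 1.0, 3.0, 5.0]:
    # the two values should agree to numerical precision
    print(x, target(x), g_form(x))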

Yes, you can think of the G function as a very complex activation function in a shallow network.
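Schematically, the analogy reads something like the sketch below. This is a toy stand-in, not the paper's exact parameterization: the projection weights into each unit are fixed, and only the parameters of the (G-function-like) activation are optimized; the simple a*z^b*exp(-c*z) form is just a placeholder for the real Meijer G activation.

import numpy as np

def adaptive_activation(z, theta):
    # placeholder for a parametrized G function: theta controls the shape,
    # and it is theta (not the incoming weights) that gets trained
    a, b, c = theta
    return a * np.power(z, b) * np.exp(-c * z)

def shallow_metamodel(X, W, thetas):
    # W is a fixed projection onto the hidden units (inputs assumed in (0, 1]);
    # each unit applies its own adaptive activation, and the outputs are summed
    Z = X @ W                                   # shape (n_samples, n_units)
    H = np.column_stack([adaptive_activation(Z[:, j], thetas[j])
                         for j in range(Z.shape[1])])
    return H.sum(axis=1)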

Thanks!