waleedka / hiddenlayer

Neural network graphs and training metrics for PyTorch, TensorFlow, and Keras.
MIT License

How to fold when a constant is added? #26

Open theevann opened 5 years ago

theevann commented 5 years ago

Hello! How can I write a proper folding rule for this?

image

I don't know how to remove the "Constant" node by including it in a fold, since it has no parent.

waleedka commented 5 years ago

You can remove the constant using the transform: Prune("Constant").
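As a minimal sketch of where that transform goes (assuming the PyTorch backend and a graph built with `hl.build_graph`), something like the following should work. `render_without_constants`, `model`, and `example_input` are hypothetical names standing in for your own code; they are not part of the library.

```python
def render_without_constants(model, example_input):
    """Build a HiddenLayer graph with every Constant node pruned.

    Hypothetical helper: `model` is your network and `example_input`
    is a sample input tensor of the shape the network expects.
    """
    # Imported inside the function so this sketch can be loaded
    # even where hiddenlayer is not installed.
    import hiddenlayer as hl

    transforms = [
        # Remove all nodes whose op is "Constant", along with their edges.
        hl.transforms.Prune("Constant"),
    ]
    return hl.build_graph(model, example_input, transforms=transforms)
```

Transforms are applied in list order, so if you keep a rule that mentions Constant (for example a `Constant > Unsqueeze` fold), it has to come before the `Prune`; once a node is pruned, no later rule can match it.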

I don't know the rules you have in place already, but based on the image you posted I'm guessing you have something like this:

    Fold("Constant > Unsqueeze", "ConstUnsqueeze")
    Fold("Gather > Unsqueeze", "GatherUnsqueeze")

If so, then you can build on the above as follows:

    # 1. Merge the right branch:
    Fold("Shape > (ConstUnsqueeze | GatherUnsqueeze) > Concat", "RightBranch")

    # 2. Merge the rest
    Fold("Relu > RightBranch > Reshape", "FullBlock")

Use this as a hint only. Since I don't have access to your code, I'm just guessing here.
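Putting the snippets above together, a hedged sketch of the full list might look like this. The first two rules are the ones guessed at above, and the patterns are untested against your model. `build_folded_graph`, `model`, and `example_input` are hypothetical placeholders. Order matters here: a rule can only match names that earlier rules have already produced.

```python
def build_folded_graph(model, example_input):
    """Apply the folding rules from this thread in dependency order.

    Hypothetical helper: `model` is your network and `example_input`
    is a sample input tensor of the shape the network expects.
    """
    # Imported inside the function so this sketch can be loaded
    # even where hiddenlayer is not installed.
    import hiddenlayer as hl

    transforms = [
        # Assumed existing rules (guessed from the posted image).
        hl.transforms.Fold("Constant > Unsqueeze", "ConstUnsqueeze"),
        hl.transforms.Fold("Gather > Unsqueeze", "GatherUnsqueeze"),
        # 1. Merge the right branch, using the names created above.
        hl.transforms.Fold(
            "Shape > (ConstUnsqueeze | GatherUnsqueeze) > Concat", "RightBranch"
        ),
        # 2. Merge the rest into a single node.
        hl.transforms.Fold("Relu > RightBranch > Reshape", "FullBlock"),
    ]
    return hl.build_graph(model, example_input, transforms=transforms)
```

If a later rule does not match, a reasonable first check is to render the graph with only the earlier rules and confirm the intermediate names (`ConstUnsqueeze`, `GatherUnsqueeze`, `RightBranch`) actually appear.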