jemisjoky / TorchMPS

PyTorch toolbox for matrix product state models
MIT License

Unmerging `linear_region` of an `MPS` in `adaptive_mode` changes the output #7

Closed: philip-bl closed this issue 4 years ago

philip-bl commented 4 years ago

Hello. I've been playing with your code. I created an MPS in adaptive_mode and wanted to look at its cores when they are not merged in pairs. So I unmerged them, but the output changed (and I guess the weights changed as well). Below is a small example demonstrating the problem:

import torch
from torchmps import MPS

# Build an MPS in adaptive mode, where neighboring cores are merged in pairs
mpo = MPS(input_dim=9, output_dim=10, bond_dim=11, init_std=0.0, adaptive_mode=True).train(False)
random_input = torch.randn(1, 9, 2)
output_before_unmerge = mpo(random_input)
# Split the merged cores apart and evaluate the same input again
mpo.linear_region.unmerge()
output_after_unmerge = mpo(random_input)
print(output_before_unmerge[0])
# tensor([0.0138, 0.0138, 0.0138, 0.0138, 0.0138, 0.0138, 0.0138, 0.0138, 0.0138,
#         0.0138], grad_fn=<SelectBackward>)
print(output_after_unmerge[0])
# tensor([-8.2583, -8.2583, -8.2583, -8.2583, -8.2583, -8.2583, -8.2583, -8.2583,
#         -8.2583, -8.2583], grad_fn=<SelectBackward>)

I guess I am misunderstanding something about how to use unmerge. If so, could you please explain how I should do this?

jemisjoky commented 4 years ago

Thanks for finding that issue @philip-bl!

To be honest, coding up the adaptive mode in PyTorch was a lot trickier than I had expected, and there were some hacky things I ended up doing to make the merging and unmerging processes work smoothly behind the scenes. One such thing was keeping two sets of the core tensors in memory, one for each merge "offset" (i.e. which half of the bonds are held fixed), along with a pointer indicating which set receives gradient updates during the computation.
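
For intuition, here's a rough, self-contained sketch of that two-copies-plus-pointer bookkeeping. The names (ToyMergedRegion, cores_by_offset, current_offset) are made up for illustration and don't match the actual attribute names in TorchMPS, and the real hand-off between offsets goes through unmerging and re-merging the cores rather than a plain copy:

import torch
import torch.nn as nn

class ToyMergedRegion(nn.Module):
    # Illustrative only: one copy of the core tensors per merge offset,
    # plus a pointer saying which copy currently receives gradient updates.
    def __init__(self, num_cores=4, bond_dim=11, feature_dim=2):
        super().__init__()
        self.cores_by_offset = nn.ParameterList([
            nn.Parameter(torch.randn(num_cores, bond_dim, feature_dim, bond_dim))
            for _ in range(2)
        ])
        self.current_offset = 0  # pointer to the "live" copy

    def switch_offset(self):
        # In TorchMPS the hand-off happens by unmerging and re-merging the
        # cores; here it's reduced to a plain copy so only the pointer flip
        # is visible.
        new_offset = 1 - self.current_offset
        with torch.no_grad():
            self.cores_by_offset[new_offset].copy_(self.cores_by_offset[self.current_offset])
        self.current_offset = new_offset

    def forward(self, batch):
        # Only the copy selected by the pointer enters the computation, so
        # only it gets gradient updates. The einsum is just a placeholder
        # reduction standing in for the real MPS contraction.
        cores = self.cores_by_offset[self.current_offset]
        return torch.einsum("nlfr,bnf->bl", cores, batch)

The point of the pattern is that flipping the pointer (and handing the cores over to the other copy) shouldn't change the forward output, which is exactly what the example further down demonstrates.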

I hadn't really intended the merge and unmerge methods of MergedLinearRegion to be called by users, and any use of unmerge() was supposed to be followed by a call to merge(offset) before the model's forward method was called. I'm not entirely sure what's going wrong in your example, but I suspect the model's pointer is referencing an unmerged configuration when it expects a merged one. Either way, my apologies for this not being at all clear from the documentation; I'll add some warnings there!
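
For anyone who does want to poke at this by hand, here's roughly the pattern I had in mind (treat the exact call as a sketch and check the method signatures in MergedLinearRegion before relying on it):

import torch
from torchmps import MPS

mpo = MPS(input_dim=9, output_dim=10, bond_dim=11, init_std=0.0,
          adaptive_mode=True).train(False)
random_input = torch.randn(1, 9, 2)
output_before = mpo(random_input)

mpo.linear_region.unmerge()   # split the merged cores so they can be inspected
# ... look at the unmerged cores here ...
mpo.linear_region.merge(0)    # re-merge (here at offset 0) before the next forward call

output_after = mpo(random_input)

In the intended behind-the-scenes flow, this unmerge/merge round trip doesn't change the output, which is what the example below illustrates.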

If you initialize with a very low value of merge_threshold, you can see that the intended unmerging and merging behavior (which happens behind the scenes) doesn't end up changing the output. However, if you want to play around with this behavior yourself, I'd recommend looking at the first part of MergedLinearRegion.forward for how merge and unmerge were intended to be used.

A quick modification of your example that shows the behind-the-scenes merging/unmerging behavior:

import torch
from torchmps import MPS

# A very low merge_threshold makes the behind-the-scenes unmerge/merge
# (i.e. the offset switch) happen after just a couple of inputs
mpo = MPS(input_dim=9, output_dim=10, bond_dim=11, init_std=0.0,
          adaptive_mode=True, merge_threshold=2).train(False)
random_input = torch.randn(1, 9, 2)
for n in range(5):
    print(mpo(random_input)[0])

    # Every other iteration, check which merge offset is active and how
    # the cores are currently grouped into modules
    if n % 2 == 0:
        print("Offset =", mpo.linear_region.offset)
        print(mpo.linear_region.module_list)

# tensor([-0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332,
#         -0.0332, -0.0332], grad_fn=<SelectBackward>)
# Offset = 0
# ModuleList(
#   (0): MergedInput()
#   (1): MergedOutput()
#   (2): MergedInput()
# )
# ^^^INITIAL CONFIGURATION OF THE MPS

# tensor([-0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332,
#         -0.0332, -0.0332], grad_fn=<SelectBackward>)
# tensor([-0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332,
#         -0.0332, -0.0332], grad_fn=<SelectBackward>)
# Offset = 1
# ModuleList(
#   (0): InputSite()
#   (1): MergedInput()
#   (2): MergedOutput()
#   (3): MergedInput()
#   (4): InputSite()
# )
# ^^^THE MODULES AND MERGE OFFSET HAVE CHANGED, BUT OUTPUT IS THE SAME

# tensor([-0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332,
#         -0.0332, -0.0332], grad_fn=<SelectBackward>)
# tensor([-0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332,
#         -0.0332, -0.0332], grad_fn=<SelectBackward>)
# Offset = 0
# ModuleList(
#   (0): MergedInput()
#   (1): MergedOutput()
#   (2): MergedInput()
# )
# ^^^BACK TO THE INITIAL CONFIGURATION

# tensor([-0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332, -0.0332,
#         -0.0332, -0.0332], grad_fn=<SelectBackward>)

jemisjoky commented 4 years ago

I just added a bit of documentation/warnings about not calling these methods directly, so I'm going to close this for now. Let me know if you run into any other issues while exploring this behavior, though!