yanweiyue / AgentPrune


Question about Eq.(11) #1

Open TienyuZuo opened 6 days ago

TienyuZuo commented 6 days ago

Thank you very much for your outstanding work. I have been following the latest research on LLM multi-agent systems, and your work has given me great inspiration. I would like to ask how the convex optimization problem in Eq. (11) is handled in the code. I read run_gsm8k.py and noticed that the optimization part makes no mention of the nuclear norm; it seems to include only the Distribution Approximation term (i.e., Eq. (9)):

```python
single_loss = -log_prob * utility
loss_list.append(single_loss)
total_loss = torch.mean(torch.stack(loss_list))
total_loss.backward()
```

yanweiyue commented 6 days ago

Thank you for your kind words and for following our work on LLM multi-agent systems! We're delighted to hear that it has been helpful and inspiring for your research.

Due to an oversight in our version management, some code related to the nuclear norm was inadvertently omitted from the release. You were very perceptive to catch this gap in the existing code, and we appreciate your meticulousness and care.

To address your question, the code has now been updated to align with the convex optimization formulation described in Eq. (11). Below is the corrected code snippet:

```python
import torch
import torch.nn.functional as F

def nuclear_norm(matrix):
    # Nuclear norm: the sum of singular values, which encourages low-rank
    # (and hence sparse) communication structure.
    _, S, _ = torch.svd(matrix)
    return torch.sum(S)

def frobenius_norm(A, S):
    # Frobenius distance between the fixed mask and the trainable logits.
    return torch.norm(A - S, p='fro')

# Reshape the trainable logits into (num_agents, num_agents) adjacency matrices.
spatial_matrix_train = realized_graph.spatial_logits.reshape((len(agent_names), len(agent_names)))
temporal_matrix_train = realized_graph.temporal_logits.reshape((len(agent_names), len(agent_names)))
spatial_matrix_fixed = torch.tensor(kwargs["fixed_spatial_masks"], dtype=torch.float32).reshape((len(agent_names), len(agent_names)))
temporal_matrix_fixed = torch.tensor(kwargs["fixed_temporal_masks"], dtype=torch.float32).reshape((len(agent_names), len(agent_names)))

# Nuclear-norm terms for the spatial and temporal communication graphs.
loss_s = nuclear_norm(spatial_matrix_train)
loss_t = nuclear_norm(temporal_matrix_train)
# Frobenius terms measure how far the learned matrices drift from the fixed masks.
frob_loss_s = frobenius_norm(spatial_matrix_fixed, spatial_matrix_train)
frob_loss_t = frobenius_norm(temporal_matrix_fixed, temporal_matrix_train)

# The ReLU hinges activate only when the drift exceeds the slack args.delta,
# implementing the constraints of Eq. (11) as penalty terms.
add_loss = loss_s + loss_t + F.relu(frob_loss_s - args.delta) + F.relu(frob_loss_t - args.delta)
```

This revised implementation regularizes both the spatial and temporal matrices with the nuclear norm, and enforces the Frobenius-norm constraints against the fixed masks through the hinge (ReLU) penalty terms.
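For intuition, here is a minimal, self-contained sketch of how such a regularizer combines with the `-log_prob * utility` term during training. The tensor shapes, the placeholder logits, the `delta` value, and the dummy utility are illustrative assumptions for this toy example, not the repository's actual configuration:

```python
import torch
import torch.nn.functional as F

def nuclear_norm(matrix):
    # Sum of singular values of the matrix (the nuclear norm).
    _, S, _ = torch.svd(matrix)
    return torch.sum(S)

def frobenius_norm(A, S):
    return torch.norm(A - S, p='fro')

torch.manual_seed(0)

# Toy setup: 3 agents, so the communication graph is a 3x3 logit matrix.
spatial_train = torch.randn(3, 3, requires_grad=True)
spatial_fixed = torch.ones(3, 3)   # assumed fixed mask (fully connected)
delta = 0.1                        # assumed slack for the Frobenius constraint

# Eq. (9)-style policy-gradient term with a dummy log-prob and utility.
log_prob = torch.log_softmax(spatial_train.flatten(), dim=0)[0]
utility = 1.0
pg_loss = -log_prob * utility

# Eq. (11)-style regularizer: low-rank pressure plus a hinged closeness penalty.
add_loss = nuclear_norm(spatial_train) + F.relu(
    frobenius_norm(spatial_fixed, spatial_train) - delta
)

total_loss = pg_loss + add_loss
total_loss.backward()  # gradients flow through both terms into the logits
```

Both terms are differentiable almost everywhere, so a single `backward()` call propagates gradients from the utility term and the regularizer into the same graph logits.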

Thank you again for bringing this to our attention. Your feedback helps us improve, and we’re grateful for your insight. Please let us know if you have any further questions or need additional clarification.

Best regards😀