ngraymon / termfactory

Exploring a formal representation for the residual terms and/or cleaning up code of the residual equation generator
MIT License

Implement `opt_einsum` generator for full cc #76

Open ngraymon opened 2 years ago

ngraymon commented 2 years ago

See the example below:

def add_m0_n0_fully_connected_terms_optimized(R, ansatz, truncation, h_args, t_args, opt_paths):
    """Optimized calculation of the operator(name='', rank=0, m=0, n=0) fully_connected terms."""

    if ansatz.ground_state:
        R += h_args[(0, 0)]

        if truncation.at_least_linear:
            if truncation.singles:
                R += np.einsum('aci, cbi -> ab', h_args[(0, 1)], t_args[(1, 0)])
    else:
        R += h_args[(0, 0)]

        if truncation.at_least_linear:
            if truncation.singles:
                R += np.einsum('aci, cbi -> ab', h_args[(1, 0)], t_args[(0, 1)])
    return

Replace

R += np.einsum('aci, cbi -> ab', h_args[(0, 1)], t_args[(1, 0)])

with

R += next(optimized_einsum)(h_args[(0, 1)], t_args[(1, 0)])
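
A minimal sketch of how the replacement could behave, assuming `optimized_einsum` is an iterator over expressions precompiled with `oe.contract_expression`. The sizes `A` and `N`, the random test arrays, and the names `R_ref` and `R_opt` are made up for illustration and are not taken from the generator.

```python
import numpy as np
import opt_einsum as oe

A, N = 3, 4  # hypothetical dimensions, for illustration only

# Precompile the contraction once; the subscripts are those of the
# np.einsum call being replaced, the shapes match its two operands.
opt_paths = [oe.contract_expression('aci,cbi->ab', (A, A, N), (A, A, N))]

rng = np.random.default_rng(0)
h_01 = rng.standard_normal((A, A, N))  # stands in for h_args[(0, 1)]
t_10 = rng.standard_normal((A, A, N))  # stands in for t_args[(1, 0)]

# Reference result from the original np.einsum call.
R_ref = np.zeros((A, A))
R_ref += np.einsum('aci, cbi -> ab', h_01, t_10)

# Same update, consuming the next precompiled expression from the iterator.
R_opt = np.zeros((A, A))
optimized_einsum = iter(opt_paths)
R_opt += next(optimized_einsum)(h_01, t_10)
```

Because the expressions are consumed positionally, the function that fills the list and the function that calls `next` would have to walk the same `ansatz`/`truncation` branches in the same order.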

ngraymon commented 2 years ago

Instead of looping over a single order:

for order in range(1, max_order+1):

we have to loop over both dimensions (m, n):

for omega_term in master_omega.operator_list:
    specifier_string = f"m{omega_term.m}_n{omega_term.n}"
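
As a rough illustration of the loop above, the specifier string can be used to build the per-term function names. `OmegaTerm` and the hand-picked `(m, n)` pairs are hypothetical stand-ins for the entries of `master_omega.operator_list`; only the `.m` and `.n` attributes are modelled.

```python
from collections import namedtuple

# Hypothetical stand-in for the entries of master_omega.operator_list.
OmegaTerm = namedtuple("OmegaTerm", ["m", "n"])

# Illustrative (m, n) pairs, chosen by hand for this sketch.
operator_list = [OmegaTerm(0, 0), OmegaTerm(0, 1), OmegaTerm(1, 0), OmegaTerm(3, 0)]

function_names = []
for omega_term in operator_list:
    # Same specifier format as in the loop above.
    specifier_string = f"m{omega_term.m}_n{omega_term.n}"
    function_names.append(f"compute_{specifier_string}_optimized_paths")

print(function_names)
```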

ngraymon commented 2 years ago

The final structure is different from the original design.

compute_all_optimized_paths returns a dictionary of all optimized paths for proj^m_n.

compute_m#_n#_optimized_paths returns a dictionary with a single key-value pair, where the key is the tuple (m, n) and the value is a list of length 3.

For example:

connected_opt_path_list = compute_m3_n0_fully_connected_optimized_paths(A, N, ansatz, truncation)
linked_opt_path_list = compute_m3_n0_linked_disconnected_optimized_paths(A, N, ansatz, truncation)
unlinked_opt_path_list = compute_m3_n0_unlinked_disconnected_optimized_paths(A, N, ansatz, truncation)

return_dict = {
    (3, 0): [connected_opt_path_list, linked_opt_path_list, unlinked_opt_path_list]
}
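
One way the single-key dictionaries might be merged into the dictionary returned by `compute_all_optimized_paths` is sketched below. The stub functions take no arguments and return placeholder strings instead of real contract expressions; both simplifications, and the `per_term_functions` parameter, are assumptions made for this sketch.

```python
def compute_m0_n0_optimized_paths():
    # Placeholder strings stand in for the three lists of contract expressions.
    return {(0, 0): [["connected"], ["linked"], ["unlinked"]]}


def compute_m3_n0_optimized_paths():
    return {(3, 0): [["connected"], ["linked"], ["unlinked"]]}


def compute_all_optimized_paths(per_term_functions):
    """Merge the single-key dictionaries into one dictionary keyed by (m, n)."""
    all_paths = {}
    for func in per_term_functions:
        all_paths.update(func())
    return all_paths


all_opt_paths = compute_all_optimized_paths(
    [compute_m0_n0_optimized_paths, compute_m3_n0_optimized_paths]
)

# Each value unpacks into the three term types, in the order shown above.
connected, linked, unlinked = all_opt_paths[(3, 0)]
```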

Each compute_m#_n#_{term_type_name}_optimized_paths function returns the appropriate list, filled with contract expressions. For example:

def compute_m0_n0_fully_connected_optimized_paths(A, N, ansatz, truncation):
    """Calculate optimized einsum paths for the fully_connected terms."""

    fully_connected_opt_path_list = []

    if ansatz.ground_state:

        if truncation.at_least_linear:
            if truncation.singles:
                fully_connected_opt_path_list.append(oe.contract_expression((A, A, N), (A, A, N)))

        if truncation.at_least_quadratic:
            if truncation.singles:
                fully_connected_opt_path_list.extend([
                    oe.contract_expression((A, A, N, N), (A, A, N), (A, A, N)),
                    oe.contract_expression((A, A, N, N), (A, A, N), (A, A, N))
                ])
            if truncation.doubles:
                fully_connected_opt_path_list.append(oe.contract_expression((A, A, N, N), (A, A, N, N)))

        if truncation.at_least_cubic:
            if truncation.singles:
                fully_connected_opt_path_list.extend([
                    oe.contract_expression((A, A, N, N, N), (A, A, N), (A, A, N), (A, A, N)),
                    oe.contract_expression((A, A, N, N, N), (A, A, N), (A, A, N), (A, A, N)),
                    oe.contract_expression((A, A, N, N, N), (A, A, N), (A, A, N), (A, A, N)),
                    oe.contract_expression((A, A, N, N, N), (A, A, N), (A, A, N), (A, A, N)),
                    oe.contract_expression((A, A, N, N, N), (A, A, N), (A, A, N), (A, A, N)),
                    oe.contract_expression((A, A, N, N, N), (A, A, N), (A, A, N), (A, A, N))
                ])
            if truncation.doubles:
                fully_connected_opt_path_list.extend([
                    oe.contract_expression((A, A, N, N, N), (A, A, N), (A, A, N, N)),
                    oe.contract_expression((A, A, N, N, N), (A, A, N, N), (A, A, N))
                ])
            if truncation.triples:
                fully_connected_opt_path_list.append(oe.contract_expression((A, A, N, N, N), (A, A, N, N, N)))
    else:
        raise Exception('Hot Band amplitudes not implemented properly and have not been theoretically verified!')

    return fully_connected_opt_path_list

ngraymon commented 2 years ago

I forgot to include the actual einsum operation spec in the contraction expressions! Need to fix this.
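
A hedged sketch of what the fix could look like for the one term whose subscripts appear earlier in this thread (the linear singles contraction `'aci, cbi -> ab'`). `oe.contract_expression` takes the subscripts string as its first argument, followed by the operand shapes. The function name `compute_m0_n0_linear_singles_paths` and the `SimpleNamespace` stand-ins for `ansatz` and `truncation` are hypothetical; the subscripts for the quadratic and cubic terms are not given here and are left out.

```python
from types import SimpleNamespace

import numpy as np
import opt_einsum as oe


def compute_m0_n0_linear_singles_paths(A, N, ansatz, truncation):
    """Build the contract expression for the linear singles term only."""
    paths = []
    if ansatz.ground_state and truncation.at_least_linear and truncation.singles:
        # Subscripts first, then one shape per operand.
        paths.append(oe.contract_expression('aci,cbi->ab', (A, A, N), (A, A, N)))
    return paths


# Hypothetical stand-ins exposing only the flags used above.
ansatz = SimpleNamespace(ground_state=True)
truncation = SimpleNamespace(at_least_linear=True, singles=True)

A, N = 2, 3
paths = compute_m0_n0_linear_singles_paths(A, N, ansatz, truncation)

h = np.arange(A * A * N, dtype=float).reshape(A, A, N)
t = np.ones((A, A, N))
result = paths[0](h, t)
expected = np.einsum('aci,cbi->ab', h, t)
```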