Closed smu160 closed 3 months ago
Woah this was on our to-do list for a bit recently, thanks so much for implementing it :) Hopefully I'll get the time to review this in a day or two, and we can (hopefully) add this as part of the next release! Thanks for helping improve Manim.
Hi @JasonGrace2282,
Thank you for taking the time to review. I added type hints to the instance variables, method parameters, and return types of methods (only the ones I modified).
Please let me know if there is anything else that needs to be addressed.
Thank you!!
Hi @JasonGrace2282 ,
I checked in the changes you requested. Please let me know if anything else needs to be amended.
Thank you!!
Overview: What does this pull request change?
Hi, I apologize for dropping this enhancement for almost 9 months now. On the bright side, I'm glad I was able to come back to this after the change to using
av
for writing frames. As a show of good faith, I made sure to run benchmarks on two different machines, while accounting for all rendering quality levels. I will also attach the benchmark script to this PR in order to allow others to verify on different setups/machines.
This pull request introduces a significant performance improvement to the
SceneFileWriter
class in Manim by utilizing multithreading for the frame encoding process.
Motivation and Explanation: Why and how do your changes improve the library?
The current implementation of
SceneFileWriter
encodes frames serially, which can be a bottleneck during the rendering process, especially for high-resolution videos. By introducing a separate thread for writing the video frames, we increase throughput and decrease the overall rendering time.
Links to added or changed documentation pages
N/A
Benchmark Results
I carried out benchmarks on two different machines: an M2 MacBook Air and a Debian server that runs on an AMD Ryzen 9 7950X. The results show a noticeable reduction in rendering times across various quality levels with the multithreaded implementation as compared to the current implementation (main).
M2 MacBook Air
Debian Server with AMD Ryzen 9 7950X
Summary
System Information
Graphs above were annotated with system information.
Further Information and Comments
Steps to Reproduce
This assumes you're using my fork of manim.
Please put render.py and benchmark.py in the manim root. Note that you'll need to set the absolute path to the manim repo root using
REPO_PATH
Install dependencies:
benchmark.py
```python import subprocess import time import matplotlib.pyplot as plt import seaborn as sns import os import shutil REPO_PATH = "" MAIN_BRANCH = "main" CUSTOM_BRANCH = "multithreaded-sfw" BENCHMARK_SCENE = "Insertion" RENDER_COMMAND_TEMPLATE = "manim -q{} render.py " + BENCHMARK_SCENE QUALITY_LEVELS = ["l", "m", "h", "p", "k"] MEDIA_DIR = os.path.join(REPO_PATH, "media") def checkout_branch(branch_name): subprocess.run(["git", "checkout", branch_name], cwd=REPO_PATH) def run_benchmark(): if not REPO_PATH: raise ValueError("please provide an absolute path to the manim root directory") times = [] for quality in QUALITY_LEVELS: command = RENDER_COMMAND_TEMPLATE.format(quality) start_time = time.time() subprocess.run(command, shell=True, cwd=REPO_PATH) end_time = time.time() elapsed_time = end_time - start_time times.append(elapsed_time) if os.path.exists(MEDIA_DIR): shutil.rmtree(MEDIA_DIR) return times def plot_results(main_times, custom_times, machine_info): sns.set(style="whitegrid") plt.figure(figsize=(18, 12)) bar_width = 0.35 index = range(len(QUALITY_LEVELS)) bars1 = plt.bar(index, main_times, bar_width, label=MAIN_BRANCH, color=sns.color_palette("Blues", n_colors=5)[2]) bars2 = plt.bar([i + bar_width for i in index], custom_times, bar_width, label=CUSTOM_BRANCH, color=sns.color_palette("Oranges", n_colors=5)[2]) plt.xlabel('Quality Level', fontsize=14, fontweight='bold') plt.ylabel('Time (seconds)', fontsize=14, fontweight='bold') plt.title('Manim Rendering Time', fontsize=16, fontweight='bold') plt.xticks([i + bar_width / 2 for i in index], ['-q' + ql for ql in QUALITY_LEVELS], fontsize=12) plt.yticks(fontsize=12) plt.legend(fontsize=12) plt.grid(axis='y', linestyle='--', alpha=0.7) for bars in [bars1, bars2]: for bar in bars: height = bar.get_height() plt.text(bar.get_x() + bar.get_width() / 2.0, height, f'{height:.2f}', ha='center', va='bottom', fontsize=12, fontweight='bold') # Add machine info as text annotation machine_info_text = 
"\n".join(machine_info) plt.text(-0.3, max(max(main_times, custom_times))-7, machine_info_text, bbox=dict(boxstyle="round,pad=0.3", edgecolor="black", facecolor="white")) # plt.tight_layout() plt.savefig("benchmark_results_comparison.png", dpi=300) def get_machine_info(): try: neofetch_output = subprocess.check_output("neofetch --stdout", shell=True).decode().strip().split('\n') relevant_info = [ line for line in neofetch_output if any(keyword in line for keyword in ["OS:", "Kernel:", "CPU:", "Memory:"]) ] except subprocess.CalledProcessError: relevant_info = ["Neofetch is not installed or failed to run"] return relevant_info if __name__ == "__main__": # Run benchmarks on main branch print("Switching to main branch...") checkout_branch(MAIN_BRANCH) main_times = run_benchmark() # Run benchmarks on custom branch print("Switching to custom branch...") checkout_branch(CUSTOM_BRANCH) custom_times = run_benchmark() # Gather machine info print("Gathering machine information...") machine_info = get_machine_info() print(machine_info) # Plot and save results print("Plotting results...") plot_results(main_times, custom_times, machine_info) print("Benchmarking completed. Results saved to benchmark_results_comparison.png") ```render.py
```python """ Probably the ugliest python code written, but it works. Cleaning up, optimizing, and an overall refactor would be ideal. """ import numpy as np import itertools from manim import * import time FRAC_1_SQRT_2 = 1.0 / np.sqrt(2) # 1 / sqrt(2) class Array(VGroup): def __init__(self, values, **kwargs): super().__init__(**kwargs) self.values = values self.element_width = 1 self.element_color = WHITE initial_array = [Rectangle(height=1, width=1).set_color(GRAY_A) for _ in values] for i in range(1, len(initial_array)): initial_array[i].next_to(initial_array[i - 1], RIGHT, buff=0) self.add(*initial_array) self.elements = VGroup() for i, val in enumerate(values): tex = ( Tex(val) if val is not None else Tex(".", width=0, height=0, color=BLACK) ) tex.move_to(i * RIGHT * self.element_width) self.elements.add(tex) self.elements.set_color(self.element_color) self.add(self.elements) self.move_to(ORIGIN) def scale_fig(self, factor: float) -> None: """Scales the array by the given factor""" self.scale(factor) class Insertion(Scene): def construct(self): now = time.time() title = Tex(r"Strategy 3: Insertion") self.play( Write(title), ) self.wait() self.play(FadeOut(title)) code = """ let distance = 1 << target; let num_pairs = state.len() >> 1; for i in 0..num_pairs { let s0 = i + ((i >> target) << target); let s1 = s0 + distance; // update amplitudes here } """ rendered_code = Code( code=code, tab_width=4, language="rust", insert_line_no=False, ) self.play(Create(rendered_code)) self.wait(5) self.play(FadeOut(rendered_code)) num_qubits = Variable(4, MathTex("num\_qubits"), var_type=Integer).scale(0.75) state_len = Variable( int(2 ** num_qubits.tracker.get_value()), MathTex("state\_len"), var_type=Integer, ).scale(0.75) target = Variable(0, MathTex("target"), var_type=Integer).scale(0.75) num_pairs = Variable( int(state_len.tracker.get_value()) >> 1, MathTex("num\_pairs"), var_type=Integer, ).scale(0.75) distance = Variable(1, MathTex("distance"), 
var_type=Integer).scale(0.75) arr = Array( [ r"$z_{" + str(i) + r"}$" for i in range(int(state_len.tracker.get_value())) ] ) arr.scale_fig(0.75) vgroup = ( VGroup(num_qubits, state_len, target, num_pairs, distance) .arrange(DOWN) .move_to(ORIGIN + UP * 2) ) arr.next_to(vgroup, DOWN * 2) self.add(vgroup, arr) self.wait(1) start = arr[0].get_bottom() + DOWN * 2 end = arr[0].get_bottom() side_0 = Arrow(start=start, end=end, color=BLUE).scale(0.5, scale_tips=True) side_1 = Arrow(start=start, end=end, color=YELLOW).scale(0.5, scale_tips=True) self.add(side_0, side_1) for t in range(int(num_qubits.tracker.get_value())): self.play( target.tracker.animate.set_value(t), distance.tracker.animate.set_value(1 << t), ) for i in range(int(num_pairs.tracker.get_value())): s0 = i + ( (i >> int(target.tracker.get_value())) << int(target.tracker.get_value()) ) s1 = s0 + int(distance.tracker.get_value()) self.play( side_0.animate.next_to(arr[s0], DOWN), side_1.animate.next_to(arr[s1], DOWN), ) assembly = """ .LBB1_2: mov r11, r8 and r11, r9 lea rax, [r8 + r11] cmp rax, rsi jae .LBB1_6 lea r10, [rax + rdi] cmp r10, rsi jae .LBB1_7 shl rax, 4 vmovsd xmm2, qword ptr [rcx + rax] vmovsd xmm3, qword ptr [rcx + rax + 8] shl r10, 4 vmovsd xmm4, qword ptr [rcx + r10] vmovsd xmm5, qword ptr [rcx + r10 + 8] vmulsd xmm6, xmm1, xmm2 vmulsd xmm7, xmm0, xmm3 vaddsd xmm6, xmm7, xmm6 vmovsd qword ptr [rcx + rax], xmm6 vmulsd xmm3, xmm1, xmm3 vmulsd xmm2, xmm0, xmm2 vsubsd xmm2, xmm3, xmm2 vmovsd qword ptr [rcx + rax + 8], xmm2 inc r8 vmulsd xmm2, xmm1, xmm4 vmulsd xmm3, xmm0, xmm5 vsubsd xmm2, xmm2, xmm3 vmovsd qword ptr [rcx + r10], xmm2 vmulsd xmm2, xmm1, xmm5 vmulsd xmm3, xmm0, xmm4 vaddsd xmm2, xmm3, xmm2 vmovsd qword ptr [rcx + r10 + 8], xmm2 cmp rdx, r8 jne .LBB1_2""" self.clear() rendered_asm = Code( code=assembly, tab_width=4, language="nasm", insert_line_no=False, ).scale(0.65) self.play(Create(rendered_asm)) self.wait(5) self.play(FadeOut(rendered_asm)) end = time.time() elapsed = end - now 
print(f"elapsed: {elapsed}") ```Reviewer Checklist