Sxela / WarpFusion


animatediff only uses last scene #105

Open: usergenic opened this issue 4 months ago

usergenic commented 4 months ago

Discussed on Discord. Here are the notes:

I notice that when a video is split into scenes with content-aware scheduling, animatediff batches seem to start from the last scene and skip the others. This appears to happen both when defining manual splits and when using content-aware scheduling to do the same.

I added print("filtered_scenes:", filtered_scenes) right before these lines inside the Do the Run cell:

        for scene in filtered_scenes:
          if 'animatediff' in model_version:

It does show all the scenes, so I'm confused at this point. I lose track of the code after that, because I can't see how the subsequent do_run_adiff(args) call knows about any of those scenes: nothing I understand about the scoping of the variables involved seems to match what is happening.
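
On the scoping question: a Python for loop does not create a new scope, so names assigned inside a loop at the top level of a notebook cell are ordinary globals that keep their last values after the loop ends, and a function defined in the notebook looks those globals up when it is called, not when it is defined. A minimal sketch with made-up scene values (report() is a hypothetical stand-in, not a WarpFusion function):

```python
# Hypothetical scene list; in the notebook this would be filtered_scenes.
scenes = [(0, 5), (6, 10), (11, 20)]

for scene in scenes:
    scene_start, scene_end = scene

def report():
    # scene_start/scene_end are looked up in the module's globals at
    # call time, so this sees whatever the loop assigned last.
    return scene_start, scene_end

print(scene)     # (11, 20): the loop variable outlives the loop
print(report())  # (11, 20)
```

So any code path that reads scene_start and scene_end without reassigning them from the current scene will see the last scene's values, provided an earlier loop left them bound that way.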

usergenic commented 4 months ago

I think the fix may be to change the code in cell 4, which in the animatediff case never unpacks scene into scene_start and scene_end and never sets args.start_frame and args.max_frames:

        for scene in filtered_scenes:
          if 'animatediff' in model_version:
              #reset batch length to max size each scene if == -1
              if batch_length_bkup == -1:
                batch_length = scene_end-scene_start+1
              else:
                batch_length = min(batch_length_bkup, scene_end-scene_start+1)
              # batch_length = min(total_batch_length, scene_end-scene_start+1) #make sure to not overflow with -1 batch size
              do_run_adiff(args)
          else:
            scene_start, scene_end = scene
            args.start_frame = scene_start
            scene_end = min(scene_end+1, total_frames)
            args.max_frames = scene_end
            # frame_range = [scene_start, scene_end]
            # print('scene_start, scene_end, frame_range', scene_start, scene_end, frame_range)
            # print(frame_range, args.max_frames, args.start_frame)
            do_run()

Perhaps change to:

        for scene in filtered_scenes:
          scene_start, scene_end = scene
          args.start_frame = scene_start
          scene_end = min(scene_end+1, total_frames)
          args.max_frames = scene_end
          if 'animatediff' in model_version:
              #reset batch length to max size each scene if == -1
              if batch_length_bkup == -1:
                batch_length = scene_end-scene_start+1
              else:
                batch_length = min(batch_length_bkup, scene_end-scene_start+1)
              # batch_length = min(total_batch_length, scene_end-scene_start+1) #make sure to not overflow with -1 batch size
              do_run_adiff(args)
          else:
            # frame_range = [scene_start, scene_end]
            # print('scene_start, scene_end, frame_range', scene_start, scene_end, frame_range)
            # print(frame_range, args.max_frames, args.start_frame)
            do_run()
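
The effect of this change can be checked outside the notebook with a stub. In this sketch, do_run_adiff_stub, run_scenes and the SimpleNamespace standing in for args are all hypothetical; the stub only records the frame range it would render. With set_args_per_scene=True the loop body mirrors the proposed fix: each scene is unpacked and written to args before the call.

```python
from types import SimpleNamespace


def do_run_adiff_stub(args, log):
    # Hypothetical stand-in for do_run_adiff: it only records the
    # frame range it would have rendered.
    log.append((args.start_frame, args.max_frames))


def run_scenes(filtered_scenes, total_frames, set_args_per_scene):
    args = SimpleNamespace(start_frame=0, max_frames=total_frames)
    log = []
    for scene in filtered_scenes:
        if set_args_per_scene:
            # Proposed fix: write the current scene into args first.
            scene_start, scene_end = scene
            args.start_frame = scene_start
            args.max_frames = min(scene_end + 1, total_frames)
        # With set_args_per_scene=False nothing above runs, matching the
        # original animatediff branch: args keeps its initial values.
        do_run_adiff_stub(args, log)
    return log


scenes = [(0, 5), (6, 10), (11, 20)]
print(run_scenes(scenes, 21, set_args_per_scene=False))
# [(0, 21), (0, 21), (0, 21)]: every call gets the same stale range
print(run_scenes(scenes, 21, set_args_per_scene=True))
# [(0, 6), (6, 11), (11, 21)]: one range per scene
```
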

usergenic commented 4 months ago

For what it's worth, this does fix the problem. However, I almost always run straight into another one: with animatediff and scenes, either the batch size or the context size exceeds the scene size in quick-cut areas. I think the only way around that is to convert the automatic scene splits into a manual list I can hand-edit to avoid those short scenes.
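
Before hand-editing a manual split list, it may help to list which automatically detected scenes are shorter than the animatediff context length, since those are the quick-cut areas that trigger the size problems. A small sketch, assuming scenes are inclusive (start, end) frame pairs and using a hypothetical context length of 16; short_scenes is not a WarpFusion function:

```python
def short_scenes(scenes, context_length):
    """Return the inclusive (start, end) scenes with fewer than context_length frames."""
    return [(s, e) for s, e in scenes if e - s + 1 < context_length]


# Made-up scene list for illustration.
scenes = [(0, 5), (6, 10), (11, 40), (41, 44), (45, 90)]
print(short_scenes(scenes, context_length=16))
# [(0, 5), (6, 10), (41, 44)]
```
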

Sxela commented 4 months ago

The solution for smaller scenes is to overlap them, so that [0,5], [6,10] becomes [0,16], [6,22].
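
The overlap idea can be written as a small function. This is a sketch, not WarpFusion's code: it assumes scenes arrive as inclusive (start, end) pairs, returns half-open (start, end) ranges with an exclusive end, and assumes a context length of 16; under those assumptions it reproduces the [0,16], [6,22] example. Shifting the window back when a scene would run past the last frame is an extra step added here, not something the comment specifies.

```python
def overlap_scenes(scenes, context_length, total_frames):
    """Extend scenes shorter than context_length so each covers a full window.

    Takes inclusive (start, end) pairs and returns half-open (start, end)
    ranges, so neighbouring short scenes end up overlapping.
    """
    ranges = []
    for start, end in scenes:
        end = min(end + 1, total_frames)  # inclusive -> exclusive end
        if end - start < context_length:
            end = start + context_length
            if end > total_frames:
                # Assumption of this sketch: shift the window back rather
                # than run past the last frame.
                end = total_frames
                start = max(0, end - context_length)
        ranges.append((start, end))
    return ranges


print(overlap_scenes([(0, 5), (6, 10)], context_length=16, total_frames=100))
# [(0, 16), (6, 22)]
```
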

Sxela commented 4 months ago

Something like this:

          scene_start, scene_end = scene
          args.start_frame = scene_start
          scene_end = min(scene_end+1, total_frames)
          args.max_frames = scene_end
          if 'animatediff' in model_version:
              #reset batch length to max size each scene if == -1
              if batch_length_bkup == -1:
                batch_length = scene_end-scene_start+1
              else:
                batch_length = min(batch_length_bkup, scene_end-scene_start+1)

              if scene_end-scene_start+1 < context_length: #ensure batch size == context size 
                batch_length = context_length
                scene_end = scene_end + context_length - (scene_end-scene_start+1)

              do_run_adiff(args)
          else:
            do_run()
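
One detail worth checking in the snippet above: args.max_frames is assigned before scene_end is extended, so if do_run_adiff reads the frame range only from args (an assumption, since its internals aren't shown here), the extension never reaches it. Also, assuming filtered_scenes holds inclusive pairs, scene_end is already exclusive after the +1, so scene_end-scene_start+1 counts one frame more than the scene has. Below is a sketch of the same per-scene logic as a pure function (plan_scene is hypothetical) that extends the scene first and only then returns the values a caller would copy into args:

```python
def plan_scene(scene, total_frames, context_length, batch_length_bkup):
    """Hypothetical helper mirroring the loop body above for one scene.

    scene is an inclusive (start, end) pair; batch_length_bkup == -1
    means "use the whole scene as one batch".
    """
    scene_start, scene_end = scene
    scene_end = min(scene_end + 1, total_frames)  # exclusive end
    n_frames = scene_end - scene_start

    if batch_length_bkup == -1:
        batch_length = n_frames
    else:
        batch_length = min(batch_length_bkup, n_frames)

    if n_frames < context_length:
        # Extend short scenes to a full context window; the batch then
        # equals context_length unless clamped at the last frame.
        scene_end = min(scene_start + context_length, total_frames)
        batch_length = scene_end - scene_start

    # max_frames is derived after the extension, so it sees the new range.
    return {"start_frame": scene_start, "max_frames": scene_end,
            "batch_length": batch_length}


print(plan_scene((0, 5), total_frames=100, context_length=16, batch_length_bkup=-1))
# {'start_frame': 0, 'max_frames': 16, 'batch_length': 16}
```

In the loop, a caller would set args.start_frame, args.max_frames and batch_length from the returned dict just before do_run_adiff(args).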