Great Project! Here are a few suggestions. #61

Closed
Isi-dev opened this issue Jul 22, 2024 · 1 comment

Isi-dev commented Jul 22, 2024

Thanks a lot! UniAnimate is currently the best image-animation project I have tested: so far I have noticed no artifacts, and the animation is quite smooth. I am running it on Windows with 12 GB of VRAM and no memory issues. Here are a few suggestions based on the errors I faced while trying to run inference.

The first error was something like: "The shape of the 2D attn_mask is torch.Size([77, 77]), but should be (1, 1)."
This took some time to resolve, especially because it did not come from your project but from the open_clip dependency: recent releases of open_clip include a change that can trigger this error in some environments. In case anyone else encounters it, simply navigate to Lib\site-packages\open_clip\transformer.py in your virtual environment and change batch_first: bool = True on line 329 to batch_first: bool = False.
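If it helps, here is a small sketch to locate that file in your active environment (it only relies on open_clip's package location; the line number may shift between open_clip versions):

```python
# Print the path of open_clip's transformer.py in the current environment,
# then edit it by hand: around line 329, change
#     batch_first: bool = True,
# to
#     batch_first: bool = False,
import os
import open_clip

print(os.path.join(os.path.dirname(open_clip.__file__), "transformer.py"))
```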

After fixing the first error, I ran inference and hit a second error, something like: "list index out of range" and "local variable 'vit_frame' referenced before assignment."
The cause was that reference videos or pose sequences shorter than the max_frames set in UniAnimate\configs\UniAnimate_infer.yaml were not handled by the load_video_frames function in UniAnimate\tools\inferences\inference_unianimate_entrance.py. I suggest modifying load_video_frames as shown below to handle all error cases related to the reference video or pose frames:

```python
# Imports used below (already present in inference_unianimate_entrance.py):
import logging
import os

import cv2
import torch
from PIL import Image


def load_video_frames(ref_image_path, pose_file_path, train_trans, vit_transforms,
                      train_trans_pose, max_frames=32, frame_interval=1,
                      resolution=[512, 768], get_first_frame=True,
                      vit_resolution=[224, 224]):
    # Retry up to 5 times in case of transient read failures.
    for _ in range(5):
        try:
            dwpose_all = {}
            frames_all = {}
            for ii_index in sorted(os.listdir(pose_file_path)):
                if ii_index != "ref_pose.jpg":
                    dwpose_all[ii_index] = Image.open(os.path.join(pose_file_path, ii_index))
                    # The reference image is replicated once per pose frame.
                    frames_all[ii_index] = Image.fromarray(
                        cv2.cvtColor(cv2.imread(ref_image_path), cv2.COLOR_BGR2RGB))

            pose_ref = Image.open(os.path.join(pose_file_path, "ref_pose.jpg"))

            # Sample max_frames poses for video generation.
            stride = frame_interval
            total_frame_num = len(frames_all)
            cover_frame_num = stride * (max_frames - 1) + 1

            start_frame = 0
            if total_frame_num < cover_frame_num:
                print(f'_total_frame_num ({total_frame_num}) is smaller than cover_frame_num '
                      f'({cover_frame_num}), the sampled frame interval is changed')
                # Shrink the stride so the available frames spread across max_frames slots.
                stride = max((total_frame_num - 1) // (max_frames - 1), 1)
                end_frame = stride * max_frames
            else:
                end_frame = start_frame + cover_frame_num

            frame_list = []
            dwpose_list = []
            random_ref_frame = frames_all[list(frames_all.keys())[0]]
            if random_ref_frame.mode != 'RGB':
                random_ref_frame = random_ref_frame.convert('RGB')
            random_ref_dwpose = pose_ref
            if random_ref_dwpose.mode != 'RGB':
                random_ref_dwpose = random_ref_dwpose.convert('RGB')

            for i_index in range(start_frame, end_frame, stride):
                if i_index < len(frames_all):  # keep the index within bounds
                    i_key = list(frames_all.keys())[i_index]
                    i_frame = frames_all[i_key].convert('RGB')
                    i_dwpose = dwpose_all[i_key].convert('RGB')
                    frame_list.append(i_frame)
                    dwpose_list.append(i_dwpose)

            if frame_list:
                middle_index = 0
                ref_frame = frame_list[middle_index]
                vit_frame = vit_transforms(ref_frame)
                random_ref_frame_tmp = train_trans_pose(random_ref_frame)
                random_ref_dwpose_tmp = train_trans_pose(random_ref_dwpose)
                misc_data_tmp = torch.stack([train_trans_pose(ss) for ss in frame_list], dim=0)
                video_data_tmp = torch.stack([train_trans(ss) for ss in frame_list], dim=0)
                dwpose_data_tmp = torch.stack([train_trans_pose(ss) for ss in dwpose_list], dim=0)

                # Zero-padded tensors of length max_frames; short sequences fill
                # only the first len(frame_list) slots.
                video_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
                dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
                misc_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
                random_ref_frame_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
                random_ref_dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])

                video_data[:len(frame_list), ...] = video_data_tmp
                misc_data[:len(frame_list), ...] = misc_data_tmp
                dwpose_data[:len(frame_list), ...] = dwpose_data_tmp
                random_ref_frame_data[:, ...] = random_ref_frame_tmp
                random_ref_dwpose_data[:, ...] = random_ref_dwpose_tmp

                return vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data

        except Exception as e:
            logging.info(f'Error reading video frame: {e}')
            continue

    # All attempts failed (or no frames were collected): signal failure to the caller.
    return None, None, None, None, None, None
```
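To make the resampling branch concrete, here is a minimal standalone sketch with hypothetical numbers (a 20-frame pose sequence against the default max_frames=32), showing why the in-loop bounds check is what keeps short sequences safe:

```python
# Hypothetical numbers: 20 pose frames, max_frames=32, frame_interval=1.
max_frames, frame_interval = 32, 1
total_frame_num = 20

stride = frame_interval
cover_frame_num = stride * (max_frames - 1) + 1  # 32 frames needed at stride 1

if total_frame_num < cover_frame_num:
    stride = max((total_frame_num - 1) // (max_frames - 1), 1)  # -> 1
    end_frame = stride * max_frames                             # -> 32
else:
    end_frame = cover_frame_num

sampled = [i for i in range(0, end_frame, stride) if i < total_frame_num]
print(len(sampled))  # 20 real frames; the remaining 12 of 32 slots stay zero-padded
```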

Or you can come up with a better implementation. Thank you.
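One caveat with the approach above: after five failed attempts the function returns six Nones instead of raising, so the call site should check for that. A minimal sketch (the call-site variable names here are assumed, not taken from the repo):

```python
(vit_frame, video_data, misc_data, dwpose_data,
 random_ref_frame_data, random_ref_dwpose_data) = load_video_frames(
    ref_image_path, pose_file_path, train_trans, vit_transforms, train_trans_pose)

if vit_frame is None:
    # Every retry failed, or no usable frames were collected.
    raise RuntimeError("load_video_frames could not read the reference image or pose frames")
```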

@wangxiang1230
Collaborator

Hi, thanks for your efforts. We have merged your code into our repo, and it is a good implementation!

Isi-dev closed this as completed Jul 23, 2024