f-lab-edu / virtual-try-on


Modularization work for the model #5

Open f-lab-sage opened 8 months ago

f-lab-sage commented 8 months ago

Example


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class VTryOnService:

    def __init__(self):
        self.model = VTryOnModel(...)
        self.model.load_model()

        self.seg_model = Segment(...)

    def handler(self):
        # img download

        output = self.model.inference(...)
        # img upload
        # ...

class VTryOnModel:

    def __init__(self, val, ...):
        self.val = val
        self.warp_model = None
        self.gen_model = None
        # additional init

    def load_model(self):
        warp_model = AFWM(opt, 3)
        print(warp_model)
        warp_model.eval()
        warp_model.cuda()
        load_checkpoint(warp_model, opt.warp_checkpoint)

        gen_model = ResUnetGenerator(7, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d)
        print(gen_model)
        gen_model.eval()
        gen_model.cuda()
        load_checkpoint(gen_model, opt.gen_checkpoint)

        self.warp_model = warp_model
        self.gen_model = gen_model

    def inference(self, p_img, c_img, m_img):
        real_image = p_img
        clothes = c_img
        # the edge mask is extracted from the clothes image
        edge = m_img
        edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(int))
        clothes = clothes * edge

        flow_out = self.warp_model(real_image.cuda(), clothes.cuda())
        warped_cloth, last_flow = flow_out
        warped_edge = F.grid_sample(edge.cuda(), last_flow.permute(0, 2, 3, 1),
                                    mode='bilinear', padding_mode='zeros')

        gen_inputs = torch.cat([real_image.cuda(), warped_cloth, warped_edge], 1)
        gen_outputs = self.gen_model(gen_inputs)
        p_rendered, m_composite = torch.split(gen_outputs, [3, 1], 1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        # keep only the mask regions covered by the warped edge
        m_composite = m_composite * warped_edge
        # blend: warped cloth where the mask is on, rendered person elsewhere
        p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)

        ...

        return combined_img
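The final `p_tryon` line is a mask-weighted alpha blend of the warped cloth over the rendered person. A minimal NumPy sketch of just that step, using toy arrays in place of the real model tensors:

```python
import numpy as np

# Toy stand-ins for the real (batch, channel, H, W) tensors.
warped_cloth = np.full((1, 3, 4, 4), 0.8)   # "cloth" pixels
p_rendered = np.full((1, 3, 4, 4), -0.2)    # "rendered person" pixels
m_composite = np.zeros((1, 1, 4, 4))
m_composite[:, :, :2, :] = 1.0              # mask is on in the top half

# Same blend as in inference(): the mask selects the warped cloth,
# (1 - mask) keeps the rendered person.
p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)
```

The single-channel mask broadcasts across the three color channels, so no explicit `repeat` is needed.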
f-lab-sage commented 8 months ago

Consider Supabase for the DB part

f-lab-sage commented 8 months ago

How to write a PR

Include a brief description

f-lab-sage commented 8 months ago

When using a configuration file

config.yaml

env:
  env_name: dev  # dev, stage, ...
db:
  endpoint: xxx
  timeout: xxx
model:
  batch_size: xxx
  model_path: s3://xxxxx

When actually loading it, call load_yaml(xxx)
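`load_yaml` itself is not shown in the thread; a minimal sketch, assuming it just parses the config.yaml layout above into a nested dict with PyYAML:

```python
import yaml  # PyYAML


def load_yaml(path: str) -> dict:
    """Parse a YAML config file into a nested dict."""
    with open(path) as f:
        return yaml.safe_load(f)


# e.g. cfg = load_yaml("config.yaml"); cfg["db"]["endpoint"]
```

`safe_load` is preferred over `load` because config files should never need to construct arbitrary Python objects.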

f-lab-sage commented 8 months ago

Direction for revising the ML module


from PIL import Image

# init stage
vton = VTryOnModel(start_epoch=1, epoch_iter=0)
vton.load_model()

seg = ClothSegmentationModel()

# when a user request comes in
person_img = Image.open("xxx.jpg")
cloth_img = Image.open("yyy.jpg")

dataset, dataset_size = vton.load_data("path where the data is stored")

# segmentation

# preprocessing
pp_image = vton.preprocess("image paths")

output_img = vton.inference(pp_image)
output_img.save("")

# upload
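One way to formalize the modularization sketched above is a shared base interface that both VTryOnModel and ClothSegmentationModel implement, so the service only depends on `load_model`/`inference`. A hypothetical sketch (the names `BaseModel` and `DummyClothSegmentationModel` are illustrative, not from the repo):

```python
from abc import ABC, abstractmethod


class BaseModel(ABC):
    """Hypothetical common interface for the service's models."""

    @abstractmethod
    def load_model(self) -> None:
        """Load weights; called once at service init."""

    @abstractmethod
    def inference(self, *inputs):
        """Run the model on preprocessed inputs."""


class DummyClothSegmentationModel(BaseModel):
    # Stand-in implementation so the service wiring can be exercised
    # without GPU weights or checkpoint files.
    def load_model(self) -> None:
        self.loaded = True

    def inference(self, img):
        assert getattr(self, "loaded", False), "call load_model() first"
        return [[1, 0], [0, 1]]  # fake 2x2 mask
```

A dummy implementation like this also makes the handler unit-testable without CUDA.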