
Can you provide the forward code for testing? #3

Closed
HappyAIWalker opened this issue Aug 24, 2022 · 2 comments

Comments

HappyAIWalker commented Aug 24, 2022

I tried to re-implement the forward pass as described in the paper and load the pretrained weights, but the result is not right.
So, could you provide the forward code for testing performance?
Thanks for your reply.
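
One generic way to narrow down a mismatch like this is to compare the checkpoint's parameter names and shapes against the re-implementation before calling load_state_dict. A minimal sketch, assuming the file stores a flat state_dict (as the test script later in this thread suggests); `model` here stands for whatever re-implementation is being tested, e.g. the Nerop module posted in the next comment:

import torch

ckpt = torch.load('experiments/pretrain_models/neurop_fivek_lite.pth', map_location='cpu')
# List every parameter name and shape stored in the checkpoint ...
for name, tensor in ckpt.items():
    print(name, tuple(tensor.shape))
# ... and compare against the re-implemented model's own state_dict keys.
model_keys = set(model.state_dict().keys())
print('missing from checkpoint:', sorted(model_keys - set(ckpt.keys())))
print('unexpected in checkpoint:', sorted(set(ckpt.keys()) - model_keys))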


Wenju-Huang commented Aug 27, 2022

My reproduced results are also wrong. I also tried setting the scalar strength to 0 and testing only the first neurOp, but the result was still not right. Is there any pre- or post-processing applied to the image before or after the model? My code is below, thanks for your reply.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 7, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.act = nn.ReLU()
        self.gap_max = nn.AdaptiveMaxPool2d(1)
        self.gap_avg = nn.AdaptiveAvgPool2d(1)
    
    def forward(self, img):
        _, _, h, w = img.size()
        short_side = min(h, w)
        factor = 256/short_side
        img_resize = F.interpolate(img, scale_factor=factor, mode='bilinear')
        
        x = self.act(self.conv1(img_resize))
        x = self.act(self.conv2(x))
        y_max = self.gap_max(x).flatten(1)
        y_mean = self.gap_avg(x).flatten(1)
        y_std = torch.std(x, dim=(2, 3)).flatten(1)
        y = torch.cat((y_std, y_mean, y_max), dim=1)
        return y

class Predictor(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc3 = nn.Linear(96, 1)
        self.act = nn.Tanh()
        
    def forward(self, fea):
        return self.act(self.fc3(fea))

class Render(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv2d(3, 64, 1)
        self.mid_conv = nn.Conv2d(64, 64, 1)
        self.decoder = nn.Conv2d(64, 3, 1)
        self.act = nn.ReLU()
        
    def forward(self, img, v):
        z = self.encoder(img)
        # Reshape the predicted scalar strength so it broadcasts over channels
        # and spatial positions for any batch size (not just batch size 1).
        z = self.act(self.mid_conv(z + v.view(-1, 1, 1, 1)))
        y = self.decoder(z)
        return y
    
class Nerop(nn.Module):
    def __init__(self):
        super().__init__()
        self.image_encoder = Encoder()
        
        self.ex_predictor = Predictor()
        self.bc_predictor = Predictor()
        self.vb_predictor = Predictor()
        
        self.ex_renderer = Render()
        self.bc_renderer = Render()
        self.vb_renderer = Render()

    def forward(self, img):
        img_code = self.image_encoder(img)
        v_ex = self.ex_predictor(img_code)
        img_ex = self.ex_renderer(img, v_ex)
        
        img_code = self.image_encoder(img_ex)
        v_bc = self.bc_predictor(img_code)
        img_bc = self.bc_renderer(img_ex, v_bc)
        
        img_code = self.image_encoder(img_bc)
        v_vb = self.vb_predictor(img_code)
        img_vb = self.vb_renderer(img_bc, v_vb)
        
        return img_vb
    
if __name__ == '__main__':
    from PIL import Image
    import torchvision.transforms.functional as tf
    model = Nerop()
    print('Load model')
    model.load_state_dict(torch.load('experiments/pretrain_models/neurop_fivek_lite.pth'))
    model.eval()
    img = Image.open('website/img/input/4931.jpg')
    img_tensor = tf.to_tensor(img).unsqueeze(0)
    with torch.no_grad():
        output = model(img_tensor).squeeze(0)
    # Clamp to [0, 1] so to_pil_image does not wrap values outside the valid range.
    img_enhance = tf.to_pil_image(output.clamp(0, 1))
    img_enhance.save('enhance.jpeg')
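
For reference, the zero-strength check mentioned in the comment above might look like the sketch below. It reuses the `model` and `img_tensor` from the script and runs only the first neural operator; whether v = 0 should leave the image (nearly) unchanged is an assumption about the paper's formulation, not a confirmed property of the released weights.

# Sketch: apply only the exposure renderer with the scalar strength forced to 0.
with torch.no_grad():
    v_zero = torch.zeros(1, 1)
    img_ex = model.ex_renderer(img_tensor, v_zero)
    print('max deviation from input:', (img_ex - img_tensor).abs().max().item())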

amberwangyili (Owner) commented Aug 29, 2022

I tried to re-implement the forward pass as described in the paper and load the pretrained weights, but the result is not right. So, could you provide the forward code for testing performance? Thanks for your reply.

Thanks for your attention @Wenju-Huang @HappyAIWalker! For testing the performance, please refer to our PyTorch implementation: https://github.com/amberwangyili/neurop-pytorch
