
How to load pretrained Enhancement model in maxim_torch.py #9

Closed
marisancans opened this issue Mar 10, 2023 · 2 comments

Comments

@marisancans

I downloaded the weights from the MAXIM repo. Then, in the maxim-pytorch repo, I ran the jax2torch.py script:

```
python maxim_pytorch/jax2torch.py -c maxim_ckpt_Enhancement_FiveK_checkpoint.npz
```

This works and produces a torch_weight.pth file.
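Before choosing model arguments, it can help to check which output heads the converted file actually contains. The sketch below groups checkpoint keys by stage. It assumes the converted keys follow the `stage_<i>_output_conv_<j>` naming visible in the RuntimeError further down; `output_convs_by_stage` is a hypothetical helper, not part of maxim-pytorch.

```python
import re
from collections import defaultdict

# Assumed key pattern, taken from the RuntimeError reported in this issue:
# "stage_<stage>_output_conv_<index>.<weight|bias>"
OUTPUT_CONV_KEY = re.compile(r"^stage_(\d+)_output_conv_(\d+)\.")


def output_convs_by_stage(keys):
    """Hypothetical helper: map stage index -> sorted output-conv indices in `keys`."""
    found = defaultdict(set)
    for key in keys:
        match = OUTPUT_CONV_KEY.match(key)
        if match:
            stage, index = int(match.group(1)), int(match.group(2))
            found[stage].add(index)
    return {stage: sorted(found[stage]) for stage in sorted(found)}


# Example using the key names reported in this issue.
reported = [
    "stage_1_output_conv_0.bias", "stage_1_output_conv_0.weight",
    "stage_1_output_conv_1.bias", "stage_1_output_conv_1.weight",
    "stage_1_output_conv_2.bias", "stage_1_output_conv_2.weight",
]
print(output_convs_by_stage(reported))  # {1: [0, 1, 2]}
```

Running it on both `torch.load("torch_weight.pth", map_location="cpu").keys()` and `model.state_dict().keys()` shows whether the checkpoint and the model agree on which stage carries the output convolutions and how many of them there are.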

I then tried to load it, but I can't tell whether I'm passing the wrong arguments or the code is wrong:

```python
from maxim_pytorch.maxim_torch import MAXIM_dns_3s
import torch

# These params are from https://github.com/google-research/maxim/blob/3c8265171ffccc80c3c9124844aef0d381609956/maxim/models/maxim.py#L910
# Note: s2 is never passed to the model; num_stages, num_groups and
# num_bottleneck_blocks are not set in the constructor call below.
s2 = {
    "features": 32,
    "depth": 3,
    "num_stages": 2,
    "num_groups": 2,
    "num_bottleneck_blocks": 2,
    "block_gmlp_factor": 2,
    "grid_gmlp_factor": 2,
    "input_proj_factor": 2,
    "channels_reduction": 4,
}

model = MAXIM_dns_3s(features=32, depth=3, block_gmlp_factor=2, grid_gmlp_factor=2, input_proj_factor=2, channels_reduction=4, num_supervision_scales=2)
state = torch.load("torch_weight.pth", map_location="cpu")

model.load_state_dict(state)  # raises the RuntimeError below
model.eval()
```

I get this error:

```
RuntimeError: Error(s) in loading state_dict for MAXIM_dns_3s:
	Unexpected key(s) in state_dict: "stage_1_output_conv_0.bias", "stage_1_output_conv_0.weight", "stage_1_output_conv_1.bias", "stage_1_output_conv_1.weight", "stage_1_output_conv_2.bias", "stage_1_output_conv_2.weight".
```
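The error only names the keys PyTorch could not place. A sketch for comparing the two key sets directly, so both directions of the mismatch are visible, assuming `model` and `state` are the objects from the snippet above; `diff_state_dict_keys` and `module_names` are hypothetical helpers, and the toy key names in the example are made up.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, mirroring load_state_dict's report.

    missing:    expected by the model but absent from the checkpoint
    unexpected: present in the checkpoint but unknown to the model
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


def module_names(keys):
    """Collapse parameter names such as 'conv.weight' / 'conv.bias' to 'conv'."""
    return sorted({key.rsplit(".", 1)[0] for key in keys})


# Toy example with made-up names; with the real objects, pass
# model.state_dict().keys() and state.keys() instead.
missing, unexpected = diff_state_dict_keys(
    ["stage_0_conv.weight", "stage_0_conv.bias"],
    ["stage_0_conv.weight", "stage_0_conv.bias",
     "stage_1_output_conv_0.weight", "stage_1_output_conv_0.bias"],
)
print(missing)                   # []
print(module_names(unexpected))  # ['stage_1_output_conv_0']
```

PyTorch's own `model.load_state_dict(state, strict=False)` returns the same two lists as `missing_keys` and `unexpected_keys`, but it also loads whatever does match, so a partially loaded model can go unnoticed; an explicit diff like this avoids that.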
@wj320

wj320 commented Sep 27, 2023


I ran into the same problem. Have you solved it?

@marisancans
Author

No, this repo is dead. We reverse-engineered the code and implemented it ourselves in PyTorch, but got really bad results at large resolutions and won't use it in the future.
