Skip to content

Commit

Permalink
output mask by default, overlay optional
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielNobbe committed Oct 27, 2023
1 parent 7ac5aef commit 2895c8f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
9 changes: 7 additions & 2 deletions SegGPT/SegGPT_inference/seggpt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def run_one_image(img, tgt, model, device):
return output


def inference_image(model, device, img_path, img2_paths, tgt2_paths, out_path):
def inference_image(model, device, img_path, img2_paths, tgt2_paths, out_path, ovl_path):
res, hres = 448, 448

image = Image.open(img_path).convert("RGB")
Expand Down Expand Up @@ -99,7 +99,12 @@ def inference_image(model, device, img_path, img2_paths, tgt2_paths, out_path):
size=[size[1], size[0]],
mode='nearest',
).permute(0, 2, 3, 1)[0].numpy()
output = Image.fromarray((input_image * (0.6 * output / 255 + 0.4)).astype(np.uint8))

if ovl_path is not None:
overlay = Image.fromarray((input_image * (0.6 * output / 255 + 0.4)).astype(np.uint8))
overlay.save(ovl_path)

output = Image.fromarray(output.astype(np.uint8))
output.save(out_path)


Expand Down
9 changes: 7 additions & 2 deletions SegGPT/SegGPT_inference/seggpt_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ def get_args_parser():
choices=['instance', 'semantic'], default='instance')
parser.add_argument('--device', type=str, help='cuda or cpu',
default='cuda')
parser.add_argument('--output_dir', type=str, help='path to output',
parser.add_argument('--output_dir', type=str, help='path to output folder for output mask',
default='./')
parser.add_argument('--overlay-dir', type=str, help='path to output folder for combined mask and input (for visualising)', default=None)
return parser.parse_args()


Expand Down Expand Up @@ -61,8 +62,12 @@ def prepare_model(chkpt_dir, arch='seggpt_vit_large_patch16_input896x448', seg_t

img_name = os.path.basename(args.input_image)
out_path = os.path.join(args.output_dir, "output_" + '.'.join(img_name.split('.')[:-1]) + '.png')
if args.overlay_dir is not None:
ovl_path = os.path.join(args.overlay_dir, "overlay_" + '.'.join(img_name.split('.')[:-1]) + '.png')
else:
ovl_path = None

inference_image(model, device, args.input_image, args.prompt_image, args.prompt_target, out_path)
inference_image(model, device, args.input_image, args.prompt_image, args.prompt_target, out_path, ovl_path)

if args.input_video is not None:
assert args.prompt_target is not None and len(args.prompt_target) == 1
Expand Down

0 comments on commit 2895c8f

Please sign in to comment.