import os
import tyro
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from core.options import AllConfigs, Options
from core.gs import GaussianRenderer
import dearpygui.dearpygui as dpg
import kiui
from kiui.cam import OrbitCamera


class GUI:
    """Interactive Dear PyGui viewer for a Gaussian-splatting ``.ply`` file.

    The gaussians are loaded with ``GaussianRenderer.load_ply(opt.test_path)`` and
    re-rendered on the CUDA device whenever ``need_update`` is set by a camera or
    rendering-option callback.
    """

    # Width in pixels of the control panel placed to the right of the render.
    CONTROL_PANEL_WIDTH = 600

    def __init__(self, opt: Options):
        self.opt = opt
        self.W = opt.output_size
        self.H = opt.output_size
        self.cam = OrbitCamera(self.W, self.H, r=opt.cam_radius, fovy=opt.fovy)

        self.device = torch.device("cuda")
        # Sets self.tan_half_fov and self.proj_matrix; rebuilt when the FoV slider moves.
        self._update_projection(np.deg2rad(opt.fovy))

        # Key selecting which renderer output is displayed ("image" or "alpha").
        self.mode = "image"
        # [H, W, 3] float RGB buffer shown through the dpg raw texture.
        self.buffer_image = np.ones((self.H, self.W, 3), dtype=np.float32)
        self.need_update = True  # test_step only re-renders while this is True

        # renderer
        self.renderer = GaussianRenderer(opt)
        self.gaussain_scale_factor = 1
        self.gaussians = self.renderer.load_ply(opt.test_path).to(self.device)

        dpg.create_context()
        self._dpg_context_created = True
        self.register_dpg()
        self.test_step()

    def __del__(self):
        # __init__ may raise before the dpg context exists (e.g. load_ply failing);
        # only tear down a context this instance actually created.
        if getattr(self, "_dpg_context_created", False):
            self._dpg_context_created = False
            dpg.destroy_context()

    def _update_projection(self, fovy_rad):
        """Rebuild ``tan_half_fov`` and ``proj_matrix`` for a vertical FoV in radians.

        The matrix is laid out for row-vector use (``cam_view @ proj_matrix``),
        matching the transposed view matrix built in ``test_step``. The same
        tangent is used for both axes because the viewport is square (W == H).
        """
        opt = self.opt
        self.tan_half_fov = np.tan(0.5 * fovy_rad)
        proj = torch.zeros(4, 4, dtype=torch.float32, device=self.device)
        proj[0, 0] = 1 / self.tan_half_fov
        proj[1, 1] = 1 / self.tan_half_fov
        proj[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
        proj[3, 2] = -(opt.zfar * opt.znear) / (opt.zfar - opt.znear)
        proj[2, 3] = 1
        self.proj_matrix = proj

    @torch.no_grad()
    def test_step(self):
        """Re-render into ``buffer_image`` if needed and upload it to ``_texture``."""
        if not self.need_update:
            return

        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        # torch.tensor always copies, so the in-place flip below can never write
        # back into the camera's own pose array (regardless of self.device).
        cam_poses = torch.tensor(self.cam.pose, device=self.device).unsqueeze(0)
        cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction

        # cameras needed by gaussian rasterizer
        cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
        cam_view_proj = cam_view @ self.proj_matrix  # [V, 4, 4]
        # NOTE(review): negating the translation is unusual for a camera position;
        # kept unchanged — confirm against how the renderer uses cam_pos.
        cam_pos = -cam_poses[:, :3, 3]  # [V, 3]

        image = self.renderer.render(
            self.gaussians.unsqueeze(0),
            cam_view.unsqueeze(0),
            cam_view_proj.unsqueeze(0),
            cam_pos.unsqueeze(0),
            scale_modifier=self.gaussain_scale_factor,
        )[self.mode]
        image = image.squeeze(1)  # [B, C, H, W]
        if self.mode == "alpha":
            # alpha is presumably single-channel; tile to 3 channels for the RGB texture
            image = image.repeat(1, 3, 1, 1)

        image = F.interpolate(
            image,
            size=(self.H, self.W),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)

        # [C, H, W] -> [H, W, C]; dpg needs a contiguous buffer, else it can segfault.
        self.buffer_image = (
            image.permute(1, 2, 0).clamp(0, 1).contiguous().detach().cpu().numpy()
        )
        self.need_update = False

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)  # milliseconds
        fps = int(1000 / t) if t > 0 else 0
        dpg.set_value("_log_infer_time", f"{t:.4f}ms ({fps} FPS)")
        dpg.set_value("_texture", self.buffer_image)  # buffer must be contiguous, else seg fault!

    def register_dpg(self):
        """Build the dpg texture, windows, input handlers, viewport and themes."""
        self._register_texture()
        self._register_primary_window()
        self._register_control_window()
        self._register_camera_handlers()
        self._setup_viewport()

    def _register_texture(self):
        """Register the raw float RGB texture that ``test_step`` writes into."""
        with dpg.texture_registry(show=False):
            dpg.add_raw_texture(
                self.W,
                self.H,
                self.buffer_image,
                format=dpg.mvFormat_Float_rgb,
                tag="_texture",
            )

    def _register_primary_window(self):
        """Create the fixed window that displays the rendered texture."""
        with dpg.window(
            tag="_primary_window",
            width=self.W,
            height=self.H,
            pos=[0, 0],
            no_move=True,
            no_title_bar=True,
            no_scrollbar=True,
        ):
            dpg.add_image("_texture")

    def _register_control_window(self):
        """Create the side panel with timing info and rendering options."""
        with dpg.window(
            label="Control",
            tag="_control_window",
            width=self.CONTROL_PANEL_WIDTH,
            height=self.H,
            pos=[self.W, 0],
            no_move=True,
            no_title_bar=True,
        ):
            # button theme (applies only to mvButton items in this window)
            with dpg.theme() as theme_button:
                with dpg.theme_component(dpg.mvButton):
                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)

            # timer stuff
            with dpg.group(horizontal=True):
                dpg.add_text("Infer time: ")
                dpg.add_text("no data", tag="_log_infer_time")

            # rendering options
            with dpg.collapsing_header(label="Rendering", default_open=True):

                def callback_change_mode(sender, app_data):
                    self.mode = app_data
                    self.need_update = True

                dpg.add_combo(
                    ("image", "alpha"),
                    label="mode",
                    default_value=self.mode,
                    callback=callback_change_mode,
                )

                def callback_set_fovy(sender, app_data):
                    self.cam.fovy = np.deg2rad(app_data)
                    # The projection used by test_step must follow the new FoV.
                    self._update_projection(self.cam.fovy)
                    # NOTE(review): GaussianRenderer is built from opt and may cache
                    # its own FoV tangent; keep it in sync if exposed — confirm.
                    if hasattr(self.renderer, "tan_half_fov"):
                        self.renderer.tan_half_fov = self.tan_half_fov
                    self.need_update = True

                dpg.add_slider_int(
                    label="FoV (vertical)",
                    min_value=1,
                    max_value=120,
                    format="%d deg",
                    default_value=int(round(np.rad2deg(self.cam.fovy))),
                    callback=callback_set_fovy,
                )

                def callback_set_gaussain_scale(sender, app_data):
                    self.gaussain_scale_factor = app_data
                    self.need_update = True

                dpg.add_slider_float(
                    label="gaussain scale",
                    min_value=0,
                    max_value=1,
                    format="%.2f",
                    default_value=self.gaussain_scale_factor,
                    callback=callback_set_gaussain_scale,
                )

        dpg.bind_item_theme("_control_window", theme_button)

    def _register_camera_handlers(self):
        """Map mouse input over the render window to orbit / zoom / pan."""

        def callback_camera_drag_rotate(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return
            dx = app_data[1]
            dy = app_data[2]
            self.cam.orbit(dx, dy)
            self.need_update = True

        def callback_camera_wheel_scale(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return
            self.cam.scale(app_data)
            self.need_update = True

        def callback_camera_drag_pan(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return
            dx = app_data[1]
            dy = app_data[2]
            self.cam.pan(dx, dy)
            self.need_update = True

        with dpg.handler_registry():
            # for camera moving
            dpg.add_mouse_drag_handler(
                button=dpg.mvMouseButton_Left,
                callback=callback_camera_drag_rotate,
            )
            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
            dpg.add_mouse_drag_handler(
                button=dpg.mvMouseButton_Middle,
                callback=callback_camera_drag_pan,
            )

    def _setup_viewport(self):
        """Create and show the viewport, remove render-window padding, load optional font."""
        dpg.create_viewport(
            title="Gaussian3D",
            width=self.W + self.CONTROL_PANEL_WIDTH,
            # extra height on Windows, presumably to make room for the title bar
            height=self.H + (45 if os.name == "nt" else 0),
            resizable=False,
        )

        ### global theme
        with dpg.theme() as theme_no_padding:
            with dpg.theme_component(dpg.mvAll):
                # set all padding to 0 to avoid scroll bar
                for style in (
                    dpg.mvStyleVar_WindowPadding,
                    dpg.mvStyleVar_FramePadding,
                    dpg.mvStyleVar_CellPadding,
                ):
                    dpg.add_theme_style(style, 0, 0, category=dpg.mvThemeCat_Core)

        dpg.bind_item_theme("_primary_window", theme_no_padding)

        dpg.setup_dearpygui()

        ### register a larger font
        # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf
        font_path = "LXGWWenKai-Regular.ttf"
        if os.path.exists(font_path):
            with dpg.font_registry():
                with dpg.font(font_path, 18) as default_font:
                    dpg.bind_font(default_font)

        dpg.show_viewport()

    def render(self):
        """Run the dpg event loop, re-rendering the scene each frame when needed."""
        while dpg.is_dearpygui_running():
            self.test_step()
            dpg.render_dearpygui_frame()


if __name__ == "__main__":
    opt = tyro.cli(AllConfigs)

    # load a saved ply and visualize
    test_path = opt.test_path
    if not (isinstance(test_path, str) and test_path.endswith(".ply")):
        raise SystemExit("--test_path must be a .ply file saved by infer.py")

    gui = GUI(opt)
    gui.render()