From 5cdf8cd2665b44d9026214395e5d92a135708719 Mon Sep 17 00:00:00 2001
From: JettHu
Date: Sun, 28 Apr 2024 19:15:48 +0800
Subject: [PATCH 1/3] feat: add tcd scheduler

---
 comfy/k_diffusion/sampling.py        | 35 ++++++++++++++++++++++++++++
 comfy/model_sampling.py              |  6 ++---
 comfy/samplers.py                    |  2 +-
 comfy_extras/nodes_custom_sampler.py | 18 ++++++++++++++
 comfy_extras/nodes_model_advanced.py |  5 +++-
 5 files changed, 60 insertions(+), 6 deletions(-)

diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 7af016829d..a883062916 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -808,3 +808,38 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non
             d_prime = w1 * d + w2 * d_2 + w3 * d_3
             x = x + d_prime * dt
     return x
+
+@torch.no_grad()
+def sample_tcd(
+    model,
+    x,
+    sigmas,
+    extra_args=None,
+    callback=None,
+    disable=None,
+    noise_sampler=None,
+    eta=0.3,
+):
+    extra_args = {} if extra_args is None else extra_args
+    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+
+    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
+    timesteps_s = torch.floor((1 - eta) * model_sampling.timestep(sigmas)).to(dtype=torch.long).detach()
+    timesteps_s[-1] = 0
+    alpha_prod_s = model_sampling.alphas_cumprod[timesteps_s]
+    beta_prod_s = 1 - alpha_prod_s
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)  # predicted_original_sample
+        eps = (x - denoised) / sigmas[i]
+        denoised = alpha_prod_s[i + 1].sqrt() * denoised + beta_prod_s[i + 1].sqrt() * eps
+
+        if callback is not None:
+            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
+
+        x = denoised
+        if eta > 0 and sigmas[i + 1] > 0:
+            noise = noise_sampler(sigmas[i], sigmas[i + 1])
+            x = x / alpha_prod_s[i+1].sqrt() + noise * (sigmas[i+1]**2 + 1 - 1/alpha_prod_s[i+1]).sqrt()
+
+    return x
diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py
index 37976b326a..5872ac8a50 100644
--- a/comfy/model_sampling.py
+++ b/comfy/model_sampling.py
@@ -64,12 +64,10 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps
         self.linear_start = linear_start
         self.linear_end = linear_end

-        # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
-        # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
-        # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
-
         sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
         self.set_sigmas(sigmas)
+        # register alphas_cumprod for some samplers, such as tcd
+ self.register_buffer("alphas_cumprod", alphas_cumprod.float()) def set_sigmas(self, sigmas): self.register_buffer('sigmas', sigmas.float()) diff --git a/comfy/samplers.py b/comfy/samplers.py index b12b0fd1bf..5748a6e9de 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -506,7 +506,7 @@ def max_denoise(self, model_wrap, sigmas): KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"] + "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "tcd"] class KSAMPLER(Sampler): def __init__(self, sampler_function, extra_options={}, inpaint_options={}): diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 7afdbf4bf6..578e747d9e 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -315,6 +315,23 @@ def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_ "s_noise":s_noise }) return (sampler, ) +class SamplerTCD: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "eta": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, eta=0.3): + sampler = comfy.samplers.ksampler("tcd", {"eta": eta}) + return (sampler, ) + class Noise_EmptyNoise: def __init__(self): self.seed = 0 @@ -599,6 +616,7 @@ def add_noise(self, model, noise, sigmas, latent_image): "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE, "SamplerDPMAdaptative": SamplerDPMAdaptative, + "SamplerTCD": SamplerTCD, "SplitSigmas": SplitSigmas, "FlipSigmas": FlipSigmas, diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 21af4b7333..2aa2805a1b 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -73,7 +73,7 @@ class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["eps", "v_prediction", "lcm", "x0"],), + "sampling": (["eps", "v_prediction", "lcm", "x0", "tcd"],), "zsnr": ("BOOLEAN", {"default": False}), }} @@ -95,6 +95,9 @@ def patch(self, model, sampling, zsnr): sampling_base = ModelSamplingDiscreteDistilled elif sampling == "x0": sampling_type = X0 + elif sampling == "tcd": + sampling_type = comfy.model_sampling.EPS + sampling_base= ModelSamplingDiscreteDistilled class ModelSamplingAdvanced(sampling_base, sampling_type): pass From 9b8f97c190928749a816c7a3636dbeb063d5ee54 Mon Sep 17 00:00:00 2001 From: JettHu Date: Tue, 30 Apr 2024 10:55:46 +0800 Subject: [PATCH 2/3] fix: Make sure timesteps_s in sample_tcd is on cpu --- comfy/k_diffusion/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index a883062916..d9032feb51 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -825,7 +825,7 @@ def sample_tcd( s_in = x.new_ones([x.shape[0]]) model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") - timesteps_s = torch.floor((1 - eta) * model_sampling.timestep(sigmas)).to(dtype=torch.long).detach() + timesteps_s = torch.floor((1 - eta) * 
model_sampling.timestep(sigmas)).to(dtype=torch.long).detach().cpu() timesteps_s[-1] = 0 alpha_prod_s = model_sampling.alphas_cumprod[timesteps_s] beta_prod_s = 1 - alpha_prod_s From 6fd91a8f006636bf0a2ddf306a280b6ca0591af7 Mon Sep 17 00:00:00 2001 From: JettHu Date: Mon, 3 Jun 2024 19:36:15 +0800 Subject: [PATCH 3/3] fix: correct sampling when gamma is 0 --- comfy/k_diffusion/sampling.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index d9032feb51..829fa383ec 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -841,5 +841,7 @@ def sample_tcd( if eta > 0 and sigmas[i + 1] > 0: noise = noise_sampler(sigmas[i], sigmas[i + 1]) x = x / alpha_prod_s[i+1].sqrt() + noise * (sigmas[i+1]**2 + 1 - 1/alpha_prod_s[i+1]).sqrt() + else: + x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2) return x
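
Note on the eta parameter (apparently the gamma that the PATCH 3/3 subject refers to, following the TCD paper's naming): at each step sample_tcd denoises toward timestep s = floor((1 - eta) * t), i.e. below the next scheduled timestep t, and then re-noises back up to the scheduled sigma when eta > 0 and steps remain. The snippet below only illustrates that floor computation with made-up timestep values; it is not part of the patches.

    import torch

    eta = 0.3  # default exposed by the SamplerTCD node
    # Hypothetical discrete timesteps a scheduler might pick for a short run.
    t = torch.tensor([999, 749, 499, 249, 0])
    s = torch.floor((1 - eta) * t).long()
    s[-1] = 0  # sample_tcd pins the final target timestep to 0
    print(s.tolist())  # [699, 524, 349, 174, 0]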
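Why the branch added in PATCH 3/3 multiplies by sqrt(1 + sigma^2): sample_tcd builds a DDPM-style sample sqrt(alpha_prod_s) * x0 + sqrt(1 - alpha_prod_s) * eps, while the k-diffusion samplers in this file work in the x0 + sigma * eps convention. Since _register_schedule defines sigma = sqrt((1 - alphas_cumprod) / alphas_cumprod), the conversion factor 1 / sqrt(alphas_cumprod) equals sqrt(1 + sigma^2); with eta == 0 the target timestep coincides with the next scheduled one, so the multiplication is exactly that conversion (and at the final step, where sigma is 0, it is a no-op). The snippet below just checks the identity numerically on a toy schedule; the beta values are illustrative, not taken from any real model.

    import torch

    # Toy linear-beta schedule standing in for the model's real one.
    betas = torch.linspace(1e-4, 2e-2, 1000)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    # k-diffusion sigmas, as computed in _register_schedule above.
    sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

    # sqrt(1 + sigma^2) == 1 / sqrt(alphas_cumprod), so scaling by it converts a
    # DDPM-scaled sample back to the x0 + sigma * eps convention.
    assert torch.allclose(torch.sqrt(1 + sigmas ** 2), alphas_cumprod.rsqrt(), rtol=1e-4)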