
An attempt at DDPM inference #1

Open
wdcww opened this issue Nov 30, 2024 · 2 comments

Comments

@wdcww
Owner

wdcww commented Nov 30, 2024

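I noticed that among the .yml files in confs/ of the Repaint repository,
those that use celeba256_250000.pt or places256_300000.pt as model_path have no key-value pair with the key classifier_path.
So I first looked at inference that does not involve a classifier (unconditional inference, without guidance?). In commit 1 you can see that I commented out many of the parts involving classifier_path.

About DDPM sampling in the function p_sample()

In the .yml file, predict_xstart: false sets model_mean_type to ModelMeanType.EPSILON.

In that case, each model_output is treated as the predicted noise eps.
eps is passed to the function _predict_xstart_from_eps() to obtain the predicted original image x_0,
which is then passed to q_posterior_mean_variance() to obtain posterior_mean.
posterior_mean is returned to p_sample(), where it is out["mean"].

The DDPM sample is then obtained as sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise.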
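The path described above (eps, then x_0, then posterior mean, then sample) can be sketched in plain NumPy. This is a minimal standalone illustration, not the repo code: the linear beta schedule, the choice log_variance = log(beta_t), and the helper names make_schedule, q_posterior_mean and p_sample_eps are assumptions made for the example, and a random array stands in for the network output.

```python
import numpy as np

def make_schedule(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    # Assumed linear schedule; the repo builds its schedule from the config.
    betas = np.linspace(beta_start, beta_end, num_steps, dtype=np.float64)
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas)
    alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
    return betas, alphas, alphas_cumprod, alphas_cumprod_prev

def predict_xstart_from_eps(x_t, t, eps, alphas_cumprod):
    # Invert x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps for x_0.
    return (x_t - np.sqrt(1.0 - alphas_cumprod[t]) * eps) / np.sqrt(alphas_cumprod[t])

def q_posterior_mean(x_start, x_t, t, betas, alphas, ac, ac_prev):
    # Mean of q(x_{t-1} | x_t, x_0).
    coef1 = betas[t] * np.sqrt(ac_prev[t]) / (1.0 - ac[t])
    coef2 = (1.0 - ac_prev[t]) * np.sqrt(alphas[t]) / (1.0 - ac[t])
    return coef1 * x_start + coef2 * x_t

def p_sample_eps(x_t, t, eps, log_variance, rng, schedule):
    betas, alphas, ac, ac_prev = schedule
    pred_xstart = predict_xstart_from_eps(x_t, t, eps, ac)
    mean = q_posterior_mean(pred_xstart, x_t, t, betas, alphas, ac, ac_prev)
    noise = rng.standard_normal(x_t.shape)
    nonzero_mask = 0.0 if t == 0 else 1.0  # no noise is added at t == 0
    return mean + nonzero_mask * np.exp(0.5 * log_variance) * noise

schedule = make_schedule()
rng = np.random.default_rng(0)
x_t = rng.standard_normal((1, 3, 8, 8))
eps = rng.standard_normal(x_t.shape)  # stand-in for model_output
t = 500
sample = p_sample_eps(x_t, t, eps, np.log(schedule[0][t]), rng, schedule)
print(sample.shape)
```

At t == 0 the posterior mean in this sketch reduces to the predicted x_0, because the coefficient on x_t vanishes when alphas_cumprod_prev is 1.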

@wdcww
Owner Author

wdcww commented Nov 30, 2024

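However, after reading Algorithm 2 in the DDPM paper, I tried replacing the way p_sample() computes sample with a form that matches Algorithm 2.


To do this, I made the following three changes:

In Repaint's gaussian_diffusion.py, I added the following here:

# Test the same inference formula as in the DDPM paper

 self.one_chu_sqrt_alpha = 1.0 / np.sqrt(alphas)  # 1 / sqrt(alpha_t)

 self.betas_chu_sqrt_one_mins_alpha_cumprod = self.betas / np.sqrt(1.0 - self.alphas_cumprod)  # beta_t / sqrt(1 - alpha_bar_t)

Here, I replaced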

return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
        }

with:

return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
            "model_output": model_output,  # exposed to test the same inference formula as in the DDPM paper
        }

At this line in p_sample(), I added a new if-else branch and kept the original expression for sample in the else branch, as shown below:

        if conf.is_ddpm_paper_get_xprev:
            print("conf.is_ddpm_paper_get_xprev")
            sample = (_extract_into_tensor(self.one_chu_sqrt_alpha, t, x.shape)
                  * (
                          x - out["model_output"] * _extract_into_tensor(self.betas_chu_sqrt_one_mins_alpha_cumprod,t, x.shape)
                     )
                 ) + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
        else:
            sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise

As you can see, this requires one new line in the .yml file, like this:
is_ddpm_paper_get_xprev: false


Each of these three changes can be seen in my code.
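The new branch can be illustrated outside the repo with a small NumPy sketch. It reuses the two coefficient names from the first change; the linear beta schedule, the extract helper (a stand-in for _extract_into_tensor), the function name algorithm2_step and the choice sigma_t = sqrt(beta_t) are assumptions made only for this example.

```python
import numpy as np

# Assumed toy linear schedule; the repo builds betas from its config.
betas = np.linspace(1e-4, 0.02, 1000, dtype=np.float64)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)

# The two coefficient arrays added in the first change.
one_chu_sqrt_alpha = 1.0 / np.sqrt(alphas)
betas_chu_sqrt_one_mins_alpha_cumprod = betas / np.sqrt(1.0 - alphas_cumprod)

def extract(arr, t, broadcast_shape):
    """Pick arr[t] per batch element and reshape it to broadcast over x."""
    out = arr[t]
    return out.reshape(out.shape + (1,) * (len(broadcast_shape) - out.ndim))

def algorithm2_step(x, t, eps, sigma, noise):
    # x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t/sqrt(1 - abar_t) * eps) + sigma_t * z
    nonzero_mask = (t != 0).astype(x.dtype).reshape(-1, *([1] * (x.ndim - 1)))
    mean = extract(one_chu_sqrt_alpha, t, x.shape) * (
        x - eps * extract(betas_chu_sqrt_one_mins_alpha_cumprod, t, x.shape)
    )
    return mean + nonzero_mask * sigma * noise

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4))
eps = rng.standard_normal(x.shape)   # stand-in for out["model_output"]
t = np.array([0, 999])               # one timestep index per batch element
sigma = extract(np.sqrt(betas), t, x.shape)
x_prev = algorithm2_step(x, t, eps, sigma, rng.standard_normal(x.shape))
print(x_prev.shape)
```

The batch element with t == 0 receives no added noise, mirroring nonzero_mask in p_sample().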

@wdcww
Owner Author

wdcww commented Nov 30, 2024

Isn't equation (56) of "Tutorial on Diffusion Models for Imaging and Vision" essentially the same as the original sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise in p_sample()?

And the line I added, sample = (_extract_into_tensor(self.one_chu_sqrt_alpha, t, x.shape) * ( x - out["model_output"] * _extract_into_tensor(self.betas_chu_sqrt_one_mins_alpha_cumprod,t, x.shape) ) ) + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise, would then be an attempt to implement equation (58) of "Tutorial on Diffusion Models for Imaging and Vision", and so to perform "Inference on a Denoising Diffusion Probabilistic Model. (Version Predict noise)"?
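For reference, the two update rules being compared can be written in the DDPM paper's notation as below. This is a sketch of the standard forms rather than a transcription of the tutorial's equations: $\hat{x}_0$ denotes the predicted original image, $\epsilon_\theta$ the predicted noise, and $\sigma_t$ whichever standard deviation the configuration selects, all of which are notational assumptions here.

```latex
% Posterior-mean form (what out["mean"] holds):
x_{t-1} = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,\hat{x}_0
        + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t
        + \sigma_t z

% Noise-prediction form (DDPM Algorithm 2):
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}
          \left( x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t) \right)
        + \sigma_t z
```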


If so, to use sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise,
model_output should be treated as pred_xstart (that is, self.model_mean_type == ModelMeanType.START_X in the code below),
and then passed to the function on the line commented "# # # q(x_{t-1} | x_t, x_0) # # #".

 elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
            if self.model_mean_type == ModelMeanType.START_X:
                # The model predicts x_0 (the original image) directly (predict-x_0 version)
                # This is the model_mean_type when predict_xstart=True is enabled
                pred_xstart = process_xstart(model_output)
            else:
                # self.model_mean_type is ModelMeanType.EPSILON (when the .yml file has predict_xstart: false) ###
                # The model predicts the noise term epsilon,
                # and the original image x_0 is derived from the noise and the current x_t
                pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))

            model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) # # # q(x_{t-1} | x_t, x_0) # # #

Whereas with the line I added, sample = (_extract_into_tensor(self.one_chu_sqrt_alpha, t, x.shape) * ( x - out["model_output"] * _extract_into_tensor(self.betas_chu_sqrt_one_mins_alpha_cumprod,t, x.shape) ) ) + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise,
model_output should be treated as the noise.


Conjecture:

The parameters in the .yml config should come in these two combinations:

When model_output is treated as the predicted noise:

is_ddpm_paper_get_xprev: true
predict_xstart: false

When model_output is treated as the predicted x0:

is_ddpm_paper_get_xprev: false
predict_xstart: true
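One way to probe the conjecture above is to compare the two means numerically when model_output is the noise. The sketch below does this on an assumed linear beta schedule with random arrays standing in for x_t and the network output; the function names are hypothetical, and the clip flag assumes that process_xstart clamps the predicted x_0 to [-1, 1] when clipping is enabled in the config.

```python
import numpy as np

# Assumed toy linear schedule, not read from the repo config.
betas = np.linspace(1e-4, 0.02, 1000, dtype=np.float64)
alphas = 1.0 - betas
ac = np.cumprod(alphas)
ac_prev = np.append(1.0, ac[:-1])

def mean_via_xstart(x_t, eps, t, clip=False):
    # eps -> predicted x_0 -> mean of q(x_{t-1} | x_t, x_0)
    x0 = (x_t - np.sqrt(1.0 - ac[t]) * eps) / np.sqrt(ac[t])
    if clip:
        x0 = np.clip(x0, -1.0, 1.0)  # assumed behaviour of process_xstart
    coef1 = betas[t] * np.sqrt(ac_prev[t]) / (1.0 - ac[t])
    coef2 = (1.0 - ac_prev[t]) * np.sqrt(alphas[t]) / (1.0 - ac[t])
    return coef1 * x0 + coef2 * x_t

def mean_via_algorithm2(x_t, eps, t):
    # Mean term of DDPM Algorithm 2, using eps directly.
    return (x_t - betas[t] / np.sqrt(1.0 - ac[t]) * eps) / np.sqrt(alphas[t])

rng = np.random.default_rng(0)
x_t = rng.standard_normal((3, 8, 8))
eps = rng.standard_normal(x_t.shape)
for t in (0, 1, 500, 999):
    diff = np.max(np.abs(mean_via_xstart(x_t, eps, t) - mean_via_algorithm2(x_t, eps, t)))
    print(t, diff)
```

In this sketch the two means agree to floating-point precision as long as the predicted x_0 is not clipped, which suggests the original out["mean"] line may also be valid with predict_xstart: false; with clip=True they can differ. This checks only the algebra on a toy schedule, not the behaviour of the repo code.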
