Skip to content

Commit

Permalink
Merge pull request #25 from NaruseMioShirakana/main
Browse files Browse the repository at this point in the history
浅扩散Onnx
  • Loading branch information
NaruseMioShirakana authored Jul 7, 2023
2 parents 178e2c0 + 812e106 commit ae4120a
Show file tree
Hide file tree
Showing 3 changed files with 606 additions and 22 deletions.
39 changes: 33 additions & 6 deletions diffusion/diffusion_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def extract(a, t):
return a[t].reshape((1, 1, 1, 1))


def extract2(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
Expand Down Expand Up @@ -214,6 +220,15 @@ def forward(self, x, t, interval, cond):
return x_prev


class AlphasCumprod(nn.Module):
def __init__(self, alphas_cumprod):
super().__init__()
self.alphas_cumprod = alphas_cumprod

def forward(self, t):
return extract(self.alphas_cumprod, t)


class GaussianDiffusion(nn.Module):
def __init__(self,
out_dims=128,
Expand Down Expand Up @@ -491,8 +506,9 @@ def get_x_pred(self, x_1, noise_t, t_1, t_prev):
return x_pred

def OnnxExport(self, project_name=None, init_noise=None, hidden_channels=256, export_denoise=True, export_pred=True, export_after=True):
self.DDIM_pred = DDimNoisePredictor(self.alphas_cumprod, self.denoise_fn)
self.DDIM_pred = torch.jit.script(self.DDIM_pred)
# self.DDIM_pred = DDimNoisePredictor(self.alphas_cumprod, self.denoise_fn)
# self.DDIM_pred = torch.jit.script(self.DDIM_pred)
self.alpha = AlphasCumprod(self.alphas_cumprod)
self.denoise_fn = torch.jit.script(self.denoise_fn)

cond = torch.randn([1, self.n_hidden, 10]).cpu()
Expand All @@ -516,7 +532,7 @@ def OnnxExport(self, project_name=None, init_noise=None, hidden_channels=256, ex
torch.onnx.export(
self.denoise_fn,
(x.cpu(), ot_1.cpu(), cond.cpu()),
f"exp/{project_name}/{project_name}_denoise.onnx",
f"checkpoints/{project_name}/{project_name}_denoise.onnx",
input_names=["noise", "time", "condition"],
output_names=["noise_pred"],
dynamic_axes={
Expand All @@ -536,7 +552,7 @@ def OnnxExport(self, project_name=None, init_noise=None, hidden_channels=256, ex
torch.onnx.export(
self.xp,
(x.cpu(), noise_pred.cpu(), t_1.cpu(), t_prev.cpu()),
f"exp/{project_name}/{project_name}_pred.onnx",
f"checkpoints/{project_name}/{project_name}_pred.onnx",
input_names=["noise", "noise_pred", "time", "time_prev"],
output_names=["noise_pred_o"],
dynamic_axes={
Expand All @@ -546,10 +562,20 @@ def OnnxExport(self, project_name=None, init_noise=None, hidden_channels=256, ex
opset_version=16
)

torch.onnx.export(
self.alpha,
(t_1.cpu()),
f"checkpoints/{project_name}/{project_name}_alpha.onnx",
input_names=["time"],
output_names=["alphas_cumprod"],
opset_version=16
)

'''
torch.onnx.export(
self.DDIM_pred,
(x.cpu(), t_1.cpu(), spd_up_ddim.cpu(), cond.cpu()),
f"exp/{project_name}/{project_name}_ddim_pred.onnx",
f"checkpoints/{project_name}/{project_name}_ddim_pred.onnx",
input_names=["noise", "time", "speedup", "time_prev"],
output_names=["noise_pred_o"],
dynamic_axes={
Expand All @@ -558,6 +584,7 @@ def OnnxExport(self, project_name=None, init_noise=None, hidden_channels=256, ex
},
opset_version=16
)
'''

x_pred = self.get_x_pred(x, noise_pred, t_1, t_prev)
noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond)
Expand Down Expand Up @@ -586,7 +613,7 @@ def OnnxExport(self, project_name=None, init_noise=None, hidden_channels=256, ex
torch.onnx.export(
self.ad,
x.cpu(),
f"exp/{project_name}/{project_name}_after.onnx",
f"checkpoints/{project_name}/{project_name}_after.onnx",
input_names=["x"],
output_names=["mel_out"],
dynamic_axes={
Expand Down
Loading

0 comments on commit ae4120a

Please sign in to comment.