Spaces:
Running
Running
import copy

import torch
def _zero_pad_dim1(tensor, target_len, dtype=None):
    """Return ``tensor`` right-padded with zeros along dim 1 up to ``target_len``.

    The padding lives on ``tensor``'s device. ``dtype=None`` lets torch pick its
    default dtype for the padding. The input tensor is never modified.
    """
    pad_shape = list(tensor.shape)
    pad_shape[1] = target_len - tensor.shape[1]
    padding = torch.zeros(pad_shape, device=tensor.device, dtype=dtype)
    return torch.cat([tensor, padding], dim=1)


def _stack_cfg_conditioning(batch, conditioning_keys, uncond_tokens, reuse_cond):
    """Build the doubled (conditional, then unconditional) conditioning dict.

    For every key, ``{key}_mask`` / ``{key}_embeddings`` entries (or the raw
    ``key`` entry) of ``batch`` are concatenated along dim 0 with their
    unconditional counterpart. When ``reuse_cond`` is true the conditional
    value is simply duplicated and ``uncond_tokens`` is not read.

    Neither ``batch`` nor ``uncond_tokens`` is modified: length alignment is
    done on local tensors only.

    Raises:
        ValueError: if a key has no ``{key}_embeddings`` entry and no raw entry
            in ``batch``, or if a raw entry is neither a tensor nor a list.
    """
    stacked = {}
    for key in conditioning_keys:
        mask_key = f"{key}_mask"
        emb_key = f"{key}_embeddings"

        if mask_key in batch:
            cond_mask = batch[mask_key]
            if reuse_cond:
                stacked[mask_key] = torch.cat([cond_mask, cond_mask], dim=0)
            else:
                ref_mask = uncond_tokens[mask_key]
                if cond_mask.shape[1] > ref_mask.shape[1]:
                    # Widen the unconditional mask to the conditional length:
                    # the filler is False for boolean masks and -inf otherwise.
                    if cond_mask.dtype == torch.bool:
                        uncond_mask = torch.zeros_like(cond_mask)
                    else:
                        uncond_mask = torch.ones_like(cond_mask) * -torch.inf
                    uncond_mask[:, : ref_mask.shape[1]] = ref_mask
                else:
                    uncond_mask = ref_mask
                    # The target length is taken from the unconditional
                    # *embeddings*, as in the original implementation.
                    # NOTE(review): presumably equal to the unconditional
                    # mask length -- confirm.
                    cond_mask = _zero_pad_dim1(
                        cond_mask,
                        uncond_tokens[emb_key].shape[1],
                        dtype=cond_mask.dtype,
                    )
                stacked[mask_key] = torch.cat([cond_mask, uncond_mask], dim=0)

        if emb_key in batch:
            cond_emb = batch[emb_key]
            if reuse_cond:
                stacked[emb_key] = torch.cat([cond_emb, cond_emb], dim=0)
            else:
                uncond_emb = uncond_tokens[emb_key]
                cond_len = cond_emb.shape[1]
                uncond_len = uncond_emb.shape[1]
                # Zero-pad whichever side is shorter. The result stays local so
                # that ``uncond_tokens`` keeps its original length for later
                # calls (padding it in place desynchronised it from its mask).
                if cond_len > uncond_len:
                    uncond_emb = _zero_pad_dim1(uncond_emb, cond_len)
                elif cond_len < uncond_len:
                    cond_emb = _zero_pad_dim1(cond_emb, uncond_len)
                stacked[emb_key] = torch.cat([cond_emb, uncond_emb], dim=0)
        elif key not in batch:
            raise ValueError(f"Key {key} not in batch")
        else:
            value = batch[key]
            if isinstance(value, torch.Tensor):
                other = value if reuse_cond else uncond_tokens
                stacked[key] = torch.cat([value, other], dim=0)
            elif isinstance(value, list):
                other = value if reuse_cond else uncond_tokens
                stacked[key] = [*value, *other]
            else:
                raise ValueError(
                    "Conditioning must be a tensor or a list of tensors"
                )
    return stacked


def ddpm_sampler(
    net,
    batch,
    conditioning_keys=None,
    scheduler=None,
    uncond_tokens=None,
    num_steps=1000,
    cfg_rate=0,
    generator=None,
    use_confidence_sampling=False,
    use_uncond_token=True,
    confidence_value=1.0,
    unconfidence_value=0.0,
):
    """Run an ancestral (DDPM-style) sampling loop, optionally with CFG.

    Starting from ``batch["y"]``, the sampler walks ``num_steps`` transitions
    over the gammas returned by ``scheduler`` and, at each step, uses the
    network output as the noise estimate to draw the next sample.

    The caller's ``batch`` and ``uncond_tokens`` are left untouched: the
    sampler works on a shallow copy of ``batch`` and pads conditioning
    tensors locally.

    Args:
        net: Callable taking a batch mapping and returning a
            ``(denoised, latents)`` pair.
        batch: Mapping with at least ``"y"`` (initial sample) and
            ``"previous_latents"`` (may be ``None``), plus any conditioning
            entries named by ``conditioning_keys``.
        conditioning_keys: Keys whose conditioning is doubled for
            classifier-free guidance. Guidance is only used when this is not
            ``None`` and ``cfg_rate > 0``.
        scheduler: Callable mapping a tensor of ``num_steps + 1`` times going
            from 1 down to 0 to the corresponding gammas. Required.
        uncond_tokens: Unconditional conditioning. It is indexed by
            ``"{key}_mask"`` / ``"{key}_embeddings"`` for such entries, and
            used directly as a tensor / list for raw ``key`` entries.
        num_steps: Number of sampling transitions.
        cfg_rate: Guidance weight ``w``; the guided output is
            ``cond * (1 + w) - uncond * w``.
        generator: Optional ``torch.Generator`` used for the injected noise.
        use_confidence_sampling: When true, a ``"confidence"`` entry is added
            to the batch given to ``net``.
        use_uncond_token: When false (together with confidence sampling), the
            unconditional half reuses the conditional conditioning instead of
            ``uncond_tokens``.
        confidence_value: Confidence fed for the conditional half.
        unconfidence_value: Confidence fed for the unconditional half.

    Returns:
        The final sample as a ``torch.float32`` tensor.

    Raises:
        ValueError: if ``scheduler`` is ``None`` or a conditioning entry is
            missing / has an unsupported type.
    """
    if scheduler is None:
        raise ValueError("Scheduler must be provided")

    # Work on a shallow copy so per-step entries ("y", "gamma", ...) never
    # leak into the dict owned by the caller.
    batch = copy.copy(batch)

    x_cur = batch["y"].to(torch.float32)
    latents = batch["previous_latents"]
    batch_size = x_cur.shape[0]
    device = x_cur.device

    if use_confidence_sampling:
        batch["confidence"] = torch.ones(batch_size, device=device) * confidence_value

    step_indices = torch.arange(num_steps + 1, dtype=torch.float32, device=device)
    gammas = scheduler(1 - step_indices / num_steps)
    latents_cond = latents_uncond = latents

    # A bf16/fp16 selection was commented out in the original in favour of
    # float32.
    # NOTE(review): with float32 the autocast context below presumably has no
    # effect -- confirm before switching to a half-precision dtype.
    dtype = torch.float32

    use_cfg = cfg_rate > 0 and conditioning_keys is not None
    if use_cfg:
        stacked_batch = _stack_cfg_conditioning(
            batch,
            conditioning_keys,
            uncond_tokens,
            reuse_cond=use_confidence_sampling and not use_uncond_token,
        )
        if use_confidence_sampling:
            ones = torch.ones(batch_size, device=device)
            stacked_batch["confidence"] = torch.cat(
                [ones * confidence_value, ones * unconfidence_value],
                dim=0,
            )

    for gamma_now, gamma_next in zip(gammas[:-1], gammas[1:]):
        with torch.cuda.amp.autocast(dtype=dtype):
            if use_cfg:
                stacked_batch["y"] = torch.cat([x_cur, x_cur], dim=0)
                stacked_batch["gamma"] = gamma_now.expand(batch_size * 2)
                # NOTE(review): this tests the *initial* latents, so latents
                # returned by ``net`` are only fed back when the batch
                # supplied some to begin with -- confirm this is intended.
                stacked_batch["previous_latents"] = (
                    torch.cat([latents_cond, latents_uncond], dim=0)
                    if latents is not None
                    else None
                )
                denoised_all, latents_all = net(stacked_batch)
                denoised_cond, denoised_uncond = denoised_all.chunk(2, dim=0)
                latents_cond, latents_uncond = latents_all.chunk(2, dim=0)
                denoised = denoised_cond * (1 + cfg_rate) - denoised_uncond * cfg_rate
            else:
                batch["y"] = x_cur
                batch["gamma"] = gamma_now.expand(batch_size)
                batch["previous_latents"] = latents
                denoised, latents = net(batch)

        # Clean-sample estimate from the predicted noise, clamped to [-1, 1],
        # then the noise estimate that is consistent with the clamped sample.
        x_pred = (x_cur - torch.sqrt(1 - gamma_now) * denoised) / torch.sqrt(gamma_now)
        x_pred = torch.clamp(x_pred, -1, 1)
        noise_pred = (x_cur - torch.sqrt(gamma_now) * x_pred) / torch.sqrt(
            1 - gamma_now
        )

        # Per-step alpha = gamma_now / gamma_next, computed in log space and
        # clipped to [0, 1].
        log_alpha_t = torch.log(gamma_now) - torch.log(gamma_next)
        alpha_t = torch.clip(torch.exp(log_alpha_t), 0, 1)
        x_mean = torch.rsqrt(alpha_t) * (
            x_cur - torch.rsqrt(1 - gamma_now) * (1 - alpha_t) * noise_pred
        )
        var_t = 1 - alpha_t
        eps = torch.randn(x_cur.shape, device=x_cur.device, generator=generator)
        x_cur = x_mean + torch.sqrt(var_t) * eps

    return x_cur.to(torch.float32)