import numpy as np
import torch
from PIL import Image

def view_images(images, num_rows=1, offset_ratio=0.02):
    # Arrange one or more (H, W, C) uint8 images into a grid and return it as a PIL image.
    if type(images) is list:
        num_images = len(images)
    elif images.ndim == 4:
        num_images = images.shape[0]
    else:
        images = [images]
        num_images = 1
    # Pad with blank white images so the total count divides evenly into the requested rows.
    num_empty = (num_rows - num_images % num_rows) % num_rows
    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    # Start from a white canvas with a small gutter (offset) between grid cells.
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h,
                   j * (w + offset): j * (w + offset) + w] = images[i * num_cols + j]

    pil_img = Image.fromarray(image_)
    return pil_img
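
# Example (illustrative sketch only): tile a few random images into one row.
# The array shapes and file name below are assumptions for demonstration, not part of the utilities above.
#
#     frames = [np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) for _ in range(3)]
#     grid = view_images(frames, num_rows=1)
#     grid.save("grid.png")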

def diffusion_step(model, latents, context, t, guidance_scale, low_resource=False):
    # One denoising step with classifier-free guidance.
    if low_resource:
        # Run the unconditional and text-conditioned passes separately to save memory.
        noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
        noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
    else:
        # Run both passes in a single batched forward call.
        latents_input = torch.cat([latents] * 2)
        noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
    # Push the prediction away from the unconditional estimate toward the text-conditioned one.
    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
    latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
    return latents

def latent2image(vae, latents):
    # Undo the Stable Diffusion latent scaling factor before decoding with the VAE.
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents)["sample"]
    # Map from [-1, 1] to uint8 images in (batch, H, W, C) layout.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image * 255).astype(np.uint8)
    return image

def init_latent(latent, model, height, width, generator, batch_size):
    # Sample an initial latent if none is given; the VAE downsamples the image by a factor of 8.
    if latent is None:
        latent = torch.randn(
            (1, model.unet.in_channels, height // 8, width // 8),
            generator=generator,
        )
    # Repeat the same starting latent across the batch so every prompt shares the seed.
    latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
    return latent, latents

def text2image_ldm_stable(
    model,
    prompt,
    num_inference_steps=50,
    guidance_scale=7.5,
    generator=None,
    latent=None,
    low_resource=False,
):
    # Full text-to-image sampling loop. `model` is expected to expose tokenizer, text_encoder,
    # unet, vae, scheduler, and run_safety_checker (as a Stable Diffusion pipeline does).
    height = width = 512
    batch_size = len(prompt)

    # Encode the prompts.
    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]

    # Encode empty prompts for the unconditional branch of classifier-free guidance.
    max_length = text_input.input_ids.shape[-1]
    uncond_input = model.tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]

    # Keep the two contexts separate in low-resource mode; otherwise batch them together.
    context = [uncond_embeddings, text_embeddings]
    if not low_resource:
        context = torch.cat(context)

    latent, latents = init_latent(latent, model, height, width, generator, batch_size)

    # Denoise over the scheduler's timesteps.
    model.scheduler.set_timesteps(num_inference_steps)
    for t in model.scheduler.timesteps:
        latents = diffusion_step(model, latents, context, t, guidance_scale, low_resource)

    # Decode the final latents and run the pipeline's safety checker.
    image = latent2image(model.vae, latents)
    image, _ = model.run_safety_checker(image=image, device=model.device, dtype=text_embeddings.dtype)
    return image
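

# Example usage (illustrative sketch, assuming the diffusers StableDiffusionPipeline and a GPU;
# the model id, seed, and prompt below are placeholders, not part of the utilities above):
#
#     from diffusers import StableDiffusionPipeline
#
#     model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
#     g_cpu = torch.Generator().manual_seed(0)
#     images = text2image_ldm_stable(model, ["a photo of a cat"], generator=g_cpu)
#     view_images(images)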