import gradio as gr
import torch
import io
from PIL import Image
import numpy as np
import spaces
import math
import re
from einops import rearrange
from mmengine.config import Config
from src.builder import BUILDER

import matplotlib

matplotlib.use("Agg")  # headless backend: render figures without a display server
import matplotlib.pyplot as plt

from scripts.camera.cam_dataset import Cam_Generator
from scripts.camera.visualization.visualize_batch import make_perspective_figures

from huggingface_hub import snapshot_download
import os

# Regex for a signed int/float with optional exponent, and a pattern that
# captures the three camera parameters (roll, pitch, fov) from free-form text.
NUM = r"[+-]?(?:\d+(?:\.\d+)?|\.\d+)(?:[eE][+-]?\d+)?"
CAM_PATTERN = re.compile(
    r"(?:camera parameters.*?:|roll.*?:)\s*(" + NUM + r")\s*,\s*(" + NUM + r")\s*,\s*(" + NUM + r")",
    re.IGNORECASE | re.DOTALL,
)
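
# A minimal sketch (not called anywhere below) of how CAM_PATTERN could be
# applied to a model response; the app delegates this parsing to Cam_Generator.
def parse_camera_params(text):
    """Return (roll, pitch, fov) as floats, or None if no match is found."""
    match = CAM_PATTERN.search(text)
    return tuple(float(g) for g in match.groups()) if match else None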


def center_crop(image):
    """Crop a PIL image to its largest centered square."""
    w, h = image.size
    s = min(w, h)
    left = (w - s) // 2
    top = (h - s) // 2
    return image.crop((left, top, left + s, top + s))
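
# Example: center_crop(Image.new("RGB", (1280, 720))) returns the centered
# 720x720 square (left=280, top=0).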


# Build the model from its mmengine config, fetch the base checkpoint from the
# Hub, load it, then delete the local copy to reclaim disk space.
config = Config.fromfile("configs/pipelines/stage_2_base.py")
model = BUILDER.build(config.model).eval()
_ = snapshot_download(
    repo_id="KangLiao/Puffin",
    repo_type="model",
    allow_patterns="Puffin-Base.pth",
    local_dir="checkpoints/",
    local_dir_use_symlinks=False,
    revision="main",
)
_ = model.load_state_dict(
    torch.load("checkpoints/Puffin-Base.pth", map_location="cpu"), strict=False
)
os.remove("checkpoints/Puffin-Base.pth")

# The VAE weights are distributed in a separate repo; load them with strict
# key matching, then clean up the download.
_ = snapshot_download(
    repo_id="wusize/Puffin",
    repo_type="model",
    local_dir="checkpoints/",
    local_dir_use_symlinks=False,
    revision="main",
)
_ = model.vae.load_state_dict(torch.load("checkpoints/vae.pth", map_location="cpu"), strict=True)
os.remove("checkpoints/vae.pth")

# Run in bfloat16 on GPU when available; fall back to float32 on CPU.
if torch.cuda.is_available():
    model = model.to(torch.bfloat16).cuda()
else:
    model = model.to(torch.float32)


def fig_to_image(fig):
    """Render a matplotlib figure to a PIL RGB image via an in-memory PNG."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    buf.seek(0)
    img = Image.open(buf).convert("RGB")
    buf.close()
    return img
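
# Note: fig_to_image does not close the source figure; callers remain
# responsible for plt.close(fig) to avoid leaking figure memory.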


def extract_up_lat_figs(fig_dict):
    """Split a figure dict into (first up-field fig, first latitude-field fig,
    dict of remaining figs)."""
    fig_up, fig_lat = None, None
    others = {}
    for k, fig in fig_dict.items():
        if ("up_field" in k) and (fig_up is None):
            fig_up = fig
        elif ("latitude_field" in k) and (fig_lat is None):
            fig_lat = fig
        else:
            others[k] = fig
    return fig_up, fig_lat, others
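
# Not referenced below: camera_understanding inlines an equivalent loop over
# the figures returned by make_perspective_figures.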


@torch.inference_mode()
@spaces.GPU(duration=120)
def camera_understanding(image_src, seed, progress=gr.Progress(track_tqdm=True)):
    # Note: the unused `question` parameter was dropped so the signature
    # matches the two inputs wired up in the click handler below.
    torch.cuda.empty_cache()
    torch.manual_seed(seed)
    print("CUDA available:", torch.cuda.is_available())

    prompt = (
        "Describe the image in detail. Then reason its spatial distribution and "
        "estimate its camera parameters (roll, pitch, and field-of-view)."
    )

    # Center-crop, resize to the model's 512x512 input, and scale pixels from
    # [0, 255] to [-1, 1] in CHW layout. The redundant torch.no_grad() block is
    # gone: torch.inference_mode() already disables gradients.
    image = Image.fromarray(image_src).convert("RGB")
    image = center_crop(image)
    image = image.resize((512, 512))
    x = torch.from_numpy(np.array(image)).float()
    x = x / 255.0
    x = 2 * x - 1
    x = rearrange(x, "h w c -> c h w")

    outputs = model.understand(prompt=[prompt], pixel_values=[x], progress_bar=False)

    text = outputs[0]

    # Recover the up/latitude camera fields from the generated text, then
    # visualize them over the input image. The original BGR->RGB round trip
    # was an identity and is simplified away.
    gen = Cam_Generator(mode="base")
    cam = gen.get_cam(text)

    rgb = np.array(image).astype(np.float32) / 255.0
    image_tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
    single_batch = {
        "image": image_tensor,
        "up_field": cam[:2].unsqueeze(0),
        "latitude_field": cam[2:].unsqueeze(0),
    }

    figs = make_perspective_figures(single_batch, single_batch, n_pairs=1)
    up_img = lat_img = None
    for k, fig in figs.items():
        if "up_field" in k:
            up_img = fig_to_image(fig)
        elif "latitude_field" in k:
            lat_img = fig_to_image(fig)
        plt.close(fig)  # close every figure to free memory

    # Return the visualizations too, instead of computing and discarding them.
    return text, up_img, lat_img


@torch.inference_mode()
@spaces.GPU(duration=120)
def generate_image(prompt_scene,
                   seed=42,
                   roll=0.1,
                   pitch=0.1,
                   fov=1.0,
                   progress=gr.Progress(track_tqdm=True)):
    torch.cuda.empty_cache()

    # Seed every RNG involved so a given (prompt, seed) pair is reproducible.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    print("CUDA available:", torch.cuda.is_available())

    generator = torch.Generator().manual_seed(seed)
    prompt_camera = (
        "The camera parameters (roll, pitch, and field-of-view) are: "
        f"{roll:.4f}, {pitch:.4f}, {fov:.4f}."
    )
    # Build the pixel-wise camera map and scale it by pi/2 so latitude values
    # (which span +-pi/2) land in [-1, 1].
    gen = Cam_Generator()
    cam_map = gen.get_cam(prompt_camera).to(model.device)
    cam_map = cam_map / (math.pi / 2)

    prompt = prompt_scene + " " + prompt_camera
    print("prompt:", prompt)

    # Sample four candidates per request.
    bsz = 4
    images, output_reasoning = model.generate(
        prompt=[prompt] * bsz,
        cfg_prompt=[""] * bsz,
        pixel_values_init=None,
        cfg_scale=4.5,
        num_steps=50,
        cam_values=[[cam_map]] * bsz,
        progress_bar=False,
        reasoning=False,
        prompt_reasoning=[""] * bsz,
        generator=generator,
        height=512,
        width=512,
    )

    # Map outputs from [-1, 1] back to 8-bit RGB images.
    images = rearrange(images, "b c h w -> b h w c")
    images = torch.clamp(127.5 * images + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
    return [Image.fromarray(image) for image in images]


css = '''
.gradio-container {max-width: 960px !important}
'''

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Puffin")

    with gr.Tab("Camera-controllable Image Generation"):
        gr.Markdown(value="## Camera-controllable Image Generation")

        prompt_input = gr.Textbox(label="Prompt")

        with gr.Accordion("Camera Parameters", open=True):
            with gr.Row():
                # All angles are in radians: roll/pitch span +-pi/4, fov spans
                # roughly 20 to 105 degrees.
                roll = gr.Slider(minimum=-0.7854, maximum=0.7854, value=0.1, step=0.1, label="roll value")
                pitch = gr.Slider(minimum=-0.7854, maximum=0.7854, value=-0.1, step=0.1, label="pitch value")
                fov = gr.Slider(minimum=0.3491, maximum=1.8326, value=1.5, step=0.1, label="fov value")
            seed_input = gr.Number(label="Seed (Optional)", precision=0, value=42)

        generation_button = gr.Button("Generate Images")

        image_output = gr.Gallery(label="Generated Images", columns=4, rows=1)

        examples_t2i = gr.Examples(
            label="Prompt examples",
            examples=[
                "A sunny day casts light on two warmly colored buildings—yellow with green accents and deeper orange—framed by a lush green tree, with a blue sign and street lamp adding details in the foreground.",
                "A high-vantage-point view of lush, autumn-colored mountains blanketed in green and gold, set against a clear blue sky with scattered white clouds, offering a tranquil and breathtaking vista of a serene valley below.",
                "A grand, historic castle with pointed spires and elaborate stone structures stands against a clear blue sky, flanked by a circular fountain, vibrant red flowers, and neatly trimmed hedges in a beautifully landscaped garden.",
                "A serene aerial view of a coastal landscape at sunrise/sunset, featuring warm pink and orange skies transitioning to cool blues, with calm waters stretching to rugged, snow-capped mountains in the background, creating a tranquil and picturesque scene.",
                "A worn, light-yellow-walled room with herringbone terracotta floors and three large arched windows framed in pink trim and white panes, showcasing signs of age and disrepair, overlooks a residential area through glimpses of greenery and neighboring buildings.",
            ],
            inputs=prompt_input,
        )

    with gr.Tab("Camera Understanding"):
        gr.Markdown(value="## Camera Understanding")
        image_input = gr.Image()

        understanding_button = gr.Button("Chat")
        understanding_output = gr.Textbox(label="Response")

        # Display the up/latitude field visualizations that
        # camera_understanding now returns alongside the text.
        with gr.Row():
            up_output = gr.Image(label="Up Field")
            lat_output = gr.Image(label="Latitude Field")

        with gr.Accordion("Advanced options", open=False):
            und_seed_input = gr.Number(label="Seed", precision=0, value=42)

        examples_understanding = gr.Examples(
            label="Camera Understanding examples",
            examples=[
                "assets/1.jpg",
                "assets/2.jpg",
                "assets/3.jpg",
                "assets/4.jpg",
                "assets/5.jpg",
                "assets/6.jpg",
            ],
            inputs=image_input,
        )

    generation_button.click(
        fn=generate_image,
        inputs=[prompt_input, seed_input, roll, pitch, fov],
        outputs=image_output,
    )

    understanding_button.click(
        fn=camera_understanding,
        inputs=[image_input, und_seed_input],
        outputs=[understanding_output, up_output, lat_output],
    )

demo.launch(share=True)