|
|
import torch |
|
|
|
|
|
# Allow TF32 for CUDA matmuls and for cuDNN ops. These flags are set right
# after importing torch, ahead of the remaining imports and of the model
# construction further down. See the PyTorch CUDA notes on TF32 for the
# precision/speed trade-off.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
|
|
|
|
|
import os |
|
|
import random |
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
import spaces |
|
|
from lakonlab.models.diffusions.schedulers.flow_map_sde import FlowMapSDEScheduler |
|
|
from lakonlab.ui.gradio.create_img_edit import create_interface_img_edit |
|
|
from lakonlab.pipelines.pipeline_piflux2 import PiFlux2Pipeline |
|
|
from lakonlab.pipelines.prompt_rewriters.qwen3_vl import Qwen3VLPromptRewriter |
|
|
|
|
|
|
|
|
# Default prompt text; passed as the `prompt` argument of
# `create_interface_img_edit` when the UI is built below.
DEFAULT_PROMPT = """Museum-style FIELD GUIDE poster on neutral parchment (#F3EEE3). Use Inter (or Helvetica/Arial). All text #2D3748, thin connector lines 1px #A0AEC0.

Center: full-body original fantasy creature, 3/4 standing pose. Around it: four small inset boxes labeled exactly "EYE DETAIL", "FOOT DETAIL", "SKIN TEXTURE", "SILHOUETTE SCALE" (with a simple human comparison silhouette). Bottom: a short footprint trail diagram. One small habitat vignette (misty rocky shoreline with tide pools).

Exact text (only these, clean print layout):
Top: "FIELD GUIDE"
Sub: "AURORA SHOREWALKER"
Small line: "CLASS: COASTAL DRIFTER"
Under silhouette: "HEIGHT: 1.7 m"

Crisp ink outlines with soft watercolor-like fills, high readability, balanced hierarchy, premium poster aesthetic."""
|
|
|
|
|
# Files holding the system prompts handed to the prompt rewriter below: one is
# passed as `system_prompt_text_only`, the other as the with-images prompt.
# Both paths are relative, so they resolve against the process working
# directory at the time the files are opened.
SYSTEM_PROMPT_TEXT_ONLY_PATH = 'lakonlab/pipelines/prompt_rewriters/system_prompts/default_text_only.txt'
SYSTEM_PROMPT_WITH_IMAGES_PATH = 'lakonlab/pipelines/prompt_rewriters/system_prompts/default_with_images.txt'
|
|
|
|
|
|
|
|
def _patch_diffusers_bnb_shape_check():
    """Replace ``check_quantized_param_shape`` on diffusers' bnb quantizer classes.

    If ``diffusers.quantizers.bitsandbytes.bnb_quantizer`` cannot be imported
    for any reason, the function returns without touching anything. Otherwise
    every class found in that module that exposes
    ``check_quantized_param_shape`` gets the replacement defined below, which
    derives the expected shape of the loaded parameter from the element count
    of the current parameter.
    """
    try:
        import diffusers.quantizers.bitsandbytes.bnb_quantizer as bnb_module
    except Exception:
        # Deliberate best effort: with no importable quantizer module there
        # is nothing to patch.
        return

    def _element_count(shape):
        """Return the number of elements described by ``shape`` (None if no shape)."""
        if shape is None:
            return None
        if hasattr(shape, "numel"):
            # Shape objects that can report their own element count.
            return int(shape.numel())
        # Any other iterable of dimensions: multiply the sizes together.
        total = 1
        for dim in shape:
            total *= int(dim)
        return total

    def patched_check(self, param_name, current_param, loaded_param):
        """Return True when the loaded shape equals the inferred one, else raise ValueError."""
        current_shape = getattr(current_param, "shape", None)
        loaded_shape = getattr(loaded_param, "shape", None)
        count = _element_count(current_shape)
        # Parameters whose name contains "bias" are expected flat with `count`
        # elements; everything else is expected as a column of
        # ceil(count / 2) rows.
        # NOTE(review): presumably the two-values-per-byte 4-bit layout —
        # confirm against the diffusers implementation.
        if "bias" in param_name:
            inferred_shape = (count,)
        else:
            inferred_shape = ((count + 1) // 2, 1)
        if tuple(loaded_shape) == tuple(inferred_shape):
            return True
        raise ValueError(
            f"Expected flattened shape mismatch for {param_name}: "
            f"loaded={tuple(loaded_shape)} inferred={tuple(inferred_shape)}"
        )

    # Collect every class in the module that carries the hook, then swap it.
    patch_targets = [
        candidate
        for candidate in vars(bnb_module).values()
        if isinstance(candidate, type)
        and hasattr(candidate, "check_quantized_param_shape")
    ]
    for quantizer_cls in patch_targets:
        quantizer_cls.check_quantized_param_shape = patched_check
|
|
|
|
|
|
|
|
# Install the bnb shape-check replacement first.
# NOTE(review): presumably this has to happen before `from_pretrained` below
# loads the quantized checkpoint — confirm.
_patch_diffusers_bnb_shape_check()

# Build the generation pipeline at import time, requesting bfloat16.
pipe = PiFlux2Pipeline.from_pretrained(
    'diffusers/FLUX.2-dev-bnb-4bit',
    torch_dtype=torch.bfloat16)
# Load the pi-Flow adapter from the given subfolder, targeting the pipeline's
# 'transformer' module.
pipe.load_piflow_adapter(
    'Lakonik/pi-FLUX.2',
    subfolder='gmflux2_k8_piid_4step',
    target_module_name='transformer')
# Replace the scheduler with a FlowMapSDEScheduler built from the existing
# scheduler's config, overriding the shift / step-size settings listed here.
pipe.scheduler = FlowMapSDEScheduler.from_config(
    pipe.scheduler.config, shift=3.2, use_dynamic_shifting=False, final_step_size_scale=0.5)
# Move the fully configured pipeline to the GPU.
pipe = pipe.to('cuda')
|
|
|
|
|
def _read_system_prompt(path):
    """Return the full text of the system-prompt file at ``path``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    the text has been read, instead of being left for the garbage collector.
    It is decoded as UTF-8 explicitly so the result does not depend on the
    locale's default encoding.

    Args:
        path: Location of the prompt file.

    Returns:
        The file contents as a single string.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(path, 'r', encoding='utf-8') as prompt_file:
        return prompt_file.read()


# Prompt rewriter, created once at import time on the GPU with both system
# prompts loaded from disk.
prompt_rewriter = Qwen3VLPromptRewriter(
    device_map="cuda",
    system_prompt_text_only=_read_system_prompt(SYSTEM_PROMPT_TEXT_ONLY_PATH),
    # NOTE(review): "wigh" looks like a typo of "with", but the keyword has to
    # match the Qwen3VLPromptRewriter signature, which is not visible from
    # this file — confirm there before renaming.
    system_prompt_wigh_images=_read_system_prompt(SYSTEM_PROMPT_WITH_IMAGES_PATH),
    max_new_tokens_default=512,
)
|
|
|
|
|
|
|
|
def set_random_seed(seed: int, deterministic: bool = True) -> None:
    """Seed Python, NumPy and PyTorch random number generators with ``seed``.

    Args:
        seed: Value passed to every seeding call and stored, as a string, in
            the ``PYTHONHASHSEED`` environment variable.
        deterministic: When true, additionally set
            ``torch.backends.cudnn.deterministic = True`` and
            ``torch.backends.cudnn.benchmark = False``. When false, those two
            flags are left as they are.
    """
    # Same calls, same order, expressed as a table.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # NOTE: the running interpreter fixed its hash seed at startup; this value
    # is only seen by processes that inherit the environment afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)

    if not deterministic:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
|
|
@spaces.GPU
def run_rewrite_prompt_gpu(seed, prompt, image_list, progress):
    """Rewrite ``prompt`` with the module-level ``prompt_rewriter``.

    Args:
        seed: Passed to ``set_random_seed`` before rewriting.
        prompt: The prompt text to rewrite.
        image_list: Images accompanying the prompt, or ``None``. ``None``
            selects the text-only rewrite; anything else selects the edit
            rewrite.
        progress: Callable invoked once as ``progress(0.05, desc=...)``.

    Returns:
        The first element of the batch returned by the rewriter.
    """
    set_random_seed(seed)
    progress(0.05, desc="Rewriting prompt...")
    # Both rewriter entry points take batches; wrap the single request in
    # one-element lists and unwrap the single result afterwards.
    if image_list is not None:
        rewritten_batch = prompt_rewriter.rewrite_edit_batch([image_list], [prompt])
    else:
        rewritten_batch = prompt_rewriter.rewrite_text_batch([prompt])
    return rewritten_batch[0]
|
|
|
|
|
|
|
|
def run_rewrite_prompt(seed, prompt, rewrite_prompt, in_image, progress=gr.Progress(track_tqdm=True)):
    """Rewrite the prompt if requested and return it with a ``None`` placeholder.

    Args:
        seed: Forwarded to ``run_rewrite_prompt_gpu``.
        prompt: The prompt text entered by the user.
        rewrite_prompt: Truthy to perform the rewrite; falsy to skip it.
        in_image: ``None``, an empty sequence, or a sequence of items whose
            first element (``item[0]``) is the image to forward.
        progress: Progress object forwarded to ``run_rewrite_prompt_gpu``.

    Returns:
        ``(rewritten_prompt, None)`` when ``rewrite_prompt`` is truthy,
        otherwise ``('', None)``.
    """
    has_images = in_image is not None and len(in_image) > 0
    # Keep only the first element of every input item; None means "no images".
    image_list = [entry[0] for entry in in_image] if has_images else None

    if not rewrite_prompt:
        return '', None
    return run_rewrite_prompt_gpu(seed, prompt, image_list, progress), None
|
|
|
|
|
|
|
|
@spaces.GPU
def generate(
        seed, prompt, rewrite_prompt, rewritten_prompt, in_image, width, height, steps,
        progress=gr.Progress(track_tqdm=True)):
    """Run the module-level pipeline once and return its first output image.

    Args:
        seed: Seed for the ``torch.Generator`` handed to the pipeline.
        prompt: The prompt text entered by the user.
        rewrite_prompt: When truthy, ``rewritten_prompt`` is used in place of
            ``prompt``.
        rewritten_prompt: The rewritten prompt text.
        in_image: ``None``, an empty sequence, or a sequence of items whose
            first element (``item[0]``) is an input image.
        width: Forwarded to the pipeline as ``width``.
        height: Forwarded to the pipeline as ``height``.
        steps: Forwarded to the pipeline as ``num_inference_steps``.
        progress: Not referenced in the body.
            NOTE(review): presumably declared so Gradio tracks tqdm progress
            for this call — confirm against the Gradio docs.

    Returns:
        ``images[0]`` of the pipeline output.
    """
    if in_image is not None and len(in_image) > 0:
        # Keep only the first element of every input item.
        image_list = [entry[0] for entry in in_image]
    else:
        image_list = None

    effective_prompt = rewritten_prompt if rewrite_prompt else prompt
    generator = torch.Generator().manual_seed(seed)

    output = pipe(
        image=image_list,
        prompt=effective_prompt,
        width=width,
        height=height,
        num_inference_steps=steps,
        generator=generator,
    )
    return output.images[0]
|
|
|
|
|
|
|
|
# Build the Gradio page: a Markdown header followed by the shared
# image-edit interface wired to `generate` and `run_rewrite_prompt`.
with gr.Blocks(
    analytics_enabled=False,
    title='pi-FLUX.2 Demo',
    css_paths='lakonlab/ui/gradio/style.css',
) as demo:

    # Adjacent literals are concatenated into the header text.
    md_txt = (
        '# pi-FLUX.2 Demo\n\n'
        'Official demo of the paper [pi-Flow: Policy-Based Few-Step Generation via Imitation Distillation](https://arxiv.org/abs/2510.14974). '
        '**Base model:** [FLUX.2 dev](https://huggingface.co/black-forest-labs/FLUX.2-dev). **Fast policy:** GMFlow. **Code:** [https://github.com/Lakonik/piFlow](https://github.com/Lakonik/piFlow).\n'
        '<br> Use and distribution of this app are governed by the [FLUX [dev] Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/LICENSE.txt).'
    )
    gr.Markdown(md_txt)

    create_interface_img_edit(
        generate,
        prompt=DEFAULT_PROMPT,
        steps=4,
        guidance_scale=None,
        # Names listed in the same order as the positional parameters of
        # `generate`.
        args=[
            'last_seed',
            'prompt',
            'rewrite_prompt',
            'rewritten_prompt',
            'in_image',
            'width',
            'height',
            'steps',
        ],
        rewrite_prompt_api=run_rewrite_prompt,
        # Names listed in the same order as the positional parameters of
        # `run_rewrite_prompt`.
        rewrite_prompt_args=[
            'last_seed',
            'prompt',
            'rewrite_prompt',
            'in_image',
        ],
        height=1024,
        width=1024,
    )

    # Enable the request queue and start serving.
    demo.queue().launch()
|
|
|