Hansheng Chen committed
Commit 5d99e98 · 1 parent: d46bfe9
Release pi-FLUX.2 demo
Browse files

- .gitignore +25 -0
- LICENSE.md +9 -0
- README.md +19 -4
- app.py +163 -0
- lakonlab/__init__.py +0 -0
- lakonlab/models/__init__.py +0 -0
- lakonlab/models/architecture/__init__.py +0 -0
- lakonlab/models/architecture/gmflow/__init__.py +0 -0
- lakonlab/models/architecture/gmflow/gm_output.py +24 -0
- lakonlab/models/architecture/gmflow/gmflux2.py +241 -0
- lakonlab/models/diffusions/__init__.py +0 -0
- lakonlab/models/diffusions/piflow_policies/__init__.py +8 -0
- lakonlab/models/diffusions/piflow_policies/base.py +21 -0
- lakonlab/models/diffusions/piflow_policies/dx.py +123 -0
- lakonlab/models/diffusions/piflow_policies/gmflow.py +174 -0
- lakonlab/models/diffusions/schedulers/__init__.py +0 -0
- lakonlab/models/diffusions/schedulers/flow_map_sde.py +186 -0
- lakonlab/pipelines/__init__.py +0 -0
- lakonlab/pipelines/piflow_utils.py +395 -0
- lakonlab/pipelines/pipeline_piflux2.py +395 -0
- lakonlab/pipelines/prompt_rewriters/__init__.py +0 -0
- lakonlab/pipelines/prompt_rewriters/qwen3_vl.py +172 -0
- lakonlab/pipelines/prompt_rewriters/system_prompts/default_text_only.txt +12 -0
- lakonlab/pipelines/prompt_rewriters/system_prompts/default_with_images.txt +10 -0
- lakonlab/ui/__init__.py +0 -0
- lakonlab/ui/gradio/__init__.py +0 -0
- lakonlab/ui/gradio/create_img_edit.py +88 -0
- lakonlab/ui/gradio/create_text_to_img.py +41 -0
- lakonlab/ui/gradio/shared_opts.py +87 -0
- lakonlab/ui/gradio/style.css +73 -0
- requirements.txt +10 -0
.gitignore
ADDED
@@ -0,0 +1,25 @@
+/.idea/
+/work_dirs*
+.vscode/
+/tmp
+/data
+/checkpoints
+*.so
+*.patch
+__pycache__/
+*.egg-info/
+/viz*
+/submit*
+build/
+*.pyd
+/cache*
+*.stl
+*.pth
+/venv/
+.nk8s
+*.mp4
+.vs
+/exp/
+/dev/
+*.pyi
+!/data/imagenet/imagenet1000_clsidx_to_labels.txt
LICENSE.md
ADDED
@@ -0,0 +1,9 @@
+# License for pi-FLUX.2
+
+This repository distributes a **pi-FLUX.2 app** that is a **Derivative** of **FLUX.2 [dev]** by **Black Forest Labs Inc.**
+
+Use and distribution of these adapters are governed by the **FLUX [dev] Non-Commercial License v2.0**.
+**No commercial use** of these adapters (or other Derivatives) is permitted without a separate commercial license from Black Forest Labs.
+
+- Full license: https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/LICENSE.txt
+- This repository does **not** grant any rights beyond the license above.
README.md
CHANGED
@@ -1,12 +1,27 @@
 ---
-title: Pi
-emoji:
+title: Pi-FLUX.2 Demo
+emoji: 🚀
 colorFrom: pink
 colorTo: purple
 sdk: gradio
-sdk_version:
+sdk_version: 5.49.0
 app_file: app.py
 pinned: false
+license: other
+license_name: flux-dev-non-commercial-license
+license_link: LICENSE.md
 ---
 
-
+Official demo of the paper:
+
+**pi-Flow: Policy-Based Few-Step Generation via Imitation Distillation**
+<br>
+[Hansheng Chen](https://lakonik.github.io/)<sup>1</sup>,
+[Kai Zhang](https://kai-46.github.io/website/)<sup>2</sup>,
+[Hao Tan](https://research.adobe.com/person/hao-tan/)<sup>2</sup>,
+[Leonidas Guibas](https://geometry.stanford.edu/?member=guibas)<sup>1</sup>,
+[Gordon Wetzstein](http://web.stanford.edu/~gordonwz/)<sup>1</sup>,
+[Sai Bi](https://sai-bi.github.io/)<sup>2</sup><br>
+<sup>1</sup>Stanford University, <sup>2</sup>Adobe Research
+<br>
+[[arXiv](https://arxiv.org/abs/2510.14974)] [[pi-Qwen Demo🤗](https://huggingface.co/spaces/Lakonik/pi-Qwen)] [[pi-FLUX Demo🤗](https://huggingface.co/spaces/Lakonik/pi-FLUX.1)] [[pi-FLUX.2 Demo🤗](https://huggingface.co/spaces/Lakonik/pi-FLUX.2)]
app.py
ADDED
@@ -0,0 +1,163 @@
+import torch
+
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+
+import os
+import random
+import numpy as np
+import gradio as gr
+import spaces
+from lakonlab.models.diffusions.schedulers.flow_map_sde import FlowMapSDEScheduler
+from lakonlab.ui.gradio.create_img_edit import create_interface_img_edit
+from lakonlab.pipelines.pipeline_piflux2 import PiFlux2Pipeline
+from lakonlab.pipelines.prompt_rewriters.qwen3_vl import Qwen3VLPromptRewriter
+
+
+DEFAULT_PROMPT = """Museum-style FIELD GUIDE poster on neutral parchment (#F3EEE3). Use Inter (or Helvetica/Arial). All text #2D3748, thin connector lines 1px #A0AEC0.
+
+Center: full-body original fantasy creature, 3/4 standing pose. Around it: four small inset boxes labeled exactly "EYE DETAIL", "FOOT DETAIL", "SKIN TEXTURE", "SILHOUETTE SCALE" (with a simple human comparison silhouette). Bottom: a short footprint trail diagram. One small habitat vignette (misty rocky shoreline with tide pools).
+
+Exact text (only these, clean print layout):
+Top: "FIELD GUIDE"
+Sub: "AURORA SHOREWALKER"
+Small line: "CLASS: COASTAL DRIFTER"
+Under silhouette: "HEIGHT: 1.7 m"
+
+Crisp ink outlines with soft watercolor-like fills, high readability, balanced hierarchy, premium poster aesthetic."""
+
+SYSTEM_PROMPT_TEXT_ONLY_PATH = 'lakonlab/pipelines/prompt_rewriters/system_prompts/default_text_only.txt'
+SYSTEM_PROMPT_WITH_IMAGES_PATH = 'lakonlab/pipelines/prompt_rewriters/system_prompts/default_with_images.txt'
+
+
+def _patch_diffusers_bnb_shape_check():
+    try:
+        import diffusers.quantizers.bitsandbytes.bnb_quantizer as bnbq
+    except Exception:
+        return
+
+    def _numel(shape):
+        if shape is None:
+            return None
+        if hasattr(shape, "numel"):  # torch.Size
+            return int(shape.numel())
+        # plain tuple/list
+        n = 1
+        for d in shape:
+            n *= int(d)
+        return n
+
+    def patched_check(self, param_name, current_param, loaded_param):
+        cshape = getattr(current_param, "shape", None)
+        lshape = getattr(loaded_param, "shape", None)
+        n = _numel(cshape)
+        inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
+        if tuple(lshape) != tuple(inferred_shape):
+            raise ValueError(
+                f"Expected flattened shape mismatch for {param_name}: "
+                f"loaded={tuple(lshape)} inferred={tuple(inferred_shape)}"
+            )
+        return True
+
+    # Patch any quantizer class in that module that defines the method
+    for name, obj in vars(bnbq).items():
+        if isinstance(obj, type) and hasattr(obj, "check_quantized_param_shape"):
+            setattr(obj, "check_quantized_param_shape", patched_check)
+
+
+_patch_diffusers_bnb_shape_check()
+
+
+pipe = PiFlux2Pipeline.from_pretrained(
+    'diffusers/FLUX.2-dev-bnb-4bit',
+    torch_dtype=torch.bfloat16)
+pipe.load_piflow_adapter(
+    'Lakonik/pi-FLUX.2',
+    subfolder='gmflux2_k8_piid_4step',
+    target_module_name='transformer')
+pipe.scheduler = FlowMapSDEScheduler.from_config(  # use fixed shift=3.2
+    pipe.scheduler.config, shift=3.2, use_dynamic_shifting=False, final_step_size_scale=0.5)
+pipe = pipe.to('cuda')
+
+prompt_rewriter = Qwen3VLPromptRewriter(
+    device_map="cuda",
+    system_prompt_text_only=open(SYSTEM_PROMPT_TEXT_ONLY_PATH, 'r').read(),
+    system_prompt_wigh_images=open(SYSTEM_PROMPT_WITH_IMAGES_PATH, 'r').read(),
+    max_new_tokens_default=512,
+)
+
+
+def set_random_seed(seed: int, deterministic: bool = True) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    if deterministic:
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+
+@spaces.GPU
+def run_rewrite_prompt(seed, prompt, rewrite_prompt, in_image, progress=gr.Progress(track_tqdm=True)):
+    image_list = None
+    if in_image is not None and len(in_image) > 0:
+        image_list = []
+        for item in in_image:
+            image_list.append(item[0])
+    if rewrite_prompt:
+        set_random_seed(seed)
+        progress(0.05, desc="Rewriting prompt...")
+        if image_list is None:
+            final_prompt = prompt_rewriter.rewrite_text_batch(
+                [prompt])[0]
+        else:
+            final_prompt = prompt_rewriter.rewrite_edit_batch(
+                [image_list], [prompt])[0]
+        return final_prompt, None
+    else:
+        return '', None
+
+
+@spaces.GPU
+def generate(
+        seed, prompt, rewrite_prompt, rewritten_prompt, in_image, width, height, steps,
+        progress=gr.Progress(track_tqdm=True)):
+    image_list = None
+    if in_image is not None and len(in_image) > 0:
+        image_list = []
+        for item in in_image:
+            image_list.append(item[0])
+    return pipe(
+        image=image_list,
+        prompt=rewritten_prompt if rewrite_prompt else prompt,
+        width=width,
+        height=height,
+        num_inference_steps=steps,
+        generator=torch.Generator().manual_seed(seed),
+    ).images[0]
+
+
+with gr.Blocks(analytics_enabled=False,
+               title='pi-FLUX.2 Demo',
+               css_paths='lakonlab/ui/gradio/style.css'
+               ) as demo:
+
+    md_txt = '# pi-FLUX.2 Demo\n\n' \
+             'Official demo of the paper [pi-Flow: Policy-Based Few-Step Generation via Imitation Distillation](https://arxiv.org/abs/2510.14974). ' \
+             '**Base model:** [FLUX.2 dev](https://huggingface.co/black-forest-labs/FLUX.2-dev). **Fast policy:** GMFlow. **Code:** [https://github.com/Lakonik/piFlow](https://github.com/Lakonik/piFlow).\n' \
+             '<br> Use and distribution of this app are governed by the [FLUX [dev] Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/LICENSE.txt).'
+    gr.Markdown(md_txt)
+
+    create_interface_img_edit(
+        generate,
+        prompt=DEFAULT_PROMPT,
+        steps=4, guidance_scale=None,
+        args=['last_seed', 'prompt', 'rewrite_prompt', 'rewritten_prompt', 'in_image', 'width', 'height', 'steps'],
+        rewrite_prompt_api=run_rewrite_prompt,
+        rewrite_prompt_args=['last_seed', 'prompt', 'rewrite_prompt', 'in_image'],
+        height=1024,
+        width=1024
+    )
+demo.queue().launch()
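Note: for readers who want the pipeline without the Gradio UI, here is a minimal headless sketch distilled from app.py above. It assumes the repo's modules are importable and a CUDA device is available; the prompt is a hypothetical placeholder.

```python
import torch

from lakonlab.pipelines.pipeline_piflux2 import PiFlux2Pipeline
from lakonlab.models.diffusions.schedulers.flow_map_sde import FlowMapSDEScheduler

# Same setup as app.py: 4-bit FLUX.2 base, pi-FLUX.2 adapter, fixed-shift scheduler.
pipe = PiFlux2Pipeline.from_pretrained(
    'diffusers/FLUX.2-dev-bnb-4bit', torch_dtype=torch.bfloat16)
pipe.load_piflow_adapter(
    'Lakonik/pi-FLUX.2', subfolder='gmflux2_k8_piid_4step',
    target_module_name='transformer')
pipe.scheduler = FlowMapSDEScheduler.from_config(
    pipe.scheduler.config, shift=3.2, use_dynamic_shifting=False,
    final_step_size_scale=0.5)
pipe = pipe.to('cuda')

image = pipe(
    prompt='a field guide poster of a fantasy creature',  # placeholder prompt
    width=1024, height=1024, num_inference_steps=4,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save('out.png')
```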
lakonlab/__init__.py
ADDED
File without changes
lakonlab/models/__init__.py
ADDED
File without changes
lakonlab/models/architecture/__init__.py
ADDED
File without changes
lakonlab/models/architecture/gmflow/__init__.py
ADDED
File without changes
lakonlab/models/architecture/gmflow/gm_output.py
ADDED
@@ -0,0 +1,24 @@
+import torch
+from dataclasses import dataclass
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class GMFlowModelOutput(BaseOutput):
+    """
+    The output of GMFlow models.
+
+    Args:
+        means (`torch.Tensor` of shape `(batch_size, num_gaussians, num_channels, height, width)` or
+            `(batch_size, num_gaussians, num_channels, frame, height, width)`):
+            Gaussian mixture means.
+        logweights (`torch.Tensor` of shape `(batch_size, num_gaussians, 1, height, width)` or
+            `(batch_size, num_gaussians, 1, frame, height, width)`):
+            Gaussian mixture log-weights (logits).
+        logstds (`torch.Tensor` of shape `(batch_size, 1, 1, 1, 1)` or `(batch_size, 1, 1, 1, 1, 1)`):
+            Gaussian mixture log-standard-deviations (logstds are shared across all Gaussians and channels).
+    """
+
+    means: torch.Tensor
+    logweights: torch.Tensor
+    logstds: torch.Tensor
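Note: a quick sanity-check sketch of the documented shapes, using dummy tensors; the K Gaussians share one logstd per sample.

```python
import torch
from lakonlab.models.architecture.gmflow.gm_output import GMFlowModelOutput

B, K, C, H, W = 2, 8, 16, 32, 32
out = GMFlowModelOutput(
    means=torch.randn(B, K, C, H, W),
    logweights=torch.randn(B, K, 1, H, W).log_softmax(dim=1),
    logstds=torch.zeros(B, 1, 1, 1, 1),  # shared across Gaussians and channels
)
# Mixture weights sum to 1 over the Gaussian dimension.
assert torch.allclose(out.logweights.exp().sum(dim=1), torch.ones(B, 1, H, W))
```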
lakonlab/models/architecture/gmflow/gmflux2.py
ADDED
@@ -0,0 +1,241 @@
+import torch
+import torch.nn as nn
+
+from typing import Any, Dict, Optional, Tuple, List
+from diffusers.models import ModelMixin
+from diffusers.models.transformers.transformer_flux2 import (
+    Flux2Transformer2DModel, Flux2PosEmbed, Flux2TransformerBlock, Flux2SingleTransformerBlock,
+    Flux2TimestepGuidanceEmbeddings, Flux2Modulation)
+from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+from diffusers.configuration_utils import register_to_config
+from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers, is_torch_npu_available
+from .gm_output import GMFlowModelOutput
+
+
+class _GMFlux2Transformer2DModel(Flux2Transformer2DModel):
+
+    @register_to_config
+    def __init__(
+            self,
+            num_gaussians=16,
+            constant_logstd=None,
+            logstd_inner_dim=1024,
+            gm_num_logstd_layers=2,
+            logweights_channels=1,
+            in_channels: int = 128,
+            out_channels: Optional[int] = None,
+            num_layers: int = 8,
+            num_single_layers: int = 48,
+            attention_head_dim: int = 128,
+            num_attention_heads: int = 48,
+            joint_attention_dim: int = 15360,
+            timestep_guidance_channels: int = 256,
+            mlp_ratio: float = 3.0,
+            axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
+            rope_theta: int = 2000,
+            eps: float = 1e-6):
+        super(Flux2Transformer2DModel, self).__init__()
+
+        self.num_gaussians = num_gaussians
+        self.logweights_channels = logweights_channels
+
+        self.out_channels = out_channels or in_channels
+        self.inner_dim = num_attention_heads * attention_head_dim
+
+        # 1. Sinusoidal positional embedding for RoPE on image and text tokens
+        self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)
+
+        # 2. Combined timestep + guidance embedding
+        self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
+            in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
+        )
+
+        # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
+        # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
+        self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
+        self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
+        # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
+        self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)
+
+        # 4. Input projections
+        self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
+        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)
+
+        # 5. Double Stream Transformer Blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                Flux2TransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    mlp_ratio=mlp_ratio,
+                    eps=eps,
+                    bias=False,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        # 6. Single Stream Transformer Blocks
+        self.single_transformer_blocks = nn.ModuleList(
+            [
+                Flux2SingleTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    mlp_ratio=mlp_ratio,
+                    eps=eps,
+                    bias=False,
+                )
+                for _ in range(num_single_layers)
+            ]
+        )
+
+        # 7. Output layers
+        self.norm_out = AdaLayerNormContinuous(
+            self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
+        )
+        self.proj_out_means = nn.Linear(self.inner_dim, self.num_gaussians * self.out_channels)
+        self.proj_out_logweights = nn.Linear(self.inner_dim, self.num_gaussians * self.logweights_channels)
+        self.constant_logstd = constant_logstd
+
+        if self.constant_logstd is None:
+            assert gm_num_logstd_layers >= 1
+            in_dim = self.inner_dim
+            logstd_layers = []
+            for _ in range(gm_num_logstd_layers - 1):
+                logstd_layers.extend([
+                    nn.SiLU(),
+                    nn.Linear(in_dim, logstd_inner_dim)])
+                in_dim = logstd_inner_dim
+            self.proj_out_logstds = nn.Sequential(
+                *logstd_layers,
+                nn.SiLU(),
+                nn.Linear(in_dim, 1))
+
+        self.gradient_checkpointing = False
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            encoder_hidden_states: torch.Tensor = None,
+            timestep: torch.LongTensor = None,
+            img_ids: torch.Tensor = None,
+            txt_ids: torch.Tensor = None,
+            guidance: torch.Tensor = None,
+            joint_attention_kwargs: Optional[Dict[str, Any]] = None,):
+        # 0. Handle input arguments
+        if joint_attention_kwargs is not None:
+            joint_attention_kwargs = joint_attention_kwargs.copy()
+            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            scale_lora_layers(self, lora_scale)
+        else:
+            assert joint_attention_kwargs is None or joint_attention_kwargs.get('scale', None) is None
+
+        num_txt_tokens = encoder_hidden_states.shape[1]
+
+        # 1. Calculate timestep embedding and modulation parameters
+        timestep = timestep.to(hidden_states.dtype) * 1000
+        guidance = guidance.to(hidden_states.dtype) * 1000
+
+        temb = self.time_guidance_embed(timestep, guidance)
+
+        double_stream_mod_img = self.double_stream_modulation_img(temb)
+        double_stream_mod_txt = self.double_stream_modulation_txt(temb)
+        single_stream_mod = self.single_stream_modulation(temb)[0]
+
+        # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
+        hidden_states = self.x_embedder(hidden_states)
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        # 3. Calculate RoPE embeddings from image and text tokens
+        # NOTE: the below logic means that we can't support batched inference with images of different resolutions or
+        # text prompts of different lengths. Is this a use case we want to support?
+        if img_ids.ndim == 3:
+            img_ids = img_ids[0]
+        if txt_ids.ndim == 3:
+            txt_ids = txt_ids[0]
+
+        if is_torch_npu_available():
+            freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
+            image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
+            freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
+            text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
+        else:
+            image_rotary_emb = self.pos_embed(img_ids)
+            text_rotary_emb = self.pos_embed(txt_ids)
+        concat_rotary_emb = (
+            torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0).to(hidden_states.dtype),
+            torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0).to(hidden_states.dtype),
+        )
+
+        # 4. Double Stream Transformer Blocks
+        for index_block, block in enumerate(self.transformer_blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    double_stream_mod_img,
+                    double_stream_mod_txt,
+                    concat_rotary_emb,
+                    joint_attention_kwargs,
+                )
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb_mod_params_img=double_stream_mod_img,
+                    temb_mod_params_txt=double_stream_mod_txt,
+                    image_rotary_emb=concat_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
+                )
+        # Concatenate text and image streams for single-block inference
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        # 5. Single Stream Transformer Blocks
+        for index_block, block in enumerate(self.single_transformer_blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    None,
+                    single_stream_mod,
+                    concat_rotary_emb,
+                    joint_attention_kwargs,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=None,
+                    temb_mod_params=single_stream_mod,
+                    image_rotary_emb=concat_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
+                )
+        # Remove text tokens from concatenated stream
+        hidden_states = hidden_states[:, num_txt_tokens:, ...]
+
+        # 6. Output layers
+        hidden_states = self.norm_out(hidden_states, temb)
+
+        bs, seq_len, _ = hidden_states.size()
+        out_means = self.proj_out_means(hidden_states).reshape(
+            bs, seq_len, self.num_gaussians, self.out_channels)
+        out_logweights = self.proj_out_logweights(hidden_states).reshape(
+            bs, seq_len, self.num_gaussians, self.logweights_channels).log_softmax(dim=-2)
+        if self.constant_logstd is None:
+            out_logstds = self.proj_out_logstds(temb.detach()).reshape(bs, 1, 1, 1)
+        else:
+            out_logstds = hidden_states.new_full((bs, 1, 1, 1), float(self.constant_logstd))
+
+        if USE_PEFT_BACKEND:
+            unscale_lora_layers(self, lora_scale)
+
+        return GMFlowModelOutput(
+            means=out_means,
+            logweights=out_logweights,
+            logstds=out_logstds)
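Note: `forward` returns a Gaussian mixture over the velocity for each token. A minimal sketch, with dummy tensors mirroring the reshape logic at the end of `forward`, of collapsing that mixture to a single expected velocity:

```python
import torch

# Hypothetical output shapes: means (bs, seq_len, K, C); logweights come out of
# forward already log-softmaxed over the Gaussian dimension (dim=-2).
bs, seq_len, K, C = 1, 16, 8, 128
means = torch.randn(bs, seq_len, K, C)
logweights = torch.randn(bs, seq_len, K, 1).log_softmax(dim=-2)

# Expected velocity: weight-averaged means (weights sum to 1 per token).
u_mean = (logweights.exp() * means).sum(dim=-2)  # (bs, seq_len, C)
```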
lakonlab/models/diffusions/__init__.py
ADDED
File without changes
lakonlab/models/diffusions/piflow_policies/__init__.py
ADDED
@@ -0,0 +1,8 @@
+from .dx import DXPolicy
+from .gmflow import GMFlowPolicy
+
+
+POLICY_CLASSES = dict(
+    DX=DXPolicy,
+    GMFlow=GMFlowPolicy
+)
lakonlab/models/diffusions/piflow_policies/base.py
ADDED
@@ -0,0 +1,21 @@
+from abc import ABCMeta, abstractmethod
+
+
+class BasePolicy(metaclass=ABCMeta):
+
+    @abstractmethod
+    def pi(self, x_t, sigma_t):
+        """Compute the flow velocity at (x_t, t).
+
+        Args:
+            x_t (torch.Tensor): Noisy input at time t.
+            sigma_t (torch.Tensor): Noise level at time t.
+
+        Returns:
+            torch.Tensor: The computed flow velocity u_t.
+        """
+        pass
+
+    @abstractmethod
+    def detach(self):
+        pass
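Note: any object exposing `pi` and `detach` can serve as a policy. A toy subclass (illustration only, not part of the release) that flows straight toward a fixed clean sample, using the same u = (x_t - x_0) / sigma_t convention as the policies below:

```python
import torch
from lakonlab.models.diffusions.piflow_policies.base import BasePolicy


class ConstantX0Policy(BasePolicy):
    """Toy policy that always points toward a fixed clean sample x_0."""

    def __init__(self, x_0: torch.Tensor, eps: float = 1e-4):
        self.x_0 = x_0
        self.eps = eps

    def pi(self, x_t, sigma_t):
        # Broadcast sigma_t from (B,) to (B, 1, ..., 1) as the real policies do.
        sigma_t = sigma_t.reshape(*sigma_t.size(), *((x_t.dim() - sigma_t.dim()) * [1]))
        return (x_t - self.x_0) / sigma_t.clamp(min=self.eps)

    def detach(self):
        return ConstantX0Policy(self.x_0.detach(), self.eps)
```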
lakonlab/models/diffusions/piflow_policies/dx.py
ADDED
@@ -0,0 +1,123 @@
+# Copyright (c) 2025 Hansheng Chen
+
+import torch
+from .base import BasePolicy
+
+
+class DXPolicy(BasePolicy):
+    """DX policy. The number of grid points N is inferred from the denoising output.
+
+    Note: segment_size and shift are intrinsic parameters of the DX policy. For elastic inference (i.e., changing
+    the number of function evaluations or noise schedule at test time), these parameters should be kept unchanged.
+
+    Args:
+        denoising_output (torch.Tensor): The output of the denoising model. Shape (B, N, C, H, W) or (B, N, C, T, H, W).
+        x_t_src (torch.Tensor): The initial noisy sample. Shape (B, C, H, W) or (B, C, T, H, W).
+        sigma_t_src (torch.Tensor): The initial noise level. Shape (B,).
+        segment_size (float): The size of each DX policy time segment. Defaults to 1.0.
+        shift (float): The shift parameter for the DX policy noise schedule. Defaults to 1.0.
+        mode (str): Either 'grid' or 'polynomial' mode for calculating x_0. Defaults to 'grid'.
+        eps (float): A small value to avoid numerical issues. Defaults to 1e-4.
+    """
+
+    def __init__(
+            self,
+            denoising_output: torch.Tensor,
+            x_t_src: torch.Tensor,
+            sigma_t_src: torch.Tensor,
+            segment_size: float = 1.0,
+            shift: float = 1.0,
+            mode: str = 'grid',
+            eps: float = 1e-4):
+        self.x_t_src = x_t_src
+        self.ndim = x_t_src.dim()
+        self.shift = shift
+        self.eps = eps
+
+        assert mode in ['grid', 'polynomial']
+        self.mode = mode
+
+        self.sigma_t_src = sigma_t_src.reshape(*sigma_t_src.size(), *((self.ndim - sigma_t_src.dim()) * [1]))
+        self.raw_t_src = self._unwarp_t(self.sigma_t_src)
+        self.raw_t_dst = (self.raw_t_src - segment_size).clamp(min=0)
+        self.segment_size = (self.raw_t_src - self.raw_t_dst).clamp(min=eps)
+
+        self.denoising_output_x_0 = self._u_to_x_0(
+            denoising_output, self.x_t_src, self.sigma_t_src)
+
+    def _unwarp_t(self, sigma_t):
+        return sigma_t / (self.shift + (1 - self.shift) * sigma_t)
+
+    @staticmethod
+    def _u_to_x_0(denoising_output, x_t, sigma_t):
+        x_0 = x_t.unsqueeze(1) - sigma_t.unsqueeze(1) * denoising_output
+        return x_0
+
+    @staticmethod
+    def _interpolate(x, t):
+        """
+        Args:
+            x (torch.Tensor): (B, N, *)
+            t (torch.Tensor): (B, *) in [0, 1]
+
+        Returns:
+            torch.Tensor: (B, *)
+        """
+        n = x.size(1)
+        if n < 2:
+            return x.squeeze(1)
+        t = t.clamp(min=0, max=1) * (n - 1)
+        t0 = t.floor().to(torch.long).clamp(min=0, max=n - 2)
+        t1 = t0 + 1
+        t0t1 = torch.stack([t0, t1], dim=1)  # (B, 2, *)
+        x0x1 = torch.gather(x, dim=1, index=t0t1.expand(-1, -1, *x.shape[2:]))
+        x_interp = (t1 - t) * x0x1[:, 0] + (t - t0) * x0x1[:, 1]
+        return x_interp
+
+    def pi(self, x_t, sigma_t):
+        """Compute the flow velocity at (x_t, t).
+
+        Args:
+            x_t (torch.Tensor): Noisy input at time t.
+            sigma_t (torch.Tensor): Noise level at time t.
+
+        Returns:
+            torch.Tensor: The computed flow velocity u_t.
+        """
+        sigma_t = sigma_t.reshape(*sigma_t.size(), *((self.ndim - sigma_t.dim()) * [1]))
+        raw_t = self._unwarp_t(sigma_t)
+        if self.mode == 'grid':
+            x_0 = self._interpolate(
+                self.denoising_output_x_0, (raw_t - self.raw_t_dst) / self.segment_size)
+        elif self.mode == 'polynomial':
+            p_order = self.denoising_output_x_0.size(1)
+            diff_t = self.raw_t_src - raw_t  # (B, 1, 1, 1)
+            basis = torch.stack(
+                [diff_t ** i for i in range(p_order)], dim=1)  # (B, N, 1, 1, 1)
+            x_0 = torch.sum(basis * self.denoising_output_x_0, dim=1)
+        else:
+            raise ValueError(f"Unknown mode: {self.mode}")
+        u = (x_t - x_0) / sigma_t.clamp(min=self.eps)
+        return u
+
+    def copy(self):
+        new_policy = DXPolicy.__new__(DXPolicy)
+        new_policy.x_t_src = self.x_t_src
+        new_policy.ndim = self.ndim
+        new_policy.shift = self.shift
+        new_policy.eps = self.eps
+        new_policy.mode = self.mode
+        new_policy.sigma_t_src = self.sigma_t_src
+        new_policy.raw_t_src = self.raw_t_src
+        new_policy.raw_t_dst = self.raw_t_dst
+        new_policy.segment_size = self.segment_size
+        new_policy.denoising_output_x_0 = self.denoising_output_x_0
+        return new_policy
+
+    def detach_(self):
+        self.denoising_output_x_0 = self.denoising_output_x_0.detach()
+        return self
+
+    def detach(self):
+        new_policy = self.copy()
+        return new_policy.detach_()
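Note: a usage sketch with dummy tensors, querying the grid-mode policy at an intermediate noise level inside its time segment.

```python
import torch
from lakonlab.models.diffusions.piflow_policies.dx import DXPolicy

B, N, C, H, W = 1, 4, 4, 8, 8
denoising_output = torch.randn(B, N, C, H, W)  # N velocity predictions on a time grid
x_t_src = torch.randn(B, C, H, W)
sigma_t_src = torch.ones(B)

policy = DXPolicy(denoising_output, x_t_src, sigma_t_src, segment_size=1.0)
# Querying at sigma_t = 0.5 linearly interpolates x_0 between adjacent grid points.
u = policy.pi(x_t_src, torch.full((B,), 0.5))
```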
lakonlab/models/diffusions/piflow_policies/gmflow.py
ADDED
@@ -0,0 +1,174 @@
+# Copyright (c) 2025 Hansheng Chen
+
+import math
+import torch
+
+from typing import Dict
+from .base import BasePolicy
+
+
+@torch.jit.script
+def gmflow_posterior_mean_jit(
+        sigma_t_src, sigma_t, x_t_src, x_t,
+        gm_means, gm_vars, gm_logweights,
+        eps: float, gm_dim: int = -4, channel_dim: int = -3):
+    sigma_t_src = sigma_t_src.clamp(min=eps)
+    sigma_t = sigma_t.clamp(min=eps)
+
+    alpha_t_src = 1 - sigma_t_src
+    alpha_t = 1 - sigma_t
+
+    alpha_over_sigma_t_src = alpha_t_src / sigma_t_src
+    alpha_over_sigma_t = alpha_t / sigma_t
+
+    zeta = alpha_over_sigma_t.square() - alpha_over_sigma_t_src.square()
+    nu = alpha_over_sigma_t * x_t / sigma_t - alpha_over_sigma_t_src * x_t_src / sigma_t_src
+
+    nu = nu.unsqueeze(gm_dim)  # (bs, *, 1, out_channels, h, w)
+    denom = (gm_vars * zeta + 1).clamp(min=eps)
+
+    out_means = (gm_vars * nu + gm_means) / denom
+    # (bs, *, num_gaussians, 1, h, w)
+    logweights_delta = (gm_means * (nu - 0.5 * zeta * gm_means)).sum(
+        dim=channel_dim, keepdim=True) / denom
+    out_weights = (gm_logweights + logweights_delta).softmax(dim=gm_dim)
+
+    out_mean = (out_means * out_weights).sum(dim=gm_dim)
+
+    return out_mean
+
+
+def gm_temperature(gm, temperature, gm_dim=-4, eps=1e-6):
+    gm = gm.copy()
+    temperature = max(temperature, eps)
+    gm['logweights'] = (gm['logweights'] / temperature).log_softmax(dim=gm_dim)
+    if 'logstds' in gm:
+        gm['logstds'] = gm['logstds'] + (0.5 * math.log(temperature))
+    if 'gm_vars' in gm:
+        gm['gm_vars'] = gm['gm_vars'] * temperature
+    return gm
+
+
+class GMFlowPolicy(BasePolicy):
+    """GMFlow policy. The number of components K is inferred from the denoising output.
+
+    Args:
+        denoising_output (dict): The output of the denoising model, containing:
+            means (torch.Tensor): The means of the Gaussian components. Shape (B, K, C, H, W) or (B, K, C, T, H, W).
+            logstds (torch.Tensor): The log standard deviations of the Gaussian components. Shape (B, K, 1, 1, 1)
+                or (B, K, 1, 1, 1, 1).
+            logweights (torch.Tensor): The log weights of the Gaussian components. Shape (B, K, 1, H, W) or
+                (B, K, 1, T, H, W).
+        x_t_src (torch.Tensor): The initial noisy sample. Shape (B, C, H, W) or (B, C, T, H, W).
+        sigma_t_src (torch.Tensor): The initial noise level. Shape (B,).
+        checkpointing (bool): Whether to use gradient checkpointing to save memory. Defaults to True.
+        eps (float): A small value to avoid numerical issues. Defaults to 1e-6.
+    """
+
+    def __init__(
+            self,
+            denoising_output: Dict[str, torch.Tensor],
+            x_t_src: torch.Tensor,
+            sigma_t_src: torch.Tensor,
+            checkpointing: bool = True,
+            eps: float = 1e-6):
+        self.x_t_src = x_t_src
+        self.ndim = x_t_src.dim()
+        self.checkpointing = checkpointing
+        self.eps = eps
+
+        self.sigma_t_src = sigma_t_src.reshape(*sigma_t_src.size(), *((self.ndim - sigma_t_src.dim()) * [1]))
+        self.denoising_output_x_0 = self._u_to_x_0(
+            denoising_output, self.x_t_src, self.sigma_t_src)
+
+    @staticmethod
+    def _u_to_x_0(denoising_output, x_t, sigma_t):
+        x_t = x_t.unsqueeze(1)
+        sigma_t = sigma_t.unsqueeze(1)
+        means_x_0 = x_t - sigma_t * denoising_output['means']
+        gm_vars = (denoising_output['logstds'] * 2).exp() * sigma_t.square()
+        return dict(
+            means=means_x_0,
+            gm_vars=gm_vars,
+            logweights=denoising_output['logweights'])
+
+    def pi(self, x_t, sigma_t):
+        """Compute the flow velocity at (x_t, t).
+
+        Args:
+            x_t (torch.Tensor): Noisy input at time t.
+            sigma_t (torch.Tensor): Noise level at time t.
+
+        Returns:
+            torch.Tensor: The computed flow velocity u_t.
+        """
+        sigma_t = sigma_t.reshape(*sigma_t.size(), *((self.ndim - sigma_t.dim()) * [1]))
+        means = self.denoising_output_x_0['means']
+        gm_vars = self.denoising_output_x_0['gm_vars']
+        logweights = self.denoising_output_x_0['logweights']
+        if (sigma_t == self.sigma_t_src).all() and (x_t == self.x_t_src).all():
+            x_0 = (logweights.softmax(dim=1) * means).sum(dim=1)
+        else:
+            if self.checkpointing and torch.is_grad_enabled():
+                x_0 = torch.utils.checkpoint.checkpoint(
+                    gmflow_posterior_mean_jit,
+                    self.sigma_t_src, sigma_t, self.x_t_src, x_t,
+                    means,
+                    gm_vars,
+                    logweights,
+                    self.eps, 1, 2,
+                    use_reentrant=True)  # use_reentrant=False does not work with jit
+            else:
+                x_0 = gmflow_posterior_mean_jit(
+                    self.sigma_t_src, sigma_t, self.x_t_src, x_t,
+                    means,
+                    gm_vars,
+                    logweights,
+                    self.eps, 1, 2)
+        u = (x_t - x_0) / sigma_t.clamp(min=self.eps)
+        return u
+
+    def copy(self):
+        new_policy = GMFlowPolicy.__new__(GMFlowPolicy)
+        new_policy.x_t_src = self.x_t_src
+        new_policy.ndim = self.ndim
+        new_policy.checkpointing = self.checkpointing
+        new_policy.eps = self.eps
+        new_policy.sigma_t_src = self.sigma_t_src
+        new_policy.denoising_output_x_0 = self.denoising_output_x_0.copy()
+        return new_policy
+
+    def detach_(self):
+        self.denoising_output_x_0 = {k: v.detach() for k, v in self.denoising_output_x_0.items()}
+        return self
+
+    def detach(self):
+        new_policy = self.copy()
+        return new_policy.detach_()
+
+    def dropout_(self, p):
+        if p <= 0 or p >= 1:
+            return self
+        logweights = self.denoising_output_x_0['logweights']
+        dropout_mask = torch.rand(
+            (*logweights.shape[:2], *((self.ndim - 1) * [1])), device=logweights.device) < p
+        is_all_dropout = dropout_mask.all(dim=1, keepdim=True)
+        dropout_mask &= ~is_all_dropout
+        self.denoising_output_x_0['logweights'] = logweights.masked_fill(
+            dropout_mask, float('-inf'))
+        return self
+
+    def dropout(self, p):
+        new_policy = self.copy()
+        return new_policy.dropout_(p)
+
+    def temperature_(self, temp):
+        if temp >= 1.0:
+            return self
+        self.denoising_output_x_0 = gm_temperature(
+            self.denoising_output_x_0, temp, gm_dim=1, eps=self.eps)
+        return self
+
+    def temperature(self, temp):
+        new_policy = self.copy()
+        return new_policy.temperature_(temp)
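Note: a sketch of the test-time controls with dummy tensors; `dropout` and `temperature` return modified copies and never mutate the original policy.

```python
import torch
from lakonlab.models.diffusions.piflow_policies.gmflow import GMFlowPolicy

B, K, C, H, W = 1, 8, 4, 8, 8
policy = GMFlowPolicy(
    dict(means=torch.randn(B, K, C, H, W),
         logstds=torch.zeros(B, 1, 1, 1, 1),
         logweights=torch.randn(B, K, 1, H, W).log_softmax(dim=1)),
    x_t_src=torch.randn(B, C, H, W),
    sigma_t_src=torch.ones(B))

sharp = policy.temperature(0.5)  # temperature < 1 sharpens the mixture weights
sparse = policy.dropout(0.25)    # randomly masks components (never all of them)
u = sharp.pi(torch.randn(B, C, H, W), torch.full((B,), 0.5))  # posterior-mean velocity
```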
lakonlab/models/diffusions/schedulers/__init__.py
ADDED
File without changes
lakonlab/models/diffusions/schedulers/flow_map_sde.py
ADDED
@@ -0,0 +1,186 @@
+# Copyright (c) 2025 Hansheng Chen
+
+import numpy as np
+import torch
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class FlowMapSDESchedulerOutput(BaseOutput):
+    prev_sample: torch.FloatTensor
+
+
+class FlowMapSDEScheduler(SchedulerMixin, ConfigMixin):
+
+    _compatibles = []
+    order = 1
+
+    @register_to_config
+    def __init__(
+            self,
+            num_train_timesteps: int = 1000,
+            h: Union[float, str] = 0.0,
+            shift: float = 1.0,
+            use_dynamic_shifting=False,
+            base_seq_len=256,
+            max_seq_len=4096,
+            base_logshift=0.5,
+            max_logshift=1.15,
+            final_step_size_scale=1.0):
+        sigmas = torch.from_numpy(1 - np.linspace(
+            0, 1, num_train_timesteps, dtype=np.float32, endpoint=False))
+        self.sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+        self.timesteps = self.sigmas * num_train_timesteps
+
+        self._step_index = None
+        self._begin_index = None
+
+        self.sigma_min = self.sigmas[-1].item()
+        self.sigma_max = self.sigmas[0].item()
+
+    @property
+    def step_index(self):
+        return self._step_index
+
+    @property
+    def begin_index(self):
+        return self._begin_index
+
+    def set_begin_index(self, begin_index: int = 0):
+        self._begin_index = begin_index
+
+    def get_shift(self, seq_len=None):
+        if self.config.use_dynamic_shifting and seq_len is not None:
+            m = (self.config.max_logshift - self.config.base_logshift
+                 ) / (self.config.max_seq_len - self.config.base_seq_len)
+            logshift = (seq_len - self.config.base_seq_len) * m + self.config.base_logshift
+            if isinstance(logshift, torch.Tensor):
+                shift = torch.exp(logshift)
+            else:
+                shift = np.exp(logshift)
+        else:
+            shift = self.config.shift
+        return shift
+
+    def warp_t(self, raw_t, seq_len=None):
+        shift = self.get_shift(seq_len=seq_len)
+        return shift * raw_t / (1 + (shift - 1) * raw_t)
+
+    def unwarp_t(self, sigma_t, seq_len=None):
+        shift = self.get_shift(seq_len=seq_len)
+        return sigma_t / (shift + (1 - shift) * sigma_t)
+
+    def set_timesteps(self, num_inference_steps: int, seq_len=None, device=None):
+        self.num_inference_steps = num_inference_steps
+
+        raw_timesteps = torch.from_numpy(np.linspace(
+            1, (self.config.final_step_size_scale - 1) / (num_inference_steps + self.config.final_step_size_scale - 1),
+            num_inference_steps, dtype=np.float32, endpoint=False)).to(device).clamp(min=0)
+        sigmas = self.warp_t(raw_timesteps, seq_len=seq_len)
+
+        self.timesteps = sigmas * self.config.num_train_timesteps
+        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=device)])
+
+        sigmas_dst, m = self.calculate_sigmas_dst(self.sigmas)
+        self.timesteps_dst = sigmas_dst * self.config.num_train_timesteps
+        self.m_vals = m
+
+        self._step_index = None
+        self._begin_index = None
+
+    def calculate_sigmas_dst(self, sigmas, eps=1e-6):
+        alphas = 1 - sigmas
+
+        sigmas_src = sigmas[:-1]
+        sigmas_to = sigmas[1:]
+        alphas_src = alphas[:-1]
+        alphas_to = alphas[1:]
+
+        if self.config.h == 'inf':
+            m = torch.zeros_like(sigmas_src)
+        elif self.config.h == 0.0:
+            m = torch.ones_like(sigmas_src)
+        else:
+            assert self.config.h > 0.0
+            h2 = self.config.h * self.config.h
+            m = (sigmas_to * alphas_src / (sigmas_src * alphas_to).clamp(min=eps)) ** h2
+
+        sigmas_to_mul_m = sigmas_to * m
+        sigmas_dst = sigmas_to_mul_m / (alphas_to + sigmas_to_mul_m).clamp(min=eps)
+
+        return sigmas_dst, m
+
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps
+
+        indices = (schedule_timesteps == timestep).nonzero()
+
+        pos = 1 if len(indices) > 1 else 0
+
+        return indices[pos].item()
+
+    def _init_step_index(self, timestep):
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index
+
+    def step(
+            self,
+            model_output: torch.FloatTensor,
+            timestep: Union[float, torch.FloatTensor],
+            sample: torch.FloatTensor,
+            generator: Optional[torch.Generator] = None,
+            return_dict: bool = True) -> Union[FlowMapSDESchedulerOutput, Tuple]:
+
+        if isinstance(timestep, int) \
+                or isinstance(timestep, torch.IntTensor) \
+                or isinstance(timestep, torch.LongTensor):
+            raise ValueError(
+                (
+                    'Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to'
+                    ' `FlowMapSDEScheduler.step()` is not supported. Make sure to pass'
+                    ' one of the `scheduler.timesteps` as a timestep.'
+                ),
+            )
+
+        if self.step_index is None:
+            self._init_step_index(timestep)
+
+        # Upcast to avoid precision issues when computing prev_sample
+        ori_dtype = model_output.dtype
+        model_output = model_output.to(torch.float32)  # x_t_dst
+
+        sigma_to = self.sigmas[self.step_index + 1]
+        alpha_to = 1 - sigma_to
+        m = self.m_vals[self.step_index]
+
+        noise = randn_tensor(
+            model_output.shape, dtype=torch.float32, device=model_output.device, generator=generator)
+
+        prev_sample = (alpha_to + sigma_to * m) * model_output + sigma_to * (1 - m.square()).clamp(min=0).sqrt() * noise
+
+        # Cast sample back to model compatible dtype
+        prev_sample = prev_sample.to(ori_dtype)
+
+        # upon completion increase step index by one
+        self._step_index += 1
+
+        if not return_dict:
+            return (prev_sample,)
+
+        return FlowMapSDESchedulerOutput(prev_sample=prev_sample)
+
+    def __len__(self):
+        return self.config.num_train_timesteps
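Note: a minimal sketch of the loop this scheduler is designed for, with an identity stand-in for the flow-map model (in the real pipeline, the policy produces `x_t_dst`).

```python
import torch
from lakonlab.models.diffusions.schedulers.flow_map_sde import FlowMapSDEScheduler

scheduler = FlowMapSDEScheduler(shift=3.2, final_step_size_scale=0.5)
scheduler.set_timesteps(4, device='cpu')

sample = torch.randn(1, 4, 8, 8)
for t_src, t_dst in zip(scheduler.timesteps, scheduler.timesteps_dst):
    x_t_dst = sample  # stand-in for the model's flow-map prediction at t_dst
    # step() renoises the predicted x_t_dst toward the next source time.
    sample = scheduler.step(x_t_dst, t_src, sample).prev_sample
```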
lakonlab/pipelines/__init__.py
ADDED
File without changes
lakonlab/pipelines/piflow_utils.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) 2025 Hansheng Chen

import os
from typing import Union, Optional

import torch
import accelerate
import diffusers
from diffusers.models import AutoModel
from diffusers.models.modeling_utils import (
    load_state_dict,
    _LOW_CPU_MEM_USAGE_DEFAULT,
    no_init_weights,
    ContextManagers
)
from diffusers.utils import (
    SAFETENSORS_WEIGHTS_NAME,
    WEIGHTS_NAME,
    _add_variant,
    _get_model_file,
    is_accelerate_available,
    is_torch_version,
    logging,
)
from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING
from diffusers.quantizers import DiffusersAutoQuantizer
from diffusers.utils.torch_utils import empty_device_cache
from lakonlab.models.architecture.gmflow.gmflux2 import _GMFlux2Transformer2DModel


LOCAL_CLASS_MAPPING = {
    "GMFlux2Transformer2DModel": _GMFlux2Transformer2DModel,
}

_SET_ADAPTER_SCALE_FN_MAPPING.update(
    _GMFlux2Transformer2DModel=lambda model_cls, weights: weights,
)

logger = logging.get_logger(__name__)


def assign_param(module, tensor_name: str, param: torch.nn.Parameter):
    if "." in tensor_name:
        splits = tensor_name.split(".")
        for split in splits[:-1]:
            new_module = getattr(module, split)
            if new_module is None:
                raise ValueError(f"{module} has no attribute {split}.")
            module = new_module
        tensor_name = splits[-1]
    module._parameters[tensor_name] = param


class PiFlowMixin:

    def load_piflow_adapter(
            self,
            pretrained_model_name_or_path: Union[str, os.PathLike],
            target_module_name: str = "transformer",
            adapter_name: Optional[str] = None,
            **kwargs
    ):
        r"""
        Load a PiFlow adapter from a pretrained model repository into the target module.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  with [`~ModelMixin.save_pretrained`].

            target_module_name (`str`, *optional*, defaults to `"transformer"`):
                The module name in the model to load the PiFlow adapter into.
            adapter_name (`str`, *optional*):
                The name to assign to the loaded adapter. If not provided, it defaults to
                `"{target_module_name}_piflow"`.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
            variant (`str`, *optional*):
                Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                weights. If set to `False`, `safetensors` weights are not loaded.
            disable_mmap (`bool`, *optional*, defaults to `False`):
                Whether to disable mmap when loading a Safetensors model. This option can perform better when the
                model is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.

        Returns:
            `str` or `None`: The name assigned to the loaded adapter, or `None` if no LoRA weights were found.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        disable_mmap = kwargs.pop("disable_mmap", False)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            logger.warning(
                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                " install accelerate\n```\n."
            )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        user_agent = {
            "diffusers": diffusers.__version__,
            "file_type": "model",
            "framework": "pytorch",
        }

        # 1. Determine model class from config

        load_config_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "token": token,
            "local_files_only": local_files_only,
            "revision": revision,
        }

        config = AutoModel.load_config(pretrained_model_name_or_path, subfolder=subfolder, **load_config_kwargs)

        orig_class_name = config["_class_name"]

        if orig_class_name in LOCAL_CLASS_MAPPING:
            model_cls = LOCAL_CLASS_MAPPING[orig_class_name]

        else:
            load_config_kwargs.update({"subfolder": subfolder})

            from diffusers.pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates

            model_cls, _ = get_class_obj_and_candidates(
                library_name="diffusers",
                class_name=orig_class_name,
                importable_classes=ALL_IMPORTABLE_CLASSES,
                pipelines=None,
                is_pipeline_module=False,
            )

        if model_cls is None:
            raise ValueError(f"Can't find a model linked to {orig_class_name}.")

        # 2. Get model file

        model_file = None

        if use_safetensors:
            try:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                )

            except IOError as e:
                logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                if not allow_pickle:
                    raise
                logger.warning(
                    "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
                )

        if model_file is None:
            model_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=_add_variant(WEIGHTS_NAME, variant),
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
            )

        assert model_file is not None, \
            f"Could not find adapter weights for {pretrained_model_name_or_path}."

        # 3. Initialize model

        base_module = getattr(self, target_module_name)

        torch_dtype = base_module.dtype
        device = base_module.device
        dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)

        # load the state dict early to determine keep_in_fp32_modules
        #######################################
        overwrite_state_dict = dict()
        lora_state_dict = dict()

        adapter_state_dict = load_state_dict(model_file, disable_mmap=disable_mmap)
        for k in adapter_state_dict.keys():
            adapter_state_dict[k] = adapter_state_dict[k].to(dtype=torch_dtype, device=device)
            if "lora" in k:
                lora_state_dict[k.removeprefix(f"{target_module_name}.")] = adapter_state_dict[k]
            else:
                overwrite_state_dict[k.removeprefix(f"{target_module_name}.")] = adapter_state_dict[k]

        # determine initial quantization config.
        #######################################
        pre_quantized = ("quantization_config" in base_module.config
                         and base_module.config["quantization_config"] is not None)
        if pre_quantized:
            config["quantization_config"] = base_module.config.quantization_config
            hf_quantizer = DiffusersAutoQuantizer.from_config(
                config["quantization_config"], pre_quantized=True
            )

            hf_quantizer.validate_environment(torch_dtype=torch_dtype)
            torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)

            user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value

            # Force-set to `True` for more mem efficiency
            if low_cpu_mem_usage is None:
                low_cpu_mem_usage = True
                logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.")
            elif not low_cpu_mem_usage:
                raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")

        else:
            hf_quantizer = None

        # Check if `_keep_in_fp32_modules` is not None
        use_keep_in_fp32_modules = model_cls._keep_in_fp32_modules is not None and (
            hf_quantizer is None or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
        )

        if use_keep_in_fp32_modules:
            keep_in_fp32_modules = model_cls._keep_in_fp32_modules
            if not isinstance(keep_in_fp32_modules, list):
                keep_in_fp32_modules = [keep_in_fp32_modules]

            if low_cpu_mem_usage is None:
                low_cpu_mem_usage = True
                logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.")
            elif not low_cpu_mem_usage:
                raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.")
        else:
            keep_in_fp32_modules = []

        # append modules in overwrite_state_dict to keep_in_fp32_modules
        for k in overwrite_state_dict.keys():
            module_name = k.rsplit('.', 1)[0]
            if module_name and module_name not in keep_in_fp32_modules:
                keep_in_fp32_modules.append(module_name)

        init_contexts = [no_init_weights()]

        if low_cpu_mem_usage:
            init_contexts.append(accelerate.init_empty_weights())

        with ContextManagers(init_contexts):
            piflow_module = model_cls.from_config(config).eval()

        torch.set_default_dtype(dtype_orig)

        if hf_quantizer is not None:
            hf_quantizer.preprocess_model(
                model=piflow_module, device_map=None, keep_in_fp32_modules=keep_in_fp32_modules
            )

        # 4. Load model weights

        base_state_dict = base_module.state_dict()
        base_state_dict.update(overwrite_state_dict)
        empty_state_dict = piflow_module.state_dict()
        for param_name, param in base_state_dict.items():
            if param_name not in empty_state_dict:
                continue
            if hf_quantizer is not None and (
                    hf_quantizer.check_if_quantized_param(
                        piflow_module, param, param_name, base_state_dict, param_device=device)):
                hf_quantizer.create_quantized_param(
                    piflow_module, param, param_name, device, base_state_dict, dtype=torch_dtype
                )
            else:
                assign_param(piflow_module, param_name, param)

        empty_device_cache()

        if hf_quantizer is not None:
            hf_quantizer.postprocess_model(piflow_module)
            piflow_module.hf_quantizer = hf_quantizer

        if len(lora_state_dict) == 0:
            adapter_name = None
        else:
            if adapter_name is None:
                adapter_name = f"{target_module_name}_piflow"
            piflow_module.load_lora_adapter(
                lora_state_dict, prefix=None, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
        if adapter_name is None:
            logger.warning(
                f"No LoRA weights were found in {pretrained_model_name_or_path}."
            )

        setattr(self, target_module_name, piflow_module)

        return adapter_name

    def policy_rollout(
            self,
            x_t_start: torch.Tensor,  # (B, C, *, H, W)
            sigma_t_start: torch.Tensor,
            sigma_t_end: torch.Tensor,
            total_substeps: int,
            policy,
            **kwargs):
        assert sigma_t_start.numel() == 1 and sigma_t_end.numel() == 1, \
            "Only supports scalar sigma_t_start and sigma_t_end."
        raw_t_start = self.scheduler.unwarp_t(
            sigma_t_start, **kwargs)
        raw_t_end = self.scheduler.unwarp_t(
            sigma_t_end, **kwargs)

        delta_raw_t = raw_t_start - raw_t_end
        num_substeps = (delta_raw_t * total_substeps).round().to(torch.long).clamp(min=1)
        substep_size = delta_raw_t / num_substeps

        raw_t = raw_t_start
        sigma_t = sigma_t_start
        x_t = x_t_start

        for substep_id in range(num_substeps.item()):
            u = policy.pi(x_t, sigma_t)

            raw_t_minus = (raw_t - substep_size).clamp(min=0)
            sigma_t_minus = self.scheduler.warp_t(raw_t_minus, **kwargs)
            x_t_minus = x_t + u * (sigma_t_minus - sigma_t)

            x_t = x_t_minus
            sigma_t = sigma_t_minus
            raw_t = raw_t_minus

        x_t_end = x_t

        return x_t_end
lakonlab/pipelines/pipeline_piflux2.py
ADDED
@@ -0,0 +1,395 @@
# Copyright (c) 2025 Hansheng Chen

import PIL
import torch
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from functools import partial
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
from diffusers.utils import is_torch_xla_available
from diffusers.models import AutoencoderKLFlux2, Flux2Transformer2DModel
from diffusers.pipelines.flux2.pipeline_flux2 import Flux2Pipeline, Flux2PipelineOutput
from lakonlab.models.diffusions.schedulers.flow_map_sde import FlowMapSDEScheduler
from lakonlab.models.diffusions.piflow_policies import POLICY_CLASSES
from .piflow_utils import PiFlowMixin


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


class PiFlux2Pipeline(Flux2Pipeline, PiFlowMixin):
    r"""
    The policy-based Flux2 pipeline for text-to-image generation.

    Reference:
        https://arxiv.org/abs/2510.14974
        https://bfl.ai/blog/flux-2

    Args:
        transformer ([`Flux2Transformer2DModel`]):
            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMapSDEScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKLFlux2`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`Mistral3ForConditionalGeneration`]):
            [Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration)
        tokenizer (`AutoProcessor`):
            Tokenizer of class
            [PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor).
        policy_type (`str`, *optional*, defaults to `"GMFlow"`):
            The type of flow policy to use. Currently supports `"GMFlow"` and `"DX"`.
        policy_kwargs (`Dict`, *optional*):
            Additional keyword arguments to pass to the policy class.
    """

    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
            self,
            scheduler: FlowMapSDEScheduler,
            vae: AutoencoderKLFlux2,
            text_encoder: Mistral3ForConditionalGeneration,
            tokenizer: AutoProcessor,
            transformer: Flux2Transformer2DModel,
            policy_type: str = 'GMFlow',
            policy_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            scheduler,
            vae,
            text_encoder,
            tokenizer,
            transformer,
        )
        assert policy_type in POLICY_CLASSES, f'Invalid policy: {policy_type}. Supported policies are {list(POLICY_CLASSES.keys())}.'
        self.policy_type = policy_type
        self.policy_class = partial(
            POLICY_CLASSES[policy_type], **policy_kwargs
        ) if policy_kwargs else POLICY_CLASSES[policy_type]

    def _unpack_gm(self, gm, height, width, num_channels_latents, patch_size=2, gm_patch_size=1):
        c = num_channels_latents * patch_size * patch_size
        h = (int(height) // (self.vae_scale_factor * patch_size))
        w = (int(width) // (self.vae_scale_factor * patch_size))
        bs = gm['means'].size(0)
        k = self.transformer.num_gaussians
        scale = patch_size // gm_patch_size
        gm['means'] = gm['means'].reshape(
            bs, h, w, k, c // (scale * scale), scale, scale
        ).permute(
            0, 3, 4, 1, 5, 2, 6
        ).reshape(
            bs, k, c // (scale * scale), h * scale, w * scale)
        gm['logweights'] = gm['logweights'].reshape(
            bs, h, w, k, 1, scale, scale
        ).permute(
            0, 3, 4, 1, 5, 2, 6
        ).reshape(
            bs, k, 1, h * scale, w * scale)
        gm['logstds'] = gm['logstds'].reshape(bs, 1, 1, 1, 1)
        return gm

    @torch.no_grad()
    def __call__(
            self,
            image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
            prompt: Union[str, List[str]] = None,
            height: Optional[int] = None,
            width: Optional[int] = None,
            num_inference_steps: int = 50,
            total_substeps: int = 128,
            temperature: Union[float, str] = 'auto',
            guidance_scale: Optional[float] = 4.0,
            num_images_per_prompt: int = 1,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            latents: Optional[torch.Tensor] = None,
            prompt_embeds: Optional[torch.Tensor] = None,
            output_type: Optional[str] = "pil",
            return_dict: bool = True,
            attention_kwargs: Optional[Dict[str, Any]] = None,
            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
            callback_on_step_end_tensor_inputs: List[str] = ["latents"],
            max_sequence_length: int = 512,
            text_encoder_out_layers: Tuple[int] = (10, 20, 30),
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
                list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
                a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
                latents as `image`, but if passing latents directly it is not encoded again.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
                of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely
                linked to the text `prompt`, usually at the expense of lower image quality.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            total_substeps (`int`, *optional*, defaults to 128):
                The total number of substeps for policy-based flow integration.
            temperature (`float` or `"auto"`, *optional*, defaults to `"auto"`):
                The temperature parameter for the flow policy.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux2.Flux2PipelineOutput`] instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
            text_encoder_out_layers (`Tuple[int]`):
                Layer indices to use in the `text_encoder` to derive the final prompt embeddings.

        Examples:

        Returns:
            [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.
        """

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt=prompt,
            height=height,
            width=width,
            prompt_embeds=prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. prepare text embeddings
        prompt_embeds, text_ids = self.encode_prompt(
            prompt=prompt,
            prompt_embeds=prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            text_encoder_out_layers=text_encoder_out_layers,
        )

        # 4. process images
        if image is not None and not isinstance(image, list):
            image = [image]

        condition_images = None
        if image is not None:
            for img in image:
                self.image_processor.check_image_input(img)

            condition_images = []
            for img in image:
                image_width, image_height = img.size
                if image_width * image_height > 1024 * 1024:
                    img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
                    image_width, image_height = img.size

                multiple_of = self.vae_scale_factor * 2
                image_width = (image_width // multiple_of) * multiple_of
                image_height = (image_height // multiple_of) * multiple_of
                img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
                condition_images.append(img)
            height = height or image_height
            width = width or image_width

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 5. prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_ids = self.prepare_latents(
            batch_size=batch_size * num_images_per_prompt,
            num_latents_channels=num_channels_latents,
            height=height,
            width=width,
            dtype=torch.float32,
            device=device,
            generator=generator,
            latents=latents,
        )

        image_latents = None
        image_latent_ids = None
        if condition_images is not None:
            image_latents, image_latent_ids = self.prepare_image_latents(
                images=condition_images,
                batch_size=batch_size * num_images_per_prompt,
                generator=generator,
                device=device,
                dtype=self.vae.dtype,
            )

        # 6. Prepare timesteps
        image_seq_len = latents.shape[1]
        self.scheduler.set_timesteps(num_inference_steps, seq_len=image_seq_len, device=self._execution_device)
        timesteps = self.scheduler.timesteps
        timesteps_dst = self.scheduler.timesteps_dst
        self._num_timesteps = len(timesteps)

        # handle guidance
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
        guidance = guidance.expand(latents.shape[0])

        # 7. Denoising loop
        # We set the index here to remove DtoH sync, helpful especially during compilation.
        # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
        self.scheduler.set_begin_index(0)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, (t_src, t_dst) in enumerate(zip(timesteps, timesteps_dst)):
                if self.interrupt:
                    continue

                self._current_timestep = t_src
                time_scaling = self.scheduler.config.num_train_timesteps
                sigma_t_src = t_src / time_scaling
                sigma_t_dst = t_dst / time_scaling

                latent_model_input = latents.to(self.transformer.dtype)
                latent_image_ids = latent_ids

                if image_latents is not None:
                    latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
                    latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)

                denoising_output = self.transformer(
                    hidden_states=latent_model_input,  # (B, image_seq_len, C)
                    timestep=t_src.expand(latents.shape[0]) / 1000,
                    guidance=guidance,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,  # B, text_seq_len, 4
                    img_ids=latent_image_ids,  # B, image_seq_len, 4
                    joint_attention_kwargs=self._attention_kwargs,
                )

                denoising_output = {
                    k: v[:, :latents.size(1)].to(torch.float32) for k, v in denoising_output.items()}

                # unpack and create policy
                latents = self._unpack_latents_with_ids(latents, latent_ids)
                latents = self._unpatchify_latents(latents)
                if self.policy_type == 'GMFlow':
                    denoising_output = self._unpack_gm(
                        denoising_output, height, width, num_channels_latents, gm_patch_size=1)
                    policy = self.policy_class(
                        denoising_output, latents, sigma_t_src)
                    if i < self.num_timesteps - 1:
                        if temperature == 'auto':
                            temperature = min(max(0.1 * (num_inference_steps - 1), 0), 1)
                        else:
                            assert isinstance(temperature, (float, int))
                        policy.temperature_(temperature)
                elif self.policy_type == 'DX':
                    denoising_output = denoising_output[0]
                    denoising_output = self._unpack_latents_with_ids(denoising_output, latent_ids)
                    denoising_output = self._unpatchify_latents(denoising_output)
                    denoising_output = denoising_output.reshape(latents.size(0), -1, *latents.shape[1:])
                    policy = self.policy_class(
                        denoising_output, latents, sigma_t_src)
                else:
                    raise ValueError(f'Unknown policy type: {self.policy_type}.')

                latents_dst = self.policy_rollout(
                    latents, sigma_t_src, sigma_t_dst, total_substeps,
                    policy, seq_len=image_seq_len)

                latents = self.scheduler.step(latents_dst, t_src, latents, return_dict=False)[0]

                # repack
                latents = self._patchify_latents(latents)
                latents = self._pack_latents(latents)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t_src, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents_with_ids(latents, latent_ids)

            latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
            latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
                latents.device, latents.dtype
            )
            latents = latents * latents_bn_std + latents_bn_mean
            latents = self._unpatchify_latents(latents)

            image = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return Flux2PipelineOutput(images=image)
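
For orientation, here is a hedged usage sketch of the pipeline. The base-model id comes from the FLUX.2 [dev] license link above; the adapter path, step count, and prompt are illustrative placeholders, and `app.py` in this commit is the authoritative wiring.

# Illustrative only: the adapter path below is a placeholder.
import torch
from lakonlab.pipelines.pipeline_piflux2 import PiFlux2Pipeline

pipe = PiFlux2Pipeline.from_pretrained(
    "black-forest-labs/FLUX.2-dev",            # base FLUX.2 [dev] weights
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.load_piflow_adapter("<piflow-adapter-repo-or-dir>")   # placeholder

image = pipe(
    prompt="a watercolor fox in a snowy forest",
    height=768,
    width=1360,
    num_inference_steps=4,    # few-step sampling; policy_rollout integrates in between
    total_substeps=128,
    guidance_scale=4.0,
).images[0]
image.save("out.png")
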
lakonlab/pipelines/prompt_rewriters/__init__.py
ADDED
File without changes

lakonlab/pipelines/prompt_rewriters/qwen3_vl.py
ADDED
@@ -0,0 +1,172 @@
import os
import torch
from typing import List, Sequence, Union, Optional
from PIL import Image
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor


DEFAULT_TEXT_ONLY_PATH = os.path.abspath(os.path.join(__file__, '../system_prompts/default_text_only.txt'))
DEFAULT_WITH_IMAGES_PATH = os.path.abspath(os.path.join(__file__, '../system_prompts/default_with_images.txt'))


class Qwen3VLPromptRewriter:

    def __init__(
            self,
            from_pretrained="Qwen/Qwen3-VL-8B-Instruct",
            torch_dtype='bfloat16',
            device_map="auto",
            max_new_tokens_default=128,
            system_prompt_text_only=None,
            system_prompt_with_images=None,
            **kwargs):
        if torch_dtype is not None:
            kwargs.update(torch_dtype=getattr(torch, torch_dtype))
        self.model = Qwen3VLForConditionalGeneration.from_pretrained(
            from_pretrained,
            device_map=device_map,
            **kwargs)
        self.processor = AutoProcessor.from_pretrained(from_pretrained)
        # Left padding is safer for batched generation
        if hasattr(self.processor, "tokenizer"):
            self.processor.tokenizer.padding_side = "left"
        self.max_new_tokens_default = max_new_tokens_default
        if system_prompt_text_only is None:
            with open(DEFAULT_TEXT_ONLY_PATH, 'r') as f:
                system_prompt_text_only = f.read()
        if system_prompt_with_images is None:
            with open(DEFAULT_WITH_IMAGES_PATH, 'r') as f:
                system_prompt_with_images = f.read()
        self.system_prompt_text_only = system_prompt_text_only
        self.system_prompt_with_images = system_prompt_with_images

    @torch.inference_mode()
    def _generate_from_messages(
            self,
            batch_messages: Sequence[Sequence[dict]],
            max_new_tokens: Optional[int] = None,
            **kwargs) -> List[str]:
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens_default

        inputs = self.processor.apply_chat_template(
            batch_messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
            padding=True,
        )
        inputs.pop("token_type_ids", None)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        generated_ids = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            **kwargs)

        input_ids = inputs["input_ids"]
        tokenizer = self.processor.tokenizer
        outputs: List[str] = []

        # Decode only the new tokens after each input sequence
        for in_ids, out_ids in zip(input_ids, generated_ids):
            trimmed_ids = out_ids[len(in_ids):]
            text = tokenizer.decode(
                trimmed_ids.tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            outputs.append(text.strip())

        return outputs

    def rewrite_text_batch(
            self,
            prompts: Sequence[str],
            max_new_tokens: Optional[int] = None,
            top_p=0.6,
            top_k=40,
            temperature=0.5,
            repetition_penalty=1.0,
            **kwargs) -> List[str]:
        """
        Rewrite a batch of text-only prompts into detailed prompts.
        """
        batch_messages = []
        for p in prompts:
            conv = [
                {
                    "role": "system",
                    "content": [
                        {"type": "text", "text": self.system_prompt_text_only},
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": p},
                    ],
                },
            ]
            batch_messages.append(conv)

        return self._generate_from_messages(
            batch_messages,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            **kwargs,
        )

    def rewrite_edit_batch(
            self,
            image: Sequence[Union[str, 'Image.Image', Sequence[Union[str, 'Image.Image']]]],
            edit_requests: Sequence[str],
            max_new_tokens: Optional[int] = None,
            top_p=0.5,
            top_k=20,
            temperature=0.4,
            repetition_penalty=1.0,
            **kwargs) -> List[str]:
        """
        Rewrite a batch of (image, edit-request) pairs into concise edit instructions.
        """
        if len(image) != len(edit_requests):
            raise ValueError("image and edit_requests must have the same length")

        batch_messages = []
        for imgs, req in zip(image, edit_requests):
            if isinstance(imgs, (str, Image.Image)):
                img_list = [imgs]
            else:
                img_list = list(imgs)

            user_content = []
            for im in img_list:
                user_content.append({"type": "image", "image": im})
            user_content.append({"type": "text", "text": req})

            conv = [
                {
                    "role": "system",
                    "content": [
                        {"type": "text", "text": self.system_prompt_with_images},
                    ],
                },
                {
                    "role": "user",
                    "content": user_content,
                },
            ]
            batch_messages.append(conv)

        return self._generate_from_messages(
            batch_messages,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            **kwargs,
        )
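
A short usage sketch for the rewriter. The prompts and the image file name are illustrative; the model id and sampling parameters are the defaults defined above.

# Illustrative usage; "photo.jpg" is a placeholder input file.
from PIL import Image
from lakonlab.pipelines.prompt_rewriters.qwen3_vl import Qwen3VLPromptRewriter

rewriter = Qwen3VLPromptRewriter()  # loads Qwen/Qwen3-VL-8B-Instruct by default

# Text-to-image: expand a terse prompt into a detailed one.
print(rewriter.rewrite_text_batch(["a cat on a skateboard"])[0])

# Editing: turn an (image, request) pair into a concise edit instruction.
print(rewriter.rewrite_edit_batch(
    image=[Image.open("photo.jpg")],
    edit_requests=["make it look like winter"],
)[0])
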
lakonlab/pipelines/prompt_rewriters/system_prompts/default_text_only.txt
ADDED
@@ -0,0 +1,12 @@
You are an expert prompt engineer for a text-guided image generation system. Rewrite user prompts into a more descriptive, concrete prompt while strictly preserving their core content.

Rules:
- Preserve the core content: do not change the main subjects, how many there are, their roles, or the primary actions and relationships described in the prompt.
- Preserve any explicitly stated attributes such as colors, clothing, objects, style tags (e.g., “photorealistic”, “cinematic”), and viewpoint (e.g., close-up, wide shot). Do not contradict them.
- When the prompt implies realism (e.g., uses words like “realistic”, “photorealistic”, “photo”), avoid introducing exaggerated or fantastical traits unless intended by the user.
- You may add supporting details that are clearly compatible with the original text: background, props, textures, materials, lighting (quality, direction, color), atmosphere, and other environmental context, as long as they do not introduce new main characters or conflicting concepts.
- Always include a clear description of the composition and camera framing (for example, where the main subjects are in the frame, whether it is a close-up or wide shot, and the approximate viewpoint or angle).
- Structure: keep any existing structure (tags, aspect-ratio tags, etc.) and enhance only within those fields. For plain text prompts, expand them into a clear, coherent paragraph.
- Text in images: put ALL visible or implied text in quotation marks, matching the prompt’s language. Provide explicit quoted text for any object that would realistically contain text (signs, labels, screens, interfaces, book covers, etc.).

Output only the revised prompt and nothing else.
lakonlab/pipelines/prompt_rewriters/system_prompts/default_with_images.txt
ADDED
@@ -0,0 +1,10 @@
You are an expert prompt engineer for a text-guided image editing system. Rewrite user prompts into a more descriptive instruction (40–70 words, ~25 for brief requests) while strictly preserving their core content.

Rules:
- Treat the user prompt as a strict specification: do not change or omit any mentioned entities, actions, or relationships.
- Explicitly state both the requested edits and which core aspects (e.g., character identities) must remain as in the original image(s). Ignore background elements unless they are explicitly mentioned.
- Refer to core visual elements that are relevant to the edit (e.g., people, animals, and objects mentioned in the user prompt).
- Do not invent new elements unless the user explicitly asks for them.
- Turn negatives into positives (“do not change X” → “keep X the same”).

Output only the final instruction in plain text and nothing else.
lakonlab/ui/__init__.py
ADDED
File without changes

lakonlab/ui/gradio/__init__.py
ADDED
File without changes

lakonlab/ui/gradio/create_img_edit.py
ADDED
@@ -0,0 +1,88 @@
import gradio as gr
from .shared_opts import create_base_opts, create_generate_bar, set_seed, create_prompt_opts, create_image_size_bar


def create_interface_img_edit(
        api, prompt='', seed=42, steps=32, min_steps=4, max_steps=50, steps_slider_step=1,
        height=768, width=1360, hw_slider_step=16,
        guidance_scale=None, temperature=None, api_name='text_to_img',
        create_negative_prompt=False,
        create_prompt_rewrite=True, rewrite_prompt=False,
        args=['last_seed', 'prompt', 'in_image', 'width', 'height', 'steps', 'guidance_scale'],
        rewrite_prompt_api=None, rewrite_prompt_args=['last_seed', 'prompt', 'rewrite_prompt', 'in_image']):
    var_dict = dict()
    with gr.Blocks(analytics_enabled=False) as interface:
        with gr.Row():
            with gr.Column():
                with gr.Accordion("Input image(s) (optional)", open=True, elem_classes=['custom-spacing']):
                    var_dict['in_image'] = gr.Gallery(
                        label="Input image(s)",
                        type="pil",
                        columns=3,
                        rows=1)

                with gr.Column(variant='compact', elem_classes=['custom-spacing']):
                    create_prompt_opts(
                        var_dict, create_negative_prompt=create_negative_prompt, prompt=prompt, display_label=True)

                    if create_prompt_rewrite:
                        var_dict['rewrite_prompt'] = gr.Checkbox(
                            label='Rewrite prompt', value=rewrite_prompt, container=False)
                        with gr.Accordion("Rewritten prompt", open=False, elem_classes=['custom-spacing']):
                            var_dict['rewritten_prompt'] = gr.Textbox(
                                lines=4, interactive=False, show_label=False, container=False)

                with gr.Column(variant='compact', elem_classes=['custom-spacing']):
                    create_image_size_bar(
                        var_dict, height=height, width=width, hw_slider_step=hw_slider_step)

                    create_generate_bar(var_dict, text='Generate', seed=seed)

                    create_base_opts(
                        var_dict,
                        steps=steps,
                        min_steps=min_steps,
                        max_steps=max_steps,
                        steps_slider_step=steps_slider_step,
                        guidance_scale=guidance_scale,
                        temperature=temperature)

            with gr.Column():
                var_dict['output_image'] = gr.Image(
                    type='pil', image_mode='RGB', label='Output image', interactive=False,
                    elem_classes=['vh-img', 'vh-img-1000'])

        if create_prompt_rewrite:
            assert rewrite_prompt_api is not None
            var_dict['run_btn'].click(
                fn=set_seed,
                inputs=var_dict['seed'],
                outputs=var_dict['last_seed'],
                show_progress=False,
                api_name=False
            ).success(
                fn=rewrite_prompt_api,
                inputs=[var_dict[arg] for arg in rewrite_prompt_args],
                outputs=[var_dict['rewritten_prompt'], var_dict['output_image']],
                concurrency_id='default_group', api_name=False
            ).success(
                fn=api,
                inputs=[var_dict[arg] for arg in args],
                outputs=var_dict['output_image'],
                concurrency_id='default_group', api_name=api_name
            )
        else:
            var_dict['run_btn'].click(
                fn=set_seed,
                inputs=var_dict['seed'],
                outputs=var_dict['last_seed'],
                show_progress=False,
                api_name=False
            ).success(
                fn=api,
                inputs=[var_dict[arg] for arg in args],
                outputs=var_dict['output_image'],
                concurrency_id='default_group', api_name=api_name
            )

    return interface, var_dict
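
The builder above only wires Gradio events; the `api` and `rewrite_prompt_api` callables come from the app. A sketch of the signatures implied by the default `args` and `rewrite_prompt_args` lists, with stand-in bodies (not the demo's actual handlers):

# Stand-in handlers matching the default input lists; a real app calls the
# pipeline and the prompt rewriter here.
from PIL import Image

def rewrite_prompt_api(last_seed, prompt, rewrite_prompt, in_image):
    # First .success() link: fills the "Rewritten prompt" box and clears the output image.
    rewritten = prompt.upper() if rewrite_prompt else prompt  # toy rewrite
    return rewritten, None

def api(last_seed, prompt, in_image, width, height, steps, guidance_scale):
    # Second .success() link: returns a PIL image for the `output_image` component.
    return Image.new("RGB", (int(width), int(height)), "gray")
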
lakonlab/ui/gradio/create_text_to_img.py
ADDED
@@ -0,0 +1,41 @@
import gradio as gr
from .shared_opts import create_base_opts, create_generate_bar, set_seed, create_prompt_opts, create_image_size_bar


def create_interface_text_to_img(
        api, prompt='', seed=42, steps=32, min_steps=4, max_steps=50, steps_slider_step=1,
        height=768, width=1360, hw_slider_step=16,
        guidance_scale=None, temperature=None, api_name='text_to_img',
        create_negative_prompt=False, args=['last_seed', 'prompt', 'width', 'height', 'steps', 'guidance_scale']):
    var_dict = dict()
    with gr.Blocks(analytics_enabled=False) as interface:
        var_dict['output_image'] = gr.Image(
            type='pil', image_mode='RGB', label='Output image', interactive=False, elem_classes=['vh-img', 'vh-img-700'])
        create_prompt_opts(var_dict, create_negative_prompt=create_negative_prompt, prompt=prompt)
        with gr.Column(variant='compact', elem_classes=['custom-spacing']):
            create_image_size_bar(
                var_dict, height=height, width=width, hw_slider_step=hw_slider_step)
            create_generate_bar(var_dict, text='Generate', seed=seed)
            create_base_opts(
                var_dict,
                steps=steps,
                min_steps=min_steps,
                max_steps=max_steps,
                steps_slider_step=steps_slider_step,
                guidance_scale=guidance_scale,
                temperature=temperature)

        var_dict['run_btn'].click(
            fn=set_seed,
            inputs=var_dict['seed'],
            outputs=var_dict['last_seed'],
            show_progress=False,
            api_name=False
        ).success(
            fn=api,
            inputs=[var_dict[arg] for arg in args],
            outputs=var_dict['output_image'],
            concurrency_id='default_group', api_name=api_name
        )

    return interface, var_dict
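
A minimal launch sketch using this builder. The `generate` stub stands in for the real pipeline call in `app.py`; note that `guidance_scale` must be non-`None` here, because the default `args` list reads the `guidance_scale` slider that `create_base_opts` only creates in that case.

# Minimal sketch: build and launch the text-to-image tab with a stub backend.
from PIL import Image
from lakonlab.ui.gradio.create_text_to_img import create_interface_text_to_img

def generate(last_seed, prompt, width, height, steps, guidance_scale):
    # Placeholder backend; the demo would call PiFlux2Pipeline here.
    return Image.new("RGB", (int(width), int(height)), "white")

demo, widgets = create_interface_text_to_img(generate, guidance_scale=4.0)
demo.launch()
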
lakonlab/ui/gradio/shared_opts.py
ADDED
@@ -0,0 +1,87 @@
+import random
+import gradio as gr
+
+
+def create_prompt_opts(
+        var_dict, create_negative_prompt=True, prompt='', negative_prompt='', display_label=False):
+    if display_label:
+        kwargs = dict(show_label=True, container=True, elem_classes=['force-hide-container'])
+    else:
+        kwargs = dict(show_label=False, container=False)
+    var_dict['prompt'] = gr.Textbox(
+        prompt, label='Prompt', lines=2, placeholder='Prompt', interactive=True, **kwargs)
+    if create_negative_prompt:
+        var_dict['negative_prompt'] = gr.Textbox(
+            negative_prompt, label='Negative prompt', lines=2,
+            placeholder='Negative prompt', interactive=True, **kwargs)
+
+
+def create_generate_bar(var_dict, text='Generate', variant='primary', seed=-1):
+    with gr.Row(equal_height=False, elem_classes=['generate-bar']):
+        var_dict['run_btn'] = gr.Button(text, variant=variant, scale=2)
+        var_dict['seed'] = gr.Number(
+            label='Seed', value=seed, min_width=100, precision=0, minimum=-1, maximum=2 ** 31,
+            elem_classes=['force-hide-container', 'seed-input'])
+        var_dict['random_seed'] = gr.Button('\U0001f3b2\ufe0f', elem_classes=['tool'])
+        var_dict['reuse_seed'] = gr.Button('\u267b\ufe0f', elem_classes=['tool'])
+    with gr.Column(visible=False):
+        var_dict['last_seed'] = gr.Number(value=seed, label='Last seed')
+    var_dict['reuse_seed'].click(
+        fn=lambda x: x,
+        inputs=var_dict['last_seed'],
+        outputs=var_dict['seed'],
+        show_progress=False,
+        api_name=False)
+    var_dict['random_seed'].click(
+        fn=lambda: -1,
+        outputs=var_dict['seed'],
+        show_progress=False,
+        api_name=False)
+
+
+def create_image_size_bar(var_dict, height=768, width=1360, hw_slider_step=16):
+    with gr.Row(equal_height=True, variant='compact', elem_classes=['force-hide-container']):
+        var_dict['width'] = gr.Slider(
+            label='Width', minimum=64, maximum=2048, step=hw_slider_step, value=width,
+            elem_classes=['force-hide-container'])
+        var_dict['switch_hw'] = gr.Button('\U000021C6', elem_classes=['tool'])
+        var_dict['height'] = gr.Slider(
+            label='Height', minimum=64, maximum=2048, step=hw_slider_step, value=height,
+            elem_classes=['force-hide-container'])
+    var_dict['switch_hw'].click(
+        fn=lambda w, h: (h, w),
+        inputs=[var_dict['width'], var_dict['height']],
+        outputs=[var_dict['width'], var_dict['height']],
+        show_progress=False,
+        api_name=False)
+
+
+def create_base_opts(var_dict,
+                     steps=24,
+                     min_steps=4,
+                     max_steps=50,
+                     steps_slider_step=1,
+                     guidance_scale=None,
+                     temperature=None,
+                     render=True):
+    with gr.Column(variant='compact', elem_classes=['custom-spacing'], render=render) as base_opts:
+        with gr.Row(variant='compact', elem_classes=['force-hide-container']):
+            var_dict['steps'] = gr.Slider(
+                min_steps, max_steps, value=steps, step=steps_slider_step, label='Sampling steps',
+                elem_classes=['force-hide-container'])
+        if guidance_scale is not None or temperature is not None:
+            with gr.Row(variant='compact', elem_classes=['force-hide-container']):
+                if guidance_scale is not None:
+                    var_dict['guidance_scale'] = gr.Slider(
+                        0.0, 30.0, value=guidance_scale, step=0.5, label='Guidance scale',
+                        elem_classes=['force-hide-container'])
+                if temperature is not None:
+                    var_dict['temperature'] = gr.Slider(
+                        0.0, 1.0, value=temperature, step=0.01, label='Temperature',
+                        elem_classes=['force-hide-container'])
+    return base_opts
+
+
+def set_seed(seed):
+    seed = random.randint(0, 2**31) if seed == -1 else seed
+    return seed
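Note: every builder in shared_opts.py registers its components in the caller-supplied `var_dict`, so event handlers elsewhere can look them up by key ('prompt', 'seed', 'run_btn', 'width', 'steps', ...). A minimal standalone sketch of composing the helpers; the lambda callback and output textbox are placeholders, not part of this commit:

    import gradio as gr
    from lakonlab.ui.gradio.shared_opts import (
        create_prompt_opts, create_generate_bar, create_image_size_bar,
        create_base_opts, set_seed)

    var_dict = {}
    with gr.Blocks() as demo:
        create_prompt_opts(var_dict, create_negative_prompt=False, prompt='a photo of a cat')
        create_image_size_bar(var_dict)
        create_generate_bar(var_dict)
        create_base_opts(var_dict, guidance_scale=4.0, temperature=0.9)

        out = gr.Textbox(label='Resolved settings')  # placeholder output
        var_dict['run_btn'].click(
            fn=lambda seed, prompt, w, h, steps:
                f'seed={set_seed(seed)}, {prompt!r}, {w}x{h}, {steps} steps',
            inputs=[var_dict[k] for k in ('seed', 'prompt', 'width', 'height', 'steps')],
            outputs=out)

    demo.launch()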
lakonlab/ui/gradio/style.css
ADDED
@@ -0,0 +1,73 @@
+.force-hide-container {
+    margin: 0;
+    box-shadow: none;
+    --block-border-width: 0;
+    background: transparent;
+    padding: 0;
+    overflow: visible;
+}
+
+.svelte-1vd8eap {
+    display: flex;
+    flex-direction: inherit;
+    flex-wrap: wrap;
+    gap: 0;
+    box-shadow: none;
+    border: 0;
+    border-radius: 0;
+    background: transparent;
+    overflow-y: hidden;
+}
+
+.custom-spacing {
+    padding: 10px;
+    gap: 20px;
+    flex-grow: 0 !important;
+}
+
+.generate-bar {
+    align-items: flex-end;
+}
+
+.tool {
+    max-width: 40px;
+    min-width: 40px !important;
+}
+
+/* Center the component and allow it to use the full row width */
+.vh-img {
+    display: grid;
+    justify-items: center;
+}
+
+/* Container should size to the image, but never exceed the row width */
+.vh-img .image-container {
+    inline-size: fit-content !important;  /* prefers the image's natural width */
+    max-inline-size: 100% !important;     /* ...but clamps to available width */
+    margin-inline: auto;
+    overflow: hidden;                     /* avoid odd overflow on iOS */
+}
+
+/* Image scales by BOTH constraints: height cap and row width */
+.vh-img-700 .image-container img {
+    max-block-size: 700px !important;  /* fixed max height cap */
+    max-inline-size: 100%;             /* never wider than container */
+    inline-size: auto;                 /* keep aspect ratio */
+    block-size: auto;
+    object-fit: contain;
+    display: block;
+}
+
+.vh-img-1000 .image-container img {
+    max-block-size: 1000px !important;  /* fixed max height cap */
+    max-inline-size: 100%;              /* never wider than container */
+    inline-size: auto;                  /* keep aspect ratio */
+    block-size: auto;
+    object-fit: contain;
+    display: block;
+}
+
+/* Remove the border and use box-shadow instead for height alignment */
+.gradio-container .seed-input input {
+    border: none !important;
+    box-shadow: inset 0 0 0 var(--input-border-width) var(--input-border-color) !important;
+}
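Note: these selectors target the `elem_classes` assigned by the UI builders above (`force-hide-container`, `generate-bar`, `tool`, `seed-input`, and the `vh-img*` classes on the output image). The `.svelte-1vd8eap` rule pins a Gradio-generated class name, so it may need updating if the pinned Gradio version changes. For the rules to take effect, the stylesheet must be passed to the Blocks constructor; one plausible wiring, assuming the caller loads style.css by path (app.py is not shown in this excerpt):

    import os
    import gradio as gr

    # assumption: resolve style.css relative to the repository root
    css_path = os.path.join(os.path.dirname(__file__), 'lakonlab/ui/gradio/style.css')
    with open(css_path, encoding='utf-8') as f:
        custom_css = f.read()

    with gr.Blocks(css=custom_css) as demo:
        ...  # build the UI with the shared_opts helpers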
requirements.txt
ADDED
@@ -0,0 +1,10 @@
+numpy==1.26.4
+torch==2.6.0
+torchvision==0.21.0
+diffusers==0.36.0
+peft==0.17.0
+sentencepiece
+accelerate
+transformers==4.57.3
+gradio==5.49.0
+bitsandbytes>=0.46.1
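Note: the pins form a coherent stack (torchvision 0.21.0 is the release matched to torch 2.6.0, and gradio 5.49.0 mirrors the sdk_version in README.md). An optional, standard-library-only sanity check that a running environment matches the key pins; the list of packages checked here is illustrative:

    from importlib.metadata import version

    for pkg, pinned in [('torch', '2.6.0'), ('diffusers', '0.36.0'),
                        ('transformers', '4.57.3'), ('gradio', '5.49.0')]:
        installed = version(pkg)
        # strip local version suffixes such as '+cu124' before comparing
        ok = installed.split('+')[0] == pinned
        print(f'{pkg}: {installed}', 'OK' if ok else f'!= pinned {pinned}')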