Hansheng Chen committed
Commit 5d99e98 · 1 Parent(s): d46bfe9

Release pi-FLUX.2 demo

.gitignore ADDED
@@ -0,0 +1,25 @@
+ /.idea/
+ /work_dirs*
+ .vscode/
+ /tmp
+ /data
+ /checkpoints
+ *.so
+ *.patch
+ __pycache__/
+ *.egg-info/
+ /viz*
+ /submit*
+ build/
+ *.pyd
+ /cache*
+ *.stl
+ *.pth
+ /venv/
+ .nk8s
+ *.mp4
+ .vs
+ /exp/
+ /dev/
+ *.pyi
+ !/data/imagenet/imagenet1000_clsidx_to_labels.txt
LICENSE.md ADDED
@@ -0,0 +1,9 @@
+ # License for pi-FLUX.2
+
+ This repository distributes a **pi-FLUX.2 app** that is a **Derivative** of **FLUX.2 [dev]** by **Black Forest Labs Inc.**
+
+ Use and distribution of these adapters are governed by the **FLUX [dev] Non-Commercial License v2.0**.
+ **No commercial use** of these adapters (or other Derivatives) is permitted without a separate commercial license from Black Forest Labs.
+
+ - Full license: https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/LICENSE.txt
+ - This repository does **not** grant any rights beyond the license above.
README.md CHANGED
@@ -1,12 +1,27 @@
  ---
- title: Pi FLUX.2
- emoji: 😻
+ title: Pi-FLUX.2 Demo
+ emoji: 🚀
  colorFrom: pink
  colorTo: purple
  sdk: gradio
- sdk_version: 6.1.0
+ sdk_version: 5.49.0
  app_file: app.py
  pinned: false
+ license: other
+ license_name: flux-dev-non-commercial-license
+ license_link: LICENSE.md
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Official demo of the paper:
+
+ **pi-Flow: Policy-Based Few-Step Generation via Imitation Distillation**
+ <br>
+ [Hansheng Chen](https://lakonik.github.io/)<sup>1</sup>,
+ [Kai Zhang](https://kai-46.github.io/website/)<sup>2</sup>,
+ [Hao Tan](https://research.adobe.com/person/hao-tan/)<sup>2</sup>,
+ [Leonidas Guibas](https://geometry.stanford.edu/?member=guibas)<sup>1</sup>,
+ [Gordon Wetzstein](http://web.stanford.edu/~gordonwz/)<sup>1</sup>,
+ [Sai Bi](https://sai-bi.github.io/)<sup>2</sup><br>
+ <sup>1</sup>Stanford University, <sup>2</sup>Adobe Research
+ <br>
+ [[arXiv](https://arxiv.org/abs/2510.14974)] [[pi-Qwen Demo🤗](https://huggingface.co/spaces/Lakonik/pi-Qwen)] [[pi-FLUX Demo🤗](https://huggingface.co/spaces/Lakonik/pi-FLUX.1)] [[pi-FLUX.2 Demo🤗](https://huggingface.co/spaces/Lakonik/pi-FLUX.2)]
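What this Space does, as a minimal programmatic sketch (condensed from `app.py` below; it assumes the bundled `lakonlab` package is importable and a CUDA GPU is available; the sample prompt and output filename are placeholders):

```python
import torch
from lakonlab.pipelines.pipeline_piflux2 import PiFlux2Pipeline
from lakonlab.models.diffusions.schedulers.flow_map_sde import FlowMapSDEScheduler

# Load the 4-bit FLUX.2 [dev] base and attach the pi-Flow (GMFlow policy) adapter.
pipe = PiFlux2Pipeline.from_pretrained(
    'diffusers/FLUX.2-dev-bnb-4bit', torch_dtype=torch.bfloat16)
pipe.load_piflow_adapter(
    'Lakonik/pi-FLUX.2', subfolder='gmflux2_k8_piid_4step',
    target_module_name='transformer')
pipe.scheduler = FlowMapSDEScheduler.from_config(
    pipe.scheduler.config, shift=3.2, use_dynamic_shifting=False,
    final_step_size_scale=0.5)
pipe = pipe.to('cuda')

# 4-step generation with the distilled few-step policy.
image = pipe(prompt='a watercolor fox in a misty forest',
             width=1024, height=1024, num_inference_steps=4,
             generator=torch.Generator().manual_seed(42)).images[0]
image.save('piflux2_sample.png')
```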
app.py ADDED
@@ -0,0 +1,163 @@
+ import torch
+
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+ import os
+ import random
+ import numpy as np
+ import gradio as gr
+ import spaces
+ from lakonlab.models.diffusions.schedulers.flow_map_sde import FlowMapSDEScheduler
+ from lakonlab.ui.gradio.create_img_edit import create_interface_img_edit
+ from lakonlab.pipelines.pipeline_piflux2 import PiFlux2Pipeline
+ from lakonlab.pipelines.prompt_rewriters.qwen3_vl import Qwen3VLPromptRewriter
+
+
+ DEFAULT_PROMPT = """Museum-style FIELD GUIDE poster on neutral parchment (#F3EEE3). Use Inter (or Helvetica/Arial). All text #2D3748, thin connector lines 1px #A0AEC0.
+
+ Center: full-body original fantasy creature, 3/4 standing pose. Around it: four small inset boxes labeled exactly "EYE DETAIL", "FOOT DETAIL", "SKIN TEXTURE", "SILHOUETTE SCALE" (with a simple human comparison silhouette). Bottom: a short footprint trail diagram. One small habitat vignette (misty rocky shoreline with tide pools).
+
+ Exact text (only these, clean print layout):
+ Top: "FIELD GUIDE"
+ Sub: "AURORA SHOREWALKER"
+ Small line: "CLASS: COASTAL DRIFTER"
+ Under silhouette: "HEIGHT: 1.7 m"
+
+ Crisp ink outlines with soft watercolor-like fills, high readability, balanced hierarchy, premium poster aesthetic."""
+
+ SYSTEM_PROMPT_TEXT_ONLY_PATH = 'lakonlab/pipelines/prompt_rewriters/system_prompts/default_text_only.txt'
+ SYSTEM_PROMPT_WITH_IMAGES_PATH = 'lakonlab/pipelines/prompt_rewriters/system_prompts/default_with_images.txt'
+
+
+ def _patch_diffusers_bnb_shape_check():
+     try:
+         import diffusers.quantizers.bitsandbytes.bnb_quantizer as bnbq
+     except Exception:
+         return
+
+     def _numel(shape):
+         if shape is None:
+             return None
+         if hasattr(shape, "numel"):  # torch.Size
+             return int(shape.numel())
+         # plain tuple/list
+         n = 1
+         for d in shape:
+             n *= int(d)
+         return n
+
+     def patched_check(self, param_name, current_param, loaded_param):
+         cshape = getattr(current_param, "shape", None)
+         lshape = getattr(loaded_param, "shape", None)
+         n = _numel(cshape)
+         inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
+         if tuple(lshape) != tuple(inferred_shape):
+             raise ValueError(
+                 f"Expected flattened shape mismatch for {param_name}: "
+                 f"loaded={tuple(lshape)} inferred={tuple(inferred_shape)}"
+             )
+         return True
+
+     # Patch any quantizer class in that module that defines the method
+     for name, obj in vars(bnbq).items():
+         if isinstance(obj, type) and hasattr(obj, "check_quantized_param_shape"):
+             setattr(obj, "check_quantized_param_shape", patched_check)
+
+
+ _patch_diffusers_bnb_shape_check()
+
+
+ pipe = PiFlux2Pipeline.from_pretrained(
+     'diffusers/FLUX.2-dev-bnb-4bit',
+     torch_dtype=torch.bfloat16)
+ pipe.load_piflow_adapter(
+     'Lakonik/pi-FLUX.2',
+     subfolder='gmflux2_k8_piid_4step',
+     target_module_name='transformer')
+ pipe.scheduler = FlowMapSDEScheduler.from_config(  # use fixed shift=3.2
+     pipe.scheduler.config, shift=3.2, use_dynamic_shifting=False, final_step_size_scale=0.5)
+ pipe = pipe.to('cuda')
+
+ prompt_rewriter = Qwen3VLPromptRewriter(
+     device_map="cuda",
+     system_prompt_text_only=open(SYSTEM_PROMPT_TEXT_ONLY_PATH, 'r').read(),
+     system_prompt_wigh_images=open(SYSTEM_PROMPT_WITH_IMAGES_PATH, 'r').read(),
+     max_new_tokens_default=512,
+ )
+
+
+ def set_random_seed(seed: int, deterministic: bool = True) -> None:
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     os.environ['PYTHONHASHSEED'] = str(seed)
+     if deterministic:
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
+
+
+ @spaces.GPU
+ def run_rewrite_prompt(seed, prompt, rewrite_prompt, in_image, progress=gr.Progress(track_tqdm=True)):
+     image_list = None
+     if in_image is not None and len(in_image) > 0:
+         image_list = []
+         for item in in_image:
+             image_list.append(item[0])
+     if rewrite_prompt:
+         set_random_seed(seed)
+         progress(0.05, desc="Rewriting prompt...")
+         if image_list is None:
+             final_prompt = prompt_rewriter.rewrite_text_batch(
+                 [prompt])[0]
+         else:
+             final_prompt = prompt_rewriter.rewrite_edit_batch(
+                 [image_list], [prompt])[0]
+         return final_prompt, None
+     else:
+         return '', None
+
+
+ @spaces.GPU
+ def generate(
+         seed, prompt, rewrite_prompt, rewritten_prompt, in_image, width, height, steps,
+         progress=gr.Progress(track_tqdm=True)):
+     image_list = None
+     if in_image is not None and len(in_image) > 0:
+         image_list = []
+         for item in in_image:
+             image_list.append(item[0])
+     return pipe(
+         image=image_list,
+         prompt=rewritten_prompt if rewrite_prompt else prompt,
+         width=width,
+         height=height,
+         num_inference_steps=steps,
+         generator=torch.Generator().manual_seed(seed),
+     ).images[0]
+
+
+ with gr.Blocks(analytics_enabled=False,
+                title='pi-FLUX.2 Demo',
+                css_paths='lakonlab/ui/gradio/style.css'
+                ) as demo:
+
+     md_txt = '# pi-FLUX.2 Demo\n\n' \
+              'Official demo of the paper [pi-Flow: Policy-Based Few-Step Generation via Imitation Distillation](https://arxiv.org/abs/2510.14974). ' \
+              '**Base model:** [FLUX.2 dev](https://huggingface.co/black-forest-labs/FLUX.2-dev). **Fast policy:** GMFlow. **Code:** [https://github.com/Lakonik/piFlow](https://github.com/Lakonik/piFlow).\n' \
+              '<br> Use and distribution of this app are governed by the [FLUX [dev] Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/LICENSE.txt).'
+     gr.Markdown(md_txt)
+
+     create_interface_img_edit(
+         generate,
+         prompt=DEFAULT_PROMPT,
+         steps=4, guidance_scale=None,
+         args=['last_seed', 'prompt', 'rewrite_prompt', 'rewritten_prompt', 'in_image', 'width', 'height', 'steps'],
+         rewrite_prompt_api=run_rewrite_prompt,
+         rewrite_prompt_args=['last_seed', 'prompt', 'rewrite_prompt', 'in_image'],
+         height=1024,
+         width=1024
+     )
+ demo.queue().launch()
lakonlab/__init__.py ADDED
File without changes
lakonlab/models/__init__.py ADDED
File without changes
lakonlab/models/architecture/__init__.py ADDED
File without changes
lakonlab/models/architecture/gmflow/__init__.py ADDED
File without changes
lakonlab/models/architecture/gmflow/gm_output.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ from dataclasses import dataclass
+ from diffusers.utils import BaseOutput
+
+
+ @dataclass
+ class GMFlowModelOutput(BaseOutput):
+     """
+     The output of GMFlow models.
+
+     Args:
+         means (`torch.Tensor` of shape `(batch_size, num_gaussians, num_channels, height, width)` or
+             `(batch_size, num_gaussians, num_channels, frame, height, width)`):
+             Gaussian mixture means.
+         logweights (`torch.Tensor` of shape `(batch_size, num_gaussians, 1, height, width)` or
+             `(batch_size, num_gaussians, 1, frame, height, width)`):
+             Gaussian mixture log-weights (logits).
+         logstds (`torch.Tensor` of shape `(batch_size, 1, 1, 1, 1)` or `(batch_size, 1, 1, 1, 1, 1)`):
+             Gaussian mixture log-standard-deviations (logstds are shared across all Gaussians and channels).
+     """
+
+     means: torch.Tensor
+     logweights: torch.Tensor
+     logstds: torch.Tensor
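For reference, a small shape-consistency illustration of this output dataclass (dummy tensors only; the batch size, number of Gaussians, and spatial sizes below are arbitrary):

```python
import torch
from lakonlab.models.architecture.gmflow.gm_output import GMFlowModelOutput

B, K, C, H, W = 2, 8, 16, 32, 32  # arbitrary sizes for illustration
out = GMFlowModelOutput(
    means=torch.randn(B, K, C, H, W),                           # Gaussian mixture means
    logweights=torch.zeros(B, K, 1, H, W).log_softmax(dim=1),   # per-pixel logits over the K components
    logstds=torch.zeros(B, 1, 1, 1, 1),                         # one shared log-std per sample
)
assert out.means.shape[1] == out.logweights.shape[1] == K
```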
lakonlab/models/architecture/gmflow/gmflux2.py ADDED
@@ -0,0 +1,241 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from typing import Any, Dict, Optional, Tuple, List
5
+ from diffusers.models import ModelMixin
6
+ from diffusers.models.transformers.transformer_flux2 import (
7
+ Flux2Transformer2DModel, Flux2PosEmbed, Flux2TransformerBlock, Flux2SingleTransformerBlock,
8
+ Flux2TimestepGuidanceEmbeddings, Flux2Modulation)
9
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
10
+ from diffusers.configuration_utils import register_to_config
11
+ from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers, is_torch_npu_available
12
+ from .gm_output import GMFlowModelOutput
13
+
14
+
15
+ class _GMFlux2Transformer2DModel(Flux2Transformer2DModel):
16
+
17
+ @register_to_config
18
+ def __init__(
19
+ self,
20
+ num_gaussians=16,
21
+ constant_logstd=None,
22
+ logstd_inner_dim=1024,
23
+ gm_num_logstd_layers=2,
24
+ logweights_channels=1,
25
+ in_channels: int = 128,
26
+ out_channels: Optional[int] = None,
27
+ num_layers: int = 8,
28
+ num_single_layers: int = 48,
29
+ attention_head_dim: int = 128,
30
+ num_attention_heads: int = 48,
31
+ joint_attention_dim: int = 15360,
32
+ timestep_guidance_channels: int = 256,
33
+ mlp_ratio: float = 3.0,
34
+ axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
35
+ rope_theta: int = 2000,
36
+ eps: float = 1e-6):
37
+ super(Flux2Transformer2DModel, self).__init__()
38
+
39
+ self.num_gaussians = num_gaussians
40
+ self.logweights_channels = logweights_channels
41
+
42
+ self.out_channels = out_channels or in_channels
43
+ self.inner_dim = num_attention_heads * attention_head_dim
44
+
45
+ # 1. Sinusoidal positional embedding for RoPE on image and text tokens
46
+ self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)
47
+
48
+ # 2. Combined timestep + guidance embedding
49
+ self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
50
+ in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
51
+ )
52
+
53
+ # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
54
+ # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
55
+ self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
56
+ self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
57
+ # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
58
+ self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)
59
+
60
+ # 4. Input projections
61
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
62
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)
63
+
64
+ # 5. Double Stream Transformer Blocks
65
+ self.transformer_blocks = nn.ModuleList(
66
+ [
67
+ Flux2TransformerBlock(
68
+ dim=self.inner_dim,
69
+ num_attention_heads=num_attention_heads,
70
+ attention_head_dim=attention_head_dim,
71
+ mlp_ratio=mlp_ratio,
72
+ eps=eps,
73
+ bias=False,
74
+ )
75
+ for _ in range(num_layers)
76
+ ]
77
+ )
78
+
79
+ # 6. Single Stream Transformer Blocks
80
+ self.single_transformer_blocks = nn.ModuleList(
81
+ [
82
+ Flux2SingleTransformerBlock(
83
+ dim=self.inner_dim,
84
+ num_attention_heads=num_attention_heads,
85
+ attention_head_dim=attention_head_dim,
86
+ mlp_ratio=mlp_ratio,
87
+ eps=eps,
88
+ bias=False,
89
+ )
90
+ for _ in range(num_single_layers)
91
+ ]
92
+ )
93
+
94
+ # 7. Output layers
95
+ self.norm_out = AdaLayerNormContinuous(
96
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
97
+ )
98
+ self.proj_out_means = nn.Linear(self.inner_dim, self.num_gaussians * self.out_channels)
99
+ self.proj_out_logweights = nn.Linear(self.inner_dim, self.num_gaussians * self.logweights_channels)
100
+ self.constant_logstd = constant_logstd
101
+
102
+ if self.constant_logstd is None:
103
+ assert gm_num_logstd_layers >= 1
104
+ in_dim = self.inner_dim
105
+ logstd_layers = []
106
+ for _ in range(gm_num_logstd_layers - 1):
107
+ logstd_layers.extend([
108
+ nn.SiLU(),
109
+ nn.Linear(in_dim, logstd_inner_dim)])
110
+ in_dim = logstd_inner_dim
111
+ self.proj_out_logstds = nn.Sequential(
112
+ *logstd_layers,
113
+ nn.SiLU(),
114
+ nn.Linear(in_dim, 1))
115
+
116
+ self.gradient_checkpointing = False
117
+
118
+ def forward(
119
+ self,
120
+ hidden_states: torch.Tensor,
121
+ encoder_hidden_states: torch.Tensor = None,
122
+ timestep: torch.LongTensor = None,
123
+ img_ids: torch.Tensor = None,
124
+ txt_ids: torch.Tensor = None,
125
+ guidance: torch.Tensor = None,
126
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,):
127
+ # 0. Handle input arguments
128
+ if joint_attention_kwargs is not None:
129
+ joint_attention_kwargs = joint_attention_kwargs.copy()
130
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
131
+ else:
132
+ lora_scale = 1.0
133
+
134
+ if USE_PEFT_BACKEND:
135
+ scale_lora_layers(self, lora_scale)
136
+ else:
137
+ assert joint_attention_kwargs is None or joint_attention_kwargs.get('scale', None) is None
138
+
139
+ num_txt_tokens = encoder_hidden_states.shape[1]
140
+
141
+ # 1. Calculate timestep embedding and modulation parameters
142
+ timestep = timestep.to(hidden_states.dtype) * 1000
143
+ guidance = guidance.to(hidden_states.dtype) * 1000
144
+
145
+ temb = self.time_guidance_embed(timestep, guidance)
146
+
147
+ double_stream_mod_img = self.double_stream_modulation_img(temb)
148
+ double_stream_mod_txt = self.double_stream_modulation_txt(temb)
149
+ single_stream_mod = self.single_stream_modulation(temb)[0]
150
+
151
+ # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
152
+ hidden_states = self.x_embedder(hidden_states)
153
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
154
+
155
+ # 3. Calculate RoPE embeddings from image and text tokens
156
+ # NOTE: the below logic means that we can't support batched inference with images of different resolutions or
157
+ # text prompts of different lengths. Is this a use case we want to support?
158
+ if img_ids.ndim == 3:
159
+ img_ids = img_ids[0]
160
+ if txt_ids.ndim == 3:
161
+ txt_ids = txt_ids[0]
162
+
163
+ if is_torch_npu_available():
164
+ freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
165
+ image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
166
+ freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
167
+ text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
168
+ else:
169
+ image_rotary_emb = self.pos_embed(img_ids)
170
+ text_rotary_emb = self.pos_embed(txt_ids)
171
+ concat_rotary_emb = (
172
+ torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0).to(hidden_states.dtype),
173
+ torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0).to(hidden_states.dtype),
174
+ )
175
+
176
+ # 4. Double Stream Transformer Blocks
177
+ for index_block, block in enumerate(self.transformer_blocks):
178
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
179
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
180
+ block,
181
+ hidden_states,
182
+ encoder_hidden_states,
183
+ double_stream_mod_img,
184
+ double_stream_mod_txt,
185
+ concat_rotary_emb,
186
+ joint_attention_kwargs,
187
+ )
188
+ else:
189
+ encoder_hidden_states, hidden_states = block(
190
+ hidden_states=hidden_states,
191
+ encoder_hidden_states=encoder_hidden_states,
192
+ temb_mod_params_img=double_stream_mod_img,
193
+ temb_mod_params_txt=double_stream_mod_txt,
194
+ image_rotary_emb=concat_rotary_emb,
195
+ joint_attention_kwargs=joint_attention_kwargs,
196
+ )
197
+ # Concatenate text and image streams for single-block inference
198
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
199
+
200
+ # 5. Single Stream Transformer Blocks
201
+ for index_block, block in enumerate(self.single_transformer_blocks):
202
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
203
+ hidden_states = self._gradient_checkpointing_func(
204
+ block,
205
+ hidden_states,
206
+ None,
207
+ single_stream_mod,
208
+ concat_rotary_emb,
209
+ joint_attention_kwargs,
210
+ )
211
+ else:
212
+ hidden_states = block(
213
+ hidden_states=hidden_states,
214
+ encoder_hidden_states=None,
215
+ temb_mod_params=single_stream_mod,
216
+ image_rotary_emb=concat_rotary_emb,
217
+ joint_attention_kwargs=joint_attention_kwargs,
218
+ )
219
+ # Remove text tokens from concatenated stream
220
+ hidden_states = hidden_states[:, num_txt_tokens:, ...]
221
+
222
+ # 6. Output layers
223
+ hidden_states = self.norm_out(hidden_states, temb)
224
+
225
+ bs, seq_len, _ = hidden_states.size()
226
+ out_means = self.proj_out_means(hidden_states).reshape(
227
+ bs, seq_len, self.num_gaussians, self.out_channels)
228
+ out_logweights = self.proj_out_logweights(hidden_states).reshape(
229
+ bs, seq_len, self.num_gaussians, self.logweights_channels).log_softmax(dim=-2)
230
+ if self.constant_logstd is None:
231
+ out_logstds = self.proj_out_logstds(temb.detach()).reshape(bs, 1, 1, 1)
232
+ else:
233
+ out_logstds = hidden_states.new_full((bs, 1, 1, 1), float(self.constant_logstd))
234
+
235
+ if USE_PEFT_BACKEND:
236
+ unscale_lora_layers(self, lora_scale)
237
+
238
+ return GMFlowModelOutput(
239
+ means=out_means,
240
+ logweights=out_logweights,
241
+ logstds=out_logstds)
lakonlab/models/diffusions/__init__.py ADDED
File without changes
lakonlab/models/diffusions/piflow_policies/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .dx import DXPolicy
+ from .gmflow import GMFlowPolicy
+
+
+ POLICY_CLASSES = dict(
+     DX=DXPolicy,
+     GMFlow=GMFlowPolicy
+ )
lakonlab/models/diffusions/piflow_policies/base.py ADDED
@@ -0,0 +1,21 @@
+ from abc import ABCMeta, abstractmethod
+
+
+ class BasePolicy(metaclass=ABCMeta):
+
+     @abstractmethod
+     def pi(self, x_t, sigma_t):
+         """Compute the flow velocity at (x_t, t).
+
+         Args:
+             x_t (torch.Tensor): Noisy input at time t.
+             sigma_t (torch.Tensor): Noise level at time t.
+
+         Returns:
+             torch.Tensor: The computed flow velocity u_t.
+         """
+         pass
+
+     @abstractmethod
+     def detach(self):
+         pass
lakonlab/models/diffusions/piflow_policies/dx.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright (c) 2025 Hansheng Chen
2
+
3
+ import torch
4
+ from .base import BasePolicy
5
+
6
+
7
+ class DXPolicy(BasePolicy):
8
+ """DX policy. The number of grid points N is inferred from the denoising output.
9
+
10
+ Note: segment_size and shift are intrinsic parameters of the DX policy. For elastic inference (i.e., changing
11
+ the number of function evaluations or noise schedule at test time), these parameters should be kept unchanged.
12
+
13
+ Args:
14
+ denoising_output (torch.Tensor): The output of the denoising model. Shape (B, N, C, H, W) or (B, N, C, T, H, W).
15
+ x_t_src (torch.Tensor): The initial noisy sample. Shape (B, C, H, W) or (B, C, T, H, W).
16
+ sigma_t_src (torch.Tensor): The initial noise level. Shape (B,).
17
+ segment_size (float): The size of each DX policy time segment. Defaults to 1.0.
18
+ shift (float): The shift parameter for the DX policy noise schedule. Defaults to 1.0.
19
+ mode (str): Either 'grid' or 'polynomial' mode for calculating x_0. Defaults to 'grid'.
20
+ eps (float): A small value to avoid numerical issues. Defaults to 1e-4.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ denoising_output: torch.Tensor,
26
+ x_t_src: torch.Tensor,
27
+ sigma_t_src: torch.Tensor,
28
+ segment_size: float = 1.0,
29
+ shift: float = 1.0,
30
+ mode: str = 'grid',
31
+ eps: float = 1e-4):
32
+ self.x_t_src = x_t_src
33
+ self.ndim = x_t_src.dim()
34
+ self.shift = shift
35
+ self.eps = eps
36
+
37
+ assert mode in ['grid', 'polynomial']
38
+ self.mode = mode
39
+
40
+ self.sigma_t_src = sigma_t_src.reshape(*sigma_t_src.size(), *((self.ndim - sigma_t_src.dim()) * [1]))
41
+ self.raw_t_src = self._unwarp_t(self.sigma_t_src)
42
+ self.raw_t_dst = (self.raw_t_src - segment_size).clamp(min=0)
43
+ self.segment_size = (self.raw_t_src - self.raw_t_dst).clamp(min=eps)
44
+
45
+ self.denoising_output_x_0 = self._u_to_x_0(
46
+ denoising_output, self.x_t_src, self.sigma_t_src)
47
+
48
+ def _unwarp_t(self, sigma_t):
49
+ return sigma_t / (self.shift + (1 - self.shift) * sigma_t)
50
+
51
+ @staticmethod
52
+ def _u_to_x_0(denoising_output, x_t, sigma_t):
53
+ x_0 = x_t.unsqueeze(1) - sigma_t.unsqueeze(1) * denoising_output
54
+ return x_0
55
+
56
+ @staticmethod
57
+ def _interpolate(x, t):
58
+ """
59
+ Args:
60
+ x (torch.Tensor): (B, N, *)
61
+ t (torch.Tensor): (B, *) in [0, 1]
62
+
63
+ Returns:
64
+ torch.Tensor: (B, *)
65
+ """
66
+ n = x.size(1)
67
+ if n < 2:
68
+ return x.squeeze(1)
69
+ t = t.clamp(min=0, max=1) * (n - 1)
70
+ t0 = t.floor().to(torch.long).clamp(min=0, max=n - 2)
71
+ t1 = t0 + 1
72
+ t0t1 = torch.stack([t0, t1], dim=1) # (B, 2, *)
73
+ x0x1 = torch.gather(x, dim=1, index=t0t1.expand(-1, -1, *x.shape[2:]))
74
+ x_interp = (t1 - t) * x0x1[:, 0] + (t - t0) * x0x1[:, 1]
75
+ return x_interp
76
+
77
+ def pi(self, x_t, sigma_t):
78
+ """Compute the flow velocity at (x_t, t).
79
+
80
+ Args:
81
+ x_t (torch.Tensor): Noisy input at time t.
82
+ sigma_t (torch.Tensor): Noise level at time t.
83
+
84
+ Returns:
85
+ torch.Tensor: The computed flow velocity u_t.
86
+ """
87
+ sigma_t = sigma_t.reshape(*sigma_t.size(), *((self.ndim - sigma_t.dim()) * [1]))
88
+ raw_t = self._unwarp_t(sigma_t)
89
+ if self.mode == 'grid':
90
+ x_0 = self._interpolate(
91
+ self.denoising_output_x_0, (raw_t - self.raw_t_dst) / self.segment_size)
92
+ elif self.mode == 'polynomial':
93
+ p_order = self.denoising_output_x_0.size(1)
94
+ diff_t = self.raw_t_src - raw_t # (B, 1, 1, 1)
95
+ basis = torch.stack(
96
+ [diff_t ** i for i in range(p_order)], dim=1) # (B, N, 1, 1, 1)
97
+ x_0 = torch.sum(basis * self.denoising_output_x_0, dim=1)
98
+ else:
99
+ raise ValueError(f"Unknown mode: {self.mode}")
100
+ u = (x_t - x_0) / sigma_t.clamp(min=self.eps)
101
+ return u
102
+
103
+ def copy(self):
104
+ new_policy = DXPolicy.__new__(DXPolicy)
105
+ new_policy.x_t_src = self.x_t_src
106
+ new_policy.ndim = self.ndim
107
+ new_policy.shift = self.shift
108
+ new_policy.eps = self.eps
109
+ new_policy.mode = self.mode
110
+ new_policy.sigma_t_src = self.sigma_t_src
111
+ new_policy.raw_t_src = self.raw_t_src
112
+ new_policy.raw_t_dst = self.raw_t_dst
113
+ new_policy.segment_size = self.segment_size
114
+ new_policy.denoising_output_x_0 = self.denoising_output_x_0
115
+ return new_policy
116
+
117
+ def detach_(self):
118
+ self.denoising_output_x_0 = self.denoising_output_x_0.detach()
119
+ return self
120
+
121
+ def detach(self):
122
+ new_policy = self.copy()
123
+ return new_policy.detach_()
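A minimal usage sketch of the DX policy (dummy tensors; the sizes, noise levels, and the plain Euler loop below are illustrative assumptions that mirror the substep rollout in `piflow_utils.policy_rollout`):

```python
import torch
from lakonlab.models.diffusions.piflow_policies import DXPolicy

B, N, C, H, W = 1, 4, 16, 32, 32                  # arbitrary sizes; N grid points come from the model output
x_t_src = torch.randn(B, C, H, W)                 # current noisy latent
sigma_t_src = torch.full((B,), 0.8)               # current noise level
denoising_output = torch.randn(B, N, C, H, W)     # stacked velocity predictions from the distilled model

policy = DXPolicy(denoising_output, x_t_src, sigma_t_src, segment_size=1.0, shift=1.0)

# Cheap Euler substeps under the frozen policy (no further network calls).
x_t, sigma = x_t_src, sigma_t_src
for sigma_next in (0.6, 0.4, 0.2):
    u = policy.pi(x_t, sigma)                     # analytic velocity at (x_t, sigma)
    x_t = x_t + u * (sigma_next - sigma).reshape(B, 1, 1, 1)
    sigma = torch.full((B,), sigma_next)
```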
lakonlab/models/diffusions/piflow_policies/gmflow.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright (c) 2025 Hansheng Chen
2
+
3
+ import math
4
+ import torch
5
+
6
+ from typing import Dict
7
+ from .base import BasePolicy
8
+
9
+
10
+ @torch.jit.script
11
+ def gmflow_posterior_mean_jit(
12
+ sigma_t_src, sigma_t, x_t_src, x_t,
13
+ gm_means, gm_vars, gm_logweights,
14
+ eps: float, gm_dim: int = -4, channel_dim: int = -3):
15
+ sigma_t_src = sigma_t_src.clamp(min=eps)
16
+ sigma_t = sigma_t.clamp(min=eps)
17
+
18
+ alpha_t_src = 1 - sigma_t_src
19
+ alpha_t = 1 - sigma_t
20
+
21
+ alpha_over_sigma_t_src = alpha_t_src / sigma_t_src
22
+ alpha_over_sigma_t = alpha_t / sigma_t
23
+
24
+ zeta = alpha_over_sigma_t.square() - alpha_over_sigma_t_src.square()
25
+ nu = alpha_over_sigma_t * x_t / sigma_t - alpha_over_sigma_t_src * x_t_src / sigma_t_src
26
+
27
+ nu = nu.unsqueeze(gm_dim) # (bs, *, 1, out_channels, h, w)
28
+ denom = (gm_vars * zeta + 1).clamp(min=eps)
29
+
30
+ out_means = (gm_vars * nu + gm_means) / denom
31
+ # (bs, *, num_gaussians, 1, h, w)
32
+ logweights_delta = (gm_means * (nu - 0.5 * zeta * gm_means)).sum(
33
+ dim=channel_dim, keepdim=True) / denom
34
+ out_weights = (gm_logweights + logweights_delta).softmax(dim=gm_dim)
35
+
36
+ out_mean = (out_means * out_weights).sum(dim=gm_dim)
37
+
38
+ return out_mean
39
+
40
+
41
+ def gm_temperature(gm, temperature, gm_dim=-4, eps=1e-6):
42
+ gm = gm.copy()
43
+ temperature = max(temperature, eps)
44
+ gm['logweights'] = (gm['logweights'] / temperature).log_softmax(dim=gm_dim)
45
+ if 'logstds' in gm:
46
+ gm['logstds'] = gm['logstds'] + (0.5 * math.log(temperature))
47
+ if 'gm_vars' in gm:
48
+ gm['gm_vars'] = gm['gm_vars'] * temperature
49
+ return gm
50
+
51
+
52
+ class GMFlowPolicy(BasePolicy):
53
+ """GMFlow policy. The number of components K is inferred from the denoising output.
54
+
55
+ Args:
56
+ denoising_output (dict): The output of the denoising model, containing:
57
+ means (torch.Tensor): The means of the Gaussian components. Shape (B, K, C, H, W) or (B, K, C, T, H, W).
58
+ logstds (torch.Tensor): The log standard deviations of the Gaussian components. Shape (B, K, 1, 1, 1)
59
+ or (B, K, 1, 1, 1, 1).
60
+ logweights (torch.Tensor): The log weights of the Gaussian components. Shape (B, K, 1, H, W) or
61
+ (B, K, 1, T, H, W).
62
+ x_t_src (torch.Tensor): The initial noisy sample. Shape (B, C, H, W) or (B, C, T, H, W).
63
+ sigma_t_src (torch.Tensor): The initial noise level. Shape (B,).
64
+ checkpointing (bool): Whether to use gradient checkpointing to save memory. Defaults to True.
65
+ eps (float): A small value to avoid numerical issues. Defaults to 1e-4.
66
+ """
67
+
68
+ def __init__(
69
+ self,
70
+ denoising_output: Dict[str, torch.Tensor],
71
+ x_t_src: torch.Tensor,
72
+ sigma_t_src: torch.Tensor,
73
+ checkpointing: bool = True,
74
+ eps: float = 1e-6):
75
+ self.x_t_src = x_t_src
76
+ self.ndim = x_t_src.dim()
77
+ self.checkpointing = checkpointing
78
+ self.eps = eps
79
+
80
+ self.sigma_t_src = sigma_t_src.reshape(*sigma_t_src.size(), *((self.ndim - sigma_t_src.dim()) * [1]))
81
+ self.denoising_output_x_0 = self._u_to_x_0(
82
+ denoising_output, self.x_t_src, self.sigma_t_src)
83
+
84
+ @staticmethod
85
+ def _u_to_x_0(denoising_output, x_t, sigma_t):
86
+ x_t = x_t.unsqueeze(1)
87
+ sigma_t = sigma_t.unsqueeze(1)
88
+ means_x_0 = x_t - sigma_t * denoising_output['means']
89
+ gm_vars = (denoising_output['logstds'] * 2).exp() * sigma_t.square()
90
+ return dict(
91
+ means=means_x_0,
92
+ gm_vars=gm_vars,
93
+ logweights=denoising_output['logweights'])
94
+
95
+ def pi(self, x_t, sigma_t):
96
+ """Compute the flow velocity at (x_t, t).
97
+
98
+ Args:
99
+ x_t (torch.Tensor): Noisy input at time t.
100
+ sigma_t (torch.Tensor): Noise level at time t.
101
+
102
+ Returns:
103
+ torch.Tensor: The computed flow velocity u_t.
104
+ """
105
+ sigma_t = sigma_t.reshape(*sigma_t.size(), *((self.ndim - sigma_t.dim()) * [1]))
106
+ means = self.denoising_output_x_0['means']
107
+ gm_vars = self.denoising_output_x_0['gm_vars']
108
+ logweights = self.denoising_output_x_0['logweights']
109
+ if (sigma_t == self.sigma_t_src).all() and (x_t == self.x_t_src).all():
110
+ x_0 = (logweights.softmax(dim=1) * means).sum(dim=1)
111
+ else:
112
+ if self.checkpointing and torch.is_grad_enabled():
113
+ x_0 = torch.utils.checkpoint.checkpoint(
114
+ gmflow_posterior_mean_jit,
115
+ self.sigma_t_src, sigma_t, self.x_t_src, x_t,
116
+ means,
117
+ gm_vars,
118
+ logweights,
119
+ self.eps, 1, 2,
120
+ use_reentrant=True) # use_reentrant=False does not work with jit
121
+ else:
122
+ x_0 = gmflow_posterior_mean_jit(
123
+ self.sigma_t_src, sigma_t, self.x_t_src, x_t,
124
+ means,
125
+ gm_vars,
126
+ logweights,
127
+ self.eps, 1, 2)
128
+ u = (x_t - x_0) / sigma_t.clamp(min=self.eps)
129
+ return u
130
+
131
+ def copy(self):
132
+ new_policy = GMFlowPolicy.__new__(GMFlowPolicy)
133
+ new_policy.x_t_src = self.x_t_src
134
+ new_policy.ndim = self.ndim
135
+ new_policy.checkpointing = self.checkpointing
136
+ new_policy.eps = self.eps
137
+ new_policy.sigma_t_src = self.sigma_t_src
138
+ new_policy.denoising_output_x_0 = self.denoising_output_x_0.copy()
139
+ return new_policy
140
+
141
+ def detach_(self):
142
+ self.denoising_output_x_0 = {k: v.detach() for k, v in self.denoising_output_x_0.items()}
143
+ return self
144
+
145
+ def detach(self):
146
+ new_policy = self.copy()
147
+ return new_policy.detach_()
148
+
149
+ def dropout_(self, p):
150
+ if p <= 0 or p >= 1:
151
+ return self
152
+ logweights = self.denoising_output_x_0['logweights']
153
+ dropout_mask = torch.rand(
154
+ (*logweights.shape[:2], *((self.ndim - 1) * [1])), device=logweights.device) < p
155
+ is_all_dropout = dropout_mask.all(dim=1, keepdim=True)
156
+ dropout_mask &= ~is_all_dropout
157
+ self.denoising_output_x_0['logweights'] = logweights.masked_fill(
158
+ dropout_mask, float('-inf'))
159
+ return self
160
+
161
+ def dropout(self, p):
162
+ new_policy = self.copy()
163
+ return new_policy.dropout_(p)
164
+
165
+ def temperature_(self, temp):
166
+ if temp >= 1.0:
167
+ return self
168
+ self.denoising_output_x_0 = gm_temperature(
169
+ self.denoising_output_x_0, temp, gm_dim=1, eps=self.eps)
170
+ return self
171
+
172
+ def temperature(self, temp):
173
+ new_policy = self.copy()
174
+ return new_policy.temperature_(temp)
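And the analogous sketch for the GMFlow policy (again dummy tensors; batch size 1 and the chosen noise levels are arbitrary; the Euler loop mirrors `piflow_utils.policy_rollout`):

```python
import torch
from lakonlab.models.diffusions.piflow_policies import GMFlowPolicy

B, K, C, H, W = 1, 8, 16, 32, 32                           # arbitrary sizes for illustration
denoising_output = dict(
    means=torch.randn(B, K, C, H, W),                      # per-component velocity predictions
    logstds=torch.zeros(B, 1, 1, 1, 1),                    # shared log-std, broadcast over K
    logweights=torch.zeros(B, K, 1, H, W).log_softmax(dim=1),
)
x_t_src = torch.randn(B, C, H, W)
sigma_t_src = torch.full((B,), 0.8)

policy = GMFlowPolicy(denoising_output, x_t_src, sigma_t_src)

# Roll the frozen policy with a few Euler substeps (no extra network evaluations).
with torch.no_grad():
    x_t, sigma = x_t_src, sigma_t_src
    for sigma_next in (0.5, 0.2):
        u = policy.pi(x_t, sigma)            # velocity from the analytic GM posterior mean
        x_t = x_t + u * (sigma_next - sigma).reshape(B, 1, 1, 1)
        sigma = torch.full((B,), sigma_next)
```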
lakonlab/models/diffusions/schedulers/__init__.py ADDED
File without changes
lakonlab/models/diffusions/schedulers/flow_map_sde.py ADDED
@@ -0,0 +1,186 @@
1
+ # Copyright (c) 2025 Hansheng Chen
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple, Union
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.utils import BaseOutput, logging
10
+ from diffusers.utils.torch_utils import randn_tensor
11
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
12
+
13
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
14
+
15
+
16
+ @dataclass
17
+ class FlowMapSDESchedulerOutput(BaseOutput):
18
+ prev_sample: torch.FloatTensor
19
+
20
+
21
+ class FlowMapSDEScheduler(SchedulerMixin, ConfigMixin):
22
+
23
+ _compatibles = []
24
+ order = 1
25
+
26
+ @register_to_config
27
+ def __init__(
28
+ self,
29
+ num_train_timesteps: int = 1000,
30
+ h: Union[float, str] = 0.0,
31
+ shift: float = 1.0,
32
+ use_dynamic_shifting=False,
33
+ base_seq_len=256,
34
+ max_seq_len=4096,
35
+ base_logshift=0.5,
36
+ max_logshift=1.15,
37
+ final_step_size_scale=1.0):
38
+ sigmas = torch.from_numpy(1 - np.linspace(
39
+ 0, 1, num_train_timesteps, dtype=np.float32, endpoint=False))
40
+ self.sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
41
+ self.timesteps = self.sigmas * num_train_timesteps
42
+
43
+ self._step_index = None
44
+ self._begin_index = None
45
+
46
+ self.sigma_min = self.sigmas[-1].item()
47
+ self.sigma_max = self.sigmas[0].item()
48
+
49
+ @property
50
+ def step_index(self):
51
+ return self._step_index
52
+
53
+ @property
54
+ def begin_index(self):
55
+ return self._begin_index
56
+
57
+ def set_begin_index(self, begin_index: int = 0):
58
+ self._begin_index = begin_index
59
+
60
+ def get_shift(self, seq_len=None):
61
+ if self.config.use_dynamic_shifting and seq_len is not None:
62
+ m = (self.config.max_logshift - self.config.base_logshift
63
+ ) / (self.config.max_seq_len - self.config.base_seq_len)
64
+ logshift = (seq_len - self.config.base_seq_len) * m + self.config.base_logshift
65
+ if isinstance(logshift, torch.Tensor):
66
+ shift = torch.exp(logshift)
67
+ else:
68
+ shift = np.exp(logshift)
69
+ else:
70
+ shift = self.config.shift
71
+ return shift
72
+
73
+ def warp_t(self, raw_t, seq_len=None):
74
+ shift = self.get_shift(seq_len=seq_len)
75
+ return shift * raw_t / (1 + (shift - 1) * raw_t)
76
+
77
+ def unwarp_t(self, sigma_t, seq_len=None):
78
+ shift = self.get_shift(seq_len=seq_len)
79
+ return sigma_t / (shift + (1 - shift) * sigma_t)
80
+
81
+ def set_timesteps(self, num_inference_steps: int, seq_len=None, device=None):
82
+ self.num_inference_steps = num_inference_steps
83
+
84
+ raw_timesteps = torch.from_numpy(np.linspace(
85
+ 1, (self.config.final_step_size_scale - 1) / (num_inference_steps + self.config.final_step_size_scale - 1),
86
+ num_inference_steps, dtype=np.float32, endpoint=False)).to(device).clamp(min=0)
87
+ sigmas = self.warp_t(raw_timesteps, seq_len=seq_len)
88
+
89
+ self.timesteps = sigmas * self.config.num_train_timesteps
90
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=device)])
91
+
92
+ sigmas_dst, m = self.calculate_sigmas_dst(self.sigmas)
93
+ self.timesteps_dst = sigmas_dst * self.config.num_train_timesteps
94
+ self.m_vals = m
95
+
96
+ self._step_index = None
97
+ self._begin_index = None
98
+
99
+ def calculate_sigmas_dst(self, sigmas, eps=1e-6):
100
+ alphas = 1 - sigmas
101
+
102
+ sigmas_src = sigmas[:-1]
103
+ sigmas_to = sigmas[1:]
104
+ alphas_src = alphas[:-1]
105
+ alphas_to = alphas[1:]
106
+
107
+ if self.config.h == 'inf':
108
+ m = torch.zeros_like(sigmas_src)
109
+ elif self.config.h == 0.0:
110
+ m = torch.ones_like(sigmas_src)
111
+ else:
112
+ assert self.config.h > 0.0
113
+ h2 = self.config.h * self.config.h
114
+ m = (sigmas_to * alphas_src / (sigmas_src * alphas_to).clamp(min=eps)) ** h2
115
+
116
+ sigmas_to_mul_m = sigmas_to * m
117
+ sigmas_dst = sigmas_to_mul_m / (alphas_to + sigmas_to_mul_m).clamp(min=eps)
118
+
119
+ return sigmas_dst, m
120
+
121
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
122
+ if schedule_timesteps is None:
123
+ schedule_timesteps = self.timesteps
124
+
125
+ indices = (schedule_timesteps == timestep).nonzero()
126
+
127
+ pos = 1 if len(indices) > 1 else 0
128
+
129
+ return indices[pos].item()
130
+
131
+ def _init_step_index(self, timestep):
132
+ if self.begin_index is None:
133
+ if isinstance(timestep, torch.Tensor):
134
+ timestep = timestep.to(self.timesteps.device)
135
+ self._step_index = self.index_for_timestep(timestep)
136
+ else:
137
+ self._step_index = self._begin_index
138
+
139
+ def step(
140
+ self,
141
+ model_output: torch.FloatTensor,
142
+ timestep: Union[float, torch.FloatTensor],
143
+ sample: torch.FloatTensor,
144
+ generator: Optional[torch.Generator] = None,
145
+ return_dict: bool = True) -> Union[FlowMapSDESchedulerOutput, Tuple]:
146
+
147
+ if isinstance(timestep, int) \
148
+ or isinstance(timestep, torch.IntTensor) \
149
+ or isinstance(timestep, torch.LongTensor):
150
+ raise ValueError(
151
+ (
152
+ 'Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to'
153
+ ' `FlowMapSDEScheduler.step()` is not supported. Make sure to pass'
154
+ ' one of the `scheduler.timesteps` as a timestep.'
155
+ ),
156
+ )
157
+
158
+ if self.step_index is None:
159
+ self._init_step_index(timestep)
160
+
161
+ # Upcast to avoid precision issues when computing prev_sample
162
+ ori_dtype = model_output.dtype
163
+ model_output = model_output.to(torch.float32) # x_t_dst
164
+
165
+ sigma_to = self.sigmas[self.step_index + 1]
166
+ alpha_to = 1 - sigma_to
167
+ m = self.m_vals[self.step_index]
168
+
169
+ noise = randn_tensor(
170
+ model_output.shape, dtype=torch.float32, device=model_output.device, generator=generator)
171
+
172
+ prev_sample = (alpha_to + sigma_to * m) * model_output + sigma_to * (1 - m.square()).clamp(min=0).sqrt() * noise
173
+
174
+ # Cast sample back to model compatible dtype
175
+ prev_sample = prev_sample.to(ori_dtype)
176
+
177
+ # upon completion increase step index by one
178
+ self._step_index += 1
179
+
180
+ if not return_dict:
181
+ return (prev_sample,)
182
+
183
+ return FlowMapSDESchedulerOutput(prev_sample=prev_sample)
184
+
185
+ def __len__(self):
186
+ return self.config.num_train_timesteps
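A standalone sketch of how this scheduler is driven (the tensor shapes are dummy placeholders; `model_output` here stands in for the flow-map prediction x_{t_dst} produced by the policy rollout, as noted in `step()`):

```python
import torch
from lakonlab.models.diffusions.schedulers.flow_map_sde import FlowMapSDEScheduler

# Same configuration as the demo app: fixed shift, half-size final step.
sched = FlowMapSDEScheduler(shift=3.2, use_dynamic_shifting=False, final_step_size_scale=0.5)
sched.set_timesteps(num_inference_steps=4, device='cpu')

print(sched.sigmas)         # source noise levels sigma_t for each step (plus a trailing 0)
print(sched.timesteps_dst)  # destination timesteps the flow map should jump to
print(sched.m_vals)         # interpolation factors m controlling the SDE noise injection

# One update: re-noise the predicted x_{t_dst} to the next source noise level.
sample = torch.randn(1, 16, 64, 64)          # dummy latent
model_output = torch.randn_like(sample)      # dummy flow-map prediction
next_sample = sched.step(model_output, sched.timesteps[0], sample).prev_sample
```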
lakonlab/pipelines/__init__.py ADDED
File without changes
lakonlab/pipelines/piflow_utils.py ADDED
@@ -0,0 +1,395 @@
1
+ # Copyright (c) 2025 Hansheng Chen
2
+
3
+ import os
4
+ from typing import Union, Optional
5
+
6
+ import torch
7
+ import accelerate
8
+ import diffusers
9
+ from diffusers.models import AutoModel
10
+ from diffusers.models.modeling_utils import (
11
+ load_state_dict,
12
+ _LOW_CPU_MEM_USAGE_DEFAULT,
13
+ no_init_weights,
14
+ ContextManagers
15
+ )
16
+ from diffusers.utils import (
17
+ SAFETENSORS_WEIGHTS_NAME,
18
+ WEIGHTS_NAME,
19
+ _add_variant,
20
+ _get_model_file,
21
+ is_accelerate_available,
22
+ is_torch_version,
23
+ logging,
24
+ )
25
+ from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING
26
+ from diffusers.quantizers import DiffusersAutoQuantizer
27
+ from diffusers.utils.torch_utils import empty_device_cache
28
+ from lakonlab.models.architecture.gmflow.gmflux2 import _GMFlux2Transformer2DModel
29
+
30
+
31
+ LOCAL_CLASS_MAPPING = {
32
+ "GMFlux2Transformer2DModel": _GMFlux2Transformer2DModel,
33
+ }
34
+
35
+ _SET_ADAPTER_SCALE_FN_MAPPING.update(
36
+ _GMFlux2Transformer2DModel=lambda model_cls, weights: weights,
37
+ )
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ def assign_param(module, tensor_name: str, param: torch.nn.Parameter):
43
+ if "." in tensor_name:
44
+ splits = tensor_name.split(".")
45
+ for split in splits[:-1]:
46
+ new_module = getattr(module, split)
47
+ if new_module is None:
48
+ raise ValueError(f"{module} has no attribute {split}.")
49
+ module = new_module
50
+ tensor_name = splits[-1]
51
+ module._parameters[tensor_name] = param
52
+
53
+
54
+ class PiFlowMixin:
55
+
56
+ def load_piflow_adapter(
57
+ self,
58
+ pretrained_model_name_or_path: Union[str, os.PathLike],
59
+ target_module_name: str = "transformer",
60
+ adapter_name: Optional[str] = None,
61
+ **kwargs
62
+ ):
63
+ r"""
64
+ Load a PiFlow adapter from a pretrained model repository into the target module.
65
+
66
+ Args:
67
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
68
+ Can be either:
69
+
70
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
71
+ the Hub.
72
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
73
+ with [`~ModelMixin.save_pretrained`].
74
+
75
+ target_module_name (`str`, *optional*, defaults to `"transformer"`):
76
+ The module name in the model to load the PiFlow adapter into.
77
+ adapter_name (`str`, *optional*):
78
+ The name to assign to the loaded adapter. If not provided, it defaults to
79
+ `"{target_module_name}_piflow"`.
80
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
81
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
82
+ is not used.
83
+ force_download (`bool`, *optional*, defaults to `False`):
84
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
85
+ cached versions if they exist.
86
+ proxies (`Dict[str, str]`, *optional*):
87
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
88
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
89
+ local_files_only(`bool`, *optional*, defaults to `False`):
90
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
91
+ won't be downloaded from the Hub.
92
+ token (`str` or *bool*, *optional*):
93
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
94
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
95
+ revision (`str`, *optional*, defaults to `"main"`):
96
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
97
+ allowed by Git.
98
+ subfolder (`str`, *optional*, defaults to `""`):
99
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
100
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
101
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
102
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
103
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
104
+ argument to `True` will raise an error.
105
+ variant (`str`, *optional*):
106
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
107
+ loading `from_flax`.
108
+ use_safetensors (`bool`, *optional*, defaults to `None`):
109
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
110
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
111
+ weights. If set to `False`, `safetensors` weights are not loaded.
112
+ disable_mmap ('bool', *optional*, defaults to 'False'):
113
+ Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
114
+ is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
115
+
116
+ Returns:
117
+ `str` or `None`: The name assigned to the loaded adapter, or `None` if no LoRA weights were found.
118
+ """
119
+ cache_dir = kwargs.pop("cache_dir", None)
120
+ force_download = kwargs.pop("force_download", False)
121
+ proxies = kwargs.pop("proxies", None)
122
+ token = kwargs.pop("token", None)
123
+ local_files_only = kwargs.pop("local_files_only", False)
124
+ revision = kwargs.pop("revision", None)
125
+ subfolder = kwargs.pop("subfolder", None)
126
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
127
+ variant = kwargs.pop("variant", None)
128
+ use_safetensors = kwargs.pop("use_safetensors", None)
129
+ disable_mmap = kwargs.pop("disable_mmap", False)
130
+
131
+ allow_pickle = False
132
+ if use_safetensors is None:
133
+ use_safetensors = True
134
+ allow_pickle = True
135
+
136
+ if low_cpu_mem_usage and not is_accelerate_available():
137
+ low_cpu_mem_usage = False
138
+ logger.warning(
139
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
140
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
141
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
142
+ " install accelerate\n```\n."
143
+ )
144
+
145
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
146
+ raise NotImplementedError(
147
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
148
+ " `low_cpu_mem_usage=False`."
149
+ )
150
+
151
+ user_agent = {
152
+ "diffusers": diffusers.__version__,
153
+ "file_type": "model",
154
+ "framework": "pytorch",
155
+ }
156
+
157
+ # 1. Determine model class from config
158
+
159
+ load_config_kwargs = {
160
+ "cache_dir": cache_dir,
161
+ "force_download": force_download,
162
+ "proxies": proxies,
163
+ "token": token,
164
+ "local_files_only": local_files_only,
165
+ "revision": revision,
166
+ }
167
+
168
+ config = AutoModel.load_config(pretrained_model_name_or_path, subfolder=subfolder, **load_config_kwargs)
169
+
170
+ orig_class_name = config["_class_name"]
171
+
172
+ if orig_class_name in LOCAL_CLASS_MAPPING:
173
+ model_cls = LOCAL_CLASS_MAPPING[orig_class_name]
174
+
175
+ else:
176
+ load_config_kwargs.update({"subfolder": subfolder})
177
+
178
+ from diffusers.pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
179
+
180
+ model_cls, _ = get_class_obj_and_candidates(
181
+ library_name="diffusers",
182
+ class_name=orig_class_name,
183
+ importable_classes=ALL_IMPORTABLE_CLASSES,
184
+ pipelines=None,
185
+ is_pipeline_module=False,
186
+ )
187
+
188
+ if model_cls is None:
189
+ raise ValueError(f"Can't find a model linked to {orig_class_name}.")
190
+
191
+ # 2. Get model file
192
+
193
+ model_file = None
194
+
195
+ if use_safetensors:
196
+ try:
197
+ model_file = _get_model_file(
198
+ pretrained_model_name_or_path,
199
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
200
+ cache_dir=cache_dir,
201
+ force_download=force_download,
202
+ proxies=proxies,
203
+ local_files_only=local_files_only,
204
+ token=token,
205
+ revision=revision,
206
+ subfolder=subfolder,
207
+ user_agent=user_agent,
208
+ )
209
+
210
+ except IOError as e:
211
+ logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
212
+ if not allow_pickle:
213
+ raise
214
+ logger.warning(
215
+ "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
216
+ )
217
+
218
+ if model_file is None:
219
+ model_file = _get_model_file(
220
+ pretrained_model_name_or_path,
221
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
222
+ cache_dir=cache_dir,
223
+ force_download=force_download,
224
+ proxies=proxies,
225
+ local_files_only=local_files_only,
226
+ token=token,
227
+ revision=revision,
228
+ subfolder=subfolder,
229
+ user_agent=user_agent,
230
+ )
231
+
232
+ assert model_file is not None, \
233
+ f"Could not find adapter weights for {pretrained_model_name_or_path}."
234
+
235
+ # 3. Initialize model
236
+
237
+ base_module = getattr(self, target_module_name)
238
+
239
+ torch_dtype = base_module.dtype
240
+ device = base_module.device
241
+ dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)
242
+
243
+ # load the state dict early to determine keep_in_fp32_modules
244
+ #######################################
245
+ overwrite_state_dict = dict()
246
+ lora_state_dict = dict()
247
+
248
+ adapter_state_dict = load_state_dict(model_file, disable_mmap=disable_mmap)
249
+ for k in adapter_state_dict.keys():
250
+ adapter_state_dict[k] = adapter_state_dict[k].to(dtype=torch_dtype, device=device)
251
+ if "lora" in k:
252
+ lora_state_dict[k.removeprefix(f"{target_module_name}.")] = adapter_state_dict[k]
253
+ else:
254
+ overwrite_state_dict[k.removeprefix(f"{target_module_name}.")] = adapter_state_dict[k]
255
+
256
+ # determine initial quantization config.
257
+ #######################################
258
+ pre_quantized = ("quantization_config" in base_module.config
259
+ and base_module.config["quantization_config"] is not None)
260
+ if pre_quantized:
261
+ config["quantization_config"] = base_module.config.quantization_config
262
+ hf_quantizer = DiffusersAutoQuantizer.from_config(
263
+ config["quantization_config"], pre_quantized=True
264
+ )
265
+
266
+ hf_quantizer.validate_environment(torch_dtype=torch_dtype)
267
+ torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
268
+
269
+ user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
270
+
271
+ # Force-set to `True` for more mem efficiency
272
+ if low_cpu_mem_usage is None:
273
+ low_cpu_mem_usage = True
274
+ logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.")
275
+ elif not low_cpu_mem_usage:
276
+ raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")
277
+
278
+ else:
279
+ hf_quantizer = None
280
+
281
+ # Check if `_keep_in_fp32_modules` is not None
282
+ use_keep_in_fp32_modules = model_cls._keep_in_fp32_modules is not None and (
283
+ hf_quantizer is None or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
284
+ )
285
+
286
+ if use_keep_in_fp32_modules:
287
+ keep_in_fp32_modules = model_cls._keep_in_fp32_modules
288
+ if not isinstance(keep_in_fp32_modules, list):
289
+ keep_in_fp32_modules = [keep_in_fp32_modules]
290
+
291
+ if low_cpu_mem_usage is None:
292
+ low_cpu_mem_usage = True
293
+ logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.")
294
+ elif not low_cpu_mem_usage:
295
+ raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.")
296
+ else:
297
+ keep_in_fp32_modules = []
298
+
299
+ # append modules in overwrite_state_dict to keep_in_fp32_modules
300
+ for k in overwrite_state_dict.keys():
301
+ module_name = k.rsplit('.', 1)[0]
302
+ if module_name and module_name not in keep_in_fp32_modules:
303
+ keep_in_fp32_modules.append(module_name)
304
+
305
+ init_contexts = [no_init_weights()]
306
+
307
+ if low_cpu_mem_usage:
308
+ init_contexts.append(accelerate.init_empty_weights())
309
+
310
+ with ContextManagers(init_contexts):
311
+ piflow_module = model_cls.from_config(config).eval()
312
+
313
+ torch.set_default_dtype(dtype_orig)
314
+
315
+ if hf_quantizer is not None:
316
+ hf_quantizer.preprocess_model(
317
+ model=piflow_module, device_map=None, keep_in_fp32_modules=keep_in_fp32_modules
318
+ )
319
+
320
+ # 4. Load model weights
321
+
322
+ base_state_dict = base_module.state_dict()
323
+ base_state_dict.update(overwrite_state_dict)
324
+ empty_state_dict = piflow_module.state_dict()
325
+ for param_name, param in base_state_dict.items():
326
+ if param_name not in empty_state_dict:
327
+ continue
328
+ if hf_quantizer is not None and (
329
+ hf_quantizer.check_if_quantized_param(
330
+ piflow_module, param, param_name, base_state_dict, param_device=device)):
331
+ hf_quantizer.create_quantized_param(
332
+ piflow_module, param, param_name, device, base_state_dict, dtype=torch_dtype
333
+ )
334
+ else:
335
+ assign_param(piflow_module, param_name, param)
336
+
337
+ empty_device_cache()
338
+
339
+ if hf_quantizer is not None:
340
+ hf_quantizer.postprocess_model(piflow_module)
341
+ piflow_module.hf_quantizer = hf_quantizer
342
+
343
+ if len(lora_state_dict) == 0:
344
+ adapter_name = None
345
+ else:
346
+ if adapter_name is None:
347
+ adapter_name = f"{target_module_name}_piflow"
348
+ piflow_module.load_lora_adapter(
349
+ lora_state_dict, prefix=None, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
350
+ if adapter_name is None:
351
+ logger.warning(
352
+ f"No LoRA weights were found in {pretrained_model_name_or_path}."
353
+ )
354
+
355
+ setattr(self, target_module_name, piflow_module)
356
+
357
+ return adapter_name
358
+
359
+ def policy_rollout(
360
+ self,
361
+ x_t_start: torch.Tensor, # (B, C, *, H, W)
362
+ sigma_t_start: torch.Tensor,
363
+ sigma_t_end: torch.Tensor,
364
+ total_substeps: int,
365
+ policy,
366
+ **kwargs):
367
+ assert sigma_t_start.numel() == 1 and sigma_t_end.numel() == 1, \
368
+ "Only supports scalar sigma_t_start and sigma_t_end."
369
+ raw_t_start = self.scheduler.unwarp_t(
370
+ sigma_t_start, **kwargs)
371
+ raw_t_end = self.scheduler.unwarp_t(
372
+ sigma_t_end, **kwargs)
373
+
374
+ delta_raw_t = raw_t_start - raw_t_end
375
+ num_substeps = (delta_raw_t * total_substeps).round().to(torch.long).clamp(min=1)
376
+ substep_size = delta_raw_t / num_substeps
377
+
378
+ raw_t = raw_t_start
379
+ sigma_t = sigma_t_start
380
+ x_t = x_t_start
381
+
382
+ for substep_id in range(num_substeps.item()):
383
+ u = policy.pi(x_t, sigma_t)
384
+
385
+ raw_t_minus = (raw_t - substep_size).clamp(min=0)
386
+ sigma_t_minus = self.scheduler.warp_t(raw_t_minus, **kwargs)
387
+ x_t_minus = x_t + u * (sigma_t_minus - sigma_t)
388
+
389
+ x_t = x_t_minus
390
+ sigma_t = sigma_t_minus
391
+ raw_t = raw_t_minus
392
+
393
+ x_t_end = x_t
394
+
395
+ return x_t_end
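To make the rollout logic above easier to follow, here is a simplified restatement of the same substep integration with an identity time warp (shift = 1); the helper name `euler_policy_rollout` is hypothetical and only illustrates the update rule used by `policy_rollout`:

```python
import torch

def euler_policy_rollout(policy, x_t, sigma_start, sigma_end, total_substeps=128):
    """Illustrative sketch: subdivide [sigma_end, sigma_start] and take Euler steps
    along the frozen policy's velocity field, without further network calls."""
    num_substeps = max(int(round((sigma_start - sigma_end) * total_substeps)), 1)
    substep = (sigma_start - sigma_end) / num_substeps
    sigma = torch.full((x_t.shape[0],), sigma_start)
    for _ in range(num_substeps):
        u = policy.pi(x_t, sigma)                          # analytic velocity at (x_t, sigma)
        sigma_next = (sigma - substep).clamp(min=sigma_end)
        x_t = x_t + u * (sigma_next - sigma).reshape(-1, *([1] * (x_t.dim() - 1)))
        sigma = sigma_next
    return x_t
```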
lakonlab/pipelines/pipeline_piflux2.py ADDED
@@ -0,0 +1,395 @@
1
+ # Copyright (c) 2025 Hansheng Chen
2
+
3
+ import PIL
4
+ import torch
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
+ from functools import partial
7
+ from transformers import AutoProcessor, Mistral3ForConditionalGeneration
8
+ from diffusers.utils import is_torch_xla_available
9
+ from diffusers.models import AutoencoderKLFlux2, Flux2Transformer2DModel
10
+ from diffusers.pipelines.flux2.pipeline_flux2 import Flux2Pipeline, Flux2PipelineOutput
11
+ from lakonlab.models.diffusions.schedulers.flow_map_sde import FlowMapSDEScheduler
12
+ from lakonlab.models.diffusions.piflow_policies import POLICY_CLASSES
13
+ from .piflow_utils import PiFlowMixin
14
+
15
+
16
+ if is_torch_xla_available():
17
+ import torch_xla.core.xla_model as xm
18
+
19
+ XLA_AVAILABLE = True
20
+ else:
21
+ XLA_AVAILABLE = False
22
+
23
+
24
+ class PiFlux2Pipeline(Flux2Pipeline, PiFlowMixin):
25
+ r"""
26
+ The policy-based Flux2 pipeline for text-to-image generation and image editing.
27
+
28
+ Reference:
29
+ https://arxiv.org/abs/2510.14974
30
+ https://bfl.ai/blog/flux-2
31
+
32
+ Args:
33
+ transformer ([`Flux2Transformer2DModel`]):
34
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
35
+ scheduler ([`FlowMapSDEScheduler`]):
36
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
37
+ vae ([`AutoencoderKLFlux2`]):
38
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
39
+ text_encoder ([`Mistral3ForConditionalGeneration`]):
40
+ [Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration)
41
+ tokenizer (`AutoProcessor`):
42
+ Tokenizer of class
43
+ [PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor).
44
+ policy_type (`str`, *optional*, defaults to `"GMFlow"`):
45
+ The type of flow policy to use. Currently supports `"GMFlow"` and `"DX"`.
46
+ policy_kwargs (`Dict`, *optional*):
47
+ Additional keyword arguments to pass to the policy class.
48
+ """
49
+
50
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
51
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
52
+
53
+ def __init__(
54
+ self,
55
+ scheduler: FlowMapSDEScheduler,
56
+ vae: AutoencoderKLFlux2,
57
+ text_encoder: Mistral3ForConditionalGeneration,
58
+ tokenizer: AutoProcessor,
59
+ transformer: Flux2Transformer2DModel,
60
+ policy_type: str = 'GMFlow',
61
+ policy_kwargs: Optional[Dict[str, Any]] = None,
62
+ ):
63
+ super().__init__(
64
+ scheduler,
65
+ vae,
66
+ text_encoder,
67
+ tokenizer,
68
+ transformer,
69
+ )
70
+ assert policy_type in POLICY_CLASSES, f'Invalid policy: {policy_type}. Supported policies are {list(POLICY_CLASSES.keys())}.'
71
+ self.policy_type = policy_type
72
+ self.policy_class = partial(
73
+ POLICY_CLASSES[policy_type], **policy_kwargs
74
+ ) if policy_kwargs else POLICY_CLASSES[policy_type]
75
+
76
+ def _unpack_gm(self, gm, height, width, num_channels_latents, patch_size=2, gm_patch_size=1):
77
+ c = num_channels_latents * patch_size * patch_size
78
+ h = (int(height) // (self.vae_scale_factor * patch_size))
79
+ w = (int(width) // (self.vae_scale_factor * patch_size))
80
+ bs = gm['means'].size(0)
81
+ k = self.transformer.num_gaussians
82
+ scale = patch_size // gm_patch_size
83
+ gm['means'] = gm['means'].reshape(
84
+ bs, h, w, k, c // (scale * scale), scale, scale
85
+ ).permute(
86
+ 0, 3, 4, 1, 5, 2, 6
87
+ ).reshape(
88
+ bs, k, c // (scale * scale), h * scale, w * scale)
89
+ gm['logweights'] = gm['logweights'].reshape(
90
+ bs, h, w, k, 1, scale, scale
91
+ ).permute(
92
+ 0, 3, 4, 1, 5, 2, 6
93
+ ).reshape(
94
+ bs, k, 1, h * scale, w * scale)
95
+ gm['logstds'] = gm['logstds'].reshape(bs, 1, 1, 1, 1)
96
+ return gm
97
+
98
+ @torch.no_grad()
99
+ def __call__(
100
+ self,
101
+ image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
102
+ prompt: Union[str, List[str]] = None,
103
+ height: Optional[int] = None,
104
+ width: Optional[int] = None,
105
+ num_inference_steps: int = 50,
106
+ total_substeps: int = 128,
107
+ temperature: Union[float, str] = 'auto',
108
+ guidance_scale: Optional[float] = 4.0,
109
+ num_images_per_prompt: int = 1,
110
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
111
+ latents: Optional[torch.Tensor] = None,
112
+ prompt_embeds: Optional[torch.Tensor] = None,
113
+ output_type: Optional[str] = "pil",
114
+ return_dict: bool = True,
115
+ attention_kwargs: Optional[Dict[str, Any]] = None,
116
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
117
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
118
+ max_sequence_length: int = 512,
119
+ text_encoder_out_layers: Tuple[int] = (10, 20, 30),
120
+ ):
121
+ r"""
122
+ Function invoked when calling the pipeline for generation.
123
+
124
+ Args:
125
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
126
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
127
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
128
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
129
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
130
+ latents as `image`, but if passing latents directly it is not encoded again.
131
+ prompt (`str` or `List[str]`, *optional*):
132
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
133
+ instead.
134
+ guidance_scale (`float`, *optional*, defaults to 4.0):
135
+ Guidance scale as defined in [Classifier-Free Diffusion
136
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
137
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
138
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
139
+ the text `prompt`, usually at the expense of lower image quality.
140
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
141
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
142
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
143
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
144
+ num_inference_steps (`int`, *optional*, defaults to 50):
145
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
146
+ expense of slower inference.
147
+ total_substeps (`int`, *optional*, defaults to 128):
148
+ The total number of substeps for policy-based flow integration.
149
+ temperature (`float` or `"auto"`, *optional*, defaults to `"auto"`):
150
+ The temperature parameter for the flow policy.
151
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
152
+ The number of images to generate per prompt.
153
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
154
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
155
+ to make generation deterministic.
156
+ latents (`torch.Tensor`, *optional*):
157
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
158
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
159
+ tensor will be generated by sampling using the supplied random `generator`.
160
+ prompt_embeds (`torch.Tensor`, *optional*):
161
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
162
+ provided, text embeddings will be generated from `prompt` input argument.
163
+ output_type (`str`, *optional*, defaults to `"pil"`):
164
+ The output format of the generated image. Choose between
165
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
166
+ return_dict (`bool`, *optional*, defaults to `True`):
167
+ Whether or not to return a [`~pipelines.flux2.Flux2PipelineOutput`] instead of a plain tuple.
168
+ attention_kwargs (`dict`, *optional*):
169
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
170
+ `self.processor` in
171
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
172
+ callback_on_step_end (`Callable`, *optional*):
173
+ A function that calls at the end of each denoising steps during the inference. The function is called
174
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
175
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
176
+ `callback_on_step_end_tensor_inputs`.
177
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
178
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
179
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
180
+ `._callback_tensor_inputs` attribute of your pipeline class.
181
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
182
+ text_encoder_out_layers (`Tuple[int]`):
183
+ Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
184
+
185
+ Examples:
186
+
187
+ Returns:
188
+ [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if
189
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
190
+ generated images.
191
+ """
192
+
193
+ # 1. Check inputs. Raise error if not correct
194
+ self.check_inputs(
195
+ prompt=prompt,
196
+ height=height,
197
+ width=width,
198
+ prompt_embeds=prompt_embeds,
199
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
200
+ )
201
+
202
+ self._guidance_scale = guidance_scale
203
+ self._attention_kwargs = attention_kwargs
204
+ self._current_timestep = None
205
+ self._interrupt = False
206
+
207
+ # 2. Define call parameters
208
+ if prompt is not None and isinstance(prompt, str):
209
+ batch_size = 1
210
+ elif prompt is not None and isinstance(prompt, list):
211
+ batch_size = len(prompt)
212
+ else:
213
+ batch_size = prompt_embeds.shape[0]
214
+
215
+ device = self._execution_device
216
+
217
+ # 3. prepare text embeddings
218
+ prompt_embeds, text_ids = self.encode_prompt(
219
+ prompt=prompt,
220
+ prompt_embeds=prompt_embeds,
221
+ device=device,
222
+ num_images_per_prompt=num_images_per_prompt,
223
+ max_sequence_length=max_sequence_length,
224
+ text_encoder_out_layers=text_encoder_out_layers,
225
+ )
226
+
227
+ # 4. process images
228
+ if image is not None and not isinstance(image, list):
229
+ image = [image]
230
+
231
+ condition_images = None
232
+ if image is not None:
233
+ for img in image:
234
+ self.image_processor.check_image_input(img)
235
+
236
+ condition_images = []
237
+ for img in image:
238
+ image_width, image_height = img.size
239
+ if image_width * image_height > 1024 * 1024:
240
+ img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
241
+ image_width, image_height = img.size
242
+
243
+ multiple_of = self.vae_scale_factor * 2
244
+ image_width = (image_width // multiple_of) * multiple_of
245
+ image_height = (image_height // multiple_of) * multiple_of
246
+ img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
247
+ condition_images.append(img)
248
+ height = height or image_height
249
+ width = width or image_width
250
+
251
+ height = height or self.default_sample_size * self.vae_scale_factor
252
+ width = width or self.default_sample_size * self.vae_scale_factor
253
+
254
+ # 5. prepare latent variables
255
+ num_channels_latents = self.transformer.config.in_channels // 4
256
+ latents, latent_ids = self.prepare_latents(
257
+ batch_size=batch_size * num_images_per_prompt,
258
+ num_latents_channels=num_channels_latents,
259
+ height=height,
260
+ width=width,
261
+ dtype=torch.float32,
262
+ device=device,
263
+ generator=generator,
264
+ latents=latents,
265
+ )
266
+
267
+ image_latents = None
268
+ image_latent_ids = None
269
+ if condition_images is not None:
270
+ image_latents, image_latent_ids = self.prepare_image_latents(
271
+ images=condition_images,
272
+ batch_size=batch_size * num_images_per_prompt,
273
+ generator=generator,
274
+ device=device,
275
+ dtype=self.vae.dtype,
276
+ )
277
+
278
+ # 6. Prepare timesteps
279
+ image_seq_len = latents.shape[1]
280
+ self.scheduler.set_timesteps(num_inference_steps, seq_len=image_seq_len, device=self._execution_device)
281
+ timesteps = self.scheduler.timesteps
282
+ timesteps_dst = self.scheduler.timesteps_dst
283
+ self._num_timesteps = len(timesteps)
284
+
285
+ # handle guidance
286
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
287
+ guidance = guidance.expand(latents.shape[0])
288
+
289
+ # 7. Denoising loop
290
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
291
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
292
+ self.scheduler.set_begin_index(0)
293
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
294
+ for i, (t_src, t_dst) in enumerate(zip(timesteps, timesteps_dst)):
295
+ if self.interrupt:
296
+ continue
297
+
298
+ self._current_timestep = t_src
299
+ time_scaling = self.scheduler.config.num_train_timesteps
300
+ sigma_t_src = t_src / time_scaling
301
+ sigma_t_dst = t_dst / time_scaling
302
+
303
+ latent_model_input = latents.to(self.transformer.dtype)
304
+ latent_image_ids = latent_ids
305
+
306
+ if image_latents is not None:
307
+ latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
308
+ latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
309
+
310
+ denoising_output = self.transformer(
311
+ hidden_states=latent_model_input, # (B, image_seq_len, C)
312
+ timestep=t_src.expand(latents.shape[0]) / 1000,
313
+ guidance=guidance,
314
+ encoder_hidden_states=prompt_embeds,
315
+ txt_ids=text_ids, # B, text_seq_len, 4
316
+ img_ids=latent_image_ids, # B, image_seq_len, 4
317
+ joint_attention_kwargs=self._attention_kwargs,
318
+ )
319
+
320
+ denoising_output = {
321
+ k: v[:, :latents.size(1)].to(torch.float32) for k, v in denoising_output.items()}
322
+
323
+ # unpack and create policy
324
+ latents = self._unpack_latents_with_ids(latents, latent_ids)
325
+ latents = self._unpatchify_latents(latents)
326
+ if self.policy_type == 'GMFlow':
327
+ denoising_output = self._unpack_gm(
328
+ denoising_output, height, width, num_channels_latents, gm_patch_size=1)
329
+ policy = self.policy_class(
330
+ denoising_output, latents, sigma_t_src)
331
+ if i < self.num_timesteps - 1:
332
+ if temperature == 'auto':
333
+ temperature = min(max(0.1 * (num_inference_steps - 1), 0), 1)
334
+ else:
335
+ assert isinstance(temperature, (float, int))
336
+ policy.temperature_(temperature)
337
+ elif self.policy_type == 'DX':
338
+ denoising_output = denoising_output[0]
339
+ denoising_output = self._unpack_latents_with_ids(denoising_output, latent_ids)
340
+ denoising_output = self._unpatchify_latents(denoising_output)
341
+ denoising_output = denoising_output.reshape(latents.size(0), -1, *latents.shape[1:])
342
+ policy = self.policy_class(
343
+ denoising_output, latents, sigma_t_src)
344
+ else:
345
+ raise ValueError(f'Unknown policy type: {self.policy_type}.')
346
+
347
+ latents_dst = self.policy_rollout(
348
+ latents, sigma_t_src, sigma_t_dst, total_substeps,
349
+ policy, seq_len=image_seq_len)
350
+
351
+ latents = self.scheduler.step(latents_dst, t_src, latents, return_dict=False)[0]
352
+
353
+ # repack
354
+ latents = self._patchify_latents(latents)
355
+ latents = self._pack_latents(latents)
356
+
357
+ if callback_on_step_end is not None:
358
+ callback_kwargs = {}
359
+ for k in callback_on_step_end_tensor_inputs:
360
+ callback_kwargs[k] = locals()[k]
361
+ callback_outputs = callback_on_step_end(self, i, t_src, callback_kwargs)
362
+
363
+ latents = callback_outputs.pop("latents", latents)
364
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
365
+
366
+ progress_bar.update()
367
+
368
+ if XLA_AVAILABLE:
369
+ xm.mark_step()
370
+
371
+ self._current_timestep = None
372
+
373
+ if output_type == "latent":
374
+ image = latents
375
+ else:
377
+ latents = self._unpack_latents_with_ids(latents, latent_ids)
378
+
379
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
380
+ latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
381
+ latents.device, latents.dtype
382
+ )
383
+ latents = latents * latents_bn_std + latents_bn_mean
384
+ latents = self._unpatchify_latents(latents)
385
+
386
+ image = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
387
+ image = self.image_processor.postprocess(image, output_type=output_type)
388
+
389
+ # Offload all models
390
+ self.maybe_free_model_hooks()
391
+
392
+ if not return_dict:
393
+ return (image,)
394
+
395
+ return Flux2PipelineOutput(images=image)
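A hedged usage sketch for the pipeline defined above: the checkpoint id is a placeholder for the released pi-FLUX.2 weights (substitute the actual repository), and a large-memory GPU is assumed.

```python
import torch
from lakonlab.pipelines.pipeline_piflux2 import PiFlux2Pipeline

# Placeholder repo id; replace with the actual pi-FLUX.2 checkpoint.
pipe = PiFlux2Pipeline.from_pretrained("path/to/pi-FLUX.2", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="A snowy night market lit by paper lanterns, photorealistic, wide shot",
    height=1024,
    width=1024,
    num_inference_steps=4,   # few-step sampling is the point of pi-Flow
    guidance_scale=4.0,
    generator=torch.Generator("cuda").manual_seed(42),
).images[0]
image.save("pi_flux2_sample.png")
```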
lakonlab/pipelines/prompt_rewriters/__init__.py ADDED
File without changes
lakonlab/pipelines/prompt_rewriters/qwen3_vl.py ADDED
@@ -0,0 +1,172 @@
1
+ import os
2
+ import torch
3
+ from typing import List, Sequence, Union, Optional
4
+ from PIL import Image
5
+ from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
6
+
7
+
8
+ DEFAULT_TEXT_ONLY_PATH = os.path.abspath(os.path.join(__file__, '../system_prompts/default_text_only.txt'))
9
+ DEFAULT_WITH_IMAGES_PATH = os.path.abspath(os.path.join(__file__, '../system_prompts/default_with_images.txt'))
10
+
11
+
12
+ class Qwen3VLPromptRewriter:
13
+
14
+ def __init__(
15
+ self,
16
+ from_pretrained="Qwen/Qwen3-VL-8B-Instruct",
17
+ torch_dtype='bfloat16',
18
+ device_map="auto",
19
+ max_new_tokens_default=128,
20
+ system_prompt_text_only=None,
21
+ system_prompt_with_images=None,
22
+ **kwargs):
23
+ if torch_dtype is not None:
24
+ kwargs.update(torch_dtype=getattr(torch, torch_dtype))
25
+ self.model = Qwen3VLForConditionalGeneration.from_pretrained(
26
+ from_pretrained,
27
+ device_map=device_map,
28
+ **kwargs)
29
+ self.processor = AutoProcessor.from_pretrained(from_pretrained)
30
+ # Left padding is safer for batched generation
31
+ if hasattr(self.processor, "tokenizer"):
32
+ self.processor.tokenizer.padding_side = "left"
33
+ self.max_new_tokens_default = max_new_tokens_default
34
+ if system_prompt_text_only is None:
35
+ system_prompt_text_only = open(DEFAULT_TEXT_ONLY_PATH, 'r').read()
36
+ if system_prompt_with_images is None:
37
+ system_prompt_with_images = open(DEFAULT_WITH_IMAGES_PATH, 'r').read()
38
+ self.system_prompt_text_only = system_prompt_text_only
39
+ self.system_prompt_with_images = system_prompt_with_images
40
+
41
+ @torch.inference_mode()
42
+ def _generate_from_messages(
43
+ self,
44
+ batch_messages: Sequence[Sequence[dict]],
45
+ max_new_tokens: Optional[int] = None,
46
+ **kwargs) -> List[str]:
47
+ if max_new_tokens is None:
48
+ max_new_tokens = self.max_new_tokens_default
49
+
50
+ inputs = self.processor.apply_chat_template(
51
+ batch_messages,
52
+ tokenize=True,
53
+ add_generation_prompt=True,
54
+ return_dict=True,
55
+ return_tensors="pt",
56
+ padding=True,
57
+ )
58
+ inputs.pop("token_type_ids", None)
59
+ inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
60
+
61
+ generated_ids = self.model.generate(
62
+ **inputs,
63
+ max_new_tokens=max_new_tokens,
64
+ **kwargs)
65
+
66
+ input_ids = inputs["input_ids"]
67
+ tokenizer = self.processor.tokenizer
68
+ outputs: List[str] = []
69
+
70
+ # Decode only the new tokens after each input sequence
71
+ for in_ids, out_ids in zip(input_ids, generated_ids):
72
+ trimmed_ids = out_ids[len(in_ids):]
73
+ text = tokenizer.decode(
74
+ trimmed_ids.tolist(),
75
+ skip_special_tokens=True,
76
+ clean_up_tokenization_spaces=False,
77
+ )
78
+ outputs.append(text.strip())
79
+
80
+ return outputs
81
+
82
+ def rewrite_text_batch(
83
+ self,
84
+ prompts: Sequence[str],
85
+ max_new_tokens: Optional[int] = None,
86
+ top_p=0.6,
87
+ top_k=40,
88
+ temperature=0.5,
89
+ repetition_penalty=1.0,
90
+ **kwargs) -> List[str]:
91
+ """
92
+ Rewrite a batch of text-only prompts into detailed prompts.
93
+ """
94
+ batch_messages = []
95
+ for p in prompts:
96
+ conv = [
97
+ {
98
+ "role": "system",
99
+ "content": [
100
+ {"type": "text", "text": self.system_prompt_text_only},
101
+ ],
102
+ },
103
+ {
104
+ "role": "user",
105
+ "content": [
106
+ {"type": "text", "text": p},
107
+ ],
108
+ },
109
+ ]
110
+ batch_messages.append(conv)
111
+
112
+ return self._generate_from_messages(
113
+ batch_messages,
114
+ max_new_tokens=max_new_tokens,
115
+ top_p=top_p,
116
+ top_k=top_k,
117
+ temperature=temperature,
118
+ repetition_penalty=repetition_penalty,
119
+ **kwargs,
120
+ )
121
+
122
+ def rewrite_edit_batch(
123
+ self,
124
+ image: Sequence[Union[str, 'Image.Image', Sequence[Union[str, 'Image.Image']]]],
125
+ edit_requests: Sequence[str],
126
+ max_new_tokens: Optional[int] = None,
127
+ top_p=0.5,
128
+ top_k=20,
129
+ temperature=0.4,
130
+ repetition_penalty=1.0,
131
+ **kwargs) -> List[str]:
132
+ """
133
+ Rewrite a batch of (image, edit-request) pairs into concise edit instructions.
134
+ """
135
+ if len(image) != len(edit_requests):
136
+ raise ValueError("image and edit_requests must have the same length")
137
+
138
+ batch_messages = []
139
+ for imgs, req in zip(image, edit_requests):
140
+ if isinstance(imgs, (str, Image.Image)):
141
+ img_list = [imgs]
142
+ else:
143
+ img_list = list(imgs)
144
+
145
+ user_content = []
146
+ for im in img_list:
147
+ user_content.append({"type": "image", "image": im})
148
+ user_content.append({"type": "text", "text": req})
149
+
150
+ conv = [
151
+ {
152
+ "role": "system",
153
+ "content": [
154
+ {"type": "text", "text": self.system_prompt_with_images},
155
+ ],
156
+ },
157
+ {
158
+ "role": "user",
159
+ "content": user_content,
160
+ },
161
+ ]
162
+ batch_messages.append(conv)
163
+
164
+ return self._generate_from_messages(
165
+ batch_messages,
166
+ max_new_tokens=max_new_tokens,
167
+ top_p=top_p,
168
+ top_k=top_k,
169
+ temperature=temperature,
170
+ repetition_penalty=repetition_penalty,
171
+ **kwargs,
172
+ )
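A usage sketch for the rewriter above, assuming the default `Qwen/Qwen3-VL-8B-Instruct` checkpoint can be downloaded and fits on the available device; the image path and prompts are placeholders.

```python
from lakonlab.pipelines.prompt_rewriters.qwen3_vl import Qwen3VLPromptRewriter

rewriter = Qwen3VLPromptRewriter()  # loads Qwen/Qwen3-VL-8B-Instruct by default

# Text-only prompts -> detailed prompts
detailed = rewriter.rewrite_text_batch(["a cat sleeping on a windowsill"])
print(detailed[0])

# (image, edit request) pairs -> concise edit instructions
instructions = rewriter.rewrite_edit_batch(
    image=["input.png"],  # placeholder path
    edit_requests=["make the sky stormy but do not change the cat"],
)
print(instructions[0])
```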
lakonlab/pipelines/prompt_rewriters/system_prompts/default_text_only.txt ADDED
@@ -0,0 +1,12 @@
1
+ You are an expert prompt engineer for a text-guided image generation system. Rewrite user prompts into more descriptive, concrete prompts while strictly preserving their core content.
2
+
3
+ Rules:
4
+ - Preserve the core content: do not change the main subjects, how many there are, their roles, or the primary actions and relationships described in the prompt.
5
+ - Preserve any explicitly stated attributes such as colors, clothing, objects, style tags (e.g., “photorealistic”, “cinematic”), and viewpoint (e.g., close-up, wide shot). Do not contradict them.
6
+ - When the prompt implies realism (e.g., uses words like “realistic”, “photorealistic”, “photo”), avoid introducing exaggerated or fantastical traits unless intended by the user.
7
+ - You may add supporting details that are clearly compatible with the original text: background, props, textures, materials, lighting (quality, direction, color), atmosphere, and other environmental context, as long as they do not introduce new main characters or conflicting concepts.
8
+ - Always include a clear description of the composition and camera framing (for example, where the main subjects are in the frame, whether it is a close-up or wide shot, and the approximate viewpoint or angle).
9
+ - Structure: keep any existing structure (tags, aspect-ratio tags, etc.) and enhance only within those fields. For plain text prompts, expand them into a clear, coherent paragraph.
10
+ - Text in images: put ALL visible or implied text in quotation marks, matching the prompt’s language. Provide explicit quoted text for any object that would realistically contain text (signs, labels, screens, interfaces, book covers, etc.).
11
+
12
+ Output only the revised prompt and nothing else.
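The text above is the default text-only system prompt loaded from disk by `Qwen3VLPromptRewriter`. A custom prompt can be passed instead; the override string below is invented for illustration.

```python
from lakonlab.pipelines.prompt_rewriters.qwen3_vl import Qwen3VLPromptRewriter

custom = "Rewrite the user prompt into one vivid, concrete sentence. Output only the rewritten prompt."
rewriter = Qwen3VLPromptRewriter(system_prompt_text_only=custom)
```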
lakonlab/pipelines/prompt_rewriters/system_prompts/default_with_images.txt ADDED
@@ -0,0 +1,10 @@
1
+ You are an expert prompt engineer for a text-guided image editing system. Rewrite user prompts into a more descriptive instruction (40–70 words, ~25 for brief requests) while strictly preserving their core content.
2
+
3
+ Rules:
4
+ - Treat the user prompt as a strict specification: do not change or omit any mentioned entities, actions, or relationships.
5
+ - Explicitly state both the requested edits and which core aspects (e.g., character identities) must remain as in the original image(s). Ignore background elements unless they are explicitly mentioned.
6
+ - Refer to core visual elements that are relevant to the edit (e.g., people, animals, and objects mentioned in the user prompt).
7
+ - Do not invent new elements unless the user explicitly asks for them.
8
+ - Turn negatives into positives (“do not change X” → “keep X the same”).
9
+
10
+ Output only the final instruction in plain text and nothing else.
lakonlab/ui/__init__.py ADDED
File without changes
lakonlab/ui/gradio/__init__.py ADDED
File without changes
lakonlab/ui/gradio/create_img_edit.py ADDED
@@ -0,0 +1,88 @@
1
+ import gradio as gr
2
+ from .shared_opts import create_base_opts, create_generate_bar, set_seed, create_prompt_opts, create_image_size_bar
3
+
4
+
5
+ def create_interface_img_edit(
6
+ api, prompt='', seed=42, steps=32, min_steps=4, max_steps=50, steps_slider_step=1,
7
+ height=768, width=1360, hw_slider_step=16,
8
+ guidance_scale=None, temperature=None, api_name='text_to_img',
9
+ create_negative_prompt=False,
10
+ create_prompt_rewrite=True, rewrite_prompt=False,
11
+ args=['last_seed', 'prompt', 'in_image', 'width', 'height', 'steps', 'guidance_scale'],
12
+ rewrite_prompt_api=None, rewrite_prompt_args=['last_seed', 'prompt', 'rewrite_prompt', 'in_image']):
13
+ var_dict = dict()
14
+ with gr.Blocks(analytics_enabled=False) as interface:
15
+ with gr.Row():
16
+ with gr.Column():
17
+ with gr.Accordion("Input image(s) (optional)", open=True, elem_classes=['custom-spacing']):
18
+ var_dict['in_image'] = gr.Gallery(
19
+ label="Input image(s)",
20
+ type="pil",
21
+ columns=3,
22
+ rows=1)
23
+
24
+ with gr.Column(variant='compact', elem_classes=['custom-spacing']):
25
+ create_prompt_opts(
26
+ var_dict, create_negative_prompt=create_negative_prompt, prompt=prompt, display_label=True)
27
+
28
+ if create_prompt_rewrite:
29
+ var_dict['rewrite_prompt'] = gr.Checkbox(
30
+ label='Rewrite prompt', value=rewrite_prompt, container=False)
31
+ with gr.Accordion("Rewritten prompt", open=False, elem_classes=['custom-spacing']):
32
+ var_dict['rewritten_prompt'] = gr.Textbox(
33
+ lines=4, interactive=False, show_label=False, container=False)
34
+
35
+ with gr.Column(variant='compact', elem_classes=['custom-spacing']):
36
+ create_image_size_bar(
37
+ var_dict, height=height, width=width, hw_slider_step=hw_slider_step)
38
+
39
+ create_generate_bar(var_dict, text='Generate', seed=seed)
40
+
41
+ create_base_opts(
42
+ var_dict,
43
+ steps=steps,
44
+ min_steps=min_steps,
45
+ max_steps=max_steps,
46
+ steps_slider_step=steps_slider_step,
47
+ guidance_scale=guidance_scale,
48
+ temperature=temperature)
49
+
50
+ with gr.Column():
51
+ var_dict['output_image'] = gr.Image(
52
+ type='pil', image_mode='RGB', label='Output image', interactive=False,
53
+ elem_classes=['vh-img', 'vh-img-1000'])
54
+
55
+ if create_prompt_rewrite:
56
+ assert rewrite_prompt_api is not None
57
+ var_dict['run_btn'].click(
58
+ fn=set_seed,
59
+ inputs=var_dict['seed'],
60
+ outputs=var_dict['last_seed'],
61
+ show_progress=False,
62
+ api_name=False
63
+ ).success(
64
+ fn=rewrite_prompt_api,
65
+ inputs=[var_dict[arg] for arg in rewrite_prompt_args],
66
+ outputs=[var_dict['rewritten_prompt'], var_dict['output_image']],
67
+ concurrency_id='default_group', api_name=False
68
+ ).success(
69
+ fn=api,
70
+ inputs=[var_dict[arg] for arg in args],
71
+ outputs=var_dict['output_image'],
72
+ concurrency_id='default_group', api_name=api_name
73
+ )
74
+ else:
75
+ var_dict['run_btn'].click(
76
+ fn=set_seed,
77
+ inputs=var_dict['seed'],
78
+ outputs=var_dict['last_seed'],
79
+ show_progress=False,
80
+ api_name=False
81
+ ).success(
82
+ fn=api,
83
+ inputs=[var_dict[arg] for arg in args],
84
+ outputs=var_dict['output_image'],
85
+ concurrency_id='default_group', api_name=api_name
86
+ )
87
+
88
+ return interface, var_dict
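A minimal sketch of how this builder can be driven; the `generate` stub and its argument order are assumptions that must line up with the default `args` list, and the prompt-rewrite path is disabled to keep the example self-contained.

```python
from PIL import Image
from lakonlab.ui.gradio.create_img_edit import create_interface_img_edit


def generate(seed, prompt, in_image, width, height, steps, guidance_scale):
    # Placeholder backend; the real app calls PiFlux2Pipeline here.
    return Image.new("RGB", (int(width), int(height)), "gray")


interface, widgets = create_interface_img_edit(
    api=generate,
    guidance_scale=4.0,           # needed so the 'guidance_scale' slider exists
    create_prompt_rewrite=False,  # skip the rewriter wiring in this sketch
)
interface.launch()
```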
lakonlab/ui/gradio/create_text_to_img.py ADDED
@@ -0,0 +1,41 @@
1
+ import gradio as gr
2
+ from .shared_opts import create_base_opts, create_generate_bar, set_seed, create_prompt_opts, create_image_size_bar
3
+
4
+
5
+ def create_interface_text_to_img(
6
+ api, prompt='', seed=42, steps=32, min_steps=4, max_steps=50, steps_slider_step=1,
7
+ height=768, width=1360, hw_slider_step=16,
8
+ guidance_scale=None, temperature=None, api_name='text_to_img',
9
+ create_negative_prompt=False, args=['last_seed', 'prompt', 'width', 'height', 'steps', 'guidance_scale']):
10
+ var_dict = dict()
11
+ with gr.Blocks(analytics_enabled=False) as interface:
12
+ var_dict['output_image'] = gr.Image(
13
+ type='pil', image_mode='RGB', label='Output image', interactive=False, elem_classes=['vh-img', 'vh-img-700'])
14
+ create_prompt_opts(var_dict, create_negative_prompt=create_negative_prompt, prompt=prompt)
15
+ with gr.Column(variant='compact', elem_classes=['custom-spacing']):
16
+ create_image_size_bar(
17
+ var_dict, height=height, width=width, hw_slider_step=hw_slider_step)
18
+ create_generate_bar(var_dict, text='Generate', seed=seed)
19
+ create_base_opts(
20
+ var_dict,
21
+ steps=steps,
22
+ min_steps=min_steps,
23
+ max_steps=max_steps,
24
+ steps_slider_step=steps_slider_step,
25
+ guidance_scale=guidance_scale,
26
+ temperature=temperature)
27
+
28
+ var_dict['run_btn'].click(
29
+ fn=set_seed,
30
+ inputs=var_dict['seed'],
31
+ outputs=var_dict['last_seed'],
32
+ show_progress=False,
33
+ api_name=False
34
+ ).success(
35
+ fn=api,
36
+ inputs=[var_dict[arg] for arg in args],
37
+ outputs=var_dict['output_image'],
38
+ concurrency_id='default_group', api_name=api_name
39
+ )
40
+
41
+ return interface, var_dict
lakonlab/ui/gradio/shared_opts.py ADDED
@@ -0,0 +1,87 @@
1
+ import random
2
+ import gradio as gr
3
+
4
+
5
+ def create_prompt_opts(
6
+ var_dict, create_negative_prompt=True, prompt='', negative_prompt='', display_label=False):
7
+ if display_label:
8
+ kwargs = dict(show_label=True, container=True, elem_classes=['force-hide-container'])
9
+ else:
10
+ kwargs = dict(show_label=False, container=False)
11
+ var_dict['prompt'] = gr.Textbox(
12
+ prompt, label='Prompt', lines=2, placeholder='Prompt', interactive=True, **kwargs)
13
+ if create_negative_prompt:
14
+ var_dict['negative_prompt'] = gr.Textbox(
15
+ negative_prompt, label='Negative prompt', lines=2,
16
+ placeholder='Negative prompt', interactive=True, **kwargs)
17
+
18
+
19
+ def create_generate_bar(var_dict, text='Generate', variant='primary', seed=-1):
20
+ with gr.Row(equal_height=False, elem_classes=['generate-bar']):
21
+ var_dict['run_btn'] = gr.Button(text, variant=variant, scale=2)
22
+ var_dict['seed'] = gr.Number(
23
+ label='Seed', value=seed, min_width=100, precision=0, minimum=-1, maximum=2 ** 31,
24
+ elem_classes=['force-hide-container', 'seed-input'])
25
+ var_dict['random_seed'] = gr.Button('\U0001f3b2\ufe0f', elem_classes=['tool'])
26
+ var_dict['reuse_seed'] = gr.Button('\u267b\ufe0f', elem_classes=['tool'])
27
+ with gr.Column(visible=False):
28
+ var_dict['last_seed'] = gr.Number(value=seed, label='Last seed')
29
+ var_dict['reuse_seed'].click(
30
+ fn=lambda x: x,
31
+ inputs=var_dict['last_seed'],
32
+ outputs=var_dict['seed'],
33
+ show_progress=False,
34
+ api_name=False)
35
+ var_dict['random_seed'].click(
36
+ fn=lambda: -1,
37
+ outputs=var_dict['seed'],
38
+ show_progress=False,
39
+ api_name=False)
40
+
41
+
42
+ def create_image_size_bar(var_dict, height=768, width=1360, hw_slider_step=16):
43
+ with gr.Row(equal_height=True, variant='compact', elem_classes=['force-hide-container']):
44
+ var_dict['width'] = gr.Slider(
45
+ label='Width', minimum=64, maximum=2048, step=hw_slider_step, value=width,
46
+ elem_classes=['force-hide-container'])
47
+ var_dict['switch_hw'] = gr.Button('\U000021C6', elem_classes=['tool'])
48
+ var_dict['height'] = gr.Slider(
49
+ label='Height', minimum=64, maximum=2048, step=hw_slider_step, value=height,
50
+ elem_classes=['force-hide-container'])
51
+ var_dict['switch_hw'].click(
52
+ fn=lambda w, h: (h, w),
53
+ inputs=[var_dict['width'], var_dict['height']],
54
+ outputs=[var_dict['width'], var_dict['height']],
55
+ show_progress=False,
56
+ api_name=False)
57
+
58
+
59
+ def create_base_opts(var_dict,
60
+ steps=24,
61
+ min_steps=4,
62
+ max_steps=50,
63
+ steps_slider_step=1,
64
+ guidance_scale=None,
65
+ temperature=None,
66
+ render=True):
67
+ with gr.Column(variant='compact', elem_classes=['custom-spacing'], render=render) as base_opts:
68
+ with gr.Row(variant='compact', elem_classes=['force-hide-container']):
69
+ var_dict['steps'] = gr.Slider(
70
+ min_steps, max_steps, value=steps, step=steps_slider_step, label='Sampling steps',
71
+ elem_classes=['force-hide-container'])
72
+ if guidance_scale is not None or temperature is not None:
73
+ with gr.Row(variant='compact', elem_classes=['force-hide-container']):
74
+ if guidance_scale is not None:
75
+ var_dict['guidance_scale'] = gr.Slider(
76
+ 0.0, 30.0, value=guidance_scale, step=0.5, label='Guidance scale',
77
+ elem_classes=['force-hide-container'])
78
+ if temperature is not None:
79
+ var_dict['temperature'] = gr.Slider(
80
+ 0.0, 1.0, value=temperature, step=0.01, label='Temperature',
81
+ elem_classes=['force-hide-container'])
82
+ return base_opts
83
+
84
+
85
+ def set_seed(seed):
86
+ seed = random.randint(0, 2**31) if seed == -1 else seed
87
+ return seed
lakonlab/ui/gradio/style.css ADDED
@@ -0,0 +1,73 @@
1
+ .force-hide-container {
2
+ margin: 0;
3
+ box-shadow: none;
4
+ --block-border-width: 0;
5
+ background: transparent;
6
+ padding: 0;
7
+ overflow: visible;
8
+ }
9
+
10
+ .svelte-1vd8eap {
11
+ display: flex;
12
+ flex-direction: inherit;
13
+ flex-wrap: wrap;
14
+ gap: 0;
15
+ box-shadow: none;
16
+ border: 0;
17
+ border-radius: 0;
18
+ background: transparent;
19
+ overflow-y: hidden;
20
+ }
21
+
22
+ .custom-spacing {
23
+ padding: 10px;
24
+ gap: 20px;
25
+ flex-grow: 0 !important;
26
+ }
27
+
28
+ .generate-bar {
29
+ align-items: flex-end;
30
+ }
31
+
32
+ .tool{
33
+ max-width: 40px;
34
+ min-width: 40px !important;
35
+ }
36
+
37
+ /* Center the component and allow it to use the full row width */
38
+ .vh-img {
39
+ display: grid;
40
+ justify-items: center;
41
+ }
42
+
43
+ /* Container should size to the image, but never exceed the row width */
44
+ .vh-img .image-container {
45
+ inline-size: fit-content !important; /* prefers image’s natural width */
46
+ max-inline-size: 100% !important; /* ...but clamps to available width */
47
+ margin-inline: auto;
48
+ overflow: hidden; /* avoid odd overflow on iOS */
49
+ }
50
+
51
+ /* Image scales by BOTH constraints: height cap and row width */
52
+ .vh-img-700 .image-container img {
53
+ max-block-size: 700px !important; /* fixed max height cap */
54
+ max-inline-size: 100%; /* never wider than container */
55
+ inline-size: auto; /* keep aspect ratio */
56
+ block-size: auto;
57
+ object-fit: contain;
58
+ display: block;
59
+ }
60
+
61
+ .vh-img-1000 .image-container img {
62
+ max-block-size: 1000px !important; /* fixed max height cap */
63
+ max-inline-size: 100%; /* never wider than container */
64
+ inline-size: auto; /* keep aspect ratio */
65
+ block-size: auto;
66
+ object-fit: contain;
67
+ display: block;
68
+ }
69
+
70
+ .gradio-container .seed-input input { /* remove the border and use box-shadow instead for height alignment */
71
+ border: none !important;
72
+ box-shadow: inset 0 0 0 var(--input-border-width) var(--input-border-color) !important;
73
+ }
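The classes above are opt-in hooks: components reference them through `elem_classes`, and the stylesheet is attached when building the Blocks app. A small sketch follows; the path assumes the repository layout above, and the real app may wire this up differently.

```python
import gradio as gr

# Load the stylesheet and attach it to a Blocks app; components opt in via elem_classes.
with open("lakonlab/ui/gradio/style.css") as f:
    css = f.read()

with gr.Blocks(css=css) as demo:
    gr.Button("🎲", elem_classes=["tool"])  # 40px-wide tool button
    gr.Image(label="Output image", elem_classes=["vh-img", "vh-img-700"])

demo.launch()
```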
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ numpy==1.26.4
2
+ torch==2.6.0
3
+ torchvision==0.21.0
4
+ diffusers==0.36.0
5
+ peft==0.17.0
6
+ sentencepiece
7
+ accelerate
8
+ transformers==4.57.3
9
+ gradio==5.49.0
10
+ bitsandbytes>=0.46.1