frogleo commited on
Commit
31dae9f
·
verified ·
1 Parent(s): 9d0cfb2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -14
app.py CHANGED
@@ -80,7 +80,6 @@ pipe.to(device)
80
  # flash-attn估计库估计更新了,导致冲突了,不使用预编译的了
81
  # spaces.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/FLUX.2", variant="fa3")
82
 
83
-
84
  def image_to_data_uri(img):
85
  buffered = io.BytesIO()
86
  img.save(buffered, format="PNG")
@@ -162,27 +161,32 @@ def get_duration(prompt_embeds, image_list, width, height, num_inference_steps,
162
  return max(65, num_inference_steps * step_duration + 10)
163
 
164
  @spaces.GPU(duration=get_duration)
165
- def generate_image(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
166
  # Move embeddings to GPU only when inside the GPU decorated function
167
  prompt_embeds = prompt_embeds.to(device)
168
 
169
  generator = torch.Generator(device=device).manual_seed(seed)
170
-
171
- pipe_kwargs = {
172
- "prompt_embeds": prompt_embeds,
173
- "image": image_list,
174
- "num_inference_steps": num_inference_steps,
175
- "guidance_scale": guidance_scale,
176
- "generator": generator,
177
- "width": width,
178
- "height": height,
179
- }
180
-
181
  # Progress bar for the actual generation steps
182
  if progress:
183
  progress(0, desc="Starting generation...")
 
 
 
 
 
 
184
 
185
- image = pipe(**pipe_kwargs).images[0]
 
 
 
 
 
 
 
 
 
186
  return image
187
 
188
  def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=2.5, prompt_upsampling=False, progress=gr.Progress(track_tqdm=True)):
 
80
  # flash-attn估计库估计更新了,导致冲突了,不使用预编译的了
81
  # spaces.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/FLUX.2", variant="fa3")
82
 
 
83
  def image_to_data_uri(img):
84
  buffered = io.BytesIO()
85
  img.save(buffered, format="PNG")
 
161
  return max(65, num_inference_steps * step_duration + 10)
162
 
163
  @spaces.GPU(duration=get_duration)
164
+ def generate_image(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress()):
165
  # Move embeddings to GPU only when inside the GPU decorated function
166
  prompt_embeds = prompt_embeds.to(device)
167
 
168
  generator = torch.Generator(device=device).manual_seed(seed)
169
+
 
 
 
 
 
 
 
 
 
 
170
  # Progress bar for the actual generation steps
171
  if progress:
172
  progress(0, desc="Starting generation...")
173
+
174
+ def callback_fn(pipe, step, timestep, callback_kwargs):
175
+ print(f"[Step {step}] Timestep: {timestep}")
176
+ progress_value = (step+1.0)/num_inference_steps
177
+ progress(progress_value, desc=f"Image generating, {step + 1}/{num_inference_steps} steps")
178
+ return callback_kwargs
179
 
180
+ image = pipe(
181
+ prompt_embeds=prompt_embeds,
182
+ image=image_list,
183
+ num_inference_steps=num_inference_steps,
184
+ guidance_scale=guidance_scale,
185
+ generator=generator,
186
+ width=width,
187
+ height=height,
188
+ callback_on_step_end=callback_fn,
189
+ ).images[0]
190
  return image
191
 
192
  def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=2.5, prompt_upsampling=False, progress=gr.Progress(track_tqdm=True)):