Qwen-Image-Edit-2509 Lineart Interpolation

This is a LoRA weight for lineart interpolation, trained for 3,000 steps on a randomly selected 10% of the train subset of the Mixamo 240 dataset.

⚠️ Notes

  • Still an experimental attempt.
  • Verification is still limited, but the LoRA works on some examples.

Quick start
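
The script below loads the base pipeline with this LoRA, encodes the start and end control frames into latent tokens, and denoises an in-between frame conditioned on both.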

import os
import math
import torch
import numpy as np
from PIL import Image

from diffusers import (
    QwenImageEditPipeline,
    FlowMatchEulerDiscreteScheduler,
)
from diffusers.utils.torch_utils import randn_tensor

# --------- Adjust these settings to your environment ---------
BASE_MODEL_ID = "Qwen/Qwen-Image-Edit-2509"  # same base model as used for training
LORA_DIR = "<PATH_TO_WEIGHT>"
DEVICE = "cuda"
DTYPE = torch.bfloat16

CONTROL1_IMAGE_PATH = "<PATH_TO_START_IMAGE>"
CONTROL2_IMAGE_PATH = "<PATH_TO_END_IMAGE>"
PROMPT = "<inbetween> middle frame"  # Do not change; the LoRA expects this exact prompt
NEGATIVE_PROMPT = " "
NUM_STEPS = 30
SEED = 0
# ===================


def calculate_dimensions(target_area, ratio):
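    """Scale an aspect ratio to roughly `target_area` pixels, rounding both sides to multiples of 32."""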
    width = math.sqrt(target_area * ratio)
    height = width / ratio

    width = round(width / 32) * 32
    height = round(height / 32) * 32

    return int(width), int(height), None


def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
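    """Linearly interpolate the flow-matching shift `mu` between base_shift and max_shift from the image token count."""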
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


def retrieve_timesteps(
    scheduler,
    num_inference_steps: int = None,
    device: torch.device | str | None = None,
    timesteps=None,
    sigmas=None,
    **kwargs,
):
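    """Configure the scheduler from explicit timesteps, sigmas, or a step count (mirrors the retrieve_timesteps helper used in diffusers pipelines)."""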
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")

    if timesteps is not None:
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps

    return timesteps, num_inference_steps


def main():
    torch.manual_seed(SEED)

    # ---------- 1. Load the base pipeline + LoRA ----------
    pipe: QwenImageEditPipeline = QwenImageEditPipeline.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=DTYPE,
    ).to(DEVICE)

    pipe.load_lora_weights(LORA_DIR, adapter_name="my_lora")
    pipe.set_adapters(["my_lora"], adapter_weights=[1.0]) 
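    # Optional: if GPU memory is tight, pipe.enable_model_cpu_offload() can replace the .to(DEVICE) call above.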

    transformer = pipe.transformer.to(DEVICE, dtype=DTYPE)
    vae = pipe.vae.to(DEVICE, dtype=DTYPE)
    scheduler: FlowMatchEulerDiscreteScheduler = pipe.scheduler
    vae_scale_factor = pipe.vae_scale_factor
    # ------------------------------------------------------------

    # ---------- 2. Load & preprocess the control images ----------
    control1_img = Image.open(CONTROL1_IMAGE_PATH).convert("RGB")
    control2_img = Image.open(CONTROL2_IMAGE_PATH).convert("RGB")

    # Same resize logic as the official Qwen pipeline
    image_size = control1_img.size  # (W, H)
    calculated_width, calculated_height, _ = calculate_dimensions(
        1024 * 1024, image_size[0] / image_size[1]
    )

    # Align to multiple_of for the VAE + patch packing
    multiple_of = vae_scale_factor * 2
    width = calculated_width // multiple_of * multiple_of
    height = calculated_height // multiple_of * multiple_of

    # Same image_processor usage as the official pipeline
    control1_resized = pipe.image_processor.resize(control1_img, calculated_height, calculated_width)
    control2_resized = pipe.image_processor.resize(control2_img, calculated_height, calculated_width)

    control1_px = pipe.image_processor.preprocess(
        control1_resized, calculated_height, calculated_width
    ).unsqueeze(2)  # [B,3,1,H,W]
    control2_px = pipe.image_processor.preprocess(
        control2_resized, calculated_height, calculated_width
    ).unsqueeze(2)

    control1_px = control1_px.to(DEVICE, DTYPE)
    control2_px = control2_px.to(DEVICE, DTYPE)

    bsz = control1_px.shape[0]
    num_channels_latents = transformer.config.in_channels // 4  # = vae.config.z_dim
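    # _pack_latents folds each 2x2 latent patch into the channel dimension, so the transformer's in_channels is z_dim * 4.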

    # ---------- 3. Encode the control images into latent tokens ----------
    with torch.no_grad():
        # Equivalent to the official _encode_vae_image
        ctrl1_latents_5d = pipe._encode_vae_image(control1_px, generator=None)  # [B, C,1,H',W']
        ctrl2_latents_5d = pipe._encode_vae_image(control2_px, generator=None)

    H_lat, W_lat = ctrl1_latents_5d.shape[3], ctrl1_latents_5d.shape[4]

    # Pack into the token format expected by the transformer
    ctrl1_tokens = QwenImageEditPipeline._pack_latents(
        ctrl1_latents_5d, bsz, num_channels_latents, H_lat, W_lat
    )  # [B, N, C*4]
    ctrl2_tokens = QwenImageEditPipeline._pack_latents(
        ctrl2_latents_5d, bsz, num_channels_latents, H_lat, W_lat
    )

    # ---------- 4. Randomly initialize the target latent ----------
    # Same shape as the official prepare_latents
    height_lat = 2 * (height // (vae_scale_factor * 2))
    width_lat = 2 * (width // (vae_scale_factor * 2))

    shape = (bsz, 1, num_channels_latents, height_lat, width_lat)
    latents_5d = randn_tensor(shape, device=DEVICE, dtype=DTYPE)
    latents = QwenImageEditPipeline._pack_latents(
        latents_5d, bsz, num_channels_latents, height_lat, width_lat
    )  # From here on, everything stays packed as [B, N, C*4]

    # ---------- 5. Text embeddings ----------
    with torch.no_grad():
        prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
            image=control1_resized,       # as in the official pipeline: the resized image
            prompt=[PROMPT],
            device=DEVICE,
            num_images_per_prompt=1,
            max_sequence_length=1024,
        )
        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()

    # img_shapes for the three streams (used for rotary embeddings)
    img_shapes = [[
        (1, height_lat // 2, width_lat // 2),  # target
        (1, height_lat // 2, width_lat // 2),  # control1
        (1, height_lat // 2, width_lat // 2),  # control2
    ]] * bsz
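    # Each tuple is (frames, packed_height, packed_width); the target and both control streams share the same grid here.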

    # If your implementation precomputes the rotary embeddings, do it here:
    # image_rotary_emb = transformer.pos_embed(img_shapes, txt_seq_lens, device=DEVICE)

    # ---------- 6. Prepare scheduler timesteps (following the official pipeline) ----------
    sigmas = np.linspace(1.0, 1.0 / NUM_STEPS, NUM_STEPS)
    image_seq_len = latents.shape[1]  # number of image tokens

    mu = calculate_shift(
        image_seq_len,
        scheduler.config.get("base_image_seq_len", 256),
        scheduler.config.get("max_image_seq_len", 4096),
        scheduler.config.get("base_shift", 0.5),
        scheduler.config.get("max_shift", 1.15),
    )

    timesteps, _ = retrieve_timesteps(
        scheduler,
        NUM_STEPS,
        device=DEVICE,
        sigmas=sigmas,
        mu=mu,
    )

    scheduler.set_begin_index(0)

    # No guidance is used (multi-control conditioning only)
    guidance = None

    # ---------- 7. Denoising loop ----------
    with torch.no_grad():
        for i, t in enumerate(timesteps):
            # Transformer input: concatenate target + control1 + control2 along the token axis
            latent_model_input = torch.cat([latents, ctrl1_tokens, ctrl2_tokens], dim=1)
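            # The control tokens pass through the transformer purely as conditioning;
            # only the target slice of the prediction is used for the scheduler step below.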

            # Broadcast the timestep over the batch
            timestep_batch = t.expand(latents.shape[0]).to(latents.dtype)

            model_pred_all = transformer(
                hidden_states=latent_model_input,
                timestep=timestep_batch / 1000.0,     # same scale as during training
                guidance=guidance,
                encoder_hidden_states=prompt_embeds,
                encoder_hidden_states_mask=prompt_embeds_mask,
                img_shapes=img_shapes,                # OK with this diffusers version
                txt_seq_lens=txt_seq_lens,
                # With a newer API, pass instead:
                # image_rotary_emb=image_rotary_emb,
                # attention_kwargs=None,
                return_dict=False,
            )[0]  # [B, N_total, C*4]

            # Keep only the leading target tokens
            model_pred = model_pred_all[:, : latents.size(1)]

            latents_dtype = latents.dtype
            latents = scheduler.step(model_pred, t, latents, return_dict=False)[0]
            if latents.dtype != latents_dtype:
                latents = latents.to(latents_dtype)

    # ---------- 8. decode ----------
    # packed tokens -> 5D latents
    latents = QwenImageEditPipeline._unpack_latents(
        latents,
        height=height,
        width=width,
        vae_scale_factor=vae_scale_factor,
    )  # [B, C,1,H_lat,W_lat]

    latents = latents.to(vae.dtype)

    # Undo the latents_mean / latents_std normalization (same as the official pipeline)
    latents_mean = (
        torch.tensor(vae.config.latents_mean)
        .view(1, vae.config.z_dim, 1, 1, 1)
        .to(latents.device, latents.dtype)
    )
    latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(
        1, vae.config.z_dim, 1, 1, 1
    ).to(latents.device, latents.dtype)

    latents = latents / latents_std + latents_mean  # [B,C,1,H_lat,W_lat]
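    # Encoding applied (x - mean) / std, so this restores the VAE's native latent scale before decoding.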

    # decode (T=1, so keep only frame 0)
    with torch.no_grad():
        image_latents = latents
        decoded = vae.decode(image_latents, return_dict=False)[0]  # [B,3,1,H,W]
        images = decoded[:, :, 0, :, :]                             # [B,3,H,W]

    # Qwen's image_processor converts [-1, 1] -> PIL
    images = pipe.image_processor.postprocess(images, output_type="pil")
    out: Image.Image = images[0]
    out.save("multi_control_from_controls.png")
    print("saved to multi_control_from_controls.png")


if __name__ == "__main__":
    main()

Acknowledgement
