Arlet41 committed on
Commit 8e233ce · verified
1 Parent(s): 3681b53
Files changed (1):
  1. app.py +52 -51
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+
 import cv2
 import gradio as gr
 import mediapipe as mp
@@ -6,10 +7,7 @@ import numpy as np
 from PIL import Image
 from gradio_client import Client, handle_file
 
-# ------------------------------
-# PATHS
-# ------------------------------
-example_path = os.path.join(os.path.dirname(__file__), "example")
+example_path = os.path.join(os.path.dirname(__file__), 'example')
 
 garm_list = os.listdir(os.path.join(example_path, "cloth"))
 garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]
@@ -17,30 +15,35 @@ garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]
 human_list = os.listdir(os.path.join(example_path, "human"))
 human_list_path = [os.path.join(example_path, "human", human) for human in human_list]
 
-# ------------------------------
-# MEDIAPIPE POSE
-# ------------------------------
+# Initialize MediaPipe Pose
 mp_pose = mp.solutions.pose
 pose = mp_pose.Pose(static_image_mode=True)
 mp_drawing = mp.solutions.drawing_utils
 mp_pose_landmark = mp_pose.PoseLandmark
 
+
 def detect_pose(image):
+    # Convert to RGB
     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+    # Run pose detection
     result = pose.process(image_rgb)
 
     keypoints = {}
 
     if result.pose_landmarks:
+        # Draw landmarks on image
        mp_drawing.draw_landmarks(image, result.pose_landmarks, mp_pose.POSE_CONNECTIONS)
 
+        # Get image dimensions
         height, width, _ = image.shape
 
+        # Extract specific landmarks
         landmark_indices = {
-            "left_shoulder": mp_pose_landmark.LEFT_SHOULDER,
-            "right_shoulder": mp_pose_landmark.RIGHT_SHOULDER,
-            "left_hip": mp_pose_landmark.LEFT_HIP,
-            "right_hip": mp_pose_landmark.RIGHT_HIP,
+            'left_shoulder': mp_pose_landmark.LEFT_SHOULDER,
+            'right_shoulder': mp_pose_landmark.RIGHT_SHOULDER,
+            'left_hip': mp_pose_landmark.LEFT_HIP,
+            'right_hip': mp_pose_landmark.RIGHT_HIP
         }
 
         for name, index in landmark_indices.items():
@@ -48,16 +51,9 @@ def detect_pose(image):
             x, y = int(lm.x * width), int(lm.y * height)
             keypoints[name] = (x, y)
 
+            # Draw a circle + label for debug
             cv2.circle(image, (x, y), 5, (0, 255, 0), -1)
-            cv2.putText(
-                image,
-                name,
-                (x + 5, y - 5),
-                cv2.FONT_HERSHEY_SIMPLEX,
-                0.5,
-                (255, 255, 255),
-                1,
-            )
+            cv2.putText(image, name, (x + 5, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
 
     return image
 
@@ -65,12 +61,12 @@ def detect_pose(image):
 def align_clothing(body_img, clothing_img):
     image_rgb = cv2.cvtColor(body_img, cv2.COLOR_BGR2RGB)
     result = pose.process(image_rgb)
-
     output = body_img.copy()
 
     if result.pose_landmarks:
         h, w, _ = output.shape
 
+        # Extract key points
         def get_point(landmark_id):
             lm = result.pose_landmarks.landmark[landmark_id]
             return int(lm.x * w), int(lm.y * h)
@@ -80,20 +76,28 @@ def align_clothing(body_img, clothing_img):
         left_hip = get_point(mp_pose_landmark.LEFT_HIP)
         right_hip = get_point(mp_pose_landmark.RIGHT_HIP)
 
-        dst_pts = np.array(
-            [left_shoulder, right_shoulder, right_hip, left_hip], dtype=np.float32
-        )
+        # Destination box (torso region)
+        dst_pts = np.array([
+            left_shoulder,
+            right_shoulder,
+            right_hip,
+            left_hip
+        ], dtype=np.float32)
 
+        # Source box (clothing image corners)
         src_h, src_w = clothing_img.shape[:2]
-        src_pts = np.array(
-            [[0, 0], [src_w, 0], [src_w, src_h], [0, src_h]], dtype=np.float32
-        )
-
+        src_pts = np.array([
+            [0, 0],
+            [src_w, 0],
+            [src_w, src_h],
+            [0, src_h]
+        ], dtype=np.float32)
+
+        # Compute perspective transform and warp
         matrix = cv2.getPerspectiveTransform(src_pts, dst_pts)
-        warped_clothing = cv2.warpPerspective(
-            clothing_img, matrix, (w, h), borderMode=cv2.BORDER_TRANSPARENT
-        )
+        warped_clothing = cv2.warpPerspective(clothing_img, matrix, (w, h), borderMode=cv2.BORDER_TRANSPARENT)
 
+        # Handle transparency
         if clothing_img.shape[2] == 4:
             alpha = warped_clothing[:, :, 3] / 255.0
             for c in range(3):
@@ -117,46 +121,43 @@ def process_image(human_img_path, garm_img_path):
         vt_model_type="viton_hd",
         vt_garment_type="upper_body",
         vt_repaint=False,
-        api_name="/leffa_predict_vt",
+        api_name="/leffa_predict_vt"
     )
 
     print(result)
     generated_image_path = result[0]
-    print("generated_image_path " + generated_image_path)
-
+    print("generated_image_path" + generated_image_path)
     generated_image = Image.open(generated_image_path)
-    return generated_image
 
+    return generated_image
 
-# ------------------------------
-# GRADIO UI
-# ------------------------------
 
 image_blocks = gr.Blocks().queue()
-
 with image_blocks as demo:
     gr.HTML("<center><h1>Virtual Try-On</h1></center>")
     gr.HTML("<center><p>Upload an image of a person and an image of a garment ✨</p></center>")
-
     with gr.Row():
         with gr.Column():
-            human_img = gr.Image(type="filepath", label="Human", interactive=True)
-            gr.Examples(inputs=human_img, examples_per_page=10, examples=human_list_path)
+            human_img = gr.Image(type="filepath", label='Human', interactive=True)
+            example = gr.Examples(
+                inputs=human_img,
+                examples_per_page=10,
+                examples=human_list_path
+            )
 
         with gr.Column():
-            garm_img = gr.Image(type="filepath", label="Garment", interactive=True)
-            gr.Examples(inputs=garm_img, examples_per_page=8, examples=garm_list_path)
-
+            garm_img = gr.Image(label="Garment", type="filepath", interactive=True)
+            example = gr.Examples(
+                inputs=garm_img,
+                examples_per_page=8,
+                examples=garm_list_path)
         with gr.Column():
             image_out = gr.Image(label="Processed image", type="pil")
 
     with gr.Row():
-        try_button = gr.Button(value="Try-on", variant="primary")
+        try_button = gr.Button(value="Try-on", variant='primary')
 
-    try_button.click(
-        fn=process_image,
-        inputs=[human_img, garm_img],
-        outputs=image_out,
-    )
+    # Linking the button to the processing function
+    try_button.click(fn=process_image, inputs=[human_img, garm_img], outputs=image_out)
 
 image_blocks.launch(show_error=True)
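
Note: the MediaPipe calls in `detect_pose` are unchanged in substance by this commit. For reference, the torso keypoint extraction it performs can be reproduced standalone; this is a minimal sketch using MediaPipe's Python solutions API, with `person.jpg` as a hypothetical input path:

import cv2
import mediapipe as mp

mp_pose = mp.solutions.pose

def torso_keypoints(image_bgr):
    # MediaPipe expects RGB; OpenCV loads images as BGR.
    with mp_pose.Pose(static_image_mode=True) as pose:
        result = pose.process(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
    if not result.pose_landmarks:
        return None  # no person detected
    h, w = image_bgr.shape[:2]
    lm = result.pose_landmarks.landmark
    ids = mp_pose.PoseLandmark
    # Normalized landmark coordinates -> pixel coordinates.
    return {
        "left_shoulder": (int(lm[ids.LEFT_SHOULDER].x * w), int(lm[ids.LEFT_SHOULDER].y * h)),
        "right_shoulder": (int(lm[ids.RIGHT_SHOULDER].x * w), int(lm[ids.RIGHT_SHOULDER].y * h)),
        "left_hip": (int(lm[ids.LEFT_HIP].x * w), int(lm[ids.LEFT_HIP].y * h)),
        "right_hip": (int(lm[ids.RIGHT_HIP].x * w), int(lm[ids.RIGHT_HIP].y * h)),
    }

print(torso_keypoints(cv2.imread("person.jpg")))  # hypothetical input file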
 
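The warp in `align_clothing` is likewise only reformatted here: the garment's four corners are mapped onto the shoulder/hip quadrilateral. A standalone sketch of that step, with made-up coordinates in the usage comment:

import cv2
import numpy as np

def warp_garment(garment, body_h, body_w, torso_quad):
    # torso_quad order matches the diff: left shoulder, right shoulder,
    # right hip, left hip, paired with the garment corners below.
    src_h, src_w = garment.shape[:2]
    src_pts = np.array([[0, 0], [src_w, 0], [src_w, src_h], [0, src_h]], dtype=np.float32)
    dst_pts = np.array(torso_quad, dtype=np.float32)
    matrix = cv2.getPerspectiveTransform(src_pts, dst_pts)
    # BORDER_TRANSPARENT leaves destination pixels outside the warped quad untouched.
    return cv2.warpPerspective(garment, matrix, (body_w, body_h), borderMode=cv2.BORDER_TRANSPARENT)

# Example with made-up keypoints on a 640x480 body image:
# warped = warp_garment(garment_img, 640, 480, [(150, 180), (330, 180), (320, 420), (160, 420)])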
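
Finally, the `process_image` hunk shows only the tail of the `Client.predict(...)` call; the client construction and the image arguments sit in unchanged lines outside the diff. A minimal sketch of what such a call typically looks like with `gradio_client`, where the Space name and the image-passing convention are assumptions, not read from this commit:

from gradio_client import Client, handle_file

# Placeholder Space name: the diff does not show which Space the app connects to.
client = Client("some-user/leffa-space")

result = client.predict(
    handle_file("person.jpg"),   # hypothetical person image path
    handle_file("garment.jpg"),  # hypothetical garment image path
    vt_model_type="viton_hd",
    vt_garment_type="upper_body",
    vt_repaint=False,
    api_name="/leffa_predict_vt",
)
print(result[0])  # the app treats result[0] as the generated image path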