Arlet41 committed
Commit 3681b53 · verified · 1 parent: e5ee8be

Update app.py: return the generated try-on image as a PIL object instead of a base64 payload, add debug prints around the Leffa result, and reformat the pose, warp, and Gradio UI code.

Files changed (1): app.py (+52 −60)
app.py CHANGED
@@ -1,5 +1,4 @@
 import os
-
 import cv2
 import gradio as gr
 import mediapipe as mp
@@ -7,7 +6,10 @@ import numpy as np
 from PIL import Image
 from gradio_client import Client, handle_file
 
-example_path = os.path.join(os.path.dirname(__file__), 'example')
+# ------------------------------
+# PATHS
+# ------------------------------
+example_path = os.path.join(os.path.dirname(__file__), "example")
 
 garm_list = os.listdir(os.path.join(example_path, "cloth"))
 garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]
@@ -15,35 +17,30 @@ garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]
 human_list = os.listdir(os.path.join(example_path, "human"))
 human_list_path = [os.path.join(example_path, "human", human) for human in human_list]
 
-# Initialize MediaPipe Pose
+# ------------------------------
+# MEDIAPIPE POSE
+# ------------------------------
 mp_pose = mp.solutions.pose
 pose = mp_pose.Pose(static_image_mode=True)
 mp_drawing = mp.solutions.drawing_utils
 mp_pose_landmark = mp_pose.PoseLandmark
 
-
 def detect_pose(image):
-    # Convert to RGB
     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
-    # Run pose detection
     result = pose.process(image_rgb)
 
     keypoints = {}
 
     if result.pose_landmarks:
-        # Draw landmarks on image
         mp_drawing.draw_landmarks(image, result.pose_landmarks, mp_pose.POSE_CONNECTIONS)
 
-        # Get image dimensions
         height, width, _ = image.shape
 
-        # Extract specific landmarks
         landmark_indices = {
-            'left_shoulder': mp_pose_landmark.LEFT_SHOULDER,
-            'right_shoulder': mp_pose_landmark.RIGHT_SHOULDER,
-            'left_hip': mp_pose_landmark.LEFT_HIP,
-            'right_hip': mp_pose_landmark.RIGHT_HIP
+            "left_shoulder": mp_pose_landmark.LEFT_SHOULDER,
+            "right_shoulder": mp_pose_landmark.RIGHT_SHOULDER,
+            "left_hip": mp_pose_landmark.LEFT_HIP,
+            "right_hip": mp_pose_landmark.RIGHT_HIP,
         }
 
         for name, index in landmark_indices.items():
@@ -51,9 +48,16 @@ def detect_pose(image):
             x, y = int(lm.x * width), int(lm.y * height)
             keypoints[name] = (x, y)
 
-            # Draw a circle + label for debug
             cv2.circle(image, (x, y), 5, (0, 255, 0), -1)
-            cv2.putText(image, name, (x + 5, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
+            cv2.putText(
+                image,
+                name,
+                (x + 5, y - 5),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.5,
+                (255, 255, 255),
+                1,
+            )
 
     return image
 
@@ -61,12 +65,12 @@ def detect_pose(image):
 def align_clothing(body_img, clothing_img):
     image_rgb = cv2.cvtColor(body_img, cv2.COLOR_BGR2RGB)
     result = pose.process(image_rgb)
+
     output = body_img.copy()
 
     if result.pose_landmarks:
         h, w, _ = output.shape
 
-        # Extract key points
        def get_point(landmark_id):
             lm = result.pose_landmarks.landmark[landmark_id]
             return int(lm.x * w), int(lm.y * h)
@@ -76,28 +80,20 @@ def align_clothing(body_img, clothing_img):
         left_hip = get_point(mp_pose_landmark.LEFT_HIP)
         right_hip = get_point(mp_pose_landmark.RIGHT_HIP)
 
-        # Destination box (torso region)
-        dst_pts = np.array([
-            left_shoulder,
-            right_shoulder,
-            right_hip,
-            left_hip
-        ], dtype=np.float32)
+        dst_pts = np.array(
+            [left_shoulder, right_shoulder, right_hip, left_hip], dtype=np.float32
+        )
 
-        # Source box (clothing image corners)
         src_h, src_w = clothing_img.shape[:2]
-        src_pts = np.array([
-            [0, 0],
-            [src_w, 0],
-            [src_w, src_h],
-            [0, src_h]
-        ], dtype=np.float32)
-
-        # Compute perspective transform and warp
+        src_pts = np.array(
+            [[0, 0], [src_w, 0], [src_w, src_h], [0, src_h]], dtype=np.float32
+        )
+
         matrix = cv2.getPerspectiveTransform(src_pts, dst_pts)
-        warped_clothing = cv2.warpPerspective(clothing_img, matrix, (w, h), borderMode=cv2.BORDER_TRANSPARENT)
+        warped_clothing = cv2.warpPerspective(
+            clothing_img, matrix, (w, h), borderMode=cv2.BORDER_TRANSPARENT
+        )
 
-        # Handle transparency
         if clothing_img.shape[2] == 4:
             alpha = warped_clothing[:, :, 3] / 255.0
             for c in range(3):
@@ -109,9 +105,6 @@ def align_clothing(body_img, clothing_img):
 
 
 def process_image(human_img_path, garm_img_path):
-    from io import BytesIO
-    import base64
-
     client = Client("franciszzj/Leffa")
 
     result = client.predict(
@@ -124,47 +117,46 @@ def process_image(human_img_path, garm_img_path):
         vt_model_type="viton_hd",
         vt_garment_type="upper_body",
         vt_repaint=False,
-        api_name="/leffa_predict_vt"
+        api_name="/leffa_predict_vt",
     )
 
+    print(result)
     generated_image_path = result[0]
-    generated_image = Image.open(generated_image_path)
+    print("generated_image_path " + generated_image_path)
 
-    # Convert image → base64
-    buffer = BytesIO()
-    generated_image.save(buffer, format="PNG")
-    base64_img = base64.b64encode(buffer.getvalue()).decode("utf-8")
-
-    # Return base64 (not the path)
-    return {"base64": base64_img}
+    generated_image = Image.open(generated_image_path)
+    return generated_image
 
 
+# ------------------------------
+# GRADIO UI
+# ------------------------------
+
 image_blocks = gr.Blocks().queue()
+
 with image_blocks as demo:
     gr.HTML("<center><h1>Virtual Try-On</h1></center>")
     gr.HTML("<center><p>Upload an image of a person and an image of a garment ✨</p></center>")
+
     with gr.Row():
         with gr.Column():
-            human_img = gr.Image(type="filepath", label='Human', interactive=True)
-            example = gr.Examples(
-                inputs=human_img,
-                examples_per_page=10,
-                examples=human_list_path
-            )
+            human_img = gr.Image(type="filepath", label="Human", interactive=True)
+            gr.Examples(inputs=human_img, examples_per_page=10, examples=human_list_path)
 
         with gr.Column():
-            garm_img = gr.Image(label="Garment", type="filepath", interactive=True)
-            example = gr.Examples(
-                inputs=garm_img,
-                examples_per_page=8,
-                examples=garm_list_path)
+            garm_img = gr.Image(type="filepath", label="Garment", interactive=True)
+            gr.Examples(inputs=garm_img, examples_per_page=8, examples=garm_list_path)
+
         with gr.Column():
             image_out = gr.Image(label="Processed image", type="pil")
 
     with gr.Row():
-        try_button = gr.Button(value="Try-on", variant='primary')
+        try_button = gr.Button(value="Try-on", variant="primary")
 
-    # Linking the button to the processing function
-    try_button.click(fn=process_image, inputs=[human_img, garm_img], outputs=image_out)
+    try_button.click(
+        fn=process_image,
+        inputs=[human_img, garm_img],
+        outputs=image_out,
+    )
 
 image_blocks.launch(show_error=True)
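
The substantive change in this commit is the return type of process_image: the removed code serialized the generated image to base64 and returned a dict, while the output component is gr.Image(type="pil"), which expects a PIL image (or a path or array), not a dict. A minimal round-trip sketch of the two styles, using a stand-in image rather than a real Leffa result:

# Stand-in for the generated try-on image; the real code opens result[0].
import base64
from io import BytesIO

from PIL import Image

generated_image = Image.new("RGB", (64, 64), "white")

# Removed behavior: serialize to base64 and wrap it in a dict.
buffer = BytesIO()
generated_image.save(buffer, format="PNG")
payload = {"base64": base64.b64encode(buffer.getvalue()).decode("utf-8")}

# New behavior: hand back the PIL image itself, which gr.Image(type="pil") renders directly.
decoded = Image.open(BytesIO(base64.b64decode(payload["base64"])))
assert decoded.size == generated_image.size  # same data either way; only the PIL object fits the UI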
 
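
For reference, the torso alignment in align_clothing reduces to a four-point perspective map from the garment's corners onto the shoulder/hip quad, followed by per-channel alpha compositing. A self-contained sketch with synthetic images and made-up landmark coordinates (no MediaPipe required):

# Self-contained sketch of the torso warp + alpha blend from align_clothing.
# The quad coordinates below are made up; the real code takes them from MediaPipe.
import cv2
import numpy as np

body = np.full((400, 300, 3), 180, dtype=np.uint8)  # stand-in body photo (BGR)
garment = np.zeros((200, 150, 4), dtype=np.uint8)   # stand-in garment (BGRA)
garment[:, :, 2] = 255                              # solid red
garment[:, :, 3] = 255                              # fully opaque

h, w = body.shape[:2]
# Hypothetical shoulder/hip quad, ordered like the garment corners below.
dst_pts = np.array([[90, 100], [210, 100], [220, 300], [80, 300]], dtype=np.float32)

src_h, src_w = garment.shape[:2]
src_pts = np.array([[0, 0], [src_w, 0], [src_w, src_h], [0, src_h]], dtype=np.float32)

# Map the garment corners onto the torso quad and warp into the body frame.
matrix = cv2.getPerspectiveTransform(src_pts, dst_pts)
warped = cv2.warpPerspective(garment, matrix, (w, h))

# Per-channel alpha blend, as in the 4-channel branch of align_clothing.
alpha = warped[:, :, 3] / 255.0
output = body.copy()
for c in range(3):
    output[:, :, c] = (alpha * warped[:, :, c] + (1 - alpha) * output[:, :, c]).astype(np.uint8)

cv2.imwrite("blend_demo.png", output)

Note that cv2.warpPerspective with borderMode=cv2.BORDER_TRANSPARENT, as in the committed code, leaves out-of-quad destination pixels unmodified, which is only well defined when an initialized dst buffer is passed; the sketch keeps the default BORDER_CONSTANT so the alpha channel stays zero outside the quad.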