cnxoo commited on
Commit
b86d351
·
verified ·
1 Parent(s): ac313ad

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +35 -24
main.py CHANGED
@@ -82,28 +82,40 @@ C = len(CLASS_NAMES)
82
  # ---------------------- Load head ----------------------
83
  # logger.info(f"Loading head from {HEAD_PATH}")
84
  # head = tf.keras.models.load_model(HEAD_PATH, compile=False)
85
- def load_head_model(path: str):
86
- # 1) ลอง tf.keras (.h5 / SavedModel)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  try:
88
- logger.info(f"Loading head via tf.keras from {path}")
89
- return tf.keras.models.load_model(path, compile=False)
90
- except Exception as e_tf:
91
- logger.warning(f"tf.keras load failed: {e_tf}")
92
 
93
- # 2) ถ้าเป็น .keras (Keras 3)
94
- try:
95
- import keras
96
- logger.info(f"Falling back to keras.saving.load_model for {path}")
97
- return keras.saving.load_model(path, compile=False)
98
- except Exception as e_k:
99
- raise RuntimeError(
100
- f"Cannot load head model: {path}\n"
101
- f"tf.keras error: {e_tf}\nkeras3 error: {e_k}\n"
102
- "If your file is .keras (Keras 3), either add 'keras>=3.3' to requirements "
103
- "or re-export to .h5 / SavedModel."
104
- )
105
  logger.info(f"Loading head from {HEAD_PATH}")
106
- head = tf.keras.models.load_model(HEAD_PATH, compile=False)
 
107
 
108
  # ---------------------- Load mu/sd ----------------------
109
  def _load_mu_sd():
@@ -232,17 +244,16 @@ def predict_probs(img_bytes: bytes) -> np.ndarray:
232
  ex = tf.train.Example(features=tf.train.Features(
233
  feature={'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[by]))}
234
  )).SerializeToString()
235
-
236
  out = infer(inputs=tf.constant([ex]))
237
  if "embedding" not in out:
238
  raise RuntimeError(f"Unexpected derm-foundation outputs: {list(out.keys())}")
239
-
240
- emb = out["embedding"].numpy().astype("float32") # (1, 6144)
241
- z = (emb - mu) / (sd + 1e-6) # ทำ normalization ด้วย NumPy
242
- probs = head.predict(z, verbose=0)[0] # เรียกหัวโมเดลโดยตรง
243
  return probs
244
 
245
 
 
246
  # ---------------------- Endpoints ----------------------
247
  @app.get("/health")
248
  def health():
 
82
  # ---------------------- Load head ----------------------
83
  # logger.info(f"Loading head from {HEAD_PATH}")
84
  # head = tf.keras.models.load_model(HEAD_PATH, compile=False)
85
+ import inspect
86
+
87
+ def load_head_any(path: str):
88
+ # กรณีเป็นโฟลเดอร์ SavedModel
89
+ if os.path.isdir(path) and os.path.exists(os.path.join(path, "saved_model.pb")):
90
+ m = tf.saved_model.load(path)
91
+ sig = m.signatures.get("serving_default") or next(iter(m.signatures.values()))
92
+ # หาชื่อ input/output อัตโนมัติ
93
+ in_names = list(sig.structured_input_signature[1].keys())
94
+ out_names = list(sig.structured_outputs.keys())
95
+
96
+ def _predict(z_np: np.ndarray) -> np.ndarray:
97
+ z_tf = tf.convert_to_tensor(z_np, dtype=tf.float32)
98
+ # บางรุ่นรับเป็น args บางรุ่นรับเป็น kwargs
99
+ if len(in_names) == 1:
100
+ out_dict = sig(z_tf)
101
+ else:
102
+ out_dict = sig(**{in_names[0]: z_tf})
103
+ y = out_dict[out_names[0]]
104
+ return y.numpy()
105
+ return _predict
106
+
107
+ # กรณีเป็นไฟล์ .h5 / .keras ที่ tf.keras อ่านได้
108
  try:
109
+ model = tf.keras.models.load_model(path, compile=False)
110
+ return lambda z_np: model.predict(z_np, verbose=0)
111
+ except Exception as e:
112
+ raise RuntimeError(f"Cannot load head from {path}: {e}")
113
 
114
+ # ตั้ง path ไปยัง SavedModel ที่อัปขึ้น
115
+ HEAD_PATH = os.getenv("HEAD_PATH", "Models/mlp_head_savedmodel")
 
 
 
 
 
 
 
 
 
 
116
  logger.info(f"Loading head from {HEAD_PATH}")
117
+ head_predict = load_head_any(HEAD_PATH)
118
+
119
 
120
  # ---------------------- Load mu/sd ----------------------
121
  def _load_mu_sd():
 
244
  ex = tf.train.Example(features=tf.train.Features(
245
  feature={'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[by]))}
246
  )).SerializeToString()
 
247
  out = infer(inputs=tf.constant([ex]))
248
  if "embedding" not in out:
249
  raise RuntimeError(f"Unexpected derm-foundation outputs: {list(out.keys())}")
250
+ emb = out["embedding"].numpy().astype("float32") # (1,6144)
251
+ z = (emb - mu) / (sd + 1e-6)
252
+ probs = head_predict(z)[0] # <-- เรียกหัวที่โหลดมา
 
253
  return probs
254
 
255
 
256
+
257
  # ---------------------- Endpoints ----------------------
258
  @app.get("/health")
259
  def health():