Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -82,28 +82,40 @@ C = len(CLASS_NAMES)
|
|
| 82 |
# ---------------------- Load head ----------------------
|
| 83 |
# logger.info(f"Loading head from {HEAD_PATH}")
|
| 84 |
# head = tf.keras.models.load_model(HEAD_PATH, compile=False)
|
| 85 |
-
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
try:
|
| 88 |
-
|
| 89 |
-
return
|
| 90 |
-
except Exception as
|
| 91 |
-
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
import keras
|
| 96 |
-
logger.info(f"Falling back to keras.saving.load_model for {path}")
|
| 97 |
-
return keras.saving.load_model(path, compile=False)
|
| 98 |
-
except Exception as e_k:
|
| 99 |
-
raise RuntimeError(
|
| 100 |
-
f"Cannot load head model: {path}\n"
|
| 101 |
-
f"tf.keras error: {e_tf}\nkeras3 error: {e_k}\n"
|
| 102 |
-
"If your file is .keras (Keras 3), either add 'keras>=3.3' to requirements "
|
| 103 |
-
"or re-export to .h5 / SavedModel."
|
| 104 |
-
)
|
| 105 |
logger.info(f"Loading head from {HEAD_PATH}")
|
| 106 |
-
|
|
|
|
| 107 |
|
| 108 |
# ---------------------- Load mu/sd ----------------------
|
| 109 |
def _load_mu_sd():
|
|
@@ -232,17 +244,16 @@ def predict_probs(img_bytes: bytes) -> np.ndarray:
|
|
| 232 |
ex = tf.train.Example(features=tf.train.Features(
|
| 233 |
feature={'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[by]))}
|
| 234 |
)).SerializeToString()
|
| 235 |
-
|
| 236 |
out = infer(inputs=tf.constant([ex]))
|
| 237 |
if "embedding" not in out:
|
| 238 |
raise RuntimeError(f"Unexpected derm-foundation outputs: {list(out.keys())}")
|
| 239 |
-
|
| 240 |
-
emb
|
| 241 |
-
|
| 242 |
-
probs = head.predict(z, verbose=0)[0] # เรียกหัวโมเดลโดยตรง
|
| 243 |
return probs
|
| 244 |
|
| 245 |
|
|
|
|
| 246 |
# ---------------------- Endpoints ----------------------
|
| 247 |
@app.get("/health")
|
| 248 |
def health():
|
|
|
|
| 82 |
# ---------------------- Load head ----------------------
|
| 83 |
# logger.info(f"Loading head from {HEAD_PATH}")
|
| 84 |
# head = tf.keras.models.load_model(HEAD_PATH, compile=False)
|
| 85 |
+
import inspect
|
| 86 |
+
|
| 87 |
+
def load_head_any(path: str):
|
| 88 |
+
# กรณีเป็นโฟลเดอร์ SavedModel
|
| 89 |
+
if os.path.isdir(path) and os.path.exists(os.path.join(path, "saved_model.pb")):
|
| 90 |
+
m = tf.saved_model.load(path)
|
| 91 |
+
sig = m.signatures.get("serving_default") or next(iter(m.signatures.values()))
|
| 92 |
+
# หาชื่อ input/output อัตโนมัติ
|
| 93 |
+
in_names = list(sig.structured_input_signature[1].keys())
|
| 94 |
+
out_names = list(sig.structured_outputs.keys())
|
| 95 |
+
|
| 96 |
+
def _predict(z_np: np.ndarray) -> np.ndarray:
|
| 97 |
+
z_tf = tf.convert_to_tensor(z_np, dtype=tf.float32)
|
| 98 |
+
# บางรุ่นรับเป็น args บางรุ่นรับเป็น kwargs
|
| 99 |
+
if len(in_names) == 1:
|
| 100 |
+
out_dict = sig(z_tf)
|
| 101 |
+
else:
|
| 102 |
+
out_dict = sig(**{in_names[0]: z_tf})
|
| 103 |
+
y = out_dict[out_names[0]]
|
| 104 |
+
return y.numpy()
|
| 105 |
+
return _predict
|
| 106 |
+
|
| 107 |
+
# กรณีเป็นไฟล์ .h5 / .keras ที่ tf.keras อ่านได้
|
| 108 |
try:
|
| 109 |
+
model = tf.keras.models.load_model(path, compile=False)
|
| 110 |
+
return lambda z_np: model.predict(z_np, verbose=0)
|
| 111 |
+
except Exception as e:
|
| 112 |
+
raise RuntimeError(f"Cannot load head from {path}: {e}")
|
| 113 |
|
| 114 |
+
# ตั้ง path ไปยัง SavedModel ที่อัปขึ้น
|
| 115 |
+
HEAD_PATH = os.getenv("HEAD_PATH", "Models/mlp_head_savedmodel")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
logger.info(f"Loading head from {HEAD_PATH}")
|
| 117 |
+
head_predict = load_head_any(HEAD_PATH)
|
| 118 |
+
|
| 119 |
|
| 120 |
# ---------------------- Load mu/sd ----------------------
|
| 121 |
def _load_mu_sd():
|
|
|
|
| 244 |
ex = tf.train.Example(features=tf.train.Features(
|
| 245 |
feature={'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[by]))}
|
| 246 |
)).SerializeToString()
|
|
|
|
| 247 |
out = infer(inputs=tf.constant([ex]))
|
| 248 |
if "embedding" not in out:
|
| 249 |
raise RuntimeError(f"Unexpected derm-foundation outputs: {list(out.keys())}")
|
| 250 |
+
emb = out["embedding"].numpy().astype("float32") # (1,6144)
|
| 251 |
+
z = (emb - mu) / (sd + 1e-6)
|
| 252 |
+
probs = head_predict(z)[0] # <-- เรียกหัวที่โหลดมา
|
|
|
|
| 253 |
return probs
|
| 254 |
|
| 255 |
|
| 256 |
+
|
| 257 |
# ---------------------- Endpoints ----------------------
|
| 258 |
@app.get("/health")
|
| 259 |
def health():
|