Spaces:
Running
Running
import io
from pprint import pformat

import gradio as gr
from hbutils.string import titleize
from hfutils.repository import hf_hub_repo_url
from imgutils.tagging.pixai import _open_default_category_thresholds, get_pixai_tags

# Hugging Face model repository whose tagger this demo serves.
REPO_ID = 'deepghs/pixai-tagger-v0.9-onnx'

if __name__ == '__main__':
    with gr.Blocks() as demo:
        # Header row: title plus a link to the model repository.
        with gr.Row():
            with gr.Column():
                repo_url = hf_hub_repo_url(repo_id=REPO_ID, repo_type='model')
                gr.HTML(f'<h2 style="text-align: center;">Tagger Demo For {REPO_ID}</h2>')
                gr.Markdown(f'This is the quick demo for tagger model [{REPO_ID}]({repo_url}). '
                            f'Powered by `dghs-imgutils`\'s quick demo module.')

        with gr.Row():
            # NOTE(review): `_open_default_category_thresholds` is a private imgutils helper
            # (leading underscore); it may change between imgutils releases.
            thresholds, names = _open_default_category_thresholds(model_name=REPO_ID)
            categories = sorted(names)

            # Left column: image input, one threshold slider per category, submit button.
            with gr.Column():
                with gr.Row():
                    gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    # Sliders are created in `categories` order so `_fn_submit` can zip
                    # their values back onto the categories positionally.
                    gr_thresholds = []
                    for category in categories:
                        gr_cate_threshold = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=thresholds[category],
                            step=0.001,
                            label=f'Threshold for {titleize(names[category])}',
                        )
                        gr_thresholds.append(gr_cate_threshold)
                with gr.Row():
                    gr_submit = gr.Button(value='Submit', variant='primary')

            # Right column: one prediction tab per category, plus the IP and text outputs.
            with gr.Column():
                with gr.Tabs():
                    gr_preds = []
                    for category in categories:
                        with gr.Tab(f'{titleize(names[category])}'):
                            # The caption must go through `label=`: gr.Label's first positional
                            # parameter is `value`, which would show the caption as a prediction.
                            gr_cate_label = gr.Label(label=f'{titleize(names[category])} Prediction')
                            gr_preds.append(gr_cate_label)
                    with gr.Tab('IPs Mapping'):
                        gr_ips_mapping = gr.TextArea(label="IPs (string)", lines=15)
                    with gr.Tab('Text Output'):
                        gr_text_output = gr.TextArea(label="Output (string)", lines=15)

        def _fn_submit(image, *threshold_values):
            """Run the PixAI tagger on *image* and format its results for the output widgets.

            :param image: PIL image from the input component; ``None`` when nothing was uploaded.
            :param threshold_values: One slider value per entry of ``categories``, in that order.
            :return: Tuple of (plain-text summary, pretty-printed IP mapping, then one result
                per category in ``categories`` order), matching the ``outputs`` of the submit button.
            :raises gr.Error: If no image has been uploaded.
            """
            if image is None:
                # Report a readable message in the UI instead of passing None to the tagger.
                raise gr.Error('Please upload an image before submitting.')

            _ths = dict(zip(categories, threshold_values))
            # Request per-category results (keyed by category) plus the two IP outputs.
            fmt = {
                **names,
                'ips_mapping': 'ips_mapping',
                'ips': 'ips',
            }
            res = get_pixai_tags(image=image, model_name=REPO_ID, thresholds=_ths, fmt=fmt)

            # Human-readable summary: a heading per category followed by its tag names.
            with io.StringIO() as sf:
                for category in categories:
                    print(f'# {names[category]} (#{category})', file=sf)
                    print(', '.join(res[category].keys()), file=sf)
                    print('', file=sf)
                print('# IPs', file=sf)
                print(', '.join(res['ips']), file=sf)
                print('', file=sf)
                text_output = sf.getvalue()

            return (
                text_output,
                pformat(res['ips_mapping']),
                *[res[category] for category in categories],
            )

        gr_submit.click(
            fn=_fn_submit,
            inputs=[gr_input_image, *gr_thresholds],
            outputs=[gr_text_output, gr_ips_mapping, *gr_preds],
        )

    demo.launch()