PJ2005 committed on
Commit
4cd4e6d
·
verified ·
1 Parent(s): 66f73af

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -0
app.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import re, os, tempfile, math
3
+ import gradio as gr
4
+ import pandas as pd
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ import plotly.graph_objects as go
8
+ from functools import lru_cache
9
+ from transformers import pipeline
10
+ import emoji
11
+ from langdetect import detect, DetectorFactory
12
# langdetect is non-deterministic by default; pin the seed so repeated
# runs on the same text yield the same language codes.
DetectorFactory.seed = 0  # deterministic language detection
13
+
14
# ----------------------
# Models and pipelines
# ----------------------
# UI display name -> Hugging Face Hub model id used in "Sentiment" mode.
# The label sets each model emits differ, as noted per line.
SENTIMENT_CHOICES = {
    "SST-2 (binary, reviews)": "distilbert-base-uncased-finetuned-sst-2-english",  # POSITIVE/NEGATIVE
    "Twitter (EN, 3-class)": "cardiffnlp/twitter-roberta-base-sentiment-latest",  # NEGATIVE/NEUTRAL/POSITIVE
    "Twitter (Multilingual, 3-class)": "cardiffnlp/twitter-xlm-roberta-base-sentiment"
}
# Checkpoint used in "Emotion" mode.
EMOTION_MODEL = "bhadresh-savani/distilbert-base-uncased-emotion"  # sadness/joy/love/anger/fear/surprise
23
+
24
@lru_cache(maxsize=8)
def get_pipe(task, model_id):
    """Build and memoize a transformers pipeline for (task, model_id).

    Cardiff NLP checkpoints get the tokenizer pinned explicitly to the same
    repo id; for every other model the tokenizer argument is ``None`` and
    ``pipeline`` resolves it from the model. The LRU cache keeps at most
    eight distinct pipelines alive so repeated analyses reuse loaded models.
    """
    tok = model_id if "cardiffnlp" in model_id else None
    return pipeline(task, model=model_id, tokenizer=tok)
28
+
29
# ----------------------
# Text utilities
# ----------------------
URL_RE = re.compile(r"https?://\S+|www\.\S+")
MENTION_RE = re.compile(r"@\w+")
HASHTAG_RE = re.compile(r"#(\w+)")

def clean_line(s, demojize=True, strip_social=True, lower=False):
    """Normalize one chat message before classification.

    Steps (each optional via flags): convert emoji to ``:name:`` text,
    remove URLs and @mentions while reducing ``#hashtag`` to its bare word,
    trim surrounding whitespace, and lowercase.
    """
    if demojize:
        s = emoji.demojize(s, language='en')
    if strip_social:
        # Links and handles are dropped entirely; the hashtag keeps its word.
        for pattern, replacement in ((URL_RE, ""), (MENTION_RE, ""), (HASHTAG_RE, r"\1")):
            s = pattern.sub(replacement, s)
    s = s.strip()
    return s.lower() if lower else s
48
+
49
def detect_langs(lines, probe=30):
    """Probe the language mix of a sample of lines.

    Runs langdetect over the first ``probe`` non-empty lines and tallies
    detected ISO 639-1 codes.

    Args:
        lines: Iterable of text lines (already preprocessed).
        probe: Maximum number of non-empty lines to sample.

    Returns:
        (counts, share_en): dict of language code -> hit count (detection
        failures counted under "unk"), and the fraction of the sample
        detected as English (0.0 for an empty sample).
    """
    # quick language probe on a sample
    sample = [l for l in lines if l.strip()][:probe]
    counts = {}
    for s in sample:
        try:
            code = detect(s)
            counts[code] = counts.get(code, 0) + 1
        except Exception:
            # langdetect raises LangDetectException on very short/ambiguous
            # text. A bare ``except:`` here would also swallow
            # KeyboardInterrupt/SystemExit, so catch Exception only.
            counts["unk"] = counts.get("unk", 0) + 1
    total = sum(counts.values()) or 1  # avoid division by zero on empty sample
    share_en = counts.get("en", 0) / total
    return counts, share_en
62
+
63
# ----------------------
# Core analyzer
# ----------------------
def run_analysis(
    text_block, file_obj, text_col, mode, sentiment_model_choice, auto_model, demojize_opt,
    strip_social_opt, lower_opt, batch_size
):
    """Classify chat messages and return Gradio-ready outputs.

    Gathers messages from the textbox and/or an uploaded CSV, cleans them,
    probes their language mix, picks a sentiment/emotion model, runs batched
    inference, and builds a results table, a label-distribution bar chart,
    and a downloadable CSV.

    Returns a 4-tuple matching the `run.click` outputs:
        (results DataFrame, matplotlib Figure, CSV file path, language-info
        markdown string). Error paths return an error-row DataFrame, an
        empty figure, a cleared file component, and "No language info".
    """
    # Collect lines from textbox and/or CSV
    lines = []
    if text_block:
        # One message per textbox line; drop blank lines.
        lines.extend([l.rstrip() for l in text_block.splitlines() if l.strip()])

    df_in = None
    if file_obj is not None:
        try:
            # file_obj is a Gradio file wrapper; .name is the temp path on disk.
            df_in = pd.read_csv(file_obj.name)
            use_col = text_col if (text_col and text_col in df_in.columns) else None
            if not use_col:
                # naive auto-pick: first match among common text column names
                for c in ["text", "message", "msg", "content", "body"]:
                    if c in df_in.columns:
                        use_col = c
                        break
            if not use_col:
                return (pd.DataFrame([{"error": "CSV loaded, but no text column selected/found."}]),
                        plt.figure(), gr.update(value=None), "No language info")
            lines.extend([str(x) for x in df_in[use_col].astype(str).tolist()])
        except Exception as e:
            # Surface any parse/IO failure in the results table instead of crashing.
            return (pd.DataFrame([{"error": f"Failed to read CSV: {e}"}]),
                    plt.figure(), gr.update(value=None), "No language info")

    if not lines:
        return (pd.DataFrame([{"error": "Enter text or upload CSV with a text column."}]),
                plt.figure(), gr.update(value=None), "No language info")

    # Preprocess every message with the user-selected cleaning options.
    proc = [clean_line(l, demojize=demojize_opt, strip_social=strip_social_opt, lower=lower_opt) for l in lines]

    # Language probe to optionally switch sentiment model
    lang_counts, share_en = detect_langs(proc, probe=min(30, len(proc)))
    lang_info = f"Lang probe (top): {dict(sorted(lang_counts.items(), key=lambda x: -x[1])[:3])}, EN share≈{round(share_en,2)}"

    # Choose model: auto mode uses the EN Twitter model when >=60% of the
    # probe looks English, else the multilingual one.
    if mode == "Sentiment":
        if auto_model:
            model_id = SENTIMENT_CHOICES["Twitter (EN, 3-class)"] if share_en >= 0.6 else SENTIMENT_CHOICES["Twitter (Multilingual, 3-class)"]
        else:
            model_id = SENTIMENT_CHOICES[sentiment_model_choice]
        pipe = get_pipe("sentiment-analysis", model_id)
    else:
        pipe = get_pipe("text-classification", EMOTION_MODEL)

    # Batched inference for speed
    # NOTE(review): assumes batch_size arrives as an int from the slider
    # (step=1); range() would raise on a float — TODO confirm.
    outputs = []
    for i in range(0, len(proc), batch_size):
        batch = proc[i:i+batch_size]
        outs = pipe(batch, batch_size=batch_size, truncation=True)
        # normalize to list[dict]: some pipelines return a list per input
        for out in outs:
            out0 = out[0] if isinstance(out, list) else out
            outputs.append({"label": out0["label"], "score": float(out0["score"])})

    # Build results DataFrame, pairing original (uncleaned) text with labels.
    rows = []
    for idx, (raw, out) in enumerate(zip(lines, outputs), 1):
        rows.append({
            "idx": idx,
            "text": raw,
            "label": out["label"],
            "score": round(out["score"], 4)
        })
    df = pd.DataFrame(rows)

    # Distribution plot (matplotlib)
    # NOTE(review): figures are never plt.close()d, so repeated requests
    # accumulate open figures in a long-running server — consider closing.
    counts = df["label"].value_counts().sort_values(ascending=False)
    fig, ax = plt.subplots(figsize=(6.5, 3.2))
    counts.plot(kind="bar", ax=ax, color="#4C78A8")
    ax.set_title("Label Distribution")
    ax.set_xlabel("Label")
    ax.set_ylabel("Count")
    plt.tight_layout()

    # Export to CSV. delete=False so Gradio can serve the file afterwards;
    # NOTE(review): the handle is never closed and the temp file never
    # removed — files accumulate for the process lifetime.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
    df.to_csv(tmp.name, index=False)

    return df, fig, tmp.name, lang_info
150
+
151
# ----------------------
# UI
# ----------------------
# Two-column layout: inputs/options on the left, results on the right.
with gr.Blocks(title="Chat Mood Analyzer — Ultimate") as demo:
    gr.Markdown("## Chat Mood Analyzer — Ultimate Edition")

    with gr.Row():
        with gr.Column():
            # Text can come from the box, a CSV, or both (run_analysis merges them).
            txt = gr.Textbox(lines=10, label="Paste chat (one message per line)")
            file_in = gr.File(label="Or upload CSV", file_types=[".csv"], file_count="single")
            text_col = gr.Textbox(value="", label="CSV text column (auto-detect if blank)")

            mode = gr.Radio(["Sentiment", "Emotion"], value="Sentiment", label="Analysis mode")

            with gr.Accordion("Sentiment model settings", open=False):
                # When checked, run_analysis picks EN vs multilingual from the language probe.
                auto_model = gr.Checkbox(value=True, label="Auto-pick tweet-aware EN vs Multilingual")
                sentiment_model = gr.Dropdown(
                    choices=list(SENTIMENT_CHOICES.keys()),
                    value="Twitter (EN, 3-class)",
                    label="Manual model (used if Auto is OFF)"
                )
                gr.Markdown("Tip: Twitter models understand slang/emojis better than SST‑2 review models.")

            with gr.Accordion("Preprocessing", open=False):
                # These map 1:1 onto clean_line's keyword flags.
                demojize_opt = gr.Checkbox(value=True, label="Convert emojis to text (:face_with_tears_of_joy:)")
                strip_social_opt = gr.Checkbox(value=True, label="Strip URLs/@mentions/#hashtags")
                lower_opt = gr.Checkbox(value=False, label="Lowercase text")

            batch_size = gr.Slider(1, 64, value=16, step=1, label="Batch size")

            run = gr.Button("Analyze", variant="primary")
            clear = gr.ClearButton([txt, file_in])

        with gr.Column():
            out_table = gr.Dataframe(label="Per-message results", wrap=True)
            out_plot = gr.Plot(label="Label distribution")
            download = gr.File(label="Download results (.csv)")
            lang_probe = gr.Markdown()

    # Inputs/outputs must stay in the same order as run_analysis's
    # parameters and return tuple.
    # NOTE(review): `concurrency_limit` on event listeners and boolean
    # `show_progress` are Gradio-version-sensitive — confirm against the
    # pinned gradio release.
    evt = run.click(
        fn=run_analysis,
        inputs=[txt, file_in, text_col, mode, sentiment_model, auto_model,
                demojize_opt, strip_social_opt, lower_opt, batch_size],
        outputs=[out_table, out_plot, download, lang_probe],
        concurrency_limit=4,
        show_progress=True
    )

# Queue requests (up to 64 waiting, 2 handled concurrently by default) and
# launch with a public share link.
demo.queue(max_size=64, default_concurrency_limit=2)
demo.launch(share=True)