# CG_AskPDF / app.py
import os
from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain.chains import RetrievalQA
import gradio as gr
import warnings
import uuid
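# Instruct-tuned chat models exposed in the UI. All are assumed to be reachable
# via the HuggingFace Inference API; gated repos (Llama, Gemma) additionally
# require license approval on your account.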
MODEL_OPTIONS = [
"meta-llama/Llama-3.2-3B-Instruct",
"meta-llama/Llama-3.1-8B-Instruct",
"mistralai/Mistral-7B-Instruct-v0.3",
"mistralai/Mixtral-8x7B-Instruct-v0.1",
"google/gemma-2-9b-it",
"google/gemma-2-27b-it",
"Qwen/Qwen2.5-7B-Instruct",
"Qwen/Qwen2.5-14B-Instruct",
"microsoft/Phi-3.5-mini-instruct",
"HuggingFaceH4/zephyr-7b-beta"
]
# Suppress library warnings (monkey-patch warnings.warn and silence filters)
def warn(*args, **kwargs):
pass
warnings.warn = warn
warnings.filterwarnings("ignore")
# ---------------------------
# Get credentials from environment variables
# ---------------------------
def get_huggingface_token():
"""
Get HuggingFace API token from environment.
Set this in your Space settings under Settings > Repository secrets:
- HF_TOKEN or HUGGINGFACE_TOKEN
"""
token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
if not token:
raise ValueError(
"HF_TOKEN not found. Please set it in your HuggingFace Space secrets."
)
return token
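# For local runs outside Spaces, export the token before launching, e.g.
# `export HF_TOKEN=hf_...` (placeholder; substitute your own token).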
# ---------------------------
# LLM
# ---------------------------
def get_llm(model_id: str = MODEL_OPTIONS[0], max_tokens: int = 256, temperature: float = 0.8):
token = get_huggingface_token()
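    # HuggingFaceEndpoint calls the hosted (serverless) Inference API, so no
    # model weights are downloaded into the Space itself.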
llm = HuggingFaceEndpoint(
repo_id=model_id,
max_new_tokens=max_tokens,
temperature=temperature,
huggingfacehub_api_token=token,
)
return llm
# ---------------------------
# Document loader
# ---------------------------
def document_loader(file):
# Handle file path string from Gradio
file_path = file if isinstance(file, str) else file.name
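    # PyPDFLoader returns one Document per page, with page-number metadata.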
loader = PyPDFLoader(file_path)
loaded_document = loader.load()
return loaded_document
# ---------------------------
# Text splitter
# ---------------------------
def text_splitter(data, chunk_size: int = 500, chunk_overlap: int = 50):
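    # Recursively splits on paragraph, sentence, and word boundaries; the
    # overlap preserves context across chunk borders so answers that span a
    # boundary remain retrievable.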
splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len,
)
chunks = splitter.split_documents(data)
return chunks
# ---------------------------
# Embedding model
# ---------------------------
def get_embedding_model(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
"""
Create HuggingFace embedding model.
Using sentence-transformers for efficient embeddings.
"""
embedding = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs={'device': 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
return embedding
# ---------------------------
# Vector DB
# ---------------------------
def vector_database(chunks, embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
embedding_model = get_embedding_model(embedding_model_name)
# Create unique collection name to avoid reusing cached data
collection_name = f"rag_collection_{uuid.uuid4().hex[:8]}"
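    # No persist_directory is set, so the store is in-memory and rebuilt on
    # every query; fine for a demo, wasteful if the same PDF is queried often.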
vectordb = Chroma.from_documents(
chunks,
embedding_model,
collection_name=collection_name
)
return vectordb
# ---------------------------
# Retriever
# ---------------------------
def retriever(file, chunk_size: int = 500, chunk_overlap: int = 50, embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
splits = document_loader(file)
chunks = text_splitter(splits, chunk_size, chunk_overlap)
vectordb = vector_database(chunks, embedding_model)
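    # The default retriever returns the top-4 chunks; pass
    # search_kwargs={"k": n} to as_retriever() to change that.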
retriever_obj = vectordb.as_retriever()
return retriever_obj
# ---------------------------
# QA Chain
# ---------------------------
def retriever_qa(file, query, model_choice, max_tokens, temperature, embedding_model, chunk_size, chunk_overlap):
if not file:
return "Please upload a PDF file first."
    if not query or not query.strip():
return "Please enter a query."
try:
selected_model = model_choice or MODEL_OPTIONS[0]
llm = get_llm(selected_model, int(max_tokens), float(temperature))
retriever_obj = retriever(file, int(chunk_size), int(chunk_overlap), embedding_model)
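        # "stuff" concatenates every retrieved chunk into one prompt; with
        # large chunk sizes this can exceed the model's context window.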
qa = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=retriever_obj,
return_source_documents=True,
)
response = qa.invoke({"query": query})
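        # response["source_documents"] is also available here if you want to
        # surface the retrieved chunks alongside the answer.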
return response['result']
except Exception as e:
return f"Error: {str(e)}"
# ---------------------------
# Gradio Interface
# ---------------------------
with gr.Blocks(title="QA Bot - PDF Question Answering") as demo:
gr.Markdown("# πŸ“„ QA Bot - PDF Question Answering")
gr.Markdown(
"Upload a PDF document and ask questions about its content. "
"Powered by HuggingFace models and LangChain."
)
with gr.Row():
with gr.Column(scale=1):
file_input = gr.File(
label="Upload PDF File",
file_count="single",
file_types=[".pdf"],
type="filepath"
)
query_input = gr.Textbox(
label="Your Question",
lines=3,
placeholder="Ask a question about the uploaded document..."
)
model_dropdown = gr.Dropdown(
label="LLM Model",
choices=MODEL_OPTIONS,
value=MODEL_OPTIONS[0],
)
with gr.Accordion("βš™οΈ Advanced Settings", open=False):
max_tokens_slider = gr.Slider(
label="Max New Tokens",
minimum=50,
maximum=2048,
value=256,
step=1,
info="Maximum number of tokens in the generated output"
)
temperature_slider = gr.Slider(
label="Temperature",
minimum=0.0,
maximum=2.0,
value=0.8,
step=0.1,
info="Controls randomness/creativity of responses"
)
                embedding_dropdown = gr.Dropdown(
label="Embedding Model",
choices=[
"sentence-transformers/all-MiniLM-L6-v2",
"sentence-transformers/all-mpnet-base-v2",
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
"BAAI/bge-small-en-v1.5",
"BAAI/bge-base-en-v1.5"
],
value="sentence-transformers/all-MiniLM-L6-v2",
info="Model used for generating embeddings"
)
chunk_size_slider = gr.Slider(
label="Chunk Size",
minimum=100,
maximum=2000,
value=500,
step=50,
info="Size of text chunks for processing"
)
chunk_overlap_slider = gr.Slider(
label="Chunk Overlap",
minimum=0,
maximum=500,
value=50,
step=10,
info="Overlap between consecutive chunks"
)
submit_btn = gr.Button("Ask Question", variant="primary")
with gr.Column(scale=1):
output_text = gr.Textbox(
label="Answer",
lines=15,
show_copy_button=True
)
submit_btn.click(
fn=retriever_qa,
inputs=[
file_input,
query_input,
model_dropdown,
max_tokens_slider,
temperature_slider,
            embedding_dropdown,
chunk_size_slider,
chunk_overlap_slider
],
outputs=output_text
)
gr.Markdown(
"""
### πŸ“ Instructions
1. Upload a PDF document
2. Enter your question in the text box
3. (Optional) Select a different LLM model
4. (Optional) Adjust advanced settings for fine-tuning
5. Click "Ask Question" to get an answer
### πŸ” Setup
This Space requires a HuggingFace API token. Set the following in your Space secrets:
- `HF_TOKEN`: Your HuggingFace API token (get it from https://huggingface.co/settings/tokens)
"""
)
# ---------------------------
# Launch the app
# ---------------------------
if __name__ == "__main__":
demo.launch()