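# app.py: a Gradio QnA chatbot for Hugging Face Spaces. Upload a text file,
# index it into a persisted Chroma collection, then chat over it with a
# LangChain ConversationalRetrievalChain.
# Assumes these environment variables are set: CHROMADB_PERSIST_DIRECTORY,
# HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME, HUGGINGFACEHUB_LLM_REPO_ID, and a
# HUGGINGFACEHUB_API_TOKEN (read by langchain's HuggingFaceHub wrapper).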
import os

import gradio as gr
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.document_loaders import TextLoader
from langchain.memory import ConversationBufferMemory
from langchain.llms import HuggingFaceHub
from langchain.chains import ConversationalRetrievalChain
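
# Module-level state shared across Gradio event handlers.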
embeddings = None
qa_chain = None
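
# Load the instruct-embeddings model once on first use and cache it globally.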
def load_embeddings():
    global embeddings
    if not embeddings:
        print("loading embeddings...")
        model_name = os.environ['HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME']
        embeddings = HuggingFaceInstructEmbeddings(model_name=model_name)
    return embeddings
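
# Load the uploaded text file and split it into overlapping character chunks.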
def split_file(file, chunk_size, chunk_overlap):
    print('splitting file...', file.name, chunk_size, chunk_overlap)
    loader = TextLoader(file.name)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(documents)
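
# Each document gets its own persistence directory, named after the file.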
def get_persist_directory(file_name):
    return os.path.join(os.environ['CHROMADB_PERSIST_DIRECTORY'], file_name)
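
# Split, embed, and persist an uploaded file, then report status and point
# the collection dropdown at the newly created collection.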
def process_file(file, chunk_size, chunk_overlap):
    docs = split_file(file, chunk_size, chunk_overlap)
    embeddings = load_embeddings()
    file_name, _ = os.path.splitext(os.path.basename(file.name))
    persist_directory = get_persist_directory(file_name)
    print("initializing vector store...", persist_directory)
    vectordb = Chroma.from_documents(documents=docs, embedding=embeddings,
                                     collection_name=file_name,
                                     persist_directory=persist_directory)
    print("persisting...", vectordb._client.list_collections())
    vectordb.persist()
    return 'Done!', gr.Dropdown.update(choices=get_vector_dbs(), value=file_name)

def is_dir(root, name):
    path = os.path.join(root, name)
    return os.path.isdir(path)
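
# Persisted collections are exactly the subdirectories of the persist root.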
def get_vector_dbs():
    root = os.environ['CHROMADB_PERSIST_DIRECTORY']
    if not os.path.exists(root):
        return []
    print('get vector dbs...', root)
    files = os.listdir(root)
    dirs = list(filter(lambda x: is_dir(root, x), files))
    print(dirs)
    return dirs
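
# Reopen a previously persisted Chroma collection for retrieval.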
def load_vectordb(file_name):
    embeddings = load_embeddings()
    persist_directory = get_persist_directory(file_name)
    print(persist_directory)
    vectordb = Chroma(collection_name=file_name,
                      embedding_function=embeddings,
                      persist_directory=persist_directory)
    print(vectordb._client.list_collections())
    return vectordb
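
# Rebuild the global QA chain whenever the selected document or an LLM
# setting changes: fresh buffer memory, a HuggingFaceHub LLM, and a
# retriever over the chosen collection.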
def create_qa_chain(collection_name, temperature, max_length):
    print('creating qa chain...', collection_name, temperature, max_length)
    if not collection_name:
        return
    global qa_chain
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True)
    llm = HuggingFaceHub(
        repo_id=os.environ["HUGGINGFACEHUB_LLM_REPO_ID"],
        model_kwargs={"temperature": temperature, "max_length": max_length}
    )
    vectordb = load_vectordb(collection_name)
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm, retriever=vectordb.as_retriever(), memory=memory)
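
# Re-scan the persist root and refresh the dropdown choices.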
def refresh_collection():
    choices = get_vector_dbs()
    return gr.Dropdown.update(choices=choices, value=choices[0] if choices else None)
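
# Append the new question to the chat history and clear the input box.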
def submit_message(bot_history, text):
    bot_history = bot_history + [(text, None)]
    return bot_history, ""
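
# Answer the most recent question with the QA chain and fill in the reply.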
def bot(bot_history):
    global qa_chain
    print(qa_chain, bot_history[-1][0])
    result = qa_chain.run(bot_history[-1][0])
    print(result)
    # Replace the pair rather than mutating it, since history entries may be
    # immutable tuples before Gradio's JSON round-trip converts them to lists.
    bot_history[-1] = (bot_history[-1][0], result)
    return bot_history
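
# Clearing the Chatbot just means returning None as its new value.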
def clear_bot():
    return None


title = "QnA Chatbot"
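
# Build the UI: a "File" tab to index documents and a "Bot" tab to chat.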
with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    with gr.Tab("File"):
        upload = gr.File(file_types=["text"], label="Upload File")
        chunk_size = gr.Slider(
            500, 5000, value=1000, step=100, label="Chunk Size")
        chunk_overlap = gr.Slider(0, 30, value=20, label="Chunk Overlap")
        process = gr.Button("Process")
        result = gr.Label()
    with gr.Tab("Bot"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column(scale=3):
                        choices = get_vector_dbs()
                        collection = gr.Dropdown(
                            choices, value=choices[0] if choices else None,
                            label="Document", allow_custom_value=True)
                    with gr.Column():
                        refresh = gr.Button("Refresh")
                temperature = gr.Slider(
                    0.0, 1.0, value=0.5, step=0.05, label="Temperature")
                max_length = gr.Slider(
                    20, 1000, value=100, step=10, label="Max Length")
            with gr.Column():
                chatbot = gr.Chatbot([], elem_id="chatbot").style(height=550)
                message = gr.Textbox(
                    show_label=False, placeholder="Ask me anything!")
                clear = gr.Button("Clear")
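
    # Wire up events; also build an initial chain for the default selection.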
    process.click(
        process_file,
        [upload, chunk_size, chunk_overlap],
        [result, collection]
    )
    create_qa_chain(collection.value, temperature.value, max_length.value)
    collection.change(create_qa_chain, [collection, temperature, max_length])
    temperature.change(create_qa_chain, [collection, temperature, max_length])
    max_length.change(create_qa_chain, [collection, temperature, max_length])
    refresh.click(refresh_collection, None, collection)
    message.submit(submit_message, [chatbot, message], [chatbot, message]).then(
        bot, chatbot, chatbot
    )
    clear.click(clear_bot, None, chatbot)

demo.title = title
demo.launch()