Spaces:
Runtime error
Runtime error
| import os | |
| import streamlit as st | |
| from langchain.llms import HuggingFaceHub | |
| from models import return_sum_models | |
class LLM_Langchain():
    """Streamlit chat UI that summarizes source code with a HuggingFace Hub model.

    Constructing the object renders the page header/notes and the sidebar
    controls (API key, language, model checkpoint, generation settings).
    ``form_data`` then renders the chat history and handles a new message.
    """

    def __init__(self):
        # Widget creation order determines the page/sidebar layout.
        self._render_intro()
        self._render_model_selector()
        self._render_generation_controls()

    @staticmethod
    def _render_intro():
        """Render the page header and the static warning/info notes."""
        st.header('🦜 Code summarization')
        st.warning("Warning: input function needs cleaning and may take long to be processed at first time")
        st.warning("Note: you should not copy the whole function from IDE, the \"\\n\" character needs typing by hand")
        st.info("Reference: [CodeT5](https://arxiv.org/abs/2109.00859), [The Vault](https://arxiv.org/abs/2305.06156), [CodeXGLUE](https://arxiv.org/abs/2102.04664)")
        st.info("About me: namnh113")

    @staticmethod
    def _candidate_models(language):
        """Return the checkpoint names offered for *language*.

        The first entry is whatever ``return_sum_models`` returns for the
        language (per the selector label, presumably one of the author's
        namnh113 checkpoints).
        - python/java additionally get a ``_v2`` variant of it.
        - Every language except cpp also gets the Salesforce CodeT5
          summarization checkpoints.
        """
        model_name = return_sum_models(language)
        candidates = [model_name]
        if language in ("python", "java"):
            candidates.append(model_name + "_v2")
        if language != "cpp":
            candidates += [
                "Salesforce/codet5-base-multi-sum",
                f"Salesforce/codet5-base-codexglue-sum-{language}",
            ]
        return candidates

    def _render_model_selector(self):
        """Render the API-key, language and checkpoint widgets in the sidebar.

        Sets ``self.API_KEY``, ``self.model_parent`` and ``self.checkpoint``.
        """
        self.API_KEY = st.sidebar.text_input(
            'API key',
            type='password',
            help="Type in your HuggingFace API key to use this app")
        self.model_parent = st.sidebar.selectbox(
            label="Choose language",
            options=["python", "java", "javascript", "php", "ruby", "go", "cpp"],
            help="Choose languages",
        )
        # With no language selected there is nothing to offer, so the model
        # picker is shown disabled and empty.
        model_select_disabled = self.model_parent is None
        list_model = [] if model_select_disabled else self._candidate_models(self.model_parent)
        self.checkpoint = st.sidebar.selectbox(
            label="Choose model (namnh113/... is my model)",
            options=list_model,
            help="Model used to predict",
            disabled=model_select_disabled,
        )

    def _render_generation_controls(self):
        """Render the generation sliders and collect them into ``self.model_kwargs``."""
        self.max_new_tokens = st.sidebar.slider(
            label="Token Length",
            min_value=32,
            max_value=248,
            step=4,
            value=128,
            help="Set the max tokens to get accurate results"
        )
        self.num_beams = st.sidebar.slider(
            label="num beams",
            min_value=1,
            max_value=10,
            step=1,
            value=2,
            help="Set num beam"
        )
        self.top_k = st.sidebar.slider(
            label="top k",
            min_value=1,
            max_value=50,
            step=1,
            value=30,
            help="Set the top_k"
        )
        self.top_p = st.sidebar.slider(
            label="top p",
            min_value=0.1,
            max_value=1.0,
            step=0.05,
            value=0.95,
            help="Set the top_p"
        )
        # Forwarded verbatim to the Hub inference call in generate_response.
        self.model_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "num_beams": self.num_beams,
        }

    def generate_response(self, input_text):
        """Summarize *input_text* with the selected checkpoint.

        The code is prefixed with ``"Summarize <Language>: "`` (language name
        capitalized) before being sent through ``HuggingFaceHub``.

        Args:
            input_text: the source code to summarize.

        Returns:
            The text produced by the model.

        Raises:
            ValueError: if no API key has been entered.
        """
        if not self.API_KEY:
            # langchain falls back to the HUGGINGFACEHUB_API_TOKEN environment
            # variable when no token is passed; refuse instead, so a missing
            # user key never silently uses some other token.
            raise ValueError("Please enter your API key!")
        prompt = "Summarize " + self.model_parent.capitalize() + ": " + input_text
        llm = HuggingFaceHub(
            repo_id=self.checkpoint,
            model_kwargs=self.model_kwargs,
            # Pass the per-session key directly rather than writing it to
            # os.environ, which is shared by every session in the server process.
            huggingfacehub_api_token=self.API_KEY,
        )
        return llm(prompt)

    def form_data(self):
        """Render the chat history and handle one new chat message.

        Typing ``clear`` deletes the stored history. Any exception raised while
        rendering or generating is shown in the page instead of crashing the app.
        """
        try:
            if not self.API_KEY.startswith('hf_'):
                st.warning('Please enter your API key!', icon='⚠️')
            if "messages" not in st.session_state:
                st.session_state.messages = []
            st.write(f"You are using {self.checkpoint} model")
            for message in st.session_state.messages:
                with st.chat_message(message.get('role')):
                    st.write(message.get("content"))
            text = st.chat_input(disabled=False)
            if not text:
                return
            st.session_state.messages.append({"role": "user", "content": text})
            with st.chat_message("user"):
                st.write(text)
            if text.strip().lower() == "clear":
                del st.session_state.messages
                return
            result = self.generate_response(text)
            # Start each " * " item on its own line so st.markdown renders it
            # as a bullet.
            result = result.replace(' * ', '\n* ')
            st.session_state.messages.append({"role": "assistant", "content": result})
            with st.chat_message('assistant'):
                st.markdown(result)
        except Exception as e:
            # Top-level UI boundary: show the failure to the user.
            st.error(str(e), icon="🚨")
# Script entry point: constructing the app renders the header and sidebar
# controls, then form_data() renders the chat history and handles any new
# message.
model = LLM_Langchain()
model.form_data()