Spaces:
Runtime error
Runtime error
| import torch | |
| import src.constants.config as configurations | |
| from sentence_transformers import SentenceTransformer | |
| from sentence_transformers import CrossEncoder | |
| from src.constants.credentials import cohere_trial_key, mixedbread_key | |
| import streamlit as st | |
| from src.reader import Reader | |
| from src.utils_search import UtilsSearch | |
| from copy import deepcopy | |
| import numpy as np | |
| import cohere | |
| from mixedbread_ai.client import MixedbreadAI | |
| from src.pytorch_modules.datasets.schema_string_dataset import SchemaStringDataset | |
# Select the service configuration this app runs with. From here on the name
# `configurations` refers to that config mapping, not to the imported module.
configurations = configurations.service_mxbai_msc_direct_config
semantic_column_names = configurations["semantic_column_names"]
cross_encoder_name = configurations["cross_encoder_name"]

# API clients. `co` serves the "Use Cohere" rerank path; `model` (Mixedbread)
# is passed to search_utils.retrieve and used for the alternative rerank path.
api_key = cohere_trial_key
co = cohere.Client(api_key)
model = MixedbreadAI(api_key=mixedbread_key)
def init():
    """Load the dataset and build its search index.

    Reads the module-level ``configurations`` mapping when called.

    Returns:
        A ``(df, index, search_utils)`` tuple: the frame returned by
        ``Reader.read()``, the index that ``UtilsSearch.dataframe_to_index``
        builds from it, and the ``UtilsSearch`` instance used to build it.
    """
    search_utils = UtilsSearch(configurations)
    df = Reader(config=configurations["reader_config"]).read()
    return df, search_utils.dataframe_to_index(df), search_utils
def get_possible_values_for_column(column_name, search_utils, df):
    """Return the multiselect options for ``column_name``, cached per session.

    The first call in a session computes
    ``search_utils.top_10_common_values(df, column_name)`` and stores it in
    ``st.session_state`` under the key ``column_name``. Later calls return the
    stored value without recomputing it.

    Args:
        column_name: Name of the dataframe column to fetch values for.
        search_utils: Object providing ``top_10_common_values(df, column_name)``.
        df: The dataframe the values are computed from.

    Returns:
        Whatever ``top_10_common_values`` returned for this column.
    """
    # Use item access, not getattr/setattr. The session-state proxy is a
    # mapping, so a column named e.g. "items", "values", "keys" or "get" would
    # make getattr() return the proxy's bound method instead of the stored value.
    # NOTE(review): the key shares the session-state namespace with other
    # entries (e.g. 'init_results'); consider prefixing it if that can collide.
    if column_name not in st.session_state:
        st.session_state[column_name] = search_utils.top_10_common_values(df, column_name)
    return st.session_state[column_name]
# Streamlit reruns this script on every widget interaction. Run the setup in
# init() only once per session, keep its result in st.session_state, and on
# every rerun just unpack the cached (df, index, search_utils) tuple.
if 'init_results' not in st.session_state:
    st.session_state['init_results'] = init()
df, index, search_utils = st.session_state['init_results']
# Streamlit app layout
st.title('Search Demo')

# Input fields
query = st.text_input('Enter your search query here')
use_cohere = st.checkbox('Use Cohere', value=True)  # Default to checked

# Work on a private copy of the configured filters so the user's widget choices
# never mutate the shared `configurations` mapping.
programmatic_search_config = deepcopy(configurations['programmatic_search_config'])
dynamic_programmatic_search_config = {
    "scalar_columns": [],
    "discrete_columns": [],
}

for column in programmatic_search_config['scalar_columns']:
    # One bounded min/max number-input pair per scalar column. The inputs
    # default to the configured bounds, i.e. no narrowing.
    col_name = column["column_name"]
    min_val = float(column["min_value"])
    max_val = float(column["max_value"])
    user_min = st.number_input(f'Minimum {col_name.capitalize()}', min_value=min_val, max_value=max_val, value=min_val)
    user_max = st.number_input(f'Maximum {col_name.capitalize()}', min_value=min_val, max_value=max_val, value=max_val)
    dynamic_programmatic_search_config['scalar_columns'].append(
        {"column_name": col_name, "min_value": user_min, "max_value": user_max}
    )

for column in programmatic_search_config['discrete_columns']:
    # One multiselect per discrete column.
    col_name = column["column_name"]
    # A missing or None default list means "nothing preselected".
    default_values = list(column.get("default_values") or [])
    options = list(get_possible_values_for_column(col_name, search_utils, df))
    # st.multiselect raises if any default is not among its options. The
    # option list is computed from the data and need not contain the
    # configured defaults, so append any that are missing (once each).
    for value in default_values:
        if value not in options:
            options.append(value)
    selected_values = st.multiselect(f'Select {col_name.capitalize()}', options=options, default=default_values)
    dynamic_programmatic_search_config['discrete_columns'].append(
        {"column_name": col_name, "default_values": selected_values}
    )

# Replace the configured filters with the user's current selections.
programmatic_search_config['scalar_columns'] = dynamic_programmatic_search_config['scalar_columns']
programmatic_search_config['discrete_columns'] = dynamic_programmatic_search_config['discrete_columns']
def _rerank_with_mixedbread(query, candidates, top_k=10):
    """Rerank ``candidates`` with the Mixedbread reranking endpoint.

    Args:
        query: The user's search string.
        candidates: Retrieved rows. Expected to have a positional 0..n-1 index.
        top_k: Number of rows to keep after reranking.

    Returns:
        A copy of the selected rows, in the order listed in the API response.
    """
    records = candidates.to_dict(orient='records')
    dataset_str = SchemaStringDataset(records, configurations)
    # Keep only the first 256 elements of each serialized row. These are
    # characters if "inputs" is a string (TODO confirm in SchemaStringDataset).
    documents = [item["inputs"][:256] for item in dataset_str]
    res = model.reranking(
        model=cross_encoder_name,
        query=query,
        input=documents,
        top_k=top_k,
        return_input=False,
    )
    # item.index is a position in `documents`. This maps to a row position in
    # `candidates` assuming SchemaStringDataset yields one item per record, in order.
    ids = [item.index for item in res.data]
    return candidates.iloc[ids].copy()


def _rerank_with_cohere(query, candidates, top_n=10):
    """Rerank ``candidates`` with Cohere's ``rerank-english-v3.0`` model.

    Only the ``semantic_column_names`` columns are sent. Each value is
    stringified and missing values are sent as "". ``candidates`` itself is
    not modified.

    Args:
        query: The user's search string.
        candidates: Retrieved rows. Expected to have a positional 0..n-1 index.
        top_n: Number of rows to keep after reranking.

    Returns:
        A copy of the selected rows, in the order of ``results.results``.
    """
    columns = list(semantic_column_names)
    # Fill NaN on a column subset copy, so "" only affects the text sent to
    # Cohere. The rows shown to the user keep their original values and dtypes.
    docs = [
        {name: str(value) for name, value in row.items()}
        for row in candidates[columns].fillna("").to_dict('records')
    ]
    results = co.rerank(
        query=query,
        documents=docs,
        top_n=top_n,
        model='rerank-english-v3.0',
        rank_fields=columns,
    )
    top_ids = [hit.index for hit in results.results]
    return candidates.iloc[top_ids].copy()


# Search button
if st.button('Search'):
    if not query:
        st.write("Please enter a query to search.")
    else:
        df_retrieved = search_utils.retrieve(query, df, model, index, top_k=1000, api=True)
        df_filtered = search_utils.filter_dataframe(df_retrieved, programmatic_search_config)
        # NOTE(review): ascending=True keeps the 100 rows with the *smallest*
        # 'similarities' values. That is right if retrieve() stores distances
        # in this column and inverted if it stores similarity scores; confirm
        # in UtilsSearch.retrieve.
        df_filtered = df_filtered.sort_values(by='similarities', ascending=True)
        # Positional 0..n-1 index, so rerank positions map back via iloc.
        df_filtered = df_filtered.head(100).reset_index(drop=True)
        if df_filtered.empty:
            st.write('No results found')
        else:
            if use_cohere:
                results_df = _rerank_with_cohere(query, df_filtered)
            else:
                results_df = _rerank_with_mixedbread(query, df_filtered)
            # 1-based position in the reranked order, added for both backends.
            results_df['rank'] = np.arange(1, len(results_df) + 1)
            results_df = search_utils.drop_columns(results_df, programmatic_search_config)
            st.write(results_df)