Spaces:
Running
Running
| import gc | |
| import json | |
| import sqlite3 | |
| from pathlib import Path | |
| from typing import Optional, Tuple, Any, Dict, List, Set, Union | |
| from collections import Counter | |
| import numpy as np | |
| import faiss | |
| # [수정됨] 패키지 구조 변경 반영 및 EnsembleRetriever 제거 | |
| from langchain_community.retrievers import BM25Retriever | |
| from langchain_core.documents import Document | |
| from langchain_community.vectorstores import FAISS | |
| from sentence_transformers import SentenceTransformer | |
| # 런타임에 Embeddings 클래스를 찾기 위한 로직 | |
| try: | |
| from langchain_core.embeddings import Embeddings | |
| except ImportError: | |
| try: | |
| from langchain.embeddings.base import Embeddings | |
| except ImportError: | |
| Embeddings = object | |
| # --- SQLite 헬퍼 함수 --- | |
| SQLITE_DB_NAME = "metadata_mapping.db" | |
| # === IDSelector 클래스 정의 === | |
| class MetadataIDSelector(faiss.IDSelectorBatch): | |
| def __init__(self, allowed_ids: Set[int]): | |
| super().__init__(list(allowed_ids)) | |
| def get_db_connection(persist_directory: str) -> sqlite3.Connection: | |
| """FAISS 저장 경로를 기반으로 SQLite 연결을 설정하고 반환합니다.""" | |
| db_path = Path(persist_directory) / SQLITE_DB_NAME | |
| conn = sqlite3.connect(db_path) | |
| return conn | |
| def _create_and_populate_sqlite_db(chunks: List[Document], persist_directory: str): | |
| """ | |
| 문서 청크를 기반으로 SQLite DB를 생성하고 채웁니다. | |
| [업데이트 반영] 메타데이터 구조: regulation, chapter, section, standard | |
| """ | |
| # 1. 입력 데이터 확인 (가장 중요한 체크 포인트) | |
| if not chunks: | |
| print("🚨 [오류] _create_and_populate_sqlite_db 함수에 전달된 chunks 리스트가 비어 있습니다!") | |
| print(" -> load_chunks_from_jsonl 함수가 정상적으로 파일을 읽었는지 확인해주세요.") | |
| return | |
| # 2. 저장 경로 확인 및 생성 | |
| save_dir = Path(persist_directory) | |
| save_dir.mkdir(parents=True, exist_ok=True) | |
| conn = get_db_connection(persist_directory) | |
| try: | |
| cursor = conn.cursor() | |
| # 3. 테이블 생성 (기존 테이블 삭제 후 재생성 옵션 고려) | |
| # 스키마가 변경되었으므로 기존 테이블이 있다면 충돌날 수 있습니다. | |
| # 안전하게 지우고 다시 만드는 방법을 추천합니다. (개발 단계) | |
| cursor.execute("DROP TABLE IF EXISTS documents") | |
| cursor.execute(""" | |
| CREATE TABLE documents ( | |
| faiss_id INTEGER PRIMARY KEY, | |
| source TEXT, | |
| regulation TEXT, | |
| chapter TEXT, | |
| section TEXT, | |
| standard TEXT, | |
| json_metadata TEXT | |
| ) | |
| """) | |
| # 테이블 생성 직후 커밋 (파일에 스키마 기록) | |
| conn.commit() | |
| print(f"📂 DB 테이블 생성 완료 (경로: {save_dir}/{SQLITE_DB_NAME})") | |
| # 4. 데이터 채우기 | |
| inserted_count = 0 | |
| for i, doc in enumerate(chunks): | |
| faiss_id = i | |
| metadata_json = json.dumps(doc.metadata, ensure_ascii=False) | |
| source_val = doc.metadata.get('source', '') | |
| regulation_val = doc.metadata.get('regulation', '') | |
| chapter_val = doc.metadata.get('chapter', '') | |
| section_val = doc.metadata.get('section', '') | |
| standard_val = doc.metadata.get('standard', '') | |
| if isinstance(regulation_val, list): regulation_val = ', '.join(map(str, regulation_val)) | |
| if isinstance(chapter_val, list): chapter_val = ', '.join(map(str, chapter_val)) | |
| if isinstance(section_val, list): section_val = ', '.join(map(str, section_val)) | |
| if isinstance(standard_val, list): standard_val = ', '.join(map(str, standard_val)) | |
| doc.metadata['faiss_id'] = faiss_id | |
| cursor.execute( | |
| """ | |
| INSERT OR REPLACE INTO documents | |
| (faiss_id, source, regulation, chapter, section, standard, json_metadata) | |
| VALUES (?, ?, ?, ?, ?, ?, ?) | |
| """, | |
| (faiss_id, source_val, regulation_val, chapter_val, section_val, standard_val, metadata_json) | |
| ) | |
| inserted_count += 1 | |
| # 5. 최종 커밋 | |
| conn.commit() | |
| print(f"✅ SQLite 데이터 저장 완료: 총 {inserted_count}행이 삽입되었습니다.") | |
| except Exception as e: | |
| print(f"🚨 [DB 저장 중 에러 발생] {e}") | |
| # 에러가 나도 traceback을 볼 수 있게 함 | |
| import traceback | |
| traceback.print_exc() | |
| finally: | |
| # 6. 연결 확실히 종료 | |
| conn.close() | |
| # --- LocalSentenceTransformerEmbeddings --- | |
| class LocalSentenceTransformerEmbeddings(Embeddings): | |
| def __init__(self, st_model, normalize_embeddings: bool = True, encode_batch_size: int = 32): | |
| self.model = st_model | |
| self.normalize = normalize_embeddings | |
| self.encode_batch_size = encode_batch_size | |
| def embed_documents(self, texts): | |
| vecs = self.model.encode( | |
| texts, | |
| batch_size=self.encode_batch_size, | |
| show_progress_bar=False, | |
| normalize_embeddings=self.normalize, | |
| convert_to_numpy=True, | |
| ) | |
| return vecs.tolist() | |
| def embed_query(self, text: str): | |
| vec = self.model.encode( | |
| [text], | |
| batch_size=self.encode_batch_size, | |
| show_progress_bar=False, | |
| normalize_embeddings=self.normalize, | |
| convert_to_numpy=True, | |
| )[0] | |
| return vec.tolist() | |
| def load_chunks_from_jsonl(file_paths: Union[str, List[str]]): | |
| """ | |
| JSONL 파일 로드 함수 | |
| """ | |
| if isinstance(file_paths, str): | |
| file_paths = [file_paths] | |
| restored_documents = [] | |
| print(f" 총 {len(file_paths)}개의 파일 병합 로드를 시작합니다...") | |
| for file_path in file_paths: | |
| try: | |
| file_doc_count = 0 | |
| with open(file_path, 'r', encoding='utf-8') as f: | |
| for line_number, line in enumerate(f): | |
| line = line.strip() | |
| if not line: continue | |
| data = json.loads(line) | |
| doc = Document( | |
| page_content=data.get('page_content', ""), | |
| metadata=data.get('metadata', {}) | |
| ) | |
| restored_documents.append(doc) | |
| file_doc_count += 1 | |
| print(f" - [성공] {file_path}: {file_doc_count}개 Chunk") | |
| except Exception as e: | |
| print(f" [실패] 오류 ({file_path}): {e}") | |
| continue | |
| print(f"✅ 전체 로드 완료: 총 {len(restored_documents)}개의 Chunk가 복원되었습니다.") | |
| return restored_documents | |
| # --- save_embedding_system (수정됨: Ensemble 제거 및 개별 반환) --- | |
| def save_embedding_system( | |
| chunks, | |
| persist_directory: str = r"D:/Project AI/RAG", | |
| batch_size: int = 32, | |
| device: str = 'cuda' | |
| ): | |
| Path(persist_directory).mkdir(parents=True, exist_ok=True) | |
| # 1) SQLite DB 저장 | |
| _create_and_populate_sqlite_db(chunks, persist_directory) | |
| # 2) 모델 로드 | |
| model = SentenceTransformer( | |
| 'nomic-ai/nomic-embed-text-v2-moe', | |
| trust_remote_code=True, | |
| device=device | |
| ) | |
| embeddings = LocalSentenceTransformerEmbeddings( | |
| st_model=model, | |
| normalize_embeddings=True, | |
| encode_batch_size=batch_size | |
| ) | |
| # 3) FAISS 생성 | |
| vectorstore = None | |
| for i in range(0, len(chunks), batch_size): | |
| batch = chunks[i:i + batch_size] | |
| if vectorstore is None: | |
| vectorstore = FAISS.from_documents(documents=batch, embedding=embeddings) | |
| else: | |
| vectorstore.add_documents(documents=batch) | |
| gc.collect() | |
| # 4) BM25 생성 (Ensemble 없이 독립 생성) | |
| bm25_retriever = BM25Retriever.from_documents(chunks) | |
| bm25_retriever.k = 5 | |
| # 5) 저장 | |
| vectorstore.save_local(persist_directory) | |
| # 6) 연결 반환 (개별 요소 반환) | |
| sqlite_conn = get_db_connection(persist_directory) | |
| gc.collect() | |
| return bm25_retriever, vectorstore, sqlite_conn | |
| # --- load_embedding_from_faiss (수정됨: Ensemble 제거 및 개별 반환) --- | |
| def load_embedding_from_faiss( | |
| persist_directory: str = r"D:/Project AI/RAG", | |
| top_k: int = 10, | |
| bm25_k: int = 10, | |
| embeddings: Optional[Any] = None, | |
| device: str = 'cpu' | |
| ) -> Tuple[Any, FAISS, sqlite3.Connection]: | |
| if embeddings is None: | |
| st_model = SentenceTransformer( | |
| 'nomic-ai/nomic-embed-text-v2-moe', | |
| trust_remote_code=True, | |
| device=device | |
| ) | |
| embeddings = LocalSentenceTransformerEmbeddings( | |
| st_model=st_model, | |
| normalize_embeddings=True, | |
| encode_batch_size=32 | |
| ) | |
| persist_dir = Path(persist_directory) | |
| if not persist_dir.exists(): | |
| raise FileNotFoundError(f"FAISS 경로가 없습니다: {persist_dir}") | |
| # FAISS 로드 | |
| vectorstore = FAISS.load_local( | |
| folder_path=str(persist_dir), | |
| embeddings=embeddings, | |
| allow_dangerous_deserialization=True | |
| ) | |
| # BM25 복원 (저장된 문서로부터 재생성) | |
| bm25_retriever = None | |
| docs = [] | |
| try: | |
| if hasattr(vectorstore, "docstore") and hasattr(vectorstore.docstore, "_dict"): | |
| docs = list(vectorstore.docstore._dict.values()) | |
| if docs: | |
| bm25_retriever = BM25Retriever.from_documents(docs) | |
| bm25_retriever.k = bm25_k | |
| else: | |
| print("[경고] 저장된 문서를 찾을 수 없어 BM25를 생성하지 못했습니다.") | |
| except Exception as e: | |
| print(f"[경고] 저장된 문서를 읽는 중 문제가 발생했습니다: {e}") | |
| sqlite_conn = get_db_connection(persist_directory) | |
| return bm25_retriever, vectorstore, sqlite_conn | |
| # --- search_vectorstore (단순 벡터 검색 헬퍼) --- | |
| def search_vectorstore(bm25_retriever, vectorstore, query, k=5): | |
| """ | |
| vectorstore와 bm25_retriever를 받아 앙상블(Hybrid) 검색을 수행하는 함수. | |
| EnsembleRetriever(weights=[0.6, 0.4])와 유사한 결과를 반환합니다. | |
| """ | |
| weights=[0.6, 0.4] | |
| # 1. 벡터 검색 수행 (Vector Search) | |
| # FAISS를 리트리버로 변환하여 검색 | |
| vec_retriever = vectorstore.as_retriever(search_kwargs={"k": k}) | |
| vec_docs = vec_retriever.invoke(query) | |
| # 2. 키워드 검색 수행 (BM25 Search) | |
| # 검색 개수를 k개로 맞춰서 실행 | |
| bm25_docs = bm25_retriever.invoke(query, config={"search_kwargs": {"k": k}}) | |
| # 3. 랭킹 퓨전 (Weighted Reciprocal Rank Fusion) | |
| # 두 리스트의 순위를 기반으로 가중치를 적용해 점수를 매깁니다. | |
| doc_scores = {} # 문서 내용(또는 ID) -> 점수 | |
| doc_map = {} # 문서 내용 -> 문서 객체 저장 (나중에 반환하기 위해) | |
| # 내부 함수: 순위에 따른 점수 계산 (Weight / (Rank + 1)) | |
| def apply_rank_score(docs, weight): | |
| for rank, doc in enumerate(docs): | |
| # 고유 키 생성 (page_content가 고유하다고 가정하거나, doc_id가 있다면 사용) | |
| doc_key = doc.page_content | |
| doc_map[doc_key] = doc | |
| if doc_key not in doc_scores: | |
| doc_scores[doc_key] = 0.0 | |
| # 순위가 높을수록(rank가 작을수록) 점수가 높음 | |
| score = weight / (rank + 1) | |
| doc_scores[doc_key] += score | |
| # 벡터 검색 결과 점수 반영 (가중치 0.6) | |
| apply_rank_score(vec_docs, weights[0]) | |
| # BM25 검색 결과 점수 반영 (가중치 0.4) | |
| apply_rank_score(bm25_docs, weights[1]) | |
| # 4. 점수순 정렬 (높은 점수가 상위) | |
| # 점수(item[1])를 기준으로 내림차순 정렬 | |
| sorted_docs = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True) | |
| # 5. Top-K 추출 및 문서 객체 반환 | |
| final_results = [doc_map[key] for key, score in sorted_docs[:k]] | |
| return final_results | |
| # --- search_with_metadata_filter (수정됨: 수동 병합 로직 구현) --- | |
| def search_with_metadata_filter( | |
| bm25_retriever: Any, # [변경] Ensemble 대신 BM25를 직접 받음 | |
| vectorstore: FAISS, | |
| query: str, | |
| k: int = 5, | |
| metadata_filter: Optional[Dict[str, Any]] = None, | |
| sqlite_conn: Optional[sqlite3.Connection] = None | |
| ) -> List[Document]: | |
| """ | |
| SQLite 사전 필터링 -> FAISS 벡터 검색 + BM25 검색 -> 결과 병합 | |
| """ | |
| # === 1. SQLite에서 필터링된 FAISS ID 추출 === | |
| filtered_ids = None | |
| if metadata_filter and sqlite_conn: | |
| cursor = sqlite_conn.cursor() | |
| where_clauses = [] | |
| params = [] | |
| for key, value in metadata_filter.items(): | |
| if isinstance(value, list): | |
| if not value: continue | |
| placeholders = ', '.join(['?'] * len(value)) | |
| where_clauses.append(f"{key} IN ({placeholders})") | |
| params.extend(value) | |
| else: | |
| where_clauses.append(f"{key} = ?") | |
| params.append(value) | |
| if where_clauses: | |
| where_sql = " OR ".join(where_clauses) | |
| sql_query = f"SELECT faiss_id FROM documents WHERE {where_sql}" | |
| try: | |
| cursor.execute(sql_query, params) | |
| filtered_ids = {row[0] for row in cursor.fetchall()} | |
| print(f"[사전 필터링] {len(filtered_ids)}개 ID 획득 → FAISS 검색 제한") | |
| except Exception as e: | |
| print(f"[경고] SQLite 필터링 실패: {e}") | |
| filtered_ids = None | |
| else: | |
| print("[안내] 필터 조건 없음 → 전체 검색") | |
| else: | |
| print("[안내] 필터 또는 DB 없음 → 전체 검색") | |
| # === 2. FAISS 벡터 검색 === | |
| vector_docs = [] | |
| if filtered_ids and len(filtered_ids) > 0: | |
| selector = MetadataIDSelector(filtered_ids) | |
| index: faiss.Index = vectorstore.index | |
| query_embedding = np.array(vectorstore.embeddings.embed_query(query)).astype('float32') | |
| query_embedding = query_embedding.reshape(1, -1) | |
| search_params = faiss.SearchParametersIVF(sel=selector, nprobe=20) | |
| _k = max(k * 10, 100) | |
| D, I = index.search(query_embedding, _k, params=search_params) | |
| valid_indices = [i for i in I[0] if i != -1] | |
| for idx in valid_indices[:k]: | |
| doc_id = vectorstore.index_to_docstore_id[idx] | |
| doc = vectorstore.docstore.search(doc_id) | |
| if isinstance(doc, Document): | |
| vector_docs.append(doc) | |
| print(f"[벡터 검색] {len(valid_indices)}개 후보 → {len(vector_docs)}개 유효") | |
| else: | |
| # 전체 검색 | |
| vector_retriever = vectorstore.as_retriever(search_kwargs={"k": k}) | |
| vector_docs = vector_retriever.invoke(query) | |
| print(f"[벡터 검색] 전체 검색 → {len(vector_docs)}개 후보") | |
| # === 3. BM25 검색 === | |
| bm25_docs = [] | |
| if bm25_retriever: | |
| search_k = k * 5 | |
| candidates = bm25_retriever.invoke(query, config={"search_kwargs": {"k": search_k}}) | |
| if filtered_ids: | |
| bm25_docs = [d for d in candidates if d.metadata.get('faiss_id') in filtered_ids] | |
| else: | |
| bm25_docs = candidates | |
| # Top K 자르기 | |
| bm25_docs = bm25_docs[:k] | |
| print(f"[BM25 검색] {len(candidates)}개 후보 → {len(bm25_docs)}개 필터링 후") | |
| # === 4. 병합 (Vector 우선 + 중복 제거) === | |
| combined = {id(d): d for d in (vector_docs + bm25_docs)}.values() | |
| final_results = list(combined)[:k] | |
| print(f"[최종 결과] {len(final_results)}개 문서 반환") | |
| return final_results | |
| # --- get_unique_metadata_values (빠진 함수 추가) --- | |
| def get_unique_metadata_values( | |
| sqlite_conn: sqlite3.Connection, | |
| key_name: str, | |
| partial_match: Optional[str] = None | |
| ) -> List[str]: | |
| """ | |
| 고유 값 검색 함수. | |
| key_name 인자로 'part', 'subpart', 'section', 'source' 등을 사용할 수 있습니다. | |
| """ | |
| if not sqlite_conn: | |
| return [] | |
| cursor = sqlite_conn.cursor() | |
| # 안전을 위해 key_name은 컬럼명으로 직접 사용 (SQL Injection 주의: 내부 사용 전제) | |
| # 실제 프로덕션에서는 key_name을 화이트리스트로 검증하는 것이 좋습니다. | |
| sql_query = f"SELECT DISTINCT `{key_name}` FROM documents" | |
| params = [] | |
| if partial_match: | |
| sql_query += f" WHERE `{key_name}` LIKE ?" | |
| params.append(f"%{partial_match}%") | |
| try: | |
| cursor.execute(sql_query, params) | |
| unique_values = [row[0] for row in cursor.fetchall() if row[0] is not None] | |
| return unique_values | |
| except Exception as e: | |
| print(f"[에러] 고유 값 검색 실패 ({key_name}): {e}") | |
| return [] | |
| def smart_search_vectorstore( | |
| bm25_retriever, | |
| vectorstore, | |
| query, | |
| k=5, | |
| sqlite_conn=None, | |
| enable_detailed_search=True | |
| ): | |
| """ | |
| 1단계: 하이브리드 검색으로 전체 맥락 파악 | |
| 2단계: 검색된 문서들에서 주된 'regulation'(규정) 파악 | |
| 3단계: 해당 규정으로 필터링하여 심층 검색 수행 | |
| """ | |
| # 1. 기본 검색 (하이브리드 검색 수행) | |
| # 단순 retriever.invoke가 아니라 앞서 정의한 search_vectorstore를 사용하여 | |
| # Vector + BM25 결과를 섞어서 가져옵니다. | |
| try: | |
| basic_results = search_vectorstore(bm25_retriever, vectorstore, query, k=k) | |
| except Exception as e: | |
| print(f"[스마트 검색 오류] 기본 검색 실패: {e}") | |
| return [] | |
| # 상세 검색이 비활성화되었거나 필수 컴포넌트가 없으면 기본 결과 반환 | |
| if not enable_detailed_search or not sqlite_conn: | |
| return basic_results | |
| # 2. 메타데이터 빈도 분석 (타겟 키: 'regulation') | |
| # DB 스키마가 변경되었으므로 regulation_part 대신 regulation을 사용합니다. | |
| target_metadata_key = 'regulation' | |
| extracted_values = [] | |
| for doc in basic_results: | |
| val = doc.metadata.get(target_metadata_key) | |
| if val: | |
| if isinstance(val, list): | |
| extracted_values.extend(val) | |
| elif isinstance(val, str): | |
| # DB 저장 시 ', '로 합쳐진 문자열일 수 있으므로 분리 시도 | |
| if ',' in val: | |
| extracted_values.extend([part.strip() for part in val.split(',')]) | |
| else: | |
| extracted_values.append(val) | |
| if not extracted_values: | |
| # 분석할 메타데이터가 없으면 기본 결과 반환 | |
| return basic_results | |
| counter = Counter(extracted_values) | |
| # 상위 2개의 카테고리(규정) 추출 | |
| most_common_categories = counter.most_common(2) | |
| # 3. 상세 검색 (필터링 검색) | |
| detailed_results = [] | |
| for rank, (category, count) in enumerate(most_common_categories, 1): | |
| # 상위 카테고리에 대해 필터링 조건 생성 | |
| metadata_filter = {target_metadata_key: category} | |
| try: | |
| # 변경된 search_with_metadata_filter 시그니처에 맞춰 호출 | |
| category_results = search_with_metadata_filter( | |
| bm25_retriever=bm25_retriever, # [변경] Ensemble 대신 BM25 전달 | |
| vectorstore=vectorstore, | |
| query=query, | |
| k=k, | |
| metadata_filter=metadata_filter, | |
| sqlite_conn=sqlite_conn | |
| ) | |
| detailed_results.extend(category_results) | |
| print(f"[스마트 검색] '{category}' 집중 검색 → {len(category_results)}개 추가") | |
| except Exception as e: | |
| print(f"[경고] 상세 검색 실패 ({category}): {e}") | |
| continue | |
| # 4. 결과 병합 및 중복 제거 | |
| # (상세 검색 결과 우선 + 기본 검색 결과) | |
| seen = set() | |
| final_results = [] | |
| # 우선순위: 상세 검색 결과 -> 기본 검색 결과 | |
| all_candidates = detailed_results + basic_results | |
| for doc in all_candidates: | |
| # 문서 내용과 메타데이터 문자열을 조합하여 고유 키 생성 | |
| doc_signature = (doc.page_content, str(sorted(doc.metadata.items()))) | |
| if doc_signature not in seen: | |
| seen.add(doc_signature) | |
| final_results.append(doc) | |
| # 최종적으로 k개만 반환 | |
| final_results = final_results[:k] | |
| return final_results | |
| # natural_sort_key 함수 추가 (app.py에서 사용됨) | |
| import re | |
| def natural_sort_key(s): | |
| """자연스러운 정렬을 위한 키 함수""" | |
| return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', str(s))] |