| | |
import gc
import heapq
import inspect
import itertools
import json
import logging
import math
import os

import gradio as gr
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
from rdkit.Chem import CanonSmiles, MolFromSmiles
import spaces
| |
|
| | |
# Hugging Face Hub repository that hosts the checkpoint, tokenizers and config.
MODEL_REPO_ID = "AdrianM0/smiles-to-iupac-translator"

# Artifact filenames, all downloaded from MODEL_REPO_ID.
CHECKPOINT_FILENAME = "last.ckpt"
SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
CONFIG_FILENAME = "config.json"

# Root logging: INFO and above, with timestamp and level on every record.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
| |
|
| | |
# The Lightning module and the causal-mask helper come from enhanced_trainer.py,
# which must already be importable when this module is imported.
# NOTE(review): load_model_and_tokenizers() later downloads enhanced_trainer.py
# from the Hub, but that happens after this import, so the download cannot
# rescue a missing module here — confirm this ordering is intended.
try:

    from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask

    logging.info("Successfully imported from enhanced_trainer.py.")
except ImportError as e:
    logging.error(
        f"Failed to import helper code from enhanced_trainer.py: {e}. "
        f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
    )
    # Raised at module level, so this aborts app start-up.
    raise gr.Error(
        f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}"
    )
except Exception as e:
    logging.error(
        f"An unexpected error occurred during helper code import: {e}", exc_info=True
    )
    raise gr.Error(
        f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}"
    )


# Module-level state: filled in by load_model_and_tokenizers() and read by
# predict_iupac(). Each stays None until loading gets that far.
model: pl.LightningModule | None = None
smiles_tokenizer: Tokenizer | None = None
iupac_tokenizer: Tokenizer | None = None
device: torch.device | None = None
config: dict | None = None
| |
|
| |
|
| | |
def greedy_decode(
    model: pl.LightningModule,
    src: torch.Tensor,
    src_padding_mask: torch.Tensor,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    device: torch.device,
) -> torch.Tensor:
    """Decode one source sequence by always taking the highest-scoring next token.

    The seq2seq network is reached through ``model.model`` and must provide
    ``encode``, ``decode`` and ``generator``. Generation starts from ``sos_idx``
    and runs for at most ``max_len - 1`` steps, stopping early once ``eos_idx``
    is produced.

    Returns:
        A ``[1, n]`` long tensor of generated token ids with the leading SOS
        removed (a produced EOS is kept at the end). If decoding raises, the
        error is logged and an empty ``[1, 0]`` tensor is returned instead.
    """
    model.eval()
    net = model.model

    def _single_token(token_id):
        # One token id as a [1, 1] long tensor on the target device.
        return torch.full((1, 1), token_id, dtype=torch.long, device=device)

    try:
        with torch.no_grad():
            memory = net.encode(src, src_padding_mask).to(device)
            memory_key_padding_mask = src_padding_mask.to(memory.device)

            generated = _single_token(sos_idx)
            for _ in range(max_len - 1):
                causal_mask = generate_square_subsequent_mask(generated.shape[1], device).to(device)
                # The prefix never contains padding.
                no_padding = torch.zeros(generated.shape, dtype=torch.bool, device=device)

                hidden = net.decode(
                    tgt=generated,
                    memory=memory,
                    tgt_mask=causal_mask,
                    tgt_padding_mask=no_padding,
                    memory_key_padding_mask=memory_key_padding_mask,
                )

                # Only the last position predicts the next token.
                best_id = net.generator(hidden[:, -1, :]).argmax(dim=1).item()
                generated = torch.cat([generated, _single_token(best_id)], dim=1)

                if best_id == eos_idx:
                    break

            return generated[:, 1:]

    except RuntimeError as e:
        logging.error(f"Runtime error during greedy decode: {e}", exc_info=True)
        if "CUDA out of memory" in str(e) and device.type == "cuda":
            gc.collect()
            torch.cuda.empty_cache()
        return torch.empty((1, 0), dtype=torch.long, device=device)
    except Exception as e:
        logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True)
        return torch.empty((1, 0), dtype=torch.long, device=device)
| |
|
| |
|
| | |
def beam_search_decode(
    model: pl.LightningModule,
    src: torch.Tensor,
    src_padding_mask: torch.Tensor,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int,
    device: torch.device,
    beam_width: int,
    num_return_sequences: int = 1,
    length_penalty_alpha: float = 0.6,
) -> list[tuple[torch.Tensor, float]]:
    """
    Performs beam search decoding.

    Active beams are ranked by cumulative log-probability. A hypothesis finishes
    when it ends in ``eos_idx`` or when generation reaches ``max_len`` tokens
    (SOS included). Finished hypotheses are ranked by the length-normalized score
    ``log_prob / seq_len ** length_penalty_alpha``, where ``seq_len`` counts SOS.

    Args:
        model: LightningModule whose ``.model`` provides ``encode``, ``decode``
            and ``generator``.
        src: Encoded source ids.
        src_padding_mask: Padding mask for ``src``.
        max_len: Maximum hypothesis length, SOS included.
        sos_idx / eos_idx: Start and end token ids.
        pad_idx: Accepted for interface symmetry; not used by the search.
        device: Device on which target tensors are created.
        beam_width: Number of beams kept per step (and finished hypotheses kept).
        num_return_sequences: Number of hypotheses to return (capped at beam_width).
        length_penalty_alpha: Exponent of the length normalization.

    Returns:
        A list of ``(sequence_tensor [1, seq_len], normalized_score)`` tuples,
        best first, with the leading SOS stripped. An empty list is returned
        (and the error logged) if decoding raises.
    """
    model.eval()
    transformer_model = model.model
    num_return_sequences = min(beam_width, num_return_sequences)

    def _normalized(raw_score: float, seq_len: int) -> float:
        # Length-normalize so longer hypotheses are not unfairly penalized.
        penalty = seq_len ** length_penalty_alpha
        return raw_score / penalty if penalty > 0 else raw_score

    # Min-heap of (normalized score, tie-breaker, sequence); the root is the
    # worst kept hypothesis. The unique counter ensures equal scores never fall
    # through to comparing tensors, which would raise inside heapq.
    finished_hypotheses: list[tuple[float, int, torch.Tensor]] = []
    tie_breaker = itertools.count()

    def _add_finished(seq: torch.Tensor, raw_score: float) -> None:
        entry = (_normalized(raw_score, seq.shape[1]), next(tie_breaker), seq)
        heapq.heappush(finished_hypotheses, entry)
        # Keep only the best `beam_width` finished hypotheses.
        while len(finished_hypotheses) > beam_width:
            heapq.heappop(finished_hypotheses)

    try:
        with torch.no_grad():
            memory = transformer_model.encode(src, src_padding_mask).to(device)
            memory_key_padding_mask = src_padding_mask.to(memory.device)

            # Each beam is (token ids [1, seq_len], cumulative log-probability).
            beams = [(torch.full((1, 1), sos_idx, dtype=torch.long, device=device), 0.0)]

            for step in range(max_len - 1):
                if not beams:
                    break

                # Expansions stored as (-score, token_id, parent beam index), so
                # heapq.nsmallest yields the highest-scoring ones first.
                candidates = []
                for beam_idx, (current_seq, current_score) in enumerate(beams):
                    if current_seq[0, -1].item() == eos_idx:
                        _add_finished(current_seq, current_score)
                        continue

                    tgt_mask = generate_square_subsequent_mask(current_seq.shape[1], device).to(device)
                    tgt_padding_mask = torch.zeros(current_seq.shape, dtype=torch.bool, device=device)
                    decoder_output = transformer_model.decode(
                        tgt=current_seq,
                        memory=memory,
                        tgt_mask=tgt_mask,
                        tgt_padding_mask=tgt_padding_mask,
                        memory_key_padding_mask=memory_key_padding_mask,
                    )
                    next_token_logits = transformer_model.generator(decoder_output[:, -1, :])
                    log_probs = F.log_softmax(next_token_logits, dim=-1)

                    # torch.topk raises if k exceeds the vocabulary size.
                    k = min(beam_width, log_probs.size(-1))
                    top_scores, top_indices = torch.topk(log_probs + current_score, k, dim=1)
                    for score, token_id in zip(top_scores[0].tolist(), top_indices[0].tolist()):
                        candidates.append((-score, token_id, beam_idx))

                # Parents are distinct and topk indices are distinct, so every
                # child sequence is unique; no deduplication is needed.
                new_beams = []
                for neg_score, token_id, beam_idx in heapq.nsmallest(beam_width, candidates):
                    parent_seq = beams[beam_idx][0]
                    token = torch.full((1, 1), token_id, dtype=torch.long, device=device)
                    new_beams.append((torch.cat([parent_seq, token], dim=1), -neg_score))
                beams = new_beams

                # Early stopping: compare like with like. The best still-active
                # beam is normalized at its current length, the same way finished
                # hypotheses are scored.
                if finished_hypotheses and len(finished_hypotheses) >= num_return_sequences:
                    best_active_score = max(
                        (
                            _normalized(score, seq.shape[1])
                            for seq, score in beams
                            if seq[0, -1].item() != eos_idx
                        ),
                        default=-float("inf"),
                    )
                    worst_finished_score = finished_hypotheses[0][0]
                    if best_active_score < worst_finished_score:
                        logging.debug(f"Beam search early stopping at step {step}")
                        break

            # Nothing left in `beams` has been scored yet. This covers beams that
            # picked EOS in the last step (or just before an early stop) as well as
            # beams cut off at max_len, so all of them are added.
            for seq, score in beams:
                _add_finished(seq, score)

            top_hypotheses = heapq.nlargest(num_return_sequences, finished_hypotheses)

            # Strip the leading SOS token from each returned sequence.
            return [(seq[:, 1:], score) for score, _, seq in top_hypotheses]

    except RuntimeError as e:
        logging.error(f"Runtime error during beam search: {e}", exc_info=True)
        if "CUDA out of memory" in str(e) and device.type == "cuda":
            gc.collect()
            torch.cuda.empty_cache()
        return []
    except Exception as e:
        logging.error(f"Unexpected error during beam search: {e}", exc_info=True)
        return []
| |
|
| |
|
| | |
def translate(
    model: pl.LightningModule,
    src_sentence: str,
    smiles_tokenizer: Tokenizer,
    iupac_tokenizer: Tokenizer,
    device: torch.device,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int,
    decoding_strategy: str = "Greedy",
    beam_width: int = 5,
    num_return_sequences: int = 1,
    length_penalty_alpha: float = 0.6,
) -> list[tuple[str, float]]:
    """
    Translates a single SMILES string using the specified decoding strategy.

    The SMILES is tokenized with truncation to ``max_len`` tokens, and
    ``max_len`` also bounds the generated target length.

    Args:
        decoding_strategy: "Greedy" or "Beam Search"; anything else yields an
            error marker.
        beam_width, num_return_sequences, length_penalty_alpha: Forwarded to
            beam_search_decode (Beam Search only).

    Returns:
        A list of ``(name, score)`` tuples. Greedy decoding yields one entry with
        score ``0.0``. Beam search yields the entries returned by
        beam_search_decode together with their scores. Failures are reported
        in-band as bracketed markers such as ``"[Encoding Error]"`` instead of
        being raised.

    Note:
        This calls ``enable_truncation`` on the shared ``smiles_tokenizer``, so it
        mutates that tokenizer's configuration.
    """
    model.eval()

    try:
        smiles_tokenizer.enable_truncation(max_length=max_len)
        src_encoded = smiles_tokenizer.encode(src_sentence)
        if not src_encoded or not src_encoded.ids:
            logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
            return [("[Encoding Error]", 0.0)]
        src_ids = src_encoded.ids
    except Exception as e:
        logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
        return [("[Encoding Error]", 0.0)]

    # Batch of one: shape [1, src_len].
    src = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(0)
    src_padding_mask = src == pad_idx

    # Use the caller-supplied limit rather than the global `config`, which is
    # None until load_model_and_tokenizers() succeeds.
    generation_max_len = max_len

    if decoding_strategy == "Greedy":
        tgt_tokens_tensor = greedy_decode(
            model=model,
            src=src,
            src_padding_mask=src_padding_mask,
            max_len=generation_max_len,
            sos_idx=sos_idx,
            eos_idx=eos_idx,
            device=device,
        )
        if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
            logging.warning(f"Greedy decode returned empty tensor for SMILES: {src_sentence}")
            return [("[Decoding Error - Empty Output]", 0.0)]
        results = [(tgt_tokens_tensor, 0.0)]

    elif decoding_strategy == "Beam Search":
        results = beam_search_decode(
            model=model,
            src=src,
            src_padding_mask=src_padding_mask,
            max_len=generation_max_len,
            sos_idx=sos_idx,
            eos_idx=eos_idx,
            pad_idx=pad_idx,
            device=device,
            beam_width=beam_width,
            num_return_sequences=num_return_sequences,
            length_penalty_alpha=length_penalty_alpha,
        )
        if not results:
            logging.warning(f"Beam search returned no results for SMILES: {src_sentence}")
            return [("[Decoding Error - Empty Output]", 0.0)]
    else:
        logging.error(f"Unknown decoding strategy: {decoding_strategy}")
        return [("[Error: Unknown Strategy]", 0.0)]

    # Convert token-id tensors back to text.
    translations = []
    for tgt_tokens_tensor, score in results:
        if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
            translations.append(("[Decoding Error - Empty Sequence]", score))
            continue

        tgt_tokens = tgt_tokens_tensor.flatten().cpu().tolist()
        try:
            # skip_special_tokens drops special tokens such as a trailing EOS.
            translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
            translations.append((translation, score))
        except Exception as e:
            logging.error(
                f"Error decoding target tokens {tgt_tokens}: {e}",
                exc_info=True,
            )
            translations.append(("[Decoding Error]", score))

    return translations
| |
|
| |
|
| | |
def load_model_and_tokenizers():
    """Loads tokenizers, config, and model from Hugging Face Hub.

    Populates the module globals ``device``, ``config``, ``smiles_tokenizer``,
    ``iupac_tokenizer`` and ``model``. Returns immediately if ``model`` is
    already set. Inference is pinned to CPU.

    Raises:
        gr.Error: If downloading, config validation, tokenizer loading or
            checkpoint loading fails. The message names the failing stage.
    """
    global model, smiles_tokenizer, iupac_tokenizer, device, config
    if model is not None:
        logging.info("Model and tokenizers already loaded.")
        return

    logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
    try:
        device = torch.device("cpu")
        logging.info("Using CPU. Modify code to enable GPU if available and desired.")

        # --- Download artifacts ---------------------------------------------
        logging.info("Downloading files from Hugging Face Hub...")
        cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache")
        os.makedirs(cache_dir, exist_ok=True)
        logging.info(f"Using cache directory: {cache_dir}")

        try:
            checkpoint_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir)
            smiles_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME, cache_dir=cache_dir)
            iupac_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME, cache_dir=cache_dir)
            config_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir)

            # Copy the training-module source into the working directory. An
            # existing local copy is accepted if the download fails.
            # NOTE(review): enhanced_trainer is imported at module load, before
            # this runs, so this download does not change the classes used below.
            try:
                hf_hub_download(repo_id=MODEL_REPO_ID, filename="enhanced_trainer.py", cache_dir=cache_dir, local_dir=".")
                logging.info("Downloaded enhanced_trainer.py")
            except Exception as download_err:
                if not os.path.exists("enhanced_trainer.py"):
                    raise
                logging.warning(f"Could not download enhanced_trainer.py (maybe private?), but found local file. Using local. Error: {download_err}")

            logging.info("Files downloaded successfully.")
        except Exception as e:
            logging.error(f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}", exc_info=True)
            raise gr.Error(f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}") from e

        # --- Configuration ----------------------------------------------------
        logging.info("Loading configuration...")
        try:
            # JSON is UTF-8 by specification; don't depend on the locale encoding.
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
            logging.info("Configuration loaded.")
            required_keys = ["src_vocab_size", "tgt_vocab_size", "emb_size", "nhead", "ffn_hid_dim", "num_encoder_layers", "num_decoder_layers", "dropout", "max_len", "pad_token_id", "bos_token_id", "eos_token_id"]
            missing_keys = [key for key in required_keys if config.get(key) is None]
            if missing_keys:
                raise ValueError(f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}.")
            logging.info(f"Using config: { {k: config.get(k) for k in required_keys} }")
        except Exception as e:
            logging.error(f"Error loading or validating config: {e}", exc_info=True)
            raise gr.Error(f"Config Error: {e}") from e

        # --- Tokenizers -------------------------------------------------------
        logging.info("Loading tokenizers...")
        try:
            smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
            iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)

            # A vocabulary mismatch is only warned about, not treated as fatal.
            if smiles_tokenizer.get_vocab_size() != config['src_vocab_size']:
                logging.warning(f"SMILES vocab size mismatch: Tokenizer={smiles_tokenizer.get_vocab_size()}, Config={config['src_vocab_size']}")
            if iupac_tokenizer.get_vocab_size() != config['tgt_vocab_size']:
                logging.warning(f"IUPAC vocab size mismatch: Tokenizer={iupac_tokenizer.get_vocab_size()}, Config={config['tgt_vocab_size']}")
            logging.info("Tokenizers loaded.")
        except Exception as e:
            logging.error(f"Failed to load tokenizers: {e}", exc_info=True)
            raise gr.Error(f"Tokenizer Error: Could not load tokenizers. Check logs. Error: {e}") from e

        # --- Model ------------------------------------------------------------
        logging.info("Loading model from checkpoint...")
        try:
            # Forward only the config keys that are named constructor parameters.
            # (`__init__.__code__.co_varnames` would also contain the
            # constructor's local variable names, letting unrelated keys through.)
            init_params = inspect.signature(SmilesIupacLitModule.__init__).parameters
            named_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
            expected_args = {
                name
                for name, param in init_params.items()
                if name != "self" and param.kind in named_kinds
            }
            hparams_to_pass = {k: v for k, v in config.items() if k in expected_args}
            logging.info(f"Passing hparams to LitModule: {hparams_to_pass.keys()}")

            # strict=False: checkpoint keys need not exactly match the module's state dict.
            model = SmilesIupacLitModule.load_from_checkpoint(
                checkpoint_path,
                map_location=device,
                strict=False,
                **hparams_to_pass
            )

            model.to(device)
            model.eval()
            model.freeze()
            logging.info(f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'.")

        except FileNotFoundError:
            logging.error(f"Checkpoint file not found: {checkpoint_path}")
            raise gr.Error(f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found.")
        except Exception as e:
            logging.error(f"Error loading model checkpoint {checkpoint_path}: {e}", exc_info=True)
            if "size mismatch" in str(e):
                error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint."
                logging.error(error_detail)
                raise gr.Error(f"Model Error: {error_detail} Original error: {e}") from e
            elif "unexpected keyword argument" in str(e) or "missing 1 required positional argument" in str(e):
                error_detail = f"Mismatch between config.json keys and SmilesIupacLitModule constructor arguments. Check enhanced_trainer.py and config.json. Error: {e}"
                logging.error(error_detail)
                raise gr.Error(f"Model Error: {error_detail}") from e
            elif "memory" in str(e).lower():
                logging.warning("Potential OOM error during model loading.")
                gc.collect()
                if device.type == "cuda":
                    torch.cuda.empty_cache()
                raise gr.Error(f"Model Error: OOM loading model. Check Space resources. Error: {e}") from e
            else:
                raise gr.Error(f"Model Error: Failed to load checkpoint. Check logs. Error: {e}") from e

    except gr.Error:
        # Already user-facing; propagate unchanged.
        raise
    except Exception as e:
        logging.error(f"Unexpected error during loading: {e}", exc_info=True)
        raise gr.Error(f"Initialization Error: Unexpected error. Check logs. Error: {e}") from e
| |
|
| |
|
| | |
def _as_positive_int(value):
    """Return ``value`` as a positive ``int``, or ``None`` if it is not one.

    Integral floats such as ``5.0`` are accepted, so numeric inputs that arrive
    as floats (for example through the API) are not rejected.
    """
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    if isinstance(value, int) and value > 0:
        return value
    return None


@spaces.GPU
def predict_iupac(smiles_string, decoding_strategy, num_beams, num_return_sequences):
    """
    Performs SMILES to IUPAC translation using the loaded model and selected strategy.

    Args:
        smiles_string: SMILES text from the UI. It is validated and
            canonicalized with RDKit before translation.
        decoding_strategy: "Greedy" or "Beam Search".
        num_beams: Beam width (Beam Search only). Must be a positive integer;
            integral floats are accepted.
        num_return_sequences: Hypotheses to return (Beam Search only). Must be a
            positive integer no larger than ``num_beams``.

    Returns:
        A multi-line report string. Validation, initialization and translation
        errors are returned as text starting with "Error" or "... Error".
    """
    global model, smiles_tokenizer, iupac_tokenizer, device, config

    # The globals are None until load_model_and_tokenizers() sets them. Test
    # identity instead of truthiness so objects with custom __bool__/__len__
    # are never mistaken for "not loaded".
    if any(obj is None for obj in (model, smiles_tokenizer, iupac_tokenizer, device, config)):
        error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
        logging.error(error_msg)
        return f"Initialization Error: {error_msg}"

    if not smiles_string or not smiles_string.strip():
        return "Error: Please enter a valid SMILES string."

    smiles_input = smiles_string.strip()

    # Reject unparsable SMILES and canonicalize the rest.
    try:
        mol = MolFromSmiles(smiles_input)
        if mol is None:
            return f"Error: Invalid SMILES string provided: '{smiles_input}'"
        smiles_input = CanonSmiles(smiles_input)
        logging.info(f"Canonical SMILES: {smiles_input}")
    except Exception as e:
        logging.error(f"Error during SMILES validation/canonicalization: {e}", exc_info=True)
        return f"Error: Could not process SMILES string '{smiles_input}'. RDKit error: {e}"

    if decoding_strategy == "Beam Search":
        beam_count = _as_positive_int(num_beams)
        if beam_count is None:
            return "Error: Beam width must be a positive integer."
        return_count = _as_positive_int(num_return_sequences)
        if return_count is None:
            return "Error: Number of return sequences must be a positive integer."
        if return_count > beam_count:
            return f"Error: Number of return sequences ({return_count}) cannot exceed beam width ({beam_count})."
        num_beams, num_return_sequences = beam_count, return_count
    else:
        # Greedy decoding produces exactly one hypothesis.
        num_beams = 1
        num_return_sequences = 1

    try:
        sos_idx = config["bos_token_id"]
        eos_idx = config["eos_token_id"]
        pad_idx = config["pad_token_id"]
        gen_max_len = config["max_len"]
        # Exponent of the length normalization used by beam search.
        length_penalty = 0.6

        predicted_results = translate(
            model=model,
            src_sentence=smiles_input,
            smiles_tokenizer=smiles_tokenizer,
            iupac_tokenizer=iupac_tokenizer,
            device=device,
            max_len=gen_max_len,
            sos_idx=sos_idx,
            eos_idx=eos_idx,
            pad_idx=pad_idx,
            decoding_strategy=decoding_strategy,
            beam_width=num_beams,
            num_return_sequences=num_return_sequences,
            length_penalty_alpha=length_penalty,
        )
        logging.info(f"Prediction returned {len(predicted_results)} result(s). Strategy: {decoding_strategy}, Beams: {num_beams}, Return: {num_return_sequences}")

        output_lines = [
            f"Input SMILES: {smiles_input}",
            f"Decoding Strategy: {decoding_strategy}",
        ]
        if decoding_strategy == "Beam Search":
            output_lines.append(f"Beam Width: {num_beams}")
            output_lines.append(f"Returned Sequences: {len(predicted_results)}")
            output_lines.append(f"Length Penalty Alpha: {length_penalty:.2f}")

        output_lines.append("\n--- Predictions ---")

        if not predicted_results:
            output_lines.append("No predictions generated.")
        else:
            for i, (name, score) in enumerate(predicted_results, start=1):
                # translate() reports failures in-band as bracketed markers such as
                # "[Encoding Error]" or "[Decoding Error - Empty Output]". Real
                # IUPAC names may start with "[" but do not contain "Error".
                failed = not name or (name.startswith("[") and "Error" in name)
                if failed:
                    output_lines.append(f"{i}. Prediction Failed: {name}")
                else:
                    score_info = f"(Score: {score:.4f})" if decoding_strategy == "Beam Search" else ""
                    output_lines.append(f"{i}. {name} {score_info}")

        return "\n".join(output_lines)

    except RuntimeError as e:
        logging.error(f"Runtime error during translation: {e}", exc_info=True)
        gc.collect()
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        return f"Runtime Error during translation: {e}. Check logs."
    except Exception as e:
        logging.error(f"Unexpected error during translation: {e}", exc_info=True)
        return f"Unexpected Error during translation: {e}. Check logs."
| |
|
| |
|
| | |
# Load everything once at import time. Failures are logged but not re-raised,
# so the UI defined below still starts; predict_iupac() then reports that the
# model is unavailable because the globals are still None.
try:
    load_model_and_tokenizers()
except Exception as load_err:
    if isinstance(load_err, gr.Error):
        logging.error(f"Gradio Initialization Error during load: {load_err}")
    else:
        logging.error(f"Critical error during initial model loading: {load_err}", exc_info=True)
| | |
| |
|
| |
|
| | |
title = "SMILES to IUPAC Name Translator"

# Device label shown in the description; "N/A" while `device` is still unset.
_device_label = str(device).upper() if device else "N/A"
description = f"""
Translate a SMILES string into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}).
Choose between **Greedy Decoding** (fastest, picks the most likely next word) and **Beam Search Decoding** (explores multiple possibilities, potentially better results, slower).
**Note:** Model loaded on **{_device_label}**. Beam search can be slow, especially with larger beam widths.
Check `config.json` in the repo for model details. SMILES input will be canonicalized using RDKit.
"""
| |
|
| | |
# Gradio UI: an input column (SMILES + collapsible decoding options) next to a
# results column, with event wiring and example inputs below.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as iface:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=1):
            smiles_input = gr.Textbox(
                label="SMILES String",
                placeholder="Enter SMILES string (e.g., CCO, c1ccccc1)",
                lines=2,
            )
            with gr.Accordion("Decoding Options", open=False):
                decode_strategy = gr.Radio(
                    ["Greedy", "Beam Search"],
                    label="Decoding Strategy",
                    value="Greedy",
                    info="Greedy is faster, Beam Search may be more accurate."
                )
                # Beam-only controls start hidden because the default strategy is Greedy.
                beam_width_slider = gr.Slider(
                    minimum=1,
                    maximum=20,
                    step=1,
                    value=5,
                    label="Beam Width",
                    info="Number of beams to explore (Beam Search only)",
                    visible=False
                )
                num_seq_slider = gr.Slider(
                    minimum=1,
                    maximum=5,
                    step=1,
                    value=1,
                    label="Number of Results",
                    info="How many sequences to return (Beam Search only)",
                    visible=False
                )

            submit_btn = gr.Button("Translate", variant="primary")

            # Toggle the beam-search sliders to match the selected strategy.
            def update_beam_options(strategy):
                is_beam = strategy == "Beam Search"
                return {
                    beam_width_slider: gr.update(visible=is_beam),
                    num_seq_slider: gr.update(visible=is_beam)
                }

            decode_strategy.change(
                fn=update_beam_options,
                inputs=decode_strategy,
                outputs=[beam_width_slider, num_seq_slider]
            )

        with gr.Column(scale=2):
            output_text = gr.Textbox(
                label="Translation Results",
                lines=10,
                show_copy_button=True,
            )

    # The button click and pressing Enter in the textbox run the same prediction;
    # only the click handler is exposed under an explicit API name.
    submit_btn.click(
        fn=predict_iupac,
        inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider],
        outputs=output_text,
        api_name="translate_smiles"
    )

    smiles_input.submit(
        fn=predict_iupac,
        inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider],
        outputs=output_text
    )

    # Preset inputs. The list includes a deliberately invalid SMILES to show
    # error reporting. cache_examples=False: outputs are not precomputed.
    gr.Examples(
        examples=[
            ["CCO", "Greedy", 1, 1],
            ["c1ccccc1", "Greedy", 1, 1],
            ["CC(C)Br", "Beam Search", 5, 3],
            ["C[C@H](O)c1ccccc1", "Beam Search", 10, 5],
            ["INVALID_SMILES", "Greedy", 1, 1],
            ["N#CC(C)(C)OC(=O)C(C)=C", "Beam Search", 8, 2]
        ],
        inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider],
        outputs=output_text,
        fn=predict_iupac,
        cache_examples=False,
        label="Example SMILES & Settings"
    )


# Start the server only when the file is executed as a script.
if __name__ == "__main__":
    iface.launch()