| | import torch |
| | import torch.nn as nn |
| | from transformers import PreTrainedModel |
| | import logging |
| | import floret |
| | import os |
| | from huggingface_hub import hf_hub_download |
| | from .configuration_lang import ImpressoConfig |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class LangDetectorModel(PreTrainedModel):
    """Language detector that delegates prediction to a floret model.

    The class plugs a floret binary into the ``PreTrainedModel`` interface.
    The configuration object handed to ``__init__`` is expected to expose a
    nested ``config`` attribute providing ``filename`` (the floret binary)
    and ``_name_or_path`` (the Hub repository used when the binary is not
    available locally).
    """

    config_class = ImpressoConfig

    def __init__(self, config):
        """Build the wrapper and load the floret binary.

        Args:
            config: Configuration object; ``config.config.filename`` names
                the floret binary and ``config.config._name_or_path`` names
                the Hub repository to download it from when that path does
                not exist on disk.
        """
        super().__init__(config)
        self.config = config

        # Single zero-valued parameter. The ``device`` property below reads
        # the device of the first registered parameter, so at least one must
        # exist; presumably that is this attribute's only purpose.
        self.dummy_param = nn.Parameter(torch.zeros(1))

        settings = self.config.config
        model_path = settings.filename

        # Use the binary in place when the path exists; otherwise fetch it
        # from the Hugging Face Hub and load the downloaded copy instead.
        if not os.path.exists(model_path):
            model_path = hf_hub_download(
                repo_id=settings._name_or_path,
                filename=model_path,
            )

        self.model_floret = floret.load_model(model_path)

    def forward(self, input_ids, **kwargs):
        """Predict the top label for one text or a list of texts.

        Args:
            input_ids: Either a single ``str`` or a ``list`` whose items are
                all ``str``. Despite the name, raw text is expected here.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A 2-tuple holding the two values returned by floret's
            ``predict(texts, k=1)``, passed through unchanged.

        Raises:
            ValueError: If ``input_ids`` is neither a ``str`` nor a ``list``
                made up only of ``str`` items.
        """
        if isinstance(input_ids, str):
            # A lone string is wrapped so floret always receives a list.
            batch = [input_ids]
        else:
            is_text_list = isinstance(input_ids, list) and all(
                isinstance(item, str) for item in input_ids
            )
            if not is_text_list:
                raise ValueError(f"Unexpected input type: {type(input_ids)}")
            batch = input_ids

        labels, scores = self.model_floret.predict(batch, k=1)
        return labels, scores

    @property
    def device(self):
        """Return the device of the first parameter registered on the model."""
        first_param = next(self.parameters())
        return first_param.device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Create the model from keyword arguments only.

        Positional arguments are accepted but not used. The keyword
        arguments are forwarded to ``ImpressoConfig`` and the resulting
        configuration is passed to the constructor, which loads the floret
        binary.

        Returns:
            A new instance of ``cls``.
        """
        return cls(ImpressoConfig(**kwargs))
| |
|