| """ |
| RND1 Generation Utilities. |
| |
| This module provides generation utilities and mixins for RND1 models, |
| including the main GenerationMixin class that integrates with HuggingFace. |
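
Example (illustrative sketch; assumes a model class mixing in
RND1GenerationMixin has already been loaded elsewhere)::

    output_ids = model.generate(inputs=input_ids, max_new_tokens=64)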
| """ |

from typing import Any, Dict, Optional, Union

import torch
from transformers import GenerationMixin as HFGenerationMixin
from transformers.generation import GenerationConfig

from .generation_config import RND1GenerationConfig
from .sampling import diffusion_sample


class RND1GenerationMixin(HFGenerationMixin):
    """
    Generation mixin for RND1 models.

    This mixin provides generation methods compatible with HuggingFace's
    generation API while using RND1's diffusion-based sampling internally.
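
    Example (illustrative sketch; the model class and checkpoint path are
    assumptions, not defined in this module)::

        class RND1Model(RND1PreTrainedModel, RND1GenerationMixin):
            ...

        model = RND1Model.from_pretrained("path/to/checkpoint")
        output_ids = model.generate(inputs=input_ids, max_new_tokens=64)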
| """ |
|
|
    def generate(
        self,
        inputs: Optional[torch.LongTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        # RND1-specific arguments (beyond the standard HF signature)
        prefix_ids: Optional[torch.LongTensor] = None,
        suffix_ids: Optional[torch.LongTensor] = None,
        infill_length: Optional[int] = None,
        return_dict_in_generate: Optional[bool] = None,
        **kwargs,
    ) -> Union[torch.LongTensor, Dict[str, Any]]:
| """ |
| Generate text using RND1's diffusion-based sampling. |
| |
| Follows HuggingFace's standard generate API, using diffusion sampling |
| internally. Supports both standard generation and infilling. |
| |
| Args: |
| inputs: Input token IDs to use as prefix (standard HF parameter) |
| generation_config: Generation configuration object. Default is RND1GenerationConfig. |
| prefix_ids: Alternative to inputs for infilling tasks |
| suffix_ids: Optional suffix for infilling tasks |
| infill_length: Length of infill region (for infilling) |
| return_dict_in_generate: Whether to return GenerateDecoderOnlyOutput |
| **kwargs: Additional arguments (accepted for compatibility). These will be passed to the config constructor. |
| |
| Returns: |
| Generated token IDs or GenerateDecoderOnlyOutput |
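
        Example (illustrative sketch; `model` and `tokenizer` are assumed to
        be a loaded RND1 model and a compatible tokenizer)::

            ids = tokenizer("def fibonacci(n):", return_tensors="pt").input_ids
            out = model.generate(inputs=ids, max_new_tokens=64)

            # Infilling: denoise 16 masked tokens between a prefix and suffix
            out = model.generate(prefix_ids=prefix, suffix_ids=suffix, infill_length=16)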
| """ |
        if generation_config is not None:
            gen_config = generation_config
            model_kwargs = kwargs.copy()
        else:
            # Prepare a default RND1 config, merging any kwargs into it
            gen_config, model_kwargs = self._prepare_generation_config(
                RND1GenerationConfig(), **kwargs
            )

        device = next(self.parameters()).device

        # `inputs` (the standard HF argument) takes precedence over `prefix_ids`
        if inputs is not None:
            prefix_ids = inputs.to(device)
        elif prefix_ids is not None:
            prefix_ids = prefix_ids.to(device)

        if suffix_ids is not None:
            suffix_ids = suffix_ids.to(device)

        # Resolve special token IDs, preferring the generation config and
        # falling back to the model config (with hard-coded final defaults)
        eos_token_id = gen_config.eos_token_id or getattr(self.config, "eos_token_id", 151645)
        pad_token_id = gen_config.pad_token_id or getattr(self.config, "pad_token_id", 151643)
        bos_token_id = gen_config.bos_token_id or getattr(self.config, "bos_token_id", None)
        mask_token_id = getattr(gen_config, "mask_token_id", getattr(self.config, "mask_token_id", 151669))

        # Determine the total sequence length to denoise
        if infill_length is not None and prefix_ids is not None:
            # Infilling: prefix + masked infill region + optional suffix
            prefix_len = prefix_ids.shape[1]
            suffix_len = suffix_ids.shape[1] if suffix_ids is not None else 0
            seq_len = prefix_len + infill_length + suffix_len
        else:
            # Standard generation: extend the prefix by max_new_tokens,
            # otherwise fall back to max_length or the model's position limit
            if prefix_ids is not None:
                prefix_len = prefix_ids.shape[1]
                if gen_config.max_new_tokens is not None:
                    seq_len = prefix_len + gen_config.max_new_tokens
                else:
                    seq_len = gen_config.max_length or self.config.max_position_embeddings
            else:
                seq_len = gen_config.max_length or self.config.max_position_embeddings

        # Diffusion sampling hyperparameters, with model-config fallbacks
        num_diffusion_steps = getattr(
            gen_config, "num_diffusion_steps",
            getattr(self.config, "num_diffusion_steps", 256),
        )
        temperature = float(getattr(gen_config, "temperature", 1.0))
        top_k = getattr(gen_config, "top_k", None)
        top_p = getattr(gen_config, "top_p", None)

        # Decode greedily unless the config explicitly enables sampling
        greedy = getattr(
            gen_config, "greedy",
            not bool(gen_config.do_sample) if hasattr(gen_config, "do_sample") else True,
        )
        with torch.inference_mode():
            sequences = diffusion_sample(
                model=self,
                seq_len=seq_len,
                num_steps=num_diffusion_steps,
                mask_token_id=mask_token_id,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                greedy=greedy,
                prefix_ids=prefix_ids,
                suffix_ids=suffix_ids,
                infill_length=infill_length,
                eos_token_id=eos_token_id,
                pad_token_id=pad_token_id,
                bos_token_id=bos_token_id,
                device=device,
                visualizer=model_kwargs.get("visualizer", None),
            )

        if return_dict_in_generate or getattr(gen_config, "return_dict_in_generate", False):
            from transformers.generation.utils import GenerateDecoderOnlyOutput
            return GenerateDecoderOnlyOutput(sequences=sequences)

        return sequences

    def generate_with_visualization(
        self,
        tokenizer,
        inputs: Optional[torch.LongTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        suffix_ids: Optional[torch.LongTensor] = None,
        infill_length: Optional[int] = None,
        **kwargs,
    ) -> torch.LongTensor:
| """ |
| Generate with live visualization (for demos). |
| |
| This method requires a tokenizer to display the generation process. |
| For production use, prefer `generate()`. |
| |
| Args: |
| tokenizer: Tokenizer for decoding tokens to text |
| inputs: Input token IDs to use as prefix |
| generation_config: Generation configuration object |
| suffix_ids: Optional suffix token IDs |
| infill_length: Length of infill region |
| **kwargs: Additional arguments for backward compatibility |
| |
| Returns: |
| Generated token IDs as LongTensor |
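
        Example (illustrative; assumes a loaded model and tokenizer)::

            ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
            out = model.generate_with_visualization(tokenizer, inputs=ids)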
| """ |
        from .terminal_visualizer import TerminalVisualizer

        visualizer = TerminalVisualizer(tokenizer, show_visualization=True)

        return self.generate(
            inputs=inputs,
            generation_config=generation_config,
            suffix_ids=suffix_ids,
            infill_length=infill_length,
            visualizer=visualizer,
            return_dict_in_generate=False,
            **kwargs,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Prepare inputs for generation (required by HuggingFace).

        RND1 does not use standard autoregressive decoding, so this simply
        returns the input_ids unchanged.
        """
        return {"input_ids": input_ids}