# NOTE: stray editor status text ("Spaces:" / "Running" / "Running") was paste
# residue, not Python source; removed so the file parses.
#!/usr/bin/env python3
# torch_uint32_patch.py - Direct patch for PyTorch UInt32Storage issue
# This standalone patch is designed to be imported before any torch imports
def patch_torch_for_uint32storage():
    """Monkey-patch PyTorch so artifacts referencing ``torch.UInt32Storage`` load.

    Must be imported (and therefore executed) before any torch model/JIT
    loading happens.  Three layers are patched, in order:

    1. ``pickle.Unpickler.find_class`` — plain pickle deserialization resolves
       ``torch.UInt32Storage`` to ``torch.FloatStorage``.
    2. ``torch`` itself — a ``UInt32Storage`` alias attribute plus a
       ``sys.modules`` entry, and a ``torch._C._UInt32Storage`` alias.
    3. ``torch.jit.load`` and ``torch._C.import_ir_module`` — retry wrappers
       that re-apply the alias when the error mentions ``UInt32Storage``, and
       work around a duplicate ``BlendShapeBase`` registration by reloading
       ``pymomentum`` (if it is loaded) before retrying once.

    Returns ``None``; side effects only.
    """
    import importlib
    import pickle
    import sys

    # Patching is most effective when torch has not been imported yet, because
    # some torch internals bind names at import time.
    if "torch" in sys.modules:
        print("WARNING: torch is already imported, patching may be less effective")

    # ---- Layer 1: pickle machinery ----------------------------------------
    original_find_class = pickle.Unpickler.find_class

    def patched_find_class(self, module_name, name):
        if module_name == 'torch' and name == 'UInt32Storage':
            # Import lazily to avoid circular imports during unpickling.
            import torch
            return torch.FloatStorage
        return original_find_class(self, module_name, name)

    try:
        # BUG FIX: on CPython, pickle.Unpickler is the immutable C extension
        # type (_pickle.Unpickler), so this assignment raises TypeError and
        # the original code crashed at import time.
        pickle.Unpickler.find_class = patched_find_class
    except TypeError:
        # Fallback: rebind the module-level name to a subclass.  Code that
        # instantiates pickle.Unpickler explicitly still gets the patch.
        class _PatchedUnpickler(pickle.Unpickler):
            find_class = patched_find_class

        pickle.Unpickler = _PatchedUnpickler

    # ---- Layer 2: direct torch aliases ------------------------------------
    import torch

    if not hasattr(torch, 'UInt32Storage'):
        # FloatStorage is used as the stand-in — empirically more reliable
        # than IntStorage for the checkpoints this targets.
        torch.UInt32Storage = torch.FloatStorage
        # Some loaders resolve the class as if it were a module path.
        sys.modules['torch.UInt32Storage'] = torch.FloatStorage

    if hasattr(torch, '_C') and hasattr(torch._C, '_FloatStorage'):
        setattr(torch._C, '_UInt32Storage', torch._C._FloatStorage)

    # ---- Layer 3: retry wrappers ------------------------------------------
    def _reapply_alias():
        # Emergency re-application of the alias right before a retry.
        torch.UInt32Storage = torch.FloatStorage
        sys.modules['torch.UInt32Storage'] = torch.FloatStorage

    def _reload_pymomentum():
        # Best-effort reload to clear the duplicate BlendShapeBase
        # registration; True when the reload itself succeeded.  Narrowed from
        # a bare ``except:`` that also swallowed retry failures.
        try:
            importlib.reload(sys.modules['pymomentum'])
            return True
        except Exception:
            return False

    if hasattr(torch.jit, 'load'):
        original_jit_load = torch.jit.load

        def patched_jit_load(*args, **kwargs):
            try:
                return original_jit_load(*args, **kwargs)
            except RuntimeError as e:
                msg = str(e)
                if "UInt32Storage" in msg:
                    _reapply_alias()
                    return original_jit_load(*args, **kwargs)
                if ("BlendShapeBase' already defined" in msg
                        and 'pymomentum' in sys.modules
                        and _reload_pymomentum()):
                    return original_jit_load(*args, **kwargs)
                # Re-raise if not one of our specific, recoverable errors.
                raise

        torch.jit.load = patched_jit_load

    if hasattr(torch, '_C') and hasattr(torch._C, 'import_ir_module'):
        original_import_ir = torch._C.import_ir_module

        def patched_import_ir(*args, **kwargs):
            try:
                return original_import_ir(*args, **kwargs)
            except RuntimeError as e:
                msg = str(e)
                if "UInt32Storage" in msg:
                    _reapply_alias()
                    setattr(torch._C, '_UInt32Storage', torch._C._FloatStorage)
                    return original_import_ir(*args, **kwargs)
                if ("BlendShapeBase' already defined" in msg
                        and 'pymomentum' in sys.modules
                        and _reload_pymomentum()):
                    return original_import_ir(*args, **kwargs)
                raise

        torch._C.import_ir_module = patched_import_ir

    print("PyTorch UInt32Storage patch applied successfully")
# Apply the patch as an import side effect, so callers only need to
# ``import torch_uint32_patch`` before touching torch.
patch_torch_for_uint32storage()