#!/usr/bin/env python3
# torch_uint32_patch.py - Direct patch for PyTorch UInt32Storage issue
# This standalone patch is designed to be imported before any torch imports
"""Compatibility shim that aliases ``torch.UInt32Storage`` to ``torch.FloatStorage``.

Importing this module applies the patch immediately (see the call at the
bottom of the file). The patch:

* hooks ``pickle.Unpickler.find_class`` so the global ``torch.UInt32Storage``
  resolves to ``torch.FloatStorage``;
* defines ``torch.UInt32Storage`` (and the ``torch._C`` counterpart when the
  float variant exists) if it is missing;
* wraps ``torch.jit.load`` and ``torch._C.import_ir_module`` so that a
  ``RuntimeError`` mentioning ``UInt32Storage`` or a duplicate
  ``BlendShapeBase`` definition triggers one recovery attempt and a retry.
"""

import functools
import importlib
import pickle
import sys

# Attribute set on every hook installed by this module so that applying the
# patch more than once does not stack wrappers on top of each other.
_PATCH_MARKER = "_uint32storage_patched"


def _install_find_class_hook():
    """Make unpickling resolve ``torch.UInt32Storage`` to ``torch.FloatStorage``."""
    unpickler_cls = pickle.Unpickler
    if getattr(unpickler_cls, _PATCH_MARKER, False):
        return  # Hook already installed by an earlier call.

    original_find_class = unpickler_cls.find_class

    def patched_find_class(self, module_name, name):
        if module_name == "torch" and name == "UInt32Storage":
            # Imported lazily so the hook can be installed before torch is.
            import torch

            return torch.FloatStorage
        return original_find_class(self, module_name, name)

    try:
        unpickler_cls.find_class = patched_find_class
        setattr(unpickler_cls, _PATCH_MARKER, True)
    except (TypeError, AttributeError):
        # The C-accelerated Unpickler is an immutable extension type and
        # rejects attribute assignment. Publish a subclass carrying the hook
        # instead; only code that looks up ``pickle.Unpickler`` after this
        # point (e.g. to subclass or instantiate it) will see it.
        class UInt32StorageUnpickler(unpickler_cls):
            find_class = patched_find_class

        setattr(UInt32StorageUnpickler, _PATCH_MARKER, True)
        pickle.Unpickler = UInt32StorageUnpickler


def _alias_uint32_storage(torch, force=False):
    """Point the ``UInt32Storage`` names at their float counterparts.

    Args:
        torch: The imported ``torch`` module.
        force: When true, overwrite ``torch.UInt32Storage`` even if it is
            already defined (used by the error-recovery path).
    """
    if force or not hasattr(torch, "UInt32Storage"):
        # FloatStorage is used because it seemed more likely to work than
        # IntStorage.
        torch.UInt32Storage = torch.FloatStorage
        # NOTE: this stores a class, not a module, under a dotted module name.
        # Kept as-is for backward compatibility with the original patch.
        sys.modules["torch.UInt32Storage"] = torch.FloatStorage

    core = getattr(torch, "_C", None)
    # Guarded: not every build exposes ``_FloatStorage``, and an unguarded
    # access inside an ``except`` handler would mask the original error.
    if core is not None and hasattr(core, "_FloatStorage"):
        core._UInt32Storage = core._FloatStorage


def _wrap_with_recovery(torch, original):
    """Return a wrapper around *original* that retries once on known errors.

    The wrapper forwards all arguments to *original*. On ``RuntimeError``:

    * if the message mentions ``UInt32Storage``, the storage aliases are
      re-applied and the call is retried (errors from the retry propagate);
    * if the message reports ``BlendShapeBase' already defined`` and
      ``pymomentum`` is loaded, that module is reloaded and the call retried;
      should the reload or retry fail, the original error is raised;
    * otherwise the original error is re-raised unchanged.
    """

    @functools.wraps(original)
    def wrapper(*args, **kwargs):
        try:
            return original(*args, **kwargs)
        except RuntimeError as err:
            message = str(err)
            if "UInt32Storage" in message:
                _alias_uint32_storage(torch, force=True)
                return original(*args, **kwargs)
            if (
                "BlendShapeBase' already defined" in message
                and "pymomentum" in sys.modules
            ):
                try:
                    importlib.reload(sys.modules["pymomentum"])
                    return original(*args, **kwargs)
                except Exception:
                    # Best-effort recovery failed; report the original error
                    # below rather than the secondary one.
                    pass
            raise

    setattr(wrapper, _PATCH_MARKER, True)
    return wrapper


def patch_torch_for_uint32storage():
    """
    Apply comprehensive patches to PyTorch to handle UInt32Storage

    This must be imported and called before any torch operations. Calling it
    again is safe: hooks that are already installed are left in place.
    """
    # Check if torch is already imported
    if "torch" in sys.modules:
        print("WARNING: torch is already imported, patching may be less effective")

    # Patch the pickle machinery first
    _install_find_class_hook()

    # Now import torch and apply direct patches
    import torch

    _alias_uint32_storage(torch)

    # Patch the torch.jit loading system
    if hasattr(torch.jit, "load") and not getattr(
        torch.jit.load, _PATCH_MARKER, False
    ):
        torch.jit.load = _wrap_with_recovery(torch, torch.jit.load)

    # Low-level patching for _C.import_ir_module
    core = getattr(torch, "_C", None)
    import_ir = getattr(core, "import_ir_module", None)
    if import_ir is not None and not getattr(import_ir, _PATCH_MARKER, False):
        core.import_ir_module = _wrap_with_recovery(torch, import_ir)

    print("PyTorch UInt32Storage patch applied successfully")


# Execute patch immediately when imported
patch_torch_for_uint32storage()