mlse-player-3d / torch_uint32_patch.py
Jake Reardon
Create standalone PyTorch UInt32Storage patch for reliable model loading
393572c
#!/usr/bin/env python3
# torch_uint32_patch.py - Direct patch for PyTorch UInt32Storage issue
# This standalone patch is designed to be imported before any torch imports
def patch_torch_for_uint32storage():
    """Install best-effort workarounds for ``UInt32Storage`` model-load failures.

    This must be imported and called before any torch operations, ideally
    before torch is first imported. It:

    * makes pickle unpicklers that allow patching resolve the global
      ``torch.UInt32Storage`` to ``torch.FloatStorage``;
    * aliases ``torch.UInt32Storage`` to ``torch.FloatStorage`` when it is
      missing, and ``torch._C._UInt32Storage`` to ``torch._C._FloatStorage``
      when the latter exists;
    * wraps ``torch.jit.load`` and ``torch._C.import_ir_module`` so that:
      - a ``RuntimeError`` mentioning ``UInt32Storage`` re-applies the aliases
        and retries once;
      - a "BlendShapeBase' already defined" error reloads an already-imported
        ``pymomentum`` and retries once.

    Calling it more than once is safe: anything already patched is skipped.

    NOTE(review): FloatStorage is a stand-in chosen by the original author;
    no value conversion happens here. Confirm that loaded uint32 data is
    interpreted as intended.
    """
    import sys

    if "torch" in sys.modules:
        print("WARNING: torch is already imported, patching may be less effective")

    # Patch the pickle machinery first (before importing torch), as before.
    _patch_pickle_find_class()

    import torch

    _alias_uint32_storage(torch)

    # Retry wrapper around the high-level TorchScript loader.
    if hasattr(torch.jit, "load") and not _is_patched(torch.jit.load):
        torch.jit.load = _wrap_with_retry(torch.jit.load, torch)

    # Same retry wrapper around the low-level IR import entry point.
    if hasattr(torch, "_C") and hasattr(torch._C, "import_ir_module"):
        if not _is_patched(torch._C.import_ir_module):
            torch._C.import_ir_module = _wrap_with_retry(
                torch._C.import_ir_module, torch
            )

    print("PyTorch UInt32Storage patch applied successfully")


# Attribute set on every callable this module installs, used to avoid
# wrapping/patching the same target twice.
_PATCH_FLAG = "_uint32storage_patched"


def _is_patched(func):
    """Return True if *func* was installed by this module."""
    return bool(getattr(func, _PATCH_FLAG, False))


def _make_find_class(original_find_class):
    """Return a ``find_class`` that maps ``torch.UInt32Storage`` to FloatStorage."""

    def patched_find_class(self, module_name, name):
        if module_name == "torch" and name == "UInt32Storage":
            # Deferred import: the pickle patch is installed before torch is.
            import torch

            return torch.FloatStorage
        return original_find_class(self, module_name, name)

    setattr(patched_find_class, _PATCH_FLAG, True)
    return patched_find_class


def _patch_pickle_find_class():
    """Patch ``find_class`` on every pickle unpickler class that allows it.

    ``pickle.Unpickler`` is normally the C ``_pickle.Unpickler`` extension
    type, whose attributes cannot be reassigned (``TypeError``). Previously
    that aborted the whole patch. Such classes are now skipped; they still
    resolve ``torch.UInt32Storage`` through the alias that
    ``_alias_uint32_storage`` installs on the torch module.
    """
    import pickle

    candidates = []
    for cls in (pickle.Unpickler, getattr(pickle, "_Unpickler", None)):
        if cls is not None and cls not in candidates:
            candidates.append(cls)

    for cls in candidates:
        original_find_class = cls.find_class
        if _is_patched(original_find_class):
            continue
        try:
            cls.find_class = _make_find_class(original_find_class)
        except TypeError:
            # Immutable extension type; rely on the torch attribute alias.
            continue


def _alias_uint32_storage(torch, force=False):
    """Alias the UInt32 storage names to their float counterparts.

    Args:
        torch: the imported ``torch`` module.
        force: overwrite an existing ``torch.UInt32Storage`` (used on the
            retry path). Otherwise an existing attribute is left alone.
    """
    import sys

    if force or not hasattr(torch, "UInt32Storage"):
        # Use FloatStorage, which the original author found more likely to
        # work than IntStorage.
        torch.UInt32Storage = torch.FloatStorage
        # NOTE(review): this registers a class (not a module) under a
        # module-style key. Kept for compatibility with the original patch.
        sys.modules["torch.UInt32Storage"] = torch.FloatStorage

    # Guarded lookup: _FloatStorage may be absent from a torch build. On the
    # retry path an unguarded access would replace the real load error with
    # an AttributeError.
    float_c_storage = getattr(getattr(torch, "_C", None), "_FloatStorage", None)
    if float_c_storage is not None:
        setattr(torch._C, "_UInt32Storage", float_c_storage)


def _wrap_with_retry(original, torch):
    """Wrap a TorchScript loader so two known ``RuntimeError``s get one retry.

    Args:
        original: the loader callable to wrap (``torch.jit.load`` or
            ``torch._C.import_ir_module``).
        torch: the imported ``torch`` module.

    Returns:
        A wrapper with the same call signature. Errors other than the two
        handled cases, or a failed pymomentum reload/retry, re-raise the
        original ``RuntimeError``.
    """
    import functools
    import importlib
    import sys

    @functools.wraps(original)
    def wrapper(*args, **kwargs):
        try:
            return original(*args, **kwargs)
        except RuntimeError as err:
            message = str(err)
            if "UInt32Storage" in message:
                # Re-apply the aliases in case something replaced them.
                # NOTE(review): a stream argument consumed by the first call
                # is not rewound before this retry.
                _alias_uint32_storage(torch, force=True)
                return original(*args, **kwargs)
            if (
                "BlendShapeBase' already defined" in message
                and "pymomentum" in sys.modules
            ):
                try:
                    importlib.reload(sys.modules["pymomentum"])
                    return original(*args, **kwargs)
                except Exception:
                    # Best effort only: fall through and surface the
                    # original error below.
                    pass
            raise

    setattr(wrapper, _PATCH_FLAG, True)
    return wrapper
# Apply the patch as an import-time side effect: this module is meant to be
# imported before any other torch import, so importing it is all a caller needs.
patch_torch_for_uint32storage()