#!/usr/bin/env python3
# torch_uint32_patch.py - Direct patch for PyTorch UInt32Storage issue
# This standalone patch is designed to be imported before any torch imports
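#
# Usage sketch (the module name comes from this filename; adjust if you
# rename the file):
#
#     import torch_uint32_patch  # applies the patch on import, before torch
#     import torch
#
#     model = torch.jit.load("model.pt")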

def patch_torch_for_uint32storage():
    """
    Apply comprehensive patches to PyTorch to handle UInt32Storage
    This must be imported and called before any torch operations
    """
    import sys
    import pickle
    import importlib  # importlib.reload is used in the handlers below

    # Check if torch is already imported
    if "torch" in sys.modules:
        print("WARNING: torch is already imported, patching may be less effective")

    # Patch the pickle machinery first. The C-accelerated pickle.Unpickler
    # does not allow its methods to be reassigned (doing so raises
    # TypeError), so install a subclass under the same name instead;
    # torch.load instantiates pickle_module.Unpickler and will pick it up.
    original_unpickler = pickle.Unpickler

    class PatchedUnpickler(original_unpickler):
        def find_class(self, module_name, name):
            if module_name == 'torch' and name == 'UInt32Storage':
                # Load torch lazily here to avoid circular imports
                import torch
                return torch.FloatStorage
            return super().find_class(module_name, name)

    pickle.Unpickler = PatchedUnpickler

    # Now import torch and apply direct patches
    import torch

    # Create a stand-in UInt32Storage class
    if not hasattr(torch, 'UInt32Storage'):
        # Alias FloatStorage: like uint32 it uses 4 bytes per element, and it
        # seems more likely to work here than IntStorage
        torch.UInt32Storage = torch.FloatStorage

    # Register in sys.modules as well, in case anything resolves the class
    # through an import of 'torch.UInt32Storage'
    sys.modules['torch.UInt32Storage'] = torch.FloatStorage

    # Patch _C module
    if hasattr(torch, '_C'):
        if hasattr(torch._C, '_FloatStorage'):
            setattr(torch._C, '_UInt32Storage', torch._C._FloatStorage)

    # Patch the torch.jit loading system
    if hasattr(torch.jit, 'load'):
        original_jit_load = torch.jit.load
        def patched_jit_load(*args, **kwargs):
            try:
                return original_jit_load(*args, **kwargs)
            except RuntimeError as e:
                if "UInt32Storage" in str(e):
                    # Force the UInt32Storage patch again just to be sure
                    torch.UInt32Storage = torch.FloatStorage
                    sys.modules['torch.UInt32Storage'] = torch.FloatStorage
                    # Try again
                    return original_jit_load(*args, **kwargs)
                # For BlendShapeBase errors
                elif "BlendShapeBase' already defined" in str(e) and 'pymomentum' in sys.modules:
                    try:
                        # Try to reload the module
                        importlib.reload(sys.modules['pymomentum'])
                        return original_jit_load(*args, **kwargs)
                    except Exception:
                        pass
                # Re-raise if not our specific error
                raise
        torch.jit.load = patched_jit_load

    # Low-level patching for _C.import_ir_module
    if hasattr(torch, '_C') and hasattr(torch._C, 'import_ir_module'):
        original_import_ir = torch._C.import_ir_module
        def patched_import_ir(*args, **kwargs):
            try:
                return original_import_ir(*args, **kwargs)
            except RuntimeError as e:
                error_str = str(e)
                if "UInt32Storage" in error_str:
                    # Apply emergency patching
                    torch.UInt32Storage = torch.FloatStorage
                    sys.modules['torch.UInt32Storage'] = torch.FloatStorage
                    if hasattr(torch._C, '_FloatStorage'):
                        setattr(torch._C, '_UInt32Storage', torch._C._FloatStorage)
                    return original_import_ir(*args, **kwargs)
                elif "BlendShapeBase' already defined" in error_str and 'pymomentum' in sys.modules:
                    try:
                        # Try to reload the module
                        importlib.reload(sys.modules['pymomentum'])
                        return original_import_ir(*args, **kwargs)
                    except Exception:
                        pass
                raise
        torch._C.import_ir_module = patched_import_ir

    print("PyTorch UInt32Storage patch applied successfully")

# Execute patch immediately when imported
patch_torch_for_uint32storage()
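
if __name__ == "__main__":
    # Minimal smoke test, assuming torch is installed and this build has no
    # native UInt32Storage: the patch already ran at import time above, so
    # the storage alias and the Unpickler swap should both be visible here.
    import pickle
    import torch

    assert torch.UInt32Storage is torch.FloatStorage
    assert pickle.Unpickler.__name__ == "PatchedUnpickler"
    print("Smoke test passed:", torch.UInt32Storage)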