|
|
import torch |
|
|
import torch.nn as nn |
|
|
import threading |
|
|
from torch._utils import ExceptionWrapper |
|
|
import logging |
|
|
|
|
|
def get_a_var(obj):
    """Return the first ``torch.Tensor`` found in ``obj``, or ``None``.

    ``obj`` itself is returned when it is a tensor. Lists and tuples are
    searched element by element, and dicts item by item (each ``(key, value)``
    pair, so keys are inspected before values), recursing depth-first. The
    search stops at the first tensor found. Any other type yields ``None``.
    """
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        candidates = obj
    elif isinstance(obj, dict):
        # Each item is a (key, value) tuple, handled by the tuple branch.
        candidates = obj.items()
    else:
        return None

    for candidate in candidates:
        found = get_a_var(candidate)
        if isinstance(found, torch.Tensor):
            return found
    return None
|
|
|
|
|
def parallel_apply(fct, model, inputs, device_ids):
    """Replicate ``model`` onto ``device_ids`` and call ``fct`` on every replica.

    ``model`` is copied with ``nn.parallel.replicate``, and replica ``i`` is
    paired with ``inputs[i]``. For each pair, ``fct(replica, *args)`` is called,
    where ``args`` is ``inputs[i]`` itself when it is a list/tuple and
    ``(inputs[i],)`` otherwise. Each call runs inside
    ``torch.cuda.device(d)``, where ``d`` is the device of the first tensor
    found in ``inputs[i]``. With more than one replica every call runs in its
    own thread; with exactly one it runs in the calling thread.

    Args:
        fct: callable invoked as ``fct(module, *args)`` once per replica.
        model: the ``nn.Module`` to replicate.
        inputs: one input (or argument list/tuple) per replica.
        device_ids: devices passed to ``nn.parallel.replicate``.

    Returns:
        List of the values returned by ``fct``, in the order of ``inputs``.
        An empty list when there are no replicas.

    Raises:
        AssertionError: if the number of replicas differs from ``len(inputs)``.
        Exceptions raised inside a worker are captured and re-raised in the
        caller via ``ExceptionWrapper.reraise``. The lowest failing replica
        index is reported first. An input that contains no tensor is reported
        this way as a ``RuntimeError``.
    """
    modules = nn.parallel.replicate(model, device_ids)
    assert len(modules) == len(inputs), (
        "got {} replicas but {} inputs".format(len(modules), len(inputs)))

    lock = threading.Lock()
    results = {}
    # Grad mode is thread-local, so capture the caller's setting and re-apply
    # it inside each worker.
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, inp):
        torch.set_grad_enabled(grad_enabled)
        device = None
        try:
            # Resolve the device inside the try. A missing tensor must become
            # a captured error rather than an AttributeError that kills the
            # thread and leaves results[i] unset.
            tensor = get_a_var(inp)
            if tensor is None:
                raise RuntimeError(
                    "no tensor found in input; cannot resolve the device to run on")
            device = tensor.get_device()
            with torch.cuda.device(device):
                # Wrap non-sequence inputs so a lone tensor is not unpacked
                # (sliced) by the star-call below.
                args = inp if isinstance(inp, (list, tuple)) else (inp,)
                output = fct(module, *args)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker, args=(i, module, inp))
                   for i, (module, inp) in enumerate(zip(modules, inputs))]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    elif modules:
        _worker(0, modules[0], inputs[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs
|
|
|
|
|
def get_logger(filename=None):
    """Return the shared ``'logger'`` logger, optionally mirroring output to a file.

    The ``'logger'`` logger is set to DEBUG. ``logging.basicConfig`` is called
    with a timestamped console format at INFO level; it only takes effect if
    the root logger has no handlers yet.

    Args:
        filename: optional path of a log file. When given, a DEBUG-level
            ``FileHandler`` is attached to the root logger, so records that
            propagate to root are also written to that file. Repeated calls
            with the same path reuse the existing handler instead of adding
            another one, which would duplicate every line in the file.

    Returns:
        The ``logging.Logger`` named ``'logger'``.
    """
    import os

    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)

    if filename is not None:
        root = logging.getLogger()
        # FileHandler stores the absolute path in ``baseFilename``; compare on
        # that so relative and absolute spellings of one file are deduplicated.
        target = os.path.abspath(filename)
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in root.handlers
        )
        if not already_attached:
            handler = logging.FileHandler(filename)
            handler.setLevel(logging.DEBUG)
            handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
            root.addHandler(handler)

    return logger