|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""PyTorch BERT model.""" |
|
|
|
|
|
import logging |
|
|
import numpy as np |
|
|
import torch |
|
|
from torch import nn |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
from modules.until_config import PretrainedConfig |
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
def gelu(x):
    """Exact Gaussian Error Linear Unit, computed with the error function.

    Returns ``x * Phi(x)``, where ``Phi`` is the standard normal CDF
    ``0.5 * (1 + erf(x / sqrt(2)))``.

    For reference, OpenAI GPT uses a tanh-based approximation that gives
    slightly different results::

        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    half_x = x * 0.5
    erf_term = torch.erf(x / math.sqrt(2.0))
    return half_x * (1.0 + erf_term)
|
|
|
|
|
def swish(x):
    """Swish activation: ``x * sigmoid(x)``, applied element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
|
|
|
|
|
# Lookup table from activation name to its element-wise callable.
ACT2FN = dict(
    gelu=gelu,
    relu=torch.nn.functional.relu,
    swish=swish,
)
|
|
|
|
|
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension, TF style.

    ``eps`` is added to the (biased) variance *inside* the square root, and the
    result is scaled by a learnable ``weight`` (init 1) and shifted by a
    learnable ``bias`` (init 0), both of size ``hidden_size``.
    """

    def __init__(self, hidden_size, eps=1e-12):
        """Create the affine parameters and store ``eps`` as ``variance_epsilon``."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalise ``x`` along its last dimension and apply the affine transform."""
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias
|
|
|
|
|
class PreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for loading pretrained models.

    Subclasses are constructed with a ``PretrainedConfig`` instance. Pretrained
    weights are supplied as an already-loaded ``state_dict``; nothing is
    downloaded or cached here.
    """

    def __init__(self, config, *inputs, **kwargs):
        """Store ``config``.

        Raises:
            ValueError: if ``config`` is not a ``PretrainedConfig`` instance.
        """
        super(PreTrainedModel, self).__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_weights(self, module):
        """ Initialize the weights of ``module`` in place.

        - ``nn.Linear`` / ``nn.Embedding`` weights are drawn from
          ``N(0, config.initializer_range)``.
        - This file's ``LayerNorm`` gets its bias zeroed and its weight set to
          one; the TF names ``beta`` / ``gamma`` are used instead when the
          module has both of them.
        - ``nn.Linear`` biases (when present) are zeroed.

        Other module types are left untouched. Presumably subclasses apply
        this to every submodule (e.g. via ``nn.Module.apply``) — not shown here.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, LayerNorm):
            if hasattr(module, 'beta') and hasattr(module, 'gamma'):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def resize_token_embeddings(self, new_num_tokens=None):
        """Not supported by this base class."""
        raise NotImplementedError

    @classmethod
    def init_preweight(cls, model, state_dict, prefix=None, task_config=None):
        """Load ``state_dict`` into ``model`` and return ``model``.

        Key rewriting, applied to a shallow copy (the caller's ``state_dict``
        is left unmodified):

        - any key containing ``'gamma'`` / ``'beta'`` has that substring
          replaced by ``'weight'`` / ``'bias'`` (TF-style LayerNorm names);
        - if ``prefix`` is not ``None``, it is prepended to every key.

        Missing, unexpected and erroneous keys are collected rather than
        raised. They are logged only when ``prefix`` is ``None`` and either
        ``task_config`` is ``None`` or ``task_config.local_rank == 0``.

        Args:
            model: the ``nn.Module`` to load into; modified in place.
            state_dict: mapping of parameter names to tensors.
            prefix: optional string prepended to every key before loading.
            task_config: optional object exposing ``local_rank``.

        Returns:
            ``model``.
        """
        # Work on a shallow copy so the caller's dict is not mutated. ``_metadata``
        # is a plain attribute, which ``copy()`` does not carry over, so re-attach it.
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        # Rename TF-style LayerNorm parameter names. Collect first, then rename,
        # so the dict is not resized while being iterated.
        renames = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                renames.append((key, new_key))
        for old_key, new_key in renames:
            state_dict[new_key] = state_dict.pop(old_key)

        if prefix is not None:
            for old_key in list(state_dict.keys()):
                state_dict[prefix + old_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        def load(module, prefix=''):
            # Recursively let each submodule pull its own entries out of state_dict.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='')

        if prefix is None and (task_config is None or task_config.local_rank == 0):
            logger.info("-" * 20)
            if len(missing_keys) > 0:
                logger.info("Weights of {} not initialized from pretrained model: {}"
                            .format(model.__class__.__name__, "\n   " + "\n   ".join(missing_keys)))
            if len(unexpected_keys) > 0:
                logger.info("Weights from pretrained model not used in {}: {}"
                            .format(model.__class__.__name__, "\n   " + "\n   ".join(unexpected_keys)))
            if len(error_msgs) > 0:
                logger.error("Weights from pretrained model cause errors in {}: {}"
                             .format(model.__class__.__name__, "\n   " + "\n   ".join(error_msgs)))

        return model

    @property
    def dtype(self):
        """
        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).

        When the module has no parameters, falls back to the dtype of the first
        tensor attribute found on the module tree (raises ``StopIteration`` if
        there is none).
        """
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            def find_tensor_attributes(module: nn.Module):
                # (name, tensor) pairs for plain tensor attributes of one module.
                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
                return tuples

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].dtype

    @classmethod
    def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs):
        """
        Instantiate ``cls(config, *inputs, **kwargs)`` and, when ``state_dict``
        is given, load it via ``init_preweight``.

        ``state_dict`` must be an already-loaded mapping of parameter names to
        tensors; this method performs no download or caching.
        """
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            return model
        model = cls.init_preweight(model, state_dict)

        return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CrossEn(nn.Module):
    """Cross-entropy on a similarity matrix whose row ``i`` targets column ``i``.

    Each row is log-softmaxed over its last dimension, and the loss is the
    negated mean of the diagonal log-probabilities.
    """

    def __init__(self):
        super(CrossEn, self).__init__()

    def forward(self, sim_matrix):
        """Return the scalar loss for ``sim_matrix`` (matched pairs on the diagonal)."""
        row_log_probs = F.log_softmax(sim_matrix, dim=-1)
        positive_log_probs = torch.diag(row_log_probs)
        return (-positive_log_probs).mean()
|
|
|
|
|
class MILNCELoss(nn.Module):
    """MIL-NCE style contrastive loss over a block-structured similarity matrix.

    The similarity matrix is square with side ``batch_size * n_pair``; rows and
    columns are grouped into ``batch_size`` consecutive blocks of ``n_pair``,
    and entries inside the same block are treated as positives.
    """

    def __init__(self, batch_size=1, n_pair=1,):
        super(MILNCELoss, self).__init__()
        self.batch_size = batch_size
        self.n_pair = n_pair
        # dtype used for the row-selection mask in ``forward``.
        self.bool_dtype = self._mask_dtype_for(torch.__version__)

    @staticmethod
    def _mask_dtype_for(version):
        """Return the mask dtype to use for a given torch version string.

        ``torch.bool`` from 1.3 on, ``torch.uint8`` before. Major/minor are
        compared as integers, so "1.10" correctly ranks above "1.3" (parsing
        it as the float 1.1 would not). Unparseable strings are assumed modern.
        """
        import re

        match = re.match(r"\s*(\d+)\.(\d+)", str(version))
        if match is None:
            return torch.bool
        major_minor = (int(match.group(1)), int(match.group(2)))
        return torch.bool if major_minor >= (1, 3) else torch.uint8

    def forward(self, sim_matrix):
        """Return the scalar loss for ``sim_matrix``.

        ``sim_matrix`` must be square with side ``batch_size * n_pair``: it is
        added to a mask of that shape and concatenated with its transpose.
        """
        # Block-diagonal mask: 1 where row and column fall in the same block.
        mm_mask = np.eye(self.batch_size)
        mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair)))
        mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device)

        # Exclude same-block entries from the second half (-1e12 acts as -inf).
        from_text_matrix = sim_matrix + mm_mask * -1e12
        from_video_matrix = sim_matrix.transpose(1, 0)

        new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1)
        logpt = F.log_softmax(new_sim_matrix, dim=-1)

        # Keep only the same-block entries of the first half, then pool them
        # per row with logsumexp.
        mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1)
        masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12
        new_logpt = -torch.logsumexp(masked_logpt, dim=-1)

        # Average over one row per block: the row at offset n_pair // 2.
        logpt_choice = torch.zeros_like(new_logpt)
        mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair // 2)
        logpt_choice[mark_ind] = 1
        sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=self.bool_dtype)).mean()
        return sim_loss
|
|
|
|
|
class MaxMarginRankingLoss(nn.Module):
    """Bidirectional hinge (max-margin) ranking loss on a square similarity matrix.

    With ``negative_weighting`` enabled and ``n_pair > 1``, ``batch_size > 1``,
    each loss entry is re-weighted by a block mask built from
    ``hard_negative_rate``.
    """

    def __init__(self,
                 margin=1.0,
                 negative_weighting=False,
                 batch_size=1,
                 n_pair=1,
                 hard_negative_rate=0.5,
                 ):
        super(MaxMarginRankingLoss, self).__init__()
        self.margin = margin
        self.n_pair = n_pair
        self.batch_size = batch_size
        easy_negative_rate = 1 - hard_negative_rate
        self.easy_negative_rate = easy_negative_rate
        self.negative_weighting = negative_weighting
        if n_pair > 1 and batch_size > 1:
            # Per-(sample, sample) weights, in closed form:
            #   same sample:      batch_size * (1 - easy_negative_rate)
            #   different sample: batch_size * easy_negative_rate / (batch_size - 1)
            # This equals ((1 - alpha) * I + alpha) * batch_size * (1 - easy_negative_rate)
            # with alpha = easy / ((batch_size - 1) * (1 - easy)), but avoids
            # dividing by (1 - easy_negative_rate), which is zero when
            # hard_negative_rate == 0.
            diag_weight = batch_size * (1 - easy_negative_rate)
            off_diag_weight = batch_size * easy_negative_rate / (batch_size - 1)
            mm_mask = np.full((batch_size, batch_size), off_diag_weight, dtype=np.float64)
            np.fill_diagonal(mm_mask, diag_weight)
            # Expand each sample's weight to its n_pair x n_pair block.
            mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair)))
            self.mm_mask = torch.tensor(mm_mask).float()

    def forward(self, x):
        """Return the mean hinge loss for the square similarity matrix ``x``.

        Entry ``(i, j)`` contributes
        ``relu(margin + x[i, j] - x[i, i]) + relu(margin + x[i, j] - x[j, j])``;
        diagonal entries are not excluded, so each contributes ``2 * margin``.
        """
        d = torch.diag(x)
        max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \
            F.relu(self.margin + x - d.view(1, -1))
        if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1:
            max_margin = max_margin * self.mm_mask.to(max_margin.device)
        return max_margin.mean()
|
|
|
|
|
class AllGather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor.

    Forward concatenates ``tensor`` from all ``args.world_size`` processes
    along dim 0 (or returns it unchanged when ``world_size == 1``). Backward
    returns only this process's slice of the incoming gradient, located by
    ``args.rank`` and the local dim-0 size.
    """

    @staticmethod
    def forward(ctx, tensor, args):
        world_size = args.world_size
        if world_size != 1:
            pieces = [torch.empty_like(tensor) for _ in range(world_size)]
            torch.distributed.all_gather(pieces, tensor)
            result = torch.cat(pieces, dim=0)
        else:
            result = tensor
        # Where this process's rows live inside the gathered result.
        ctx.rank = args.rank
        ctx.batch_size = tensor.shape[0]
        return result

    @staticmethod
    def backward(ctx, grad_output):
        local_rows = getattr(ctx, 'batch_size', 0)
        if local_rows <= 0:
            return (grad_output, None)
        start = local_rows * ctx.rank
        # ``args`` is not differentiable, hence the trailing None.
        return (grad_output[start:start + local_rows], None)
|
|
|