#!/usr/bin/env python3
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import torch
import numpy as np
import random
import os
import time
import argparse

from metrics import compute_metrics
from modules.tokenization_clip import SimpleTokenizer as ClipTokenizer
from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from modules.modeling import CLIP4Clip
from modules.optimization import BertAdam
from util import get_logger
from simple_dataloaders import SIMPLE_DATALOADER_DICT

global logger


def get_args():
    parser = argparse.ArgumentParser(description='Simplified CLIP4Clip Training')
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")

    parser.add_argument('--train_csv', type=str, default='data/.train.csv', help='train csv file path')
    parser.add_argument('--val_csv', type=str, default='data/.val.csv', help='val csv file path')
    parser.add_argument('--data_path', type=str, default='data/caption.pickle', help='data pickle file path')
    parser.add_argument('--features_path', type=str, default='data/videos_feature.pickle', help='feature path')

    parser.add_argument('--num_thread_reader', type=int, default=1, help='number of dataloader workers')
    parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
    parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
    parser.add_argument('--batch_size', type=int, default=128, help='batch size')
    parser.add_argument('--batch_size_val', type=int, default=16, help='batch size eval')
    parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay')
    parser.add_argument('--n_display', type=int, default=100, help='Information display frequency')
    parser.add_argument('--video_dim', type=int, default=1024, help='video feature dimension')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--max_words', type=int, default=32, help='max number of text tokens')
    parser.add_argument('--max_frames', type=int, default=12, help='max number of video frames')
    parser.add_argument('--feature_framerate', type=int, default=1, help='frame sampling rate')
    parser.add_argument('--margin', type=float, default=0.1, help='margin for loss')
    parser.add_argument('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample')
    parser.add_argument('--datatype', type=str, default='msrvtt', help='data type')

    parser.add_argument('--world_size', type=int, default=1, help='number of distributed processes')
    parser.add_argument('--rank', type=int, default=0, help='distributed process rank')
    parser.add_argument('--local_rank', type=int, default=0, help='distributed process local rank')
    parser.add_argument('--coef_lr', type=float, default=1e-3, help='learning-rate coefficient for the CLIP branch')
    parser.add_argument('--use_mil', action='store_true', help="Whether to use MIL as in Miech et al. (2020).")
    parser.add_argument('--sampled_use_mil', action='store_true', help="Whether to use sampled MIL; takes priority over use_mil.")

    parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Number of hidden layers for the text encoder.")
    parser.add_argument('--visual_num_hidden_layers', type=int, default=12, help="Number of hidden layers for the visual encoder.")
    parser.add_argument('--cross_num_hidden_layers', type=int, default=4, help="Number of hidden layers for the cross encoder.")
    parser.add_argument('--loose_type', action='store_true', help="Use loose similarity for retrieval; default is the tight type.")
    parser.add_argument('--expand_msrvtt_sentences', action='store_true', help="Expand all MSR-VTT sentences for training.")
    parser.add_argument('--linear_patch', type=str, default="2d", help="linear projection of flattened patches: 2d or 3d")
    parser.add_argument('--sim_header', type=str, default="meanP", help="choose a similarity header")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--pretrained_clip_name", default="ViT-B/32", type=str, help="Choose a CLIP version")
    parser.add_argument('--freeze_layer_num', type=int, default=0, help="Number of CLIP layers to freeze.")
    parser.add_argument('--slice_framepos', type=int, default=2,
                        help="0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.")

    # Additional arguments for dataloader compatibility
    parser.add_argument('--train_frame_order', type=int, default=0, choices=[0, 1, 2],
                        help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.")
    parser.add_argument('--eval_frame_order', type=int, default=0, choices=[0, 1, 2],
                        help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.")
    parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
    parser.add_argument('--n_pair', type=int, default=1, help='Num of pairs to output from data loader')

    args = parser.parse_args()
    return args

def set_seed_logger(args):
    global logger
    # predefine random initial seeds for reproducibility
    random.seed(args.seed)
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)  # if you are using multi-GPU
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir, exist_ok=True)

    logger = get_logger(os.path.join(args.output_dir, "log.txt"))

    if args.local_rank == 0:
        logger.info("Effective parameters:")
        for key in sorted(args.__dict__):
            logger.info("  <<< {}: {}".format(key, args.__dict__[key]))

    return args


def init_device(args, local_rank):
    global logger
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu", local_rank)
    n_gpu = torch.cuda.device_count()
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))
    args.n_gpu = n_gpu

    if args.batch_size % args.n_gpu != 0 or args.batch_size_val % args.n_gpu != 0:
        raise ValueError("Invalid batch_size/batch_size_val and n_gpu parameter: {}%{} and {}%{}, should be == 0".format(
            args.batch_size, args.n_gpu, args.batch_size_val, args.n_gpu))

    return device, n_gpu


def init_model(args, device, n_gpu, local_rank):
    # Single-node training: fix world_size/rank.
    args.world_size = 1
    args.rank = 0

    # Use the cross-base config directly.
    model = CLIP4Clip.from_pretrained("cross-base", cache_dir=PYTORCH_PRETRAINED_BERT_CACHE, task_config=args)
    model.to(device)
    return model
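
# prep_optimizer below builds four parameter groups: pretrained CLIP weights
# (names containing "clip.") train with a reduced learning rate of
# args.lr * coef_lr (1e-4 * 1e-3 = 1e-7 with the defaults), newly added layers
# train with args.lr, and in both cases bias/LayerNorm parameters are exempt
# from weight decay.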

def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=1.):
    if hasattr(model, 'module'):
        model = model.module

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    decay_param_tp = [(n, p) for n, p in param_optimizer if not any(nd in n for nd in no_decay)]
    no_decay_param_tp = [(n, p) for n, p in param_optimizer if any(nd in n for nd in no_decay)]

    decay_clip_param_tp = [(n, p) for n, p in decay_param_tp if "clip." in n]
    decay_noclip_param_tp = [(n, p) for n, p in decay_param_tp if "clip." not in n]

    no_decay_clip_param_tp = [(n, p) for n, p in no_decay_param_tp if "clip." in n]
    no_decay_noclip_param_tp = [(n, p) for n, p in no_decay_param_tp if "clip." not in n]

    weight_decay = 0.2
    optimizer_grouped_parameters = [
        {'params': [p for n, p in decay_clip_param_tp], 'weight_decay': weight_decay, 'lr': args.lr * coef_lr},
        {'params': [p for n, p in decay_noclip_param_tp], 'weight_decay': weight_decay},
        {'params': [p for n, p in no_decay_clip_param_tp], 'weight_decay': 0.0, 'lr': args.lr * coef_lr},
        {'params': [p for n, p in no_decay_noclip_param_tp], 'weight_decay': 0.0}
    ]

    scheduler = None
    optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=0.1, schedule='warmup_cosine',
                         b1=0.9, b2=0.98, e=1e-6,
                         t_total=num_train_optimization_steps, weight_decay=weight_decay,
                         max_grad_norm=1.0)

    # DataParallel over all visible GPUs (a single device id would silently
    # restrict training to one GPU despite the n_gpu > 1 gate).
    model = torch.nn.DataParallel(model) if n_gpu > 1 else model

    return optimizer, scheduler, model


def save_model(epoch, args, model, type_name=""):
    # Only save the model itself
    model_to_save = model.module if hasattr(model, 'module') else model
    output_model_file = os.path.join(
        args.output_dir, "pytorch_model.bin.{}{}".format("" if type_name == "" else type_name + ".", epoch))
    torch.save(model_to_save.state_dict(), output_model_file)
    logger.info("Model saved to %s", output_model_file)
    return output_model_file


def load_model(epoch, args, n_gpu, device, model_file=None):
    if model_file is None or len(model_file) == 0:
        model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch))
    if os.path.exists(model_file):
        model_state_dict = torch.load(model_file, map_location='cpu')
        if args.local_rank == 0:
            logger.info("Model loaded from %s", model_file)

        # Prepare model with the same cross-base config used in init_model
        cache_dir = args.cache_dir if hasattr(args, 'cache_dir') and args.cache_dir else PYTORCH_PRETRAINED_BERT_CACHE
        model = CLIP4Clip.from_pretrained("cross-base", cache_dir=cache_dir, state_dict=model_state_dict,
                                          task_config=args)
        model.to(device)
    else:
        model = None
    return model
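
# Note: BertAdam (modules/optimization.py) applies the warmup_cosine
# learning-rate schedule internally on each optimizer.step(), which is why
# prep_optimizer returns scheduler=None and train_epoch below never steps a
# separate scheduler.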

def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, scheduler, global_step, local_rank=0):
    global logger
    torch.cuda.empty_cache()
    model.train()
    log_step = args.n_display
    start_time = time.time()
    total_loss = 0

    for step, batch in enumerate(train_dataloader):
        if n_gpu == 1:
            # multi-gpu does scattering itself
            batch = tuple(t.to(device=device, non_blocking=True) for t in batch)

        input_ids, input_mask, segment_ids, video, video_mask = batch
        loss = model(input_ids, segment_ids, input_mask, video, video_mask)

        if n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu

        loss.backward()
        total_loss += float(loss)

        optimizer.step()
        optimizer.zero_grad()

        # https://github.com/openai/CLIP/issues/46
        if hasattr(model, 'module'):
            torch.clamp_(model.module.clip.logit_scale.data, max=np.log(100))
        else:
            torch.clamp_(model.clip.logit_scale.data, max=np.log(100))

        global_step += 1
        if global_step % log_step == 0 and local_rank == 0:
            logger.info("Epoch: %d/%s, Step: %d/%d, Lr: %s, Loss: %f, Time/step: %f",
                        epoch + 1, args.epochs, step + 1, len(train_dataloader),
                        "-".join([str('%.9f' % itm) for itm in sorted(list(set(optimizer.get_lr())))]),
                        float(loss), (time.time() - start_time) / (step + 1))

    total_loss = total_loss / len(train_dataloader)
    return total_loss, global_step
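
# Worked example of the multi-sentence bookkeeping used in eval_epoch below
# (numbers are illustrative, not from any dataset): with video_num = 2 and
# sentence_num = 5, a dataset may expose cut_off_points = [3, 5], i.e. the
# cumulative sentence count after each video. The `itm - 1` shift turns these
# into last-sentence indices [2, 4] (video 0 owns sentences 0..2, video 1 owns
# sentences 3..4), and the similarity matrix is then NaN-padded to the shape
# [video_num, max sentences per video] = [2, 3] before computing metrics.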

def eval_epoch(args, model, test_dataloader, device, n_gpu):
    if hasattr(model, 'module'):
        model = model.module.to(device)
    else:
        model = model.to(device)

    # #################################################################
    # below variables are used for multi-sentence retrieval
    # multi_sentence_: important tag for eval
    # cut_off_points_: used to tag the labels when calculating the metrics
    # sentence_num_: used to cut the sentence representation
    # video_num_: used to cut the video representation
    # #################################################################
    multi_sentence_ = False
    cut_off_points_, sentence_num_, video_num_ = [], -1, -1
    if hasattr(test_dataloader.dataset, 'multi_sentence_per_video') \
            and test_dataloader.dataset.multi_sentence_per_video:
        multi_sentence_ = True
        cut_off_points_ = test_dataloader.dataset.cut_off_points
        sentence_num_ = test_dataloader.dataset.sentence_num
        video_num_ = test_dataloader.dataset.video_num
        cut_off_points_ = [itm - 1 for itm in cut_off_points_]

    if multi_sentence_:
        logger.info("Eval under the multi-sentence per video clip setting.")
        logger.info("sentence num: {}, video num: {}".format(sentence_num_, video_num_))

    model.eval()
    with torch.no_grad():
        batch_list_t, batch_list_v = [], []
        batch_list_caption, batch_list_video_id = [], []
        total_video_num = 0

        # ----------------------------
        # 1. cache the features
        # ----------------------------
        for bid, batch in enumerate(test_dataloader):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, video, video_mask, \
                pairs_masked_text, pairs_token_labels, masked_video, video_labels_index, \
                pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids, \
                pairs_input_video_id = batch

            sequence_output = model.get_sequence_output(input_ids, segment_ids, input_mask)
            visual_output = model.get_visual_output(video, video_mask)

            batch_list_t.append(sequence_output)
            batch_list_v.append(visual_output)
            batch_list_caption.append(pairs_input_caption_ids)
            batch_list_video_id.append(pairs_input_video_id)
            total_video_num += video.shape[0]

        # ----------------------------------
        # 2. calculate the similarity
        # ----------------------------------
        if len(batch_list_t) > 0:
            batch_list_t = torch.cat(batch_list_t, dim=0)
            batch_list_v = torch.cat(batch_list_v, dim=0)
            if args.local_rank == 0:
                logger.info("total_video_num: {}".format(total_video_num))
            batch_list_caption = torch.cat(batch_list_caption, dim=0)
            batch_list_video_id = torch.cat(batch_list_video_id, dim=0)

            sim_matrix = model.get_similarity_logits(batch_list_t, batch_list_v,
                                                     batch_list_caption, batch_list_video_id,
                                                     loose_type=model.loose_type)
            sim_matrix = sim_matrix.cpu().numpy()

    if multi_sentence_:
        logger.info("before reshape, sim matrix size: {} x {}".format(sim_matrix.shape[0], sim_matrix.shape[1]))
        cut_off_points2len_ = [itm + 1 for itm in cut_off_points_]
        max_length = max([e_ - s_ for s_, e_ in zip([0] + cut_off_points2len_[:-1], cut_off_points2len_)])
        # pad each video's sentence scores with NaN to a common length
        sim_matrix_new = np.zeros([video_num_, max_length])
        sim_matrix_new[:, :] = np.nan
        for i in range(video_num_):
            start_ = cut_off_points2len_[i - 1] if i > 0 else 0
            for j in range(cut_off_points2len_[i] - start_):
                sim_matrix_new[i, j] = sim_matrix[i, start_ + j]
        sim_matrix = sim_matrix_new
        logger.info("after reshape, sim matrix size: {} x {}".format(sim_matrix.shape[0], sim_matrix.shape[1]))

    tv_metrics = compute_metrics(sim_matrix)
    vt_metrics = compute_metrics(sim_matrix.T)
    logger.info('\t Length-T: {}, Length-V: {}'.format(len(sim_matrix), len(sim_matrix[0])))

    logger.info("Text-to-Video:")
    logger.info('\t>>> R@1: {:.1f} - R@5: {:.1f} - R@10: {:.1f} - Median R: {:.1f} - Mean R: {:.1f}'.
                format(tv_metrics['R1'], tv_metrics['R5'], tv_metrics['R10'], tv_metrics['MR'], tv_metrics['MeanR']))
    logger.info("Video-to-Text:")
    logger.info('\t>>> V2T$R@1: {:.1f} - V2T$R@5: {:.1f} - V2T$R@10: {:.1f} - V2T$Median R: {:.1f} - V2T$Mean R: {:.1f}'.
                format(vt_metrics['R1'], vt_metrics['R5'], vt_metrics['R10'], vt_metrics['MR'], vt_metrics['MeanR']))

    R1 = tv_metrics['R1']
    return R1


def main():
    global logger
    args = get_args()
    args = set_seed_logger(args)
    device, n_gpu = init_device(args, args.local_rank)

    tokenizer = ClipTokenizer()
    model = init_model(args, device, n_gpu, args.local_rank)

    # ####################################
    # dataloader loading
    # ####################################
    assert args.datatype in SIMPLE_DATALOADER_DICT
    assert SIMPLE_DATALOADER_DICT[args.datatype]["test"] is not None \
        or SIMPLE_DATALOADER_DICT[args.datatype]["val"] is not None

    test_dataloader, test_length = None, 0
    if SIMPLE_DATALOADER_DICT[args.datatype]["test"] is not None:
        test_dataloader, test_length = SIMPLE_DATALOADER_DICT[args.datatype]["test"](args, tokenizer)

    if SIMPLE_DATALOADER_DICT[args.datatype]["val"] is not None:
        val_dataloader, val_length = SIMPLE_DATALOADER_DICT[args.datatype]["val"](args, tokenizer)
        if test_dataloader is None:
            test_dataloader, test_length = val_dataloader, val_length

    if args.local_rank == 0:
        logger.info("***** Running test *****")
        logger.info("  Num examples = %d", test_length)
        logger.info("  Batch size = %d", args.batch_size_val)
        logger.info("  Num steps = %d", len(test_dataloader))

    # ####################################
    # train and eval
    # ####################################
    if args.do_train:
        train_dataloader, train_length, sampler = SIMPLE_DATALOADER_DICT[args.datatype]["train"](args, tokenizer)
        num_train_optimization_steps = int((len(train_dataloader) + args.n_gpu - 1) / args.n_gpu) * args.epochs

        coef_lr = args.coef_lr
        optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu,
                                                     args.local_rank, coef_lr=coef_lr)

        if args.local_rank == 0:
            logger.info("***** Running training *****")
            logger.info("  Num examples = %d", train_length)
            logger.info("  Batch size = %d", args.batch_size)
            logger.info("  Num steps = %d", num_train_optimization_steps * args.n_gpu)

        best_score = 0.00001
        best_output_model_file = "None"
        global_step = 0
        for epoch in range(args.epochs):
            train_loss, global_step = train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
                                                  scheduler, global_step, local_rank=args.local_rank)
            if args.local_rank == 0:
                logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, train_loss)

            eval_result = eval_epoch(args, model, test_dataloader, device, n_gpu)
            if best_score <= eval_result:
                best_score = eval_result
                best_output_model_file = save_model(epoch, args, model, type_name="")
            if args.local_rank == 0:
                logger.info("Best R@1: %f, Model: %s", best_score, best_output_model_file)

    elif args.do_eval:
        eval_epoch(args, model, test_dataloader, device, n_gpu)


if __name__ == "__main__":
    main()