import torch
from torch.utils.data import DataLoader

from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_DataLoader
from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_TrainDataLoader


def simple_dataloader_msrvtt_train(args, tokenizer):
    """Build the MSR-VTT training loader using a plain ``DataLoader``.

    Args:
        args: Configuration object. The attributes read here are
            ``train_csv``, ``data_path``, ``features_path``, ``max_words``,
            ``feature_framerate``, ``max_frames``,
            ``expand_msrvtt_sentences``, ``train_frame_order``,
            ``slice_framepos``, ``batch_size`` and ``num_thread_reader``.
        tokenizer: Tokenizer handed through to ``MSRVTT_TrainDataLoader``.

    Returns:
        A 3-tuple ``(loader, dataset_length, None)``. The last element is
        always ``None``; presumably it fills the slot where a distributed
        sampler would otherwise be returned -- confirm against callers.
    """
    dataset_kwargs = {
        "csv_path": args.train_csv,
        "json_path": args.data_path,
        "features_path": args.features_path,
        "max_words": args.max_words,
        "feature_framerate": args.feature_framerate,
        "tokenizer": tokenizer,
        "max_frames": args.max_frames,
        "unfold_sentences": args.expand_msrvtt_sentences,
        "frame_order": args.train_frame_order,
        "slice_framepos": args.slice_framepos,
    }
    train_set = MSRVTT_TrainDataLoader(**dataset_kwargs)

    # No DistributedSampler is attached: shuffling is delegated to the
    # DataLoader itself, and the trailing partial batch is dropped.
    loader_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": args.num_thread_reader,
        "pin_memory": False,
        "shuffle": True,
        "drop_last": True,
    }
    train_loader = DataLoader(train_set, **loader_kwargs)

    return train_loader, len(train_set), None


def simple_dataloader_msrvtt_test(args, tokenizer, subset="test"):
    """Build the MSR-VTT evaluation loader using a plain ``DataLoader``.

    Args:
        args: Configuration object. The attributes read here are
            ``val_csv``, ``features_path``, ``max_words``,
            ``feature_framerate``, ``max_frames``, ``eval_frame_order``,
            ``slice_framepos``, ``batch_size_val`` and ``num_thread_reader``.
        tokenizer: Tokenizer handed through to ``MSRVTT_DataLoader``.
        subset: Accepted for interface compatibility; this function does
            not read it.

    Returns:
        A 2-tuple ``(loader, dataset_length)``.
    """
    dataset_kwargs = {
        "csv_path": args.val_csv,
        "features_path": args.features_path,
        "max_words": args.max_words,
        "feature_framerate": args.feature_framerate,
        "tokenizer": tokenizer,
        "max_frames": args.max_frames,
        "frame_order": args.eval_frame_order,
        "slice_framepos": args.slice_framepos,
    }
    eval_set = MSRVTT_DataLoader(**dataset_kwargs)

    # Evaluation keeps the dataset order and every sample (no shuffle,
    # no dropped final batch).
    loader_kwargs = {
        "batch_size": args.batch_size_val,
        "num_workers": args.num_thread_reader,
        "pin_memory": False,
        "shuffle": False,
        "drop_last": False,
    }
    eval_loader = DataLoader(eval_set, **loader_kwargs)

    return eval_loader, len(eval_set)


# Registry mapping dataset name -> split name -> loader factory.
# The "test" split of "msrvtt" has no factory registered (None).
SIMPLE_DATALOADER_DICT = {
    "msrvtt": {
        "train": simple_dataloader_msrvtt_train,
        "val": simple_dataloader_msrvtt_test,
        "test": None,
    },
}