|
|
import torch |
|
|
from torch.utils.data import DataLoader |
|
|
from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_DataLoader |
|
|
from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_TrainDataLoader |
|
|
|
|
|
def simple_dataloader_msrvtt_train(args, tokenizer):
    """Create the MSR-VTT training loader.

    Args:
        args: Configuration object; the attributes read here are
            ``train_csv``, ``data_path``, ``features_path``, ``max_words``,
            ``feature_framerate``, ``max_frames``,
            ``expand_msrvtt_sentences``, ``train_frame_order``,
            ``slice_framepos``, ``batch_size`` and ``num_thread_reader``.
        tokenizer: Forwarded unchanged to ``MSRVTT_TrainDataLoader``.

    Returns:
        A 3-tuple ``(loader, num_samples, None)`` where ``loader`` is the
        ``DataLoader`` and ``num_samples`` is ``len()`` of the underlying
        dataset. The third slot is always ``None``
        (NOTE(review): presumably a sampler placeholder -- confirm with callers).
    """
    dataset_options = dict(
        csv_path=args.train_csv,
        json_path=args.data_path,
        features_path=args.features_path,
        max_words=args.max_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        max_frames=args.max_frames,
        unfold_sentences=args.expand_msrvtt_sentences,
        frame_order=args.train_frame_order,
        slice_framepos=args.slice_framepos,
    )
    train_set = MSRVTT_TrainDataLoader(**dataset_options)

    # Training batches are shuffled and the last incomplete batch is dropped.
    loader_options = dict(
        batch_size=args.batch_size,
        num_workers=args.num_thread_reader,
        pin_memory=False,
        shuffle=True,
        drop_last=True,
    )
    train_loader = DataLoader(train_set, **loader_options)

    return train_loader, len(train_set), None
|
|
|
|
|
def simple_dataloader_msrvtt_test(args, tokenizer, subset="test"):
    """Create the MSR-VTT evaluation loader.

    Args:
        args: Configuration object; the attributes read here are
            ``val_csv``, ``features_path``, ``max_words``,
            ``feature_framerate``, ``max_frames``, ``eval_frame_order``,
            ``slice_framepos``, ``batch_size_val`` and ``num_thread_reader``.
        tokenizer: Forwarded unchanged to ``MSRVTT_DataLoader``.
        subset: Accepted for interface compatibility; not read by this
            function.

    Returns:
        A 2-tuple ``(loader, num_samples)`` where ``num_samples`` is
        ``len()`` of the underlying dataset.
    """
    dataset_options = dict(
        csv_path=args.val_csv,
        features_path=args.features_path,
        max_words=args.max_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        max_frames=args.max_frames,
        frame_order=args.eval_frame_order,
        slice_framepos=args.slice_framepos,
    )
    eval_set = MSRVTT_DataLoader(**dataset_options)

    # Evaluation keeps dataset order and keeps every sample.
    loader_options = dict(
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        pin_memory=False,
        shuffle=False,
        drop_last=False,
    )
    eval_loader = DataLoader(eval_set, **loader_options)

    return eval_loader, len(eval_set)
|
|
|
|
|
# Registry of loader factories, keyed by dataset name and then by split.
# For "msrvtt" the "val" split uses the test-loader factory and "test" has
# no factory (None).
SIMPLE_DATALOADER_DICT = {
    "msrvtt": {
        "train": simple_dataloader_msrvtt_train,
        "val": simple_dataloader_msrvtt_test,
        "test": None,
    },
}