test / simple_dataloaders.py
jaewooo's picture
Initial upload
de15dc5 verified
import torch
from torch.utils.data import DataLoader
from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_DataLoader
from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_TrainDataLoader
def simple_dataloader_msrvtt_train(args, tokenizer):
    """Build the MSRVTT training DataLoader without a distributed sampler.

    Args:
        args: Configuration object. The attributes read here are
            ``train_csv``, ``data_path``, ``features_path``, ``max_words``,
            ``feature_framerate``, ``max_frames``, ``expand_msrvtt_sentences``,
            ``train_frame_order``, ``slice_framepos``, ``batch_size`` and
            ``num_thread_reader``.
        tokenizer: Tokenizer forwarded unchanged to the dataset.

    Returns:
        A 3-tuple ``(dataloader, dataset_length, None)``.
        NOTE(review): the trailing ``None`` is presumably a placeholder for
        the sampler slot of a distributed variant; confirm against callers.
    """
    # Gather all dataset options in one mapping, then unpack into the constructor.
    dataset_options = {
        "csv_path": args.train_csv,
        "json_path": args.data_path,
        "features_path": args.features_path,
        "max_words": args.max_words,
        "feature_framerate": args.feature_framerate,
        "tokenizer": tokenizer,
        "max_frames": args.max_frames,
        "unfold_sentences": args.expand_msrvtt_sentences,
        "frame_order": args.train_frame_order,
        "slice_framepos": args.slice_framepos,
    }
    train_set = MSRVTT_TrainDataLoader(**dataset_options)

    # Plain DataLoader: shuffling is handled by the loader itself rather than
    # by a DistributedSampler, and incomplete final batches are dropped.
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        num_workers=args.num_thread_reader,
        pin_memory=False,
        shuffle=True,
        drop_last=True,
    )

    num_samples = len(train_set)
    return train_loader, num_samples, None
def simple_dataloader_msrvtt_test(args, tokenizer, subset="test"):
    """Build the MSRVTT evaluation DataLoader.

    Args:
        args: Configuration object. The attributes read here are
            ``val_csv``, ``features_path``, ``max_words``,
            ``feature_framerate``, ``max_frames``, ``eval_frame_order``,
            ``slice_framepos``, ``batch_size_val`` and ``num_thread_reader``.
        tokenizer: Tokenizer forwarded unchanged to the dataset.
        subset: Accepted for interface compatibility; this function does not
            read it and always loads from ``args.val_csv``.

    Returns:
        A 2-tuple ``(dataloader, dataset_length)``.
    """
    # Gather all dataset options in one mapping, then unpack into the constructor.
    dataset_options = {
        "csv_path": args.val_csv,
        "features_path": args.features_path,
        "max_words": args.max_words,
        "feature_framerate": args.feature_framerate,
        "tokenizer": tokenizer,
        "max_frames": args.max_frames,
        "frame_order": args.eval_frame_order,
        "slice_framepos": args.slice_framepos,
    }
    eval_set = MSRVTT_DataLoader(**dataset_options)

    # Evaluation keeps dataset order and every sample (no shuffle, no drop).
    eval_loader = DataLoader(
        eval_set,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        pin_memory=False,
        shuffle=False,
        drop_last=False,
    )

    num_samples = len(eval_set)
    return eval_loader, num_samples
# Registry of dataloader builders: dataset name -> split name -> builder.
# A split mapped to None has no builder registered for it.
SIMPLE_DATALOADER_DICT = {
    "msrvtt": {
        "train": simple_dataloader_msrvtt_train,
        "val": simple_dataloader_msrvtt_test,
        "test": None,
    },
}