Commit d7178ff2 by 前钰

Upload New File

parent 703d5d6b
import glob
import glob
import random
import os
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
class ImageDataset(Dataset):
def __init__(self, root, transforms_=None, unaligned=False, mode='train'):
self.transform = transforms.Compose(transforms_)
self.unaligned = unaligned
self.files_A = sorted(glob.glob(os.path.join(root, '%s/A' % mode) + '/*.*'))
self.files_B = sorted(glob.glob(os.path.join(root, '%s/B' % mode) + '/*.*'))
def __getitem__(self, index):
item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))
if self.unaligned:
item_B = self.transform(Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)]).convert('RGB'))
else:
item_B = self.transform(Image.open(self.files_B[index % len(self.files_B)]).convert('RGB'))
return {'A': item_A, 'B': item_B}
def __len__(self):
return max(len(self.files_A), len(self.files_B))
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment