Как загрузить несколько изображений в градациях серого как один тензор в pytorch?
В общем, количество каналов не важно.
Вам нужна операция, известная как «загрузка пакета данных». Для этого в PyTorch есть класс DataLoader
. DataLoader
классу дополнительно нужен Dataset
класс.
Если в DataLoader
размер пакета равен 64 (bs = 64), вы загрузите 64 изображения из одного раза в качестве тензора.
Если вы используете ImageFolder
, это не вернет минибатч для ты. ImageFolder
- это производный от Dataset
класс.
Проблема с ImageFolder
(если вы просто используете это) в том, что вы получите одно изображение для каждого индекса. Затем вы объедините несколько изображений в мини-серию.
Вот один пример использования ImageFolder
с данными CIFAR10.
from torchvision import transforms
imagef = torchvision.datasets.ImageFolder(r'C:\Users\dj\data\cifar10\test', transform=transforms.ToTensor())
print(imagef)
print(imagef.classes)
img, label = imagef[0]
display(img)
print(img.size())
print(label)
Из:
Dataset ImageFolder
Number of datapoints: 10000
Root location: C:\Users\dj\data\cifar10\test
StandardTransform
Transform: ToTensor()
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
tensor([[[0.6078, 0.6549, 0.6902, ..., 0.7882, 0.7922, 0.7529],
[0.6000, 0.6392, 0.6706, ..., 0.7922, 0.7961, 0.7412],
[0.6078, 0.6275, 0.6588, ..., 0.8078, 0.8000, 0.7412],
...,
[0.3490, 0.2235, 0.2392, ..., 0.3490, 0.2314, 0.2627],
[0.3490, 0.2353, 0.2471, ..., 0.2235, 0.2392, 0.2941],
[0.3608, 0.2353, 0.2392, ..., 0.2353, 0.2510, 0.2863]], ...
torch.Size([3, 32, 32])
0
Следующий пример основан на DataLoader
:
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms
import PIL.Image as Image
def pil_loader(path):
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
ds = torchvision.datasets.DatasetFolder(r'C:\Users\dj\data\cifar10\test',
loader=pil_loader,
extensions=('.png'),
transform=transforms.ToTensor())
dl = DataLoader(ds, batch_size=2)
len(dl)
for imgs,lbls in dl:
print(imgs.size()) # torch.Size([2, 3, 32, 32])
break
Это DataLoader
- то, что вам может понадобиться. Тот, который я представляю, имеет настраиваемую функцию загрузки: pil_loader
.
Вы также можете использовать ImageFolder
вместо DatasetFolder
в предыдущем примере.
Это будет примерно так:
ds = torchvision.datasets.ImageFolder(r'C:\Users\dj\data\cifar10\test', transform=transforms.ToTensor())
dl = DataLoader(ds, batch_size=3)
print(len(dl))
for imgs,lbls in dl:
print(imgs.size())
break
person
prosti
schedule
20.11.2019