在pytorch中若是使用自定义数据集,需要定义Dataset类,并覆盖父类的
__len__和__getitem__
函数
举个例子,返回常规的数据对x, y 也可以是多个x,y 比如小样本学习中需要query support对就是两个x,两个y
class MyDataset(Dataset):
定义相关数据
def __len__(self):
return len(self.x_data)
def __getitem__(self, idx):
return x, y
但是在 __getitem__中也可以返回字典类型的数据 , 例如
def __getitem__(self, idx)
batch = {'query_img': query_img,
'query_mask': query_mask,
'query_name': query_name,
'query_ignore_idx': query_ignore_idx,
'support_imgs': support_imgs,
'support_masks': support_masks,
'support_names': support_names,
'support_ignore_idxs': support_ignore_idxs,
'class_id': torch.tensor(class_sample)}
return batch
下面解释一下为什么可以返回字典.
通常当我们定义好Dataset并实例化dataset之后,会实例化一个DataLoader并将dataset传入其中,DataLoader的作用是拼接多个__getitem__获得的数据,返回一个batch的数据,在实例化DataLoader的时候有一个参数是collate_fn,它用来定义数据batch拼接方式
collate_fn (callable, optional): merges a list of samples to form a
mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
再来看一下默认的collate_fn函数是如何定义的
def default_collate(batch):
r"""Puts each data field into a tensor with outer dimension batch size"""
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
numel = sum([x.numel() for x in batch])
storage = elem.storage()._new_shared(numel)
out = elem.new(storage)
return torch.stack(batch, 0, out=out)
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
return default_collate([torch.as_tensor(b) for b in batch])
elif elem.shape == ():
return torch.as_tensor(batch)
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float64)
elif isinstance(elem, int):
return torch.tensor(batch)
elif isinstance(elem, string_classes):
return batch
elif isinstance(elem, collections.abc.Mapping):
return {key: default_collate([d[key] for d in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
return elem_type(*(default_collate(samples) for samples in zip(*batch)))
elif isinstance(elem, collections.abc.Sequence):
it = iter(batch)
elem_size = len(next(it))
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError('each element in list of batch should be of equal size')
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
raise TypeError(default_collate_err_msg_format.format(elem_type))
起作用的应该是是这一行
elif isinstance(elem, collections.abc.Mapping):
return {key: default_collate([d[key] for d in batch]) for key in elem}
其中elem是batch中的第一个元素,用列表循环式把batch中所有相同key的数据添加到同一个key的[]中
再来看一下collate_fn被调用的地方
class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
在构建DataLoader时,需要传入参数dataset,这里可以是自己自定义数据集类,比如上图myDataset
在DataLoader 送入torch中进行训练时,会自动调用数据集类的__getitem__()方法
class myDataset(Dataset):
def __init__(self, csv_file, txt_file, root_dir, other_file):
self.csv_data = pd.read_csv(csv_fi...
确实,功能上是等价的,但是有的时候,当你返回的东西特别多并且各部分功能不一样的时候,上面分组就会很方便。并且,我们dataloader返回的结果,也将会是分组的,非常之方便。,那么dataloader返回的结果也会是嵌套的,反正不会乱掉。经过这次,确实更加了解了dataloader的强大的组织为批的能力。我以前还实现过collate_fn,现在看来,无法处理多组以及嵌套的情况,今天发现,原来可以返回多组,太自由了。甚至:还可以更加复杂的,例如嵌套,后来发现,原来一组不一定是由。组成,可以是任意的,例如。
torch.utils.data.Dataset 中的 __getitem__ 方法需要实现对数据集中单个样本的访问。该方法接受一个索引,并返回数据集中该索引对应的样本。通常,样本数据是通过读取数据文件或计算生成的。
例如,如果我们有一个图像分类数据集,可以在 __getitem__ 方法中读取索引对应的图像文件,并将其转换为 PyTorch 张量,同时返回图像对应的标签。
具体实现可能会有所不同...
Pytorch官方文档:https://pytorch-cn.readthedocs.io/zh/latest/
Pytorch学习文档:https://github.com/tensor-yu/PyTorch_Tutorial
参考:https://blog.csdn.net/u011995719/article/details/85102770
文章目录PyTorch学习(2):数据加载机制前言1.Dataset类2.构建自定义Dataset子类3.DataL
以下内容都是针对Pytorch 1.0-1.1介绍。
很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握重点,所以本文将会自上而下地对Pytorch数据读取方法进行介绍。
自上而下理解三者关系
首先我们看一下DataLoader.__next__的源代码长什么样,为方便理解我只选取了nu...
idx的范围是从0到len-1(__len__的返回值)
但是如果采用了dataloader进行迭代,num_workers大于一的话,因为是多线程,所以运行速度不一样,这个时候如果在__getitem__函,数里输出idx的话,就是乱序的。但是实际上当线程数设置为1还是顺序的。
即使线程数大于1,如果返回idx,并且在dataloader迭代后的过程中输出结果的话,还是顺序的,也就是说,多线程可能速度不一样,但是最终的结果要保证和单线程的一致
转载自:https://blog.csdn.net/sinat_42239797/article/details/90641659 侵删
对于如何定义自己的Datasets我讲从以下几个方面进行解说
1.什么是Datasets?
2.为什么要定义Datasets?
3.如何定义Datasets?
定义Datasets分为以下几个板块:
1)Datasets的源代码及解说
2)Datasets的整体框架及解说
3)自己的Datasets框架及解说
4)DataLoader的使用
DataLoader Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset. --PyTorch Documents一般来说PyTorch中深度学习训练的流程是这样的: 1. 创建Dateset 2. Dataset传递给DataLoader 3. DataL...