PyTorch shuffle坑

2022-09-23 10:18:53
  • NumPy 的数组可以用np.random.shufflerandom.shuffle 打乱,元素还是之前的元素,只是位置变了;
  • PyTorch 没有自己的 shuffle 函数[1],但torch.Tensor 不能用上面两种方法 shuffle:感觉变成可放回抽样了,元素不全、有重复。

用 pytorch 如果自己写 data loading,用到 indices,且遇到 pytorch 0.3 这种不能用 numpy 数组对 torch.Tensor 做 advanced indexing(所以将 index 向量转成了 torch.Tensor),又想要每个 epoch 都 shuffle 一下 indices 的,这里可能有坑。

Code

  • pytorch 0.3、1.4
import randomimport numpyas npimport torchprint("--- torch.Tensor ---")# 买家秀print("- np.random.shuffle -")
a= torch.arange(12)print("a:", a)
np.random.shuffle(a)print("a shuffled:", a)print("- random.shuffle -")
b= torch.arange(12)print("b:", b)
random.shuffle(b)print("b shuffled:", b)print("\n--- np.ndarray ---")# 卖家秀print("- np.random.shuffle -")
c= np.arange(12)print("c:", c)
np.random.shuffle(c)print("c shuffled:", c)print("- random.shuffle -")
d= np.arange(12)print("d:", d)
random.shuffle(d)print("d shuffled:", d)
  • 输出
--- torch.Tensor ---
- np.random.shuffle -
a: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
a shuffled: tensor([0, 0, 1, 0, 3, 4, 6, 3, 1, 3, 0, 8])
- random.shuffle -
b: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
b shuffled: tensor([0, 0, 1, 1, 2, 2, 1, 6, 4, 6, 9, 3])

--- np.ndarray ---
- np.random.shuffle -
c:[ 0  1  2  3  4  5  6  7  8  9 10 11]
c shuffled:[ 6  8 10  3  1  4  5  2  0  7  9 11]
- random.shuffle -
d:[ 0  1  2  3  4  5  6  7  8  9 10 11]
d shuffled:[ 4 10  1  8  3  9 11  2  5  6  0  7]

References

  1. PyTorchshuffle
  2. pytorch shuffle 一个tensor
  • 作者:HackerTom
  • 原文链接:https://blog.csdn.net/HackerTom/article/details/107542952
    更新时间:2022-09-23 10:18:53