samplers.py 781 B

1234567891011121314151617181920212223242526272829
  1. # --------------------------------------------------------
  2. # Swin Transformer
  3. # Copyright (c) 2021 Microsoft
  4. # Licensed under The MIT License [see LICENSE for details]
  5. # Written by Ze Liu
  6. # --------------------------------------------------------
  7. import torch
  8. class SubsetRandomSampler(torch.utils.data.Sampler):
  9. r"""Samples elements randomly from a given list of indices, without replacement.
  10. Arguments:
  11. indices (sequence): a sequence of indices
  12. """
  13. def __init__(self, indices):
  14. self.epoch = 0
  15. self.indices = indices
  16. def __iter__(self):
  17. return (self.indices[i] for i in torch.randperm(len(self.indices)))
  18. def __len__(self):
  19. return len(self.indices)
  20. def set_epoch(self, epoch):
  21. self.epoch = epoch