utils.py 512 B

123456789101112131415161718192021
  1. import torch
  2. HALF = 'torch.cuda.HalfTensor'
  3. FLOAT = 'torch.cuda.FloatTensor'
  4. DTYPES = [torch.half, torch.float]
  5. ALWAYS_HALF = {torch.float: HALF,
  6. torch.half: HALF}
  7. ALWAYS_FLOAT = {torch.float: FLOAT,
  8. torch.half: FLOAT}
  9. MATCH_INPUT = {torch.float: FLOAT,
  10. torch.half: HALF}
  11. def common_init(test_case):
  12. test_case.h = 64
  13. test_case.b = 16
  14. test_case.c = 16
  15. test_case.k = 3
  16. test_case.t = 10
  17. torch.set_default_tensor_type(torch.cuda.FloatTensor)