123456789101112131415161718192021 |
- import torch
# Legacy CUDA tensor type-name strings used as the values of the tables below.
HALF = 'torch.cuda.HalfTensor'
FLOAT = 'torch.cuda.FloatTensor'

# Floating-point dtypes covered by the lookup tables below.
DTYPES = [torch.half, torch.float]

# Lookup tables keyed by dtype (torch.float first, then torch.half), each
# giving a CUDA tensor type name. Presumably these describe the expected
# result type for a given input dtype — confirm against the callers.
ALWAYS_HALF = dict.fromkeys((torch.float, torch.half), HALF)    # every dtype -> half
ALWAYS_FLOAT = dict.fromkeys((torch.float, torch.half), FLOAT)  # every dtype -> float
MATCH_INPUT = dict(zip((torch.float, torch.half), (FLOAT, HALF)))  # dtype -> same precision
def common_init(test_case):
    """Attach shared size constants to ``test_case`` and set torch's default tensor type.

    Sets the integer attributes ``h=64``, ``b=16``, ``c=16``, ``k=3`` and
    ``t=10`` on ``test_case`` (in that order), then calls
    ``torch.set_default_tensor_type(torch.cuda.FloatTensor)``.

    Args:
        test_case: Any object that accepts attribute assignment — presumably
            a ``unittest.TestCase`` instance (confirm against callers).

    Returns:
        None.
    """
    # NOTE(review): the single-letter names look like height/batch/channels/
    # kernel/time sizes — unconfirmed from this file.
    shared_sizes = (('h', 64), ('b', 16), ('c', 16), ('k', 3), ('t', 10))
    for attr_name, attr_value in shared_sizes:
        setattr(test_case, attr_name, attr_value)

    # Process-wide side effect: the default tensor type changes for all later
    # code, not just this test case. Presumably needs a CUDA-enabled build.
    # NOTE(review): set_default_tensor_type is deprecated in recent PyTorch in
    # favour of set_default_dtype / set_default_device; semantics differ, so
    # confirm before migrating.
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
|