test_rnn.py

import random
import unittest

import torch
from torch import nn

from apex import amp
from utils import common_init, HALF
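
# NOTE: common_init is assumed to set the test dimensions self.b (batch size),
# self.h (hidden size), and self.t (sequence length), and to set a CUDA
# default tensor type; the packed-sequence test below restores
# torch.cuda.FloatTensor after a CPU workaround, which implies as much.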


class TestRnnCells(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()
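
    # Shared harness: unroll `cell` over self.t timesteps, feeding the hidden
    # state back in. With amp active, every output should be cast to half
    # precision even for float32 inputs, and each input's gradient should
    # come back in that input's own dtype.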
    def run_cell_test(self, cell, state_tuple=False):
        shape = (self.b, self.h)
        for typ in [torch.float, torch.half]:
            xs = [torch.randn(shape, dtype=typ).requires_grad_()
                  for _ in range(self.t)]
            hidden_fn = lambda: torch.zeros(shape, dtype=typ)
            if state_tuple:
                hidden = (hidden_fn(), hidden_fn())
            else:
                hidden = hidden_fn()
            outputs = []
            for i in range(self.t):
                hidden = cell(xs[i], hidden)
                if state_tuple:
                    output = hidden[0]
                else:
                    output = hidden
                outputs.append(output)
            for y in outputs:
                self.assertEqual(y.type(), HALF)
            # Backprop from the last output alone reaches every timestep's
            # input through the recurrent hidden state.
            outputs[-1].float().sum().backward()
            for x in xs:
                self.assertEqual(x.grad.dtype, x.dtype)

    def test_rnn_cell_is_half(self):
        cell = nn.RNNCell(self.h, self.h)
        self.run_cell_test(cell)

    def test_gru_cell_is_half(self):
        cell = nn.GRUCell(self.h, self.h)
        self.run_cell_test(cell)

    def test_lstm_cell_is_half(self):
        cell = nn.LSTMCell(self.h, self.h)
        self.run_cell_test(cell, state_tuple=True)


class TestRnns(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()
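
    # Shared harness: run a full (possibly multi-layer, bidirectional) RNN
    # module once over a (t, b, h) batch and check the same half-cast and
    # gradient-dtype contract as the cell tests above.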
    def run_rnn_test(self, rnn, layers, bidir, state_tuple=False):
        for typ in [torch.float, torch.half]:
            x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_()
            # The hidden state's first dim is num_layers * num_directions,
            # i.e. 2 * layers when the RNN is bidirectional.
            hidden_fn = lambda: torch.zeros((layers + (layers * bidir),
                                             self.b, self.h), dtype=typ)
            if state_tuple:
                hidden = (hidden_fn(), hidden_fn())
            else:
                hidden = hidden_fn()
            output, _ = rnn(x, hidden)
            self.assertEqual(output.type(), HALF)
            output[-1, :, :].float().sum().backward()
            self.assertEqual(x.grad.dtype, x.dtype)

    def test_rnn_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=layers,
                         nonlinearity='relu', bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir)

    def test_gru_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.GRU(input_size=self.h, hidden_size=self.h, num_layers=layers,
                         bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir)

    def test_lstm_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.LSTM(input_size=self.h, hidden_size=self.h, num_layers=layers,
                          bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir, state_tuple=True)
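
    # PackedSequence variant: the half-cast contract should also hold for the
    # `.data` field of the packed output.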
    def test_rnn_packed_sequence(self):
        num_layers = 2
        rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=num_layers)
        for typ in [torch.float, torch.half]:
            x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_()
            # Sequence lengths must be in decreasing order for
            # `pack_padded_sequence`.
            lens = sorted([random.randint(self.t // 2, self.t) for _ in range(self.b)],
                          reverse=True)
            # `pack_padded_sequence` breaks if the default tensor type is
            # non-CPU, so switch to CPU tensors and restore afterwards.
            torch.set_default_tensor_type(torch.FloatTensor)
            lens = torch.tensor(lens, dtype=torch.int64, device=torch.device('cpu'))
            packed_seq = nn.utils.rnn.pack_padded_sequence(x, lens)
            torch.set_default_tensor_type(torch.cuda.FloatTensor)
            hidden = torch.zeros((num_layers, self.b, self.h), dtype=typ)
            output, _ = rnn(packed_seq, hidden)
            self.assertEqual(output.data.type(), HALF)
            output.data.float().sum().backward()
            self.assertEqual(x.grad.dtype, x.dtype)


if __name__ == '__main__':
    unittest.main()