import unittest

from apex import amp
import random
import torch
from torch import nn

from utils import common_init, HALF


class TestRnnCells(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def run_cell_test(self, cell, state_tuple=False):
        shape = (self.b, self.h)
        for typ in [torch.float, torch.half]:
            xs = [torch.randn(shape, dtype=typ).requires_grad_()
                  for _ in range(self.t)]
            hidden_fn = lambda: torch.zeros(shape, dtype=typ)
            if state_tuple:
                hidden = (hidden_fn(), hidden_fn())
            else:
                hidden = hidden_fn()
            outputs = []
            for i in range(self.t):
                hidden = cell(xs[i], hidden)
                if state_tuple:
                    output = hidden[0]
                else:
                    output = hidden
                outputs.append(output)
            # With amp enabled, every timestep's output should be half precision,
            # regardless of whether the inputs were fp32 or fp16.
            for y in outputs:
                self.assertEqual(y.type(), HALF)
            outputs[-1].float().sum().backward()
            # Gradients should come back in each input's original dtype.
            for i, x in enumerate(xs):
                self.assertEqual(x.grad.dtype, x.dtype)

    def test_rnn_cell_is_half(self):
        cell = nn.RNNCell(self.h, self.h)
        self.run_cell_test(cell)

    def test_gru_cell_is_half(self):
        cell = nn.GRUCell(self.h, self.h)
        self.run_cell_test(cell)

    def test_lstm_cell_is_half(self):
        cell = nn.LSTMCell(self.h, self.h)
        self.run_cell_test(cell, state_tuple=True)


class TestRnns(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def run_rnn_test(self, rnn, layers, bidir, state_tuple=False):
        for typ in [torch.float, torch.half]:
            x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_()
            hidden_fn = lambda: torch.zeros((layers + (layers * bidir),
                                             self.b, self.h), dtype=typ)
            if state_tuple:
                hidden = (hidden_fn(), hidden_fn())
            else:
                hidden = hidden_fn()
            output, _ = rnn(x, hidden)
            # The RNN output should be half precision, and gradients should
            # come back in the input's original dtype.
            self.assertEqual(output.type(), HALF)
            output[-1, :, :].float().sum().backward()
            self.assertEqual(x.grad.dtype, x.dtype)

    def test_rnn_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=layers,
                         nonlinearity='relu', bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir)

    def test_gru_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.GRU(input_size=self.h, hidden_size=self.h, num_layers=layers,
                         bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir)

    def test_lstm_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.LSTM(input_size=self.h, hidden_size=self.h, num_layers=layers,
                          bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir, state_tuple=True)

    def test_rnn_packed_sequence(self):
        num_layers = 2
        rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=num_layers)
        for typ in [torch.float, torch.half]:
            x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_()
            lens = sorted([random.randint(self.t // 2, self.t) for _ in range(self.b)],
                          reverse=True)
            # `pack_padded_sequence` breaks if default tensor type is non-CPU
            torch.set_default_tensor_type(torch.FloatTensor)
            lens = torch.tensor(lens, dtype=torch.int64, device=torch.device('cpu'))
            packed_seq = nn.utils.rnn.pack_padded_sequence(x, lens)
            torch.set_default_tensor_type(torch.cuda.FloatTensor)
            hidden = torch.zeros((num_layers, self.b, self.h), dtype=typ)
            output, _ = rnn(packed_seq, hidden)
            self.assertEqual(output.data.type(), HALF)
            output.data.float().sum().backward()
            self.assertEqual(x.grad.dtype, x.dtype)


if __name__ == '__main__':
    unittest.main()