import random
import unittest

import torch
from torch import nn

from apex import amp
from utils import common_init, HALF
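

# These tests check that amp.init() patches RNN cells and full RNN modules
# so that their outputs come out in half precision (HALF) whether the input
# was float32 or float16, and that gradients flowing back to the inputs
# keep each input's own dtype.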
class TestRnnCells(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()
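
    # Drives `cell` over self.t timesteps for both float32 and float16
    # inputs, checking that every output is half precision and that each
    # input's gradient dtype matches the input itself. `state_tuple=True`
    # covers LSTM cells, whose hidden state is an (h, c) pair.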
    def run_cell_test(self, cell, state_tuple=False):
        shape = (self.b, self.h)
        for typ in [torch.float, torch.half]:
            xs = [torch.randn(shape, dtype=typ).requires_grad_()
                  for _ in range(self.t)]
            hidden_fn = lambda: torch.zeros(shape, dtype=typ)
            if state_tuple:
                hidden = (hidden_fn(), hidden_fn())
            else:
                hidden = hidden_fn()
            outputs = []
            for i in range(self.t):
                hidden = cell(xs[i], hidden)
                if state_tuple:
                    output = hidden[0]
                else:
                    output = hidden
                outputs.append(output)
            for y in outputs:
                self.assertEqual(y.type(), HALF)
            outputs[-1].float().sum().backward()
            for x in xs:
                self.assertEqual(x.grad.dtype, x.dtype)

    def test_rnn_cell_is_half(self):
        cell = nn.RNNCell(self.h, self.h)
        self.run_cell_test(cell)

    def test_gru_cell_is_half(self):
        cell = nn.GRUCell(self.h, self.h)
        self.run_cell_test(cell)

    def test_lstm_cell_is_half(self):
        cell = nn.LSTMCell(self.h, self.h)
        self.run_cell_test(cell, state_tuple=True)


class TestRnns(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()
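
    # Feeds a (t, b, h) batch through `rnn` for both float32 and float16
    # inputs. The initial hidden state has num_layers * num_directions
    # layers, i.e. `layers + layers * bidir` since `bidir` is a bool.
    # `state_tuple=True` covers LSTMs, whose hidden state is an (h, c) pair.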
    def run_rnn_test(self, rnn, layers, bidir, state_tuple=False):
        for typ in [torch.float, torch.half]:
            x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_()
            hidden_fn = lambda: torch.zeros((layers + (layers * bidir),
                                             self.b, self.h), dtype=typ)
            if state_tuple:
                hidden = (hidden_fn(), hidden_fn())
            else:
                hidden = hidden_fn()
            output, _ = rnn(x, hidden)
            self.assertEqual(output.type(), HALF)
            output[-1, :, :].float().sum().backward()
            self.assertEqual(x.grad.dtype, x.dtype)
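
    # Each config is (num_layers, bidirectional): a single-layer and a
    # two-layer unidirectional network, plus a two-layer bidirectional one.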
    def test_rnn_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=layers,
                         nonlinearity='relu', bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir)

    def test_gru_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.GRU(input_size=self.h, hidden_size=self.h, num_layers=layers,
                         bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir)

    def test_lstm_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.LSTM(input_size=self.h, hidden_size=self.h, num_layers=layers,
                          bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir, state_tuple=True)
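
    # Lengths are sorted in decreasing order because older
    # `pack_padded_sequence` implementations require it (newer PyTorch
    # accepts unsorted input via enforce_sorted=False).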
    def test_rnn_packed_sequence(self):
        num_layers = 2
        rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=num_layers)
        for typ in [torch.float, torch.half]:
            x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_()
            lens = sorted([random.randint(self.t // 2, self.t) for _ in range(self.b)],
                          reverse=True)
            # `pack_padded_sequence` breaks if the default tensor type is non-CPU,
            # so temporarily switch to CPU tensors while building `lens`.
            torch.set_default_tensor_type(torch.FloatTensor)
            lens = torch.tensor(lens, dtype=torch.int64, device=torch.device('cpu'))
            packed_seq = nn.utils.rnn.pack_padded_sequence(x, lens)
            torch.set_default_tensor_type(torch.cuda.FloatTensor)
            hidden = torch.zeros((num_layers, self.b, self.h), dtype=typ)
            output, _ = rnn(packed_seq, hidden)
            self.assertEqual(output.data.type(), HALF)
            output.data.float().sum().backward()
            self.assertEqual(x.grad.dtype, x.dtype)


if __name__ == '__main__':
    unittest.main()