# test_deprecated_warning.py
  1. import unittest
  2. import torch
  3. import apex
  4. from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
  5. def init_model_and_optimizer():
  6. model = torch.nn.Linear(1, 1, bias=False).cuda()
  7. optimizer = torch.optim.SGD(model.parameters(), 1.0)
  8. return model, optimizer
  9. @unittest.skipUnless(torch.cuda.is_available(), "")
  10. class TestDeprecatedWarning(unittest.TestCase):
  11. def test_amp(self):
  12. model, optimizer = init_model_and_optimizer()
  13. with self.assertWarns(apex.DeprecatedFeatureWarning):
  14. _ = apex.amp.initialize(model, optimizer)
  15. def test_fp16_model(self):
  16. model, _ = init_model_and_optimizer()
  17. with self.assertWarns(apex.DeprecatedFeatureWarning):
  18. _ = apex.fp16_utils.FP16Model(model)
  19. def test_fp16_optimizer(self):
  20. _, optimizer = init_model_and_optimizer()
  21. with self.assertWarns(apex.DeprecatedFeatureWarning):
  22. _ = apex.fp16_utils.FP16_Optimizer(optimizer)
  23. def test_fp16_loss_scaler(self):
  24. with self.assertWarns(apex.DeprecatedFeatureWarning):
  25. apex.fp16_utils.LossScaler()
  26. class TestParallel(NcclDistributedTestBase):
  27. @property
  28. def world_size(self):
  29. return min(torch.cuda.device_count(), 2)
  30. def test_distributed_data_parallel(self):
  31. model, _ = init_model_and_optimizer()
  32. with self.assertWarns(apex.DeprecatedFeatureWarning):
  33. _ = apex.parallel.DistributedDataParallel(model)
  34. def test_convert_syncbn_model(self):
  35. model, _ = init_model_and_optimizer()
  36. with self.assertWarns(apex.DeprecatedFeatureWarning):
  37. _ = apex.parallel.convert_syncbn_model(model)
  38. if __name__ == "__main__":
  39. unittest.main()