test_gridsearch.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from click.testing import CliRunner
  5. from mim.commands.gridsearch import cli as gridsearch
  6. from mim.commands.install import cli as install
  7. from mim.commands.uninstall import cli as uninstall
  8. def setup_module():
  9. runner = CliRunner()
  10. result = runner.invoke(uninstall, ['mmcv-full', '--yes'])
  11. assert result.exit_code == 0
  12. result = runner.invoke(uninstall, ['mmcls', '--yes'])
  13. assert result.exit_code == 0
  14. @pytest.mark.parametrize('gpus', [
  15. 0,
  16. pytest.param(
  17. 1,
  18. marks=pytest.mark.skipif(
  19. not torch.cuda.is_available(), reason='requires CUDA support')),
  20. ])
  21. def test_gridsearch(gpus, tmp_path):
  22. runner = CliRunner()
  23. result = runner.invoke(install, ['mmcls', '--yes'])
  24. assert result.exit_code == 0
  25. # Since `mminstall.txt` is not included in the distribution of
  26. # mmcls<=0.23.1, we need to install mmcv-full manually.
  27. result = runner.invoke(install, ['mmcv-full', '--yes'])
  28. assert result.exit_code == 0
  29. args1 = [
  30. 'mmcls', 'tests/data/lenet5_mnist.py', f'--gpus={gpus}',
  31. f'--work-dir={tmp_path}', '--search-args', '--optimizer.lr 1e-3 1e-4'
  32. ]
  33. args2 = [
  34. 'mmcls', 'tests/data/lenet5_mnist.py', f'--gpus={gpus}',
  35. f'--work-dir={tmp_path}', '--search-args',
  36. '--optimizer.weight_decay 1e-3 1e-4'
  37. ]
  38. args3 = [
  39. 'mmcls', 'tests/data/xxx.py', f'--gpus={gpus}',
  40. f'--work-dir={tmp_path}', '--search-args', '--optimizer.lr 1e-3 1e-4'
  41. ]
  42. args4 = [
  43. 'mmcls', 'tests/data/lenet5_mnist.py', f'--gpus={gpus}',
  44. f'--work-dir={tmp_path}', '--search-args'
  45. ]
  46. result = runner.invoke(gridsearch, args1)
  47. assert result.exit_code == 0
  48. result = runner.invoke(gridsearch, args2)
  49. assert result.exit_code == 0
  50. result = runner.invoke(gridsearch, args3)
  51. assert result.exit_code != 0
  52. result = runner.invoke(gridsearch, args4)
  53. assert result.exit_code != 0
  54. def teardown_module():
  55. runner = CliRunner()
  56. result = runner.invoke(uninstall, ['mmcv-full', '--yes'])
  57. assert result.exit_code == 0
  58. result = runner.invoke(uninstall, ['mmcls', '--yes'])
  59. assert result.exit_code == 0