test_run.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from click.testing import CliRunner
  5. from mim.commands.install import cli as install
  6. from mim.commands.run import cli as run
  7. from mim.commands.uninstall import cli as uninstall
  8. def setup_module():
  9. runner = CliRunner()
  10. result = runner.invoke(uninstall, ['mmcv-full', '--yes'])
  11. assert result.exit_code == 0
  12. result = runner.invoke(uninstall, ['mmcls', '--yes'])
  13. assert result.exit_code == 0
  14. @pytest.mark.parametrize('device,gpus', [
  15. ('cpu', 0),
  16. pytest.param(
  17. 'cuda',
  18. 1,
  19. marks=pytest.mark.skipif(
  20. not torch.cuda.is_available(), reason='requires CUDA support')),
  21. ])
  22. def test_run(device, gpus, tmp_path):
  23. runner = CliRunner()
  24. result = runner.invoke(install, ['mmcls', '--yes'])
  25. assert result.exit_code == 0
  26. # Since `mminstall.txt` is not included in the distribution of
  27. # mmcls<=0.23.1, we need to install mmcv-full manually.
  28. result = runner.invoke(install, ['mmcv-full', '--yes'])
  29. assert result.exit_code == 0
  30. result = runner.invoke(run, [
  31. 'mmcls', 'train', 'tests/data/lenet5_mnist.py', f'--gpus={gpus}',
  32. f'--work-dir={tmp_path}'
  33. ])
  34. assert result.exit_code == 0
  35. result = runner.invoke(run, [
  36. 'mmcls', 'test', 'tests/data/lenet5_mnist.py',
  37. 'tests/data/epoch_1.pth', f'--device={device}', '--metrics=accuracy'
  38. ])
  39. assert result.exit_code == 0
  40. result = runner.invoke(run, [
  41. 'mmcls', 'xxx', 'tests/data/lenet5_mnist.py', 'tests/data/epoch_1.pth',
  42. f'--gpus={gpus}', '--metrics=accuracy'
  43. ])
  44. assert result.exit_code != 0
  45. result = runner.invoke(run, [
  46. 'mmcls', 'test', 'tests/data/xxx.py', 'tests/data/epoch_1.pth',
  47. f'--device={device}', '--metrics=accuracy'
  48. ])
  49. assert result.exit_code != 0
  50. def teardown_module():
  51. runner = CliRunner()
  52. result = runner.invoke(uninstall, ['mmcv-full', '--yes'])
  53. assert result.exit_code == 0
  54. result = runner.invoke(uninstall, ['mmcls', '--yes'])
  55. assert result.exit_code == 0