3
0

test_train.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from click.testing import CliRunner
  5. from mim.commands.install import cli as install
  6. from mim.commands.train import cli as train
  7. from mim.commands.uninstall import cli as uninstall
  8. def setup_module():
  9. runner = CliRunner()
  10. result = runner.invoke(uninstall, ['mmcv-full', '--yes'])
  11. assert result.exit_code == 0
  12. result = runner.invoke(uninstall, ['mmcls', '--yes'])
  13. assert result.exit_code == 0
  14. @pytest.mark.parametrize('gpus', [
  15. 0,
  16. pytest.param(
  17. 1,
  18. marks=pytest.mark.skipif(
  19. not torch.cuda.is_available(), reason='requires CUDA support')),
  20. ])
  21. def test_train(gpus, tmp_path):
  22. runner = CliRunner()
  23. result = runner.invoke(install, ['mmcls', '--yes'])
  24. assert result.exit_code == 0
  25. # Since `mminstall.txt` is not included in the distribution of
  26. # mmcls<=0.23.1, we need to install mmcv-full manually.
  27. result = runner.invoke(install, ['mmcv-full', '--yes'])
  28. assert result.exit_code == 0
  29. result = runner.invoke(train, [
  30. 'mmcls', 'tests/data/lenet5_mnist.py', f'--gpus={gpus}',
  31. f'--work-dir={tmp_path}'
  32. ])
  33. assert result.exit_code == 0
  34. result = runner.invoke(train, [
  35. 'mmcls', 'tests/data/xxx.py', f'--gpus={gpus}',
  36. f'--work-dir={tmp_path}'
  37. ])
  38. assert result.exit_code != 0
  39. def teardown_module():
  40. runner = CliRunner()
  41. result = runner.invoke(uninstall, ['mmcv-full', '--yes'])
  42. assert result.exit_code == 0
  43. result = runner.invoke(uninstall, ['mmcls', '--yes'])
  44. assert result.exit_code == 0