123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import pytest
- import torch
- from click.testing import CliRunner
- from mim.commands.gridsearch import cli as gridsearch
- from mim.commands.install import cli as install
- from mim.commands.uninstall import cli as uninstall
def setup_module():
    """Uninstall mmcv-full and mmcls before the tests in this module run.

    Each package is removed through the ``mim uninstall`` CLI with ``--yes``
    (skipping the confirmation prompt). mmcv-full goes first, then mmcls.
    Every invocation must exit with code 0.
    """
    cli_runner = CliRunner()
    for package in ('mmcv-full', 'mmcls'):
        outcome = cli_runner.invoke(uninstall, [package, '--yes'])
        assert outcome.exit_code == 0
def _gridsearch_args(config, gpus, work_dir, *search_args):
    """Build the ``mim gridsearch`` argument list for an mmcls config.

    Every value in ``search_args`` is appended after the ``--search-args``
    flag. Passing no values leaves a bare ``--search-args`` with nothing
    after it.
    """
    return [
        'mmcls', config, f'--gpus={gpus}', f'--work-dir={work_dir}',
        '--search-args', *search_args
    ]


def _describe(result):
    """Format a CliRunner result for an assertion failure message.

    CliRunner catches exceptions raised by the command and stores them on
    ``result.exception``. Without this message, a failed exit-code check
    would show neither the CLI output nor the underlying error.
    """
    return (f'exit_code={result.exit_code}\n'
            f'output:\n{result.output}\n'
            f'exception: {result.exception!r}')


@pytest.mark.parametrize('gpus', [
    0,
    pytest.param(
        1,
        marks=pytest.mark.skipif(
            not torch.cuda.is_available(), reason='requires CUDA support')),
])
def test_gridsearch(gpus, tmp_path):
    """Run ``mim gridsearch`` on mmcls and check success and failure cases.

    First installs mmcls and mmcv-full through the ``mim install`` CLI. Then:
    - a search over ``optimizer.lr`` and one over ``optimizer.weight_decay``
      on ``tests/data/lenet5_mnist.py`` must both exit with code 0;
    - a search on ``tests/data/xxx.py`` (presumably a non-existent config)
      must exit non-zero;
    - a bare ``--search-args`` with no values must exit non-zero.

    Args:
        gpus: Value passed as ``--gpus``. The 1-GPU case is skipped when CUDA
            is unavailable.
        tmp_path: pytest-provided directory passed as ``--work-dir``.
    """
    runner = CliRunner()
    result = runner.invoke(install, ['mmcls', '--yes'])
    assert result.exit_code == 0, _describe(result)
    # Since `mminstall.txt` is not included in the distribution of
    # mmcls<=0.23.1, we need to install mmcv-full manually.
    result = runner.invoke(install, ['mmcv-full', '--yes'])
    assert result.exit_code == 0, _describe(result)

    valid_config = 'tests/data/lenet5_mnist.py'
    # Valid searches: each must succeed.
    for search in ('--optimizer.lr 1e-3 1e-4',
                   '--optimizer.weight_decay 1e-3 1e-4'):
        result = runner.invoke(
            gridsearch,
            _gridsearch_args(valid_config, gpus, tmp_path, search))
        assert result.exit_code == 0, _describe(result)

    # Invalid config path: must fail.
    result = runner.invoke(
        gridsearch,
        _gridsearch_args('tests/data/xxx.py', gpus, tmp_path,
                         '--optimizer.lr 1e-3 1e-4'))
    assert result.exit_code != 0, _describe(result)

    # `--search-args` given with no values: must fail.
    result = runner.invoke(gridsearch,
                           _gridsearch_args(valid_config, gpus, tmp_path))
    assert result.exit_code != 0, _describe(result)
def teardown_module():
    """Uninstall mmcv-full and mmcls after all tests in this module finish.

    Both packages are installed by ``test_gridsearch``. Each is removed
    with ``mim uninstall <pkg> --yes``, mmcv-full first, and every call must
    exit with code 0.
    """
    runner = CliRunner()
    packages_to_remove = ['mmcv-full', 'mmcls']
    for name in packages_to_remove:
        res = runner.invoke(uninstall, [name, '--yes'])
        assert res.exit_code == 0
|