test_base.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import print_function
  15. import unittest
  16. import contextlib
  17. import paddle
  18. import paddle.fluid as fluid
  19. from paddle.fluid.framework import Program
  20. from paddle.fluid import core
  21. class LayerTest(unittest.TestCase):
  22. @classmethod
  23. def setUpClass(cls):
  24. cls.seed = 111
  25. @classmethod
  26. def tearDownClass(cls):
  27. pass
  28. def _get_place(self, force_to_use_cpu=False):
  29. # this option for ops that only have cpu kernel
  30. if force_to_use_cpu:
  31. return core.CPUPlace()
  32. else:
  33. if core.is_compiled_with_cuda():
  34. return core.CUDAPlace(0)
  35. return core.CPUPlace()
  36. @contextlib.contextmanager
  37. def static_graph(self):
  38. paddle.enable_static()
  39. scope = fluid.core.Scope()
  40. program = Program()
  41. with fluid.scope_guard(scope):
  42. with fluid.program_guard(program):
  43. paddle.seed(self.seed)
  44. paddle.framework.random._manual_program_seed(self.seed)
  45. yield
  46. def get_static_graph_result(self,
  47. feed,
  48. fetch_list,
  49. with_lod=False,
  50. force_to_use_cpu=False):
  51. exe = fluid.Executor(self._get_place(force_to_use_cpu))
  52. exe.run(fluid.default_startup_program())
  53. return exe.run(fluid.default_main_program(),
  54. feed=feed,
  55. fetch_list=fetch_list,
  56. return_numpy=(not with_lod))
  57. @contextlib.contextmanager
  58. def dynamic_graph(self, force_to_use_cpu=False):
  59. paddle.disable_static()
  60. with fluid.dygraph.guard(
  61. self._get_place(force_to_use_cpu=force_to_use_cpu)):
  62. paddle.seed(self.seed)
  63. paddle.framework.random._manual_program_seed(self.seed)
  64. yield