diff --git a/.travis.yml b/.travis.yml index b48089d52ecdb..a2963311c5af8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,25 +6,25 @@ matrix: include: - os: linux dist: trusty - env: PYTHON=2.7 + env: PYTHON=2.7 PYTHONWARNINGS=ignore - os: linux dist: trusty - env: PYTHON=3.5 + env: PYTHON=3.5 PYTHONWARNINGS=ignore - os: osx osx_image: xcode7 - env: PYTHON=2.7 + env: PYTHON=2.7 PYTHONWARNINGS=ignore - os: osx osx_image: xcode7 - env: PYTHON=3.5 + env: PYTHON=3.5 PYTHONWARNINGS=ignore - os: linux dist: trusty env: - JDK='Oracle JDK 8' - - PYTHON=3.5 + - PYTHON=3.5 PYTHONWARNINGS=ignore install: - ./.travis/install-dependencies.sh - export PATH="$HOME/miniconda/bin:$PATH" @@ -33,7 +33,7 @@ matrix: - os: linux dist: trusty - env: LINT=1 + env: LINT=1 PYTHONWARNINGS=ignore before_install: # In case we ever want to use a different version of clang-format: #- wget -O - http://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - @@ -55,10 +55,13 @@ matrix: # Run Python linting, ignore dict vs {} (C408), others are defaults - flake8 --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 - .travis/format.sh --all + # Make sure that the README is formatted properly. + - cd python + - python setup.py check --restructuredtext --strict --metadata - os: linux dist: trusty - env: VALGRIND=1 PYTHON=2.7 + env: VALGRIND=1 PYTHON=2.7 PYTHONWARNINGS=ignore before_install: - sudo apt-get update -qq - sudo apt-get install -qq valgrind @@ -75,7 +78,7 @@ matrix: # Build Linux wheels. - os: linux dist: trusty - env: LINUX_WHEELS=1 + env: LINUX_WHEELS=1 PYTHONWARNINGS=ignore install: - ./.travis/install-dependencies.sh # This command should be kept in sync with ray/python/README-building-wheels.md. @@ -86,7 +89,7 @@ matrix: # Build MacOS wheels. - os: osx osx_image: xcode7 - env: MAC_WHEELS=1 + env: MAC_WHEELS=1 PYTHONWARNINGS=ignore install: - ./.travis/install-dependencies.sh # This command should be kept in sync with ray/python/README-building-wheels.md. @@ -100,6 +103,7 @@ matrix: env: - PYTHON=3.5 - RAY_USE_NEW_GCS=on + - PYTHONWARNINGS=ignore install: @@ -131,45 +135,53 @@ script: # module is only found if the test directory is in the PYTHONPATH. 
- export PYTHONPATH="$PYTHONPATH:./test/" - - python -m pytest -v python/ray/test/test_global_state.py - - python -m pytest -v python/ray/test/test_queue.py - - python -m pytest -v python/ray/test/test_ray_init.py - - python -m pytest -v test/xray_test.py - - - python -m pytest -v test/runtest.py - - python -m pytest -v test/array_test.py - - python -m pytest -v test/actor_test.py - - python -m pytest -v test/autoscaler_test.py - - python -m pytest -v test/tensorflow_test.py - - python -m pytest -v test/failure_test.py - - python -m pytest -v test/microbenchmarks.py - - python -m pytest -v test/stress_tests.py - - python -m pytest -v test/component_failures_test.py - - python -m pytest -v test/multi_node_test.py - - python -m pytest -v test/multi_node_test_2.py - - python -m pytest -v test/recursion_test.py - - python -m pytest -v test/monitor_test.py - - python -m pytest -v test/cython_test.py - - python -m pytest -v test/credis_test.py - - python -m pytest -v test/node_manager_test.py - # ray tune tests - python python/ray/tune/test/dependency_test.py - - python -m pytest -v python/ray/tune/test/trial_runner_test.py - - python -m pytest -v python/ray/tune/test/trial_scheduler_test.py - - python -m pytest -v python/ray/tune/test/experiment_test.py - - python -m pytest -v python/ray/tune/test/tune_server_test.py - - python -m pytest -v python/ray/tune/test/ray_trial_executor_test.py - - python -m pytest -v python/ray/tune/test/automl_searcher_test.py + - python -m pytest -v --durations=10 python/ray/tune/test/trial_runner_test.py + - python -m pytest -v --durations=10 python/ray/tune/test/trial_scheduler_test.py + - python -m pytest -v --durations=10 python/ray/tune/test/experiment_test.py + - python -m pytest -v --durations=10 python/ray/tune/test/tune_server_test.py + - python -m pytest -v --durations=10 python/ray/tune/test/ray_trial_executor_test.py + - python -m pytest -v --durations=10 python/ray/tune/test/automl_searcher_test.py # ray rllib tests - - python -m pytest -v python/ray/rllib/test/test_catalog.py - - python -m pytest -v python/ray/rllib/test/test_filters.py - - python -m pytest -v python/ray/rllib/test/test_optimizers.py - - python -m pytest -v python/ray/rllib/test/test_evaluators.py + - python -m pytest -v --durations=10 python/ray/rllib/test/test_catalog.py + - python -m pytest -v --durations=10 python/ray/rllib/test/test_filters.py + - python -m pytest -v --durations=10 python/ray/rllib/test/test_optimizers.py + - python -m pytest -v --durations=10 python/ray/rllib/test/test_evaluators.py + + # Python3.5+ only. Otherwise we will get `SyntaxError` regardless of how we set the tester. 
+ - python -c 'import sys;exit(sys.version_info>=(3,5))' || python -m pytest -v --durations=10 python/ray/experimental/test/async_test.py + + - python -m pytest -v --durations=10 python/ray/test/test_global_state.py + - python -m pytest -v --durations=10 python/ray/test/test_queue.py + - python -m pytest -v --durations=10 python/ray/test/test_ray_init.py + - python -m pytest -v --durations=10 test/xray_test.py + + - python -m pytest -v --durations=10 test/runtest.py + - python -m pytest -v --durations=10 test/array_test.py + - python -m pytest -v --durations=10 test/actor_test.py + - python -m pytest -v --durations=10 test/autoscaler_test.py + - python -m pytest -v --durations=10 test/tensorflow_test.py + - python -m pytest -v --durations=10 test/failure_test.py + - python -m pytest -v --durations=10 test/microbenchmarks.py + - python -m pytest -v --durations=10 test/stress_tests.py + - python -m pytest -v --durations=10 test/component_failures_test.py + - python -m pytest -v --durations=10 test/multi_node_test.py + - python -m pytest -v --durations=10 test/multi_node_test_2.py + - python -m pytest -v --durations=10 test/recursion_test.py + - python -m pytest -v --durations=10 test/monitor_test.py + - python -m pytest -v --durations=10 test/cython_test.py + - python -m pytest -v --durations=10 test/credis_test.py + - python -m pytest -v --durations=10 test/node_manager_test.py + # TODO(yuhguo): object_manager_test.py requires a lot of CPU/memory, and + # better be put in Jenkins. However, it fails frequently in Jenkins, but + # works well in Travis. We should consider moving it back to Jenkins once + # we figure out the reason. + - python -m pytest -v --durations=10 test/object_manager_test.py # ray temp file tests - - python -m pytest -v test/tempfile_test.py + - python -m pytest -v --durations=10 test/tempfile_test.py # modin test files - python python/ray/test/test_modin.py diff --git a/.travis/install-dependencies.sh b/.travis/install-dependencies.sh index 0fb597d4686f8..c247ee92cc5c3 100755 --- a/.travis/install-dependencies.sh +++ b/.travis/install-dependencies.sh @@ -24,8 +24,8 @@ if [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "linux" ]]; then wget https://repo.continuum.io/miniconda/Miniconda2-4.5.4-Linux-x86_64.sh -O miniconda.sh -nv bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" - pip install -q cython==0.27.3 cmake tensorflow gym opencv-python pyyaml pandas==0.23.4 requests \ - feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout + pip install -q cython==0.29.0 cmake tensorflow gym opencv-python pyyaml pandas==0.23.4 requests \ + feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout mock elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then sudo apt-get update sudo apt-get install -y cmake pkg-config python-dev python-numpy build-essential autoconf curl libtool unzip @@ -33,7 +33,7 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then wget https://repo.continuum.io/miniconda/Miniconda3-4.5.4-Linux-x86_64.sh -O miniconda.sh -nv bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" - pip install -q cython==0.27.3 cmake tensorflow gym opencv-python pyyaml pandas==0.23.4 requests \ + pip install -q cython==0.29.0 cmake tensorflow gym opencv-python pyyaml pandas==0.23.4 requests \ feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then # check that brew is 
installed @@ -50,8 +50,8 @@ elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then wget https://repo.continuum.io/miniconda/Miniconda2-4.5.4-MacOSX-x86_64.sh -O miniconda.sh -nv bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" - pip install -q cython==0.27.3 cmake tensorflow gym opencv-python pyyaml pandas==0.23.4 requests \ - feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout + pip install -q cython==0.29.0 cmake tensorflow gym opencv-python pyyaml pandas==0.23.4 requests \ + feather-format lxml openpyxl xlrd py-spy setproctitle faulthandler pytest-timeout mock elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then # check that brew is installed which -s brew @@ -67,7 +67,7 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then wget https://repo.continuum.io/miniconda/Miniconda3-4.5.4-MacOSX-x86_64.sh -O miniconda.sh -nv bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" - pip install -q cython==0.27.3 cmake tensorflow gym opencv-python pyyaml pandas==0.23.4 requests \ + pip install -q cython==0.29.0 cmake tensorflow gym opencv-python pyyaml pandas==0.23.4 requests \ feather-format lxml openpyxl xlrd py-spy setproctitle pytest-timeout elif [[ "$LINT" == "1" ]]; then sudo apt-get update diff --git a/CMakeLists.txt b/CMakeLists.txt index a6734e62ce144..4779a03fcd00a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -101,7 +101,7 @@ if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") set(ray_file_list "src/ray/thirdparty/redis/src/redis-server" "src/ray/gcs/redis_module/libray_redis_module.so" - "src/ray/raylet/liblocal_scheduler_library_python.so" + "src/ray/raylet/libraylet_library_python.so" "src/ray/raylet/raylet_monitor" "src/ray/raylet/raylet") @@ -128,8 +128,8 @@ if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") add_dependencies(copy_ray copy_ray_files) # Make sure that the Python extensions are built before copying the files. - get_local_scheduler_library("python" LOCAL_SCHEDULER_LIBRARY_PYTHON) - add_dependencies(copy_ray ${LOCAL_SCHEDULER_LIBRARY_PYTHON}) + get_raylet_library("python" RAYLET_LIBRARY_PYTHON) + add_dependencies(copy_ray ${RAYLET_LIBRARY_PYTHON}) foreach(file ${ray_file_list}) add_custom_command(TARGET copy_ray POST_BUILD @@ -146,8 +146,8 @@ if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") endif() if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES") - get_local_scheduler_library("java" LOCAL_SCHEDULER_LIBRARY_JAVA) - add_dependencies(copy_ray ${LOCAL_SCHEDULER_LIBRARY_JAVA}) + get_raylet_library("java" RAYLET_LIBRARY_JAVA) + add_dependencies(copy_ray ${RAYLET_LIBRARY_JAVA}) # copy libplasma_java files add_custom_command(TARGET copy_ray POST_BUILD diff --git a/README.rst b/README.rst index 7e50123d9bf6c..05313ee6cd329 100644 --- a/README.rst +++ b/README.rst @@ -1,6 +1,4 @@ -.. raw:: html - - +.. image:: https://github.com/ray-project/ray/raw/master/doc/source/images/ray_header_logo.png .. image:: https://travis-ci.com/ray-project/ray.svg?branch=master :target: https://travis-ci.com/ray-project/ray @@ -8,7 +6,7 @@ .. image:: https://readthedocs.org/projects/ray/badge/?version=latest :target: http://ray.readthedocs.io/en/latest/?badge=latest -.. image:: https://img.shields.io/badge/pypi-0.6.0-blue.svg +.. 
image:: https://img.shields.io/badge/pypi-0.6.1-blue.svg :target: https://pypi.org/project/ray/ | @@ -41,12 +39,12 @@ Example Use Ray comes with libraries that accelerate deep learning and reinforcement learning development: -- `Ray Tune`_: Hyperparameter Optimization Framework -- `Ray RLlib`_: Scalable Reinforcement Learning +- `Tune`_: Hyperparameter Optimization Framework +- `RLlib`_: Scalable Reinforcement Learning - `Distributed Training `__ -.. _`Ray Tune`: http://ray.readthedocs.io/en/latest/tune.html -.. _`Ray RLlib`: http://ray.readthedocs.io/en/latest/rllib.html +.. _`Tune`: http://ray.readthedocs.io/en/latest/tune.html +.. _`RLlib`: http://ray.readthedocs.io/en/latest/rllib.html Installation ------------ diff --git a/cmake/Modules/ArrowExternalProject.cmake b/cmake/Modules/ArrowExternalProject.cmake index 5a054afc74f1b..67d09686b0fb7 100644 --- a/cmake/Modules/ArrowExternalProject.cmake +++ b/cmake/Modules/ArrowExternalProject.cmake @@ -14,11 +14,15 @@ # - PLASMA_STATIC_LIB # - PLASMA_SHARED_LIB -set(arrow_URL https://github.com/apache/arrow.git) -# The PR for this commit is https://github.com/apache/arrow/pull/3061. We +set(arrow_URL https://github.com/ray-project/arrow.git) +# This commit is based on https://github.com/apache/arrow/pull/3197. We # include the link here to make it easier to find the right commit because # Arrow often rewrites git history and invalidates certain commits. -set(arrow_TAG a667fca3b71772886bb2595986266d2039823dcc) +# It has been patched to fix an upstream symbol clash with TensorFlow, +# the patch is available at +# https://github.com/ray-project/arrow/commit/c347cd571e51723fc8512922f1b3a8e45e45b169 +# See the discussion in https://github.com/apache/arrow/pull/3177 +set(arrow_TAG c347cd571e51723fc8512922f1b3a8e45e45b169) set(ARROW_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/external/arrow-install) set(ARROW_HOME ${ARROW_INSTALL_PREFIX}) diff --git a/cmake/Modules/BoostExternalProject.cmake b/cmake/Modules/BoostExternalProject.cmake index 1fbbb0c0b58ef..f839304fbccfc 100644 --- a/cmake/Modules/BoostExternalProject.cmake +++ b/cmake/Modules/BoostExternalProject.cmake @@ -27,14 +27,15 @@ else() set(BOOST_ROOT ${Boost_INSTALL_PREFIX}) set(Boost_LIBRARY_DIR ${Boost_INSTALL_PREFIX}/lib) set(Boost_SYSTEM_LIBRARY ${Boost_LIBRARY_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}boost_system${CMAKE_STATIC_LIBRARY_SUFFIX}) + set(Boost_THREAD_LIBRARY ${Boost_LIBRARY_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}boost_thread${CMAKE_STATIC_LIBRARY_SUFFIX}) set(Boost_FILESYSTEM_LIBRARY ${Boost_LIBRARY_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}boost_filesystem${CMAKE_STATIC_LIBRARY_SUFFIX}) #set(boost_URL https://github.com/boostorg/boost.git) - #set(boost_TAG boost-1.65.1) + #set(boost_TAG boost-1.68.0) - set(Boost_TAR_GZ_URL http://dl.bintray.com/boostorg/release/1.65.1/source/boost_1_65_1.tar.gz) + set(Boost_TAR_GZ_URL http://dl.bintray.com/boostorg/release/1.68.0/source/boost_1_68_0.tar.gz) set(Boost_BUILD_PRODUCTS ${Boost_SYSTEM_LIBRARY} ${Boost_FILESYSTEM_LIBRARY}) - set(Boost_URL_MD5 "ee64fd29a3fe42232c6ac3c419e523cf") + set(Boost_URL_MD5 "5d8b4503582fffa9eefdb9045359c239") set(Boost_USE_STATIC_LIBS ON) @@ -48,6 +49,6 @@ else() BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${Boost_BUILD_PRODUCTS} CONFIGURE_COMMAND ./bootstrap.sh - BUILD_COMMAND bash -c "./b2 cxxflags=-fPIC cflags=-fPIC variant=release link=static --with-filesystem --with-system --with-regex -j8 install --prefix=${Boost_INSTALL_PREFIX} > /dev/null" + BUILD_COMMAND bash -c "./b2 cxxflags=-fPIC cflags=-fPIC variant=release link=static 
--with-filesystem --with-system --with-thread --with-atomic --with-chrono --with-date_time --with-regex -j8 install --prefix=${Boost_INSTALL_PREFIX} > /dev/null" INSTALL_COMMAND "") endif () diff --git a/cmake/Modules/FlatBuffersExternalProject.cmake b/cmake/Modules/FlatBuffersExternalProject.cmake index 508010afced49..15be464491cc1 100644 --- a/cmake/Modules/FlatBuffersExternalProject.cmake +++ b/cmake/Modules/FlatBuffersExternalProject.cmake @@ -18,9 +18,9 @@ if(DEFINED ENV{RAY_FLATBUFFERS_HOME} AND EXISTS $ENV{RAY_FLATBUFFERS_HOME}) add_custom_target(flatbuffers_ep) else() - set(flatbuffers_VERSION "1.9.0") + set(flatbuffers_VERSION "1.10.0") set(flatbuffers_URL "https://github.com/google/flatbuffers/archive/v${flatbuffers_VERSION}.tar.gz") - set(flatbuffers_URL_MD5 "8be7513bf960034f6873326d09521a4b") + set(flatbuffers_URL_MD5 "f7d19a3f021d93422b0bc287d7148cd2") set(FLATBUFFERS_INSTALL_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/external/flatbuffers-install") diff --git a/cmake/Modules/ThirdpartyToolchain.cmake b/cmake/Modules/ThirdpartyToolchain.cmake index d71390662ca59..f95a1b2906e91 100644 --- a/cmake/Modules/ThirdpartyToolchain.cmake +++ b/cmake/Modules/ThirdpartyToolchain.cmake @@ -54,11 +54,14 @@ ADD_THIRDPARTY_LIB(boost_system STATIC_LIB ${Boost_SYSTEM_LIBRARY}) ADD_THIRDPARTY_LIB(boost_filesystem STATIC_LIB ${Boost_FILESYSTEM_LIBRARY}) +ADD_THIRDPARTY_LIB(boost_thread + STATIC_LIB ${Boost_THREAD_LIBRARY}) add_dependencies(boost_system boost_ep) add_dependencies(boost_filesystem boost_ep) +add_dependencies(boost_thread boost_ep) -add_custom_target(boost DEPENDS boost_system boost_filesystem) +add_custom_target(boost DEPENDS boost_system boost_filesystem boost_thread) # flatbuffers include(FlatBuffersExternalProject) @@ -120,6 +123,7 @@ if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") "PYARROW_WITH_TENSORFLOW=1" "PYARROW_BUNDLE_ARROW_CPP=1" "PARQUET_HOME=${PARQUET_HOME}" + "BOOST_ROOT=${BOOST_ROOT}" "PYARROW_WITH_PARQUET=1" "PYARROW_PARALLEL=") diff --git a/dev/RELEASE_PROCESS.rst b/dev/RELEASE_PROCESS.rst index fdf59c6ad7314..6c36a334963ff 100644 --- a/dev/RELEASE_PROCESS.rst +++ b/dev/RELEASE_PROCESS.rst @@ -3,23 +3,34 @@ Release Process This document describes the process for creating new releases. -1. **Testing:** Before a release is created, significant testing should be done. - In particular, our continuous integration currently does not include - sufficient coverage for testing at scale and for testing fault tolerance, so - it is important to run tests on clusters of hundreds of machines and to make - sure that the tests complete when some machines are killed. +1. First, you should build wheels that you'd like to use for testing. That can + be done by following the `documentation for building wheels`_. -2. **Libraries:** Make sure that the libraries (e.g., RLlib, Tune) are in a - releasable state. +2. **Testing:** Before a release is created, significant testing should be done. + Run the script `test/stress_tests/run_stress_tests.sh`_ and make sure it + passes. *And make sure it is testing the right version of Ray!* This will use + the autoscaler to start a bunch of machines and run some tests. Any new + stress tests should be added to this script so that they will be run + automatically for future release testing. -3. **Increment the Python version:** Create a PR that increments the Python +3. **Libraries:** Make sure that the libraries (e.g., RLlib, Tune, SGD) are in a + releasable state. 
TODO(rkn): These libraries should be tested automatically + by the script above, but they aren't yet. + +4. **Increment the Python version:** Create a PR that increments the Python package version. See `this example`_. -4. **Create a GitHub release:** Create a GitHub release through the `GitHub +5. **Create a GitHub release:** Create a GitHub release through the `GitHub website`_. The release should be created at the commit from the previous - step. This should include release notes. + step. This should include **release notes**. Copy the style and formatting + used by previous releases. + +6. **Python wheels:** The Python wheels will automatically be built on Travis + and uploaded to the ``ray-wheels`` S3 bucket. Download these wheels (e.g., + using ``wget``) and install them with ``pip`` and run some simple Ray scripts + to verify that they work. -5. **Upload to PyPI Test:** Upload the wheels to the PyPI test site using +7. **Upload to PyPI Test:** Upload the wheels to the PyPI test site using ``twine`` (ask Robert to add you as a maintainer to the PyPI project). You'll need to run a command like @@ -27,9 +38,9 @@ This document describes the process for creating new releases. twine upload --repository-url https://test.pypi.org/legacy/ ray/.whl/* - assuming that you've built all of the wheels and put them in ``ray/.whl`` - (note that you can also get them from the "ray-wheels" S3 bucket), - that you've installed ``twine``, and that you've made PyPI accounts. + assuming that you've downloaded the wheels from the ``ray-wheels`` S3 bucket + and put them in ``ray/.whl``, that you've installed ``twine`` through + ``pip``, and that you've made PyPI accounts. Test that you can install the wheels with pip from the PyPI test repository with @@ -43,9 +54,9 @@ This document describes the process for creating new releases. installed by checking ``ray.__version__`` and ``ray.__file__``. Do this at least for MacOS and for Linux, as well as for Python 2 and Python - 3. + 3. Also do this for different versions of MacOS. -6. **Upload to PyPI:** Now that you've tested the wheels on the PyPI test +8. **Upload to PyPI:** Now that you've tested the wheels on the PyPI test repository, they can be uploaded to the main PyPI repository. Be careful, **it will not be possible to modify wheels once you upload them**, so any mistake will require a new release. You can upload the wheels with a command @@ -64,5 +75,7 @@ This document describes the process for creating new releases. finds the correct Ray version, and successfully runs some simple scripts on both MacOS and Linux as well as Python 2 and Python 3. -.. _`this example`: https://github.com/ray-project/ray/pull/1745 +.. _`documentation for building wheels`: https://github.com/ray-project/ray/blob/master/python/README-building-wheels.md +.. _`test/stress_tests/run_stress_tests.sh`: https://github.com/ray-project/ray/blob/master/test/stress_tests/run_stress_tests.sh +.. _`this example`: https://github.com/ray-project/ray/pull/3420 .. 
_`GitHub website`: https://github.com/ray-project/ray/releases diff --git a/doc/requirements-doc.txt b/doc/requirements-doc.txt index f598baa081679..7fa230d2edefd 100644 --- a/doc/requirements-doc.txt +++ b/doc/requirements-doc.txt @@ -1,5 +1,6 @@ colorama click +filelock flatbuffers funcsigs mock diff --git a/doc/source/async_api.rst b/doc/source/async_api.rst new file mode 100644 index 0000000000000..95867745f8ee6 --- /dev/null +++ b/doc/source/async_api.rst @@ -0,0 +1,87 @@ +Async API (Experimental) +======================== + +Since Python 3.5, it is possible to write concurrent code using the ``async/await`` `syntax `__. + +This document describes Ray's support for asyncio, which enables integration with popular async frameworks (e.g., aiohttp, aioredis, etc.) for high performance web and prediction serving. + +Starting Ray +------------ + +You must initialize Ray first. + +Please refer to `Starting Ray`_ for instructions. + +.. _`Starting Ray`: http://ray.readthedocs.io/en/latest/tutorial.html#starting-ray + + +Converting Ray objects into asyncio futures +------------------------------------------- + +Ray object IDs can be converted into asyncio futures with ``ray.experimental.async_api``. + +.. code-block:: python + + import asyncio + import time + import ray + from ray.experimental import async_api + + @ray.remote + def f(): + time.sleep(1) + return {'key1': ['value']} + + ray.init() + future = async_api.as_future(f.remote()) + asyncio.get_event_loop().run_until_complete(future) # {'key1': ['value']} + + +.. autofunction:: ray.experimental.async_api.as_future + + +Example Usage +------------- + ++----------------------------------------+-----------------------------------------------------+ +| **Basic Python** | **Distributed with Ray** | ++----------------------------------------+-----------------------------------------------------+ +| .. code-block:: python | .. code-block:: python | +| | | +| # Execute f serially. | # Execute f in parallel. | +| | | +| | | +| def f(): | @ray.remote | +| time.sleep(1) | def f(): | +| return 1 | time.sleep(1) | +| | return 1 | +| | | +| | ray.init() | +| results = [f() for i in range(4)] | results = ray.get([f.remote() for i in range(4)]) | ++----------------------------------------+-----------------------------------------------------+ +| **Async Python** | **Async Ray** | ++----------------------------------------+-----------------------------------------------------+ +| .. code-block:: python | .. code-block:: python | +| | | +| # Execute f asynchronously. | # Execute f asynchronously with Ray/asyncio. | +| | | +| | from ray.experimental import async_api | +| | | +| | @ray.remote | +| async def f(): | def f(): | +| await asyncio.sleep(1) | time.sleep(1) | +| return 1 | return 1 | +| | | +| | ray.init() | +| loop = asyncio.get_event_loop() | loop = asyncio.get_event_loop() | +| tasks = [f() for i in range(4)] | tasks = [async_api.as_future(f.remote()) | +| | for i in range(4)] | +| results = loop.run_until_complete( | results = loop.run_until_complete( | +| asyncio.gather(tasks)) | asyncio.gather(tasks)) | ++----------------------------------------+-----------------------------------------------------+ + + +Known Issues +------------ + +Async API support is experimental, and we are working to improve its performance. Please `let us know `__ any issues you encounter. 
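For a slightly larger sketch of how ``as_future`` composes with coroutines, several remote results can be awaited concurrently; only the ``ray.experimental.async_api.as_future`` call documented above is assumed:

.. code-block:: python

    import asyncio
    import time

    import ray
    from ray.experimental import async_api


    @ray.remote
    def slow_square(x):
        time.sleep(1)
        return x * x


    async def compute_all(xs):
        # Convert each ObjectID into an asyncio future and await them together.
        futures = [async_api.as_future(slow_square.remote(x)) for x in xs]
        return await asyncio.gather(*futures)


    ray.init()
    results = asyncio.get_event_loop().run_until_complete(compute_all(range(4)))
    print(results)  # [0, 1, 4, 9]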
diff --git a/doc/source/autoscaling.rst b/doc/source/autoscaling.rst index 90c8e92f3d278..64d57e0ef6878 100644 --- a/doc/source/autoscaling.rst +++ b/doc/source/autoscaling.rst @@ -1,7 +1,9 @@ -Cloud Setup and Auto-Scaling -============================ +Cluster Setup and Auto-Scaling +============================== -The ``ray up`` command starts or updates an AWS or GCP Ray cluster from your personal computer. Once the cluster is up, you can then SSH into it to run Ray programs. +This document provides instructions for launching a Ray cluster either privately, on AWS, or on GCP. + +The ``ray up`` command starts or updates a Ray cluster from your personal computer. Once the cluster is up, you can then SSH into it to run Ray programs. Quick start (AWS) ----------------- @@ -50,6 +52,28 @@ SSH into the head node and then run Ray programs with ``ray.init(redis_address=" # Teardown the cluster $ ray down ray/python/ray/autoscaler/gcp/example-full.yaml +Quick start (Private Cluster) +----------------------------- + +This is used when you have a list of machine IP addresses to connect in a Ray cluster. You can get started by filling out the fields in the provided `ray/python/ray/autoscaler/local/example-full.yaml `__. +Be sure to specify the proper ``head_ip``, list of ``worker_ips``, and the ``ssh_user`` field. + +Try it out by running these commands from your personal computer. Once the cluster is started, you can then +SSH into the head node and then run Ray programs with ``ray.init(redis_address="localhost:6379")``. + +.. code-block:: bash + + # Create or update the cluster. When the command finishes, it will print + # out the command that can be used to SSH into the cluster head node. + $ ray up ray/python/ray/autoscaler/local/example-full.yaml + + # Reconfigure autoscaling behavior without interrupting running jobs + $ ray up ray/python/ray/autoscaler/local/example-full.yaml \ + --max-workers=N --no-restart + + # Teardown the cluster + $ ray down ray/python/ray/autoscaler/local/example-full.yaml + Running commands on new and existing clusters --------------------------------------------- @@ -197,7 +221,8 @@ The ``example-full.yaml`` configuration is enough to get started with Ray, but f InstanceType: p2.8xlarge **Docker**: Specify docker image. This executes all commands on all nodes in the docker container, -and opens all the necessary ports to support the Ray cluster. This currently does not have GPU support. +and opens all the necessary ports to support the Ray cluster. It will also automatically install +Docker if Docker is not installed. This currently does not have GPU support. .. code-block:: yaml @@ -264,3 +289,15 @@ Additional Cloud providers -------------------------- To use Ray autoscaling on other Cloud providers or cluster management systems, you can implement the ``NodeProvider`` interface (~100 LOC) and register it in `node_provider.py `__. Contributions are welcome! + +Questions or Issues? +-------------------- + +You can post questions or issues or feedback through the following channels: + +1. `Our Mailing List`_: For discussions about development, questions about + usage, or any general questions and feedback. +2. `GitHub Issues`_: For bug reports and feature requests. + +.. _`Our Mailing List`: https://groups.google.com/forum/#!forum/ray-dev +.. 
_`GitHub Issues`: https://github.com/ray-project/ray/issues diff --git a/doc/source/conf.py b/doc/source/conf.py index 2a2b1a37c207e..8193ccf408680 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -50,6 +50,7 @@ "ray.core.generated.ray.protocol.Task", "ray.core.generated.TablePrefix", "ray.core.generated.TablePubsub", + "ray.core.generated.Language", ] for mod_name in MOCK_MODULES: sys.modules[mod_name] = mock.Mock() diff --git a/doc/source/distributed_sgd.rst b/doc/source/distributed_sgd.rst index 5d1e480766258..e5084762466b0 100644 --- a/doc/source/distributed_sgd.rst +++ b/doc/source/distributed_sgd.rst @@ -8,7 +8,7 @@ Ray SGD is built on top of the Ray task and actor abstractions to provide seamle Interface --------- -To use Ray SGD, define a `model class `__ with ``loss`` and ``optimizer`` attributes: +To use Ray SGD, define a `model class `__: .. autoclass:: ray.experimental.sgd.Model diff --git a/doc/source/example-a3c.rst b/doc/source/example-a3c.rst index 23a6a3e158115..47378fce9f915 100644 --- a/doc/source/example-a3c.rst +++ b/doc/source/example-a3c.rst @@ -29,7 +29,7 @@ You can run the code with .. code-block:: bash - python/ray/rllib/train.py --env=Pong-ram-v4 --run=A3C --config='{"num_workers": N}' + rllib train --env=Pong-ram-v4 --run=A3C --config='{"num_workers": N}' Reinforcement Learning ---------------------- diff --git a/doc/source/example-evolution-strategies.rst b/doc/source/example-evolution-strategies.rst index 8f613b08dcdd5..d048d261fff95 100644 --- a/doc/source/example-evolution-strategies.rst +++ b/doc/source/example-evolution-strategies.rst @@ -18,13 +18,13 @@ on the ``Humanoid-v1`` gym environment. .. code-block:: bash - python/ray/rllib/train.py --env=Humanoid-v1 --run=ES + rllib train --env=Humanoid-v1 --run=ES To train a policy on a cluster (e.g., using 900 workers), run the following. .. code-block:: bash - python ray/python/ray/rllib/train.py \ + rllib train \ --env=Humanoid-v1 \ --run=ES \ --redis-address= \ diff --git a/doc/source/example-policy-gradient.rst b/doc/source/example-policy-gradient.rst index 3fccb992ad16c..9b58575044c3b 100644 --- a/doc/source/example-policy-gradient.rst +++ b/doc/source/example-policy-gradient.rst @@ -21,7 +21,7 @@ Then you can run the example as follows. .. code-block:: bash - python/ray/rllib/train.py --env=Pong-ram-v4 --run=PPO + rllib train --env=Pong-ram-v4 --run=PPO This will train an agent on the ``Pong-ram-v4`` Atari environment. You can also try passing in the ``Pong-v0`` environment or the ``CartPole-v0`` environment. diff --git a/doc/source/images/ray_header_logo.png b/doc/source/images/ray_header_logo.png new file mode 100644 index 0000000000000..ecae748033ed1 Binary files /dev/null and b/doc/source/images/ray_header_logo.png differ diff --git a/doc/source/index.rst b/doc/source/index.rst index e5ebdae4536cf..7f542c9d0eceb 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -32,7 +32,7 @@ Example Use | results = [f() for i in range(4)] | results = ray.get([f.remote() for i in range(4)]) | +------------------------------------------------+----------------------------------------------------+ - +To launch a Ray cluster, either privately, on AWS, or on GCP, `follow these instructions `_. View the `codebase on GitHub`_. @@ -65,6 +65,14 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin actors.rst using-ray-with-gpus.rst webui.rst + async_api.rst + +.. toctree:: + :maxdepth: 1 + :caption: Cluster Usage + + autoscaling.rst + using-ray-on-a-cluster.rst .. 
toctree:: :maxdepth: 1 @@ -86,6 +94,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin rllib-env.rst rllib-algorithms.rst rllib-models.rst + rllib-dev.rst rllib-concepts.rst rllib-package-ref.rst @@ -123,15 +132,6 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin redis-memory-management.rst tempfile.rst -.. toctree:: - :maxdepth: 1 - :caption: Cluster Usage - - autoscaling.rst - using-ray-on-a-cluster.rst - using-ray-on-a-large-cluster.rst - using-ray-and-docker-on-a-cluster.md - .. toctree:: :maxdepth: 1 :caption: Help diff --git a/doc/source/install-on-docker.rst b/doc/source/install-on-docker.rst index 9fa245c160b9b..6baa0363f69ad 100644 --- a/doc/source/install-on-docker.rst +++ b/doc/source/install-on-docker.rst @@ -1,7 +1,7 @@ Installation on Docker ====================== -You can install Ray on any platform that runs Docker. We do not presently +You can install Ray from source on any platform that runs Docker. We do not presently publish Docker images for Ray, but you can build them yourself using the Ray distribution. @@ -25,6 +25,8 @@ the corresponding installation instructions. Linux user may find these Docker installation on EC2 with Ubuntu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. note:: The Ray `autoscaler `_ can automatically install Docker on all of the nodes of your cluster. + The instructions below show in detail how to prepare an Amazon EC2 instance running Ubuntu 16.04 for use with Docker. @@ -165,14 +167,6 @@ Launch the examples container. docker run --shm-size=1024m -t -i ray-project/examples -Hyperparameter optimization -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code-block:: bash - - cd /ray/examples/hyperopt/ - python /ray/examples/hyperopt/hyperopt_simple.py - Batch L-BFGS ~~~~~~~~~~~~ diff --git a/doc/source/installation.rst b/doc/source/installation.rst index 68bd37ae96f5d..daf9a8cf8b3e9 100644 --- a/doc/source/installation.rst +++ b/doc/source/installation.rst @@ -17,10 +17,6 @@ Trying snapshots from master Here are links to the latest wheels (which are built off of master). To install these wheels, run the following command: -.. danger:: - - These versions will have newer features but are subject to more bugs. If you encounter crashes or other instabilities, please revert to the latest stable version. - .. code-block:: bash pip install -U [link to wheel] @@ -37,16 +33,16 @@ Here are links to the latest wheels (which are built off of master). To install =================== =================== -.. _`Linux Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp37-cp37m-manylinux1_x86_64.whl -.. _`Linux Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp36-cp36m-manylinux1_x86_64.whl -.. _`Linux Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp35-cp35m-manylinux1_x86_64.whl -.. _`Linux Python 3.4`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp34-cp34m-manylinux1_x86_64.whl -.. _`Linux Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp27-cp27mu-manylinux1_x86_64.whl -.. _`MacOS Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp37-cp37m-macosx_10_6_intel.whl -.. _`MacOS Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp36-cp36m-macosx_10_6_intel.whl -.. _`MacOS Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp35-cp35m-macosx_10_6_intel.whl -.. 
_`MacOS Python 3.4`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp34-cp34m-macosx_10_6_intel.whl -.. _`MacOS Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp27-cp27m-macosx_10_6_intel.whl +.. _`Linux Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.1-cp37-cp37m-manylinux1_x86_64.whl +.. _`Linux Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.1-cp36-cp36m-manylinux1_x86_64.whl +.. _`Linux Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.1-cp35-cp35m-manylinux1_x86_64.whl +.. _`Linux Python 3.4`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.1-cp34-cp34m-manylinux1_x86_64.whl +.. _`Linux Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.1-cp27-cp27mu-manylinux1_x86_64.whl +.. _`MacOS Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.1-cp37-cp37m-macosx_10_6_intel.whl +.. _`MacOS Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.1-cp36-cp36m-macosx_10_6_intel.whl +.. _`MacOS Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.1-cp35-cp35m-macosx_10_6_intel.whl +.. _`MacOS Python 3.4`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.1-cp34-cp34m-macosx_10_6_intel.whl +.. _`MacOS Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.1-cp27-cp27m-macosx_10_6_intel.whl Building Ray from source @@ -76,7 +72,7 @@ For Ubuntu, run the following commands: # If you are on Ubuntu 14.04, you need the following. pip install cmake - pip install cython==0.27.3 + pip install cython==0.29.0 For MacOS, run the following commands: @@ -85,7 +81,7 @@ For MacOS, run the following commands: brew update brew install cmake pkg-config automake autoconf libtool openssl bison wget - pip install cython==0.27.3 + pip install cython==0.29.0 If you are using Anaconda, you may also need to run the following. @@ -120,7 +116,8 @@ that you've cloned the git repository. .. code-block:: bash - python test/runtest.py + export PYTHONPATH="$PYTHONPATH:./test/" + python -m pytest test/runtest.py Cleaning the source tree ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -131,4 +128,18 @@ The source tree can be cleaned by running git clean -f -f -x -d -in the ``ray/`` directory. +in the ``ray/`` directory. Warning: this command will delete all untracked files +and directories and will reset the repository to its checked out state. +For a shallower working directory cleanup, you may want to try: + +.. code-block:: bash + + rm -rf ./build + +under ``ray/``. Incremental builds should work as follows: + +.. code-block:: bash + + pushd ./build && make && popd + +under ``ray/``. diff --git a/doc/source/redis-memory-management.rst b/doc/source/redis-memory-management.rst index 91b207db54e11..5e6edcc02f6c4 100644 --- a/doc/source/redis-memory-management.rst +++ b/doc/source/redis-memory-management.rst @@ -7,92 +7,9 @@ servers, as described in `An Overview of the Internals task/object generation rate could risk high memory pressure, potentially leading to out-of-memory (OOM) errors. -Here, we describe an experimental feature that transparently flushes metadata -entries out of Redis memory. +In Ray `0.6.1+` Redis shards can be configured to LRU evict task and object +metadata by setting ``redis_max_memory`` when starting Ray. This supercedes the +previously documented flushing functionality. 
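A minimal sketch of enabling this, assuming ``ray.init`` accepts a ``redis_max_memory`` argument given in bytes (check the API reference for your Ray version):

.. code-block:: python

    import ray

    # Cap each Redis shard at roughly 1 GB. Once the cap is reached, the
    # least recently used task and object metadata entries are evicted.
    # (Parameter name and units are assumed here.)
    ray.init(redis_max_memory=10**9)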
-Requirements ------------- - -As of early July 2018, the automatic memory management feature requires building -Ray from source. We are planning on eliminating this step in the near future by -releasing official wheels. - -Building Ray -~~~~~~~~~~~~ - -First, follow `instructions to build Ray from source -`__ to install prerequisites. After -the prerequisites are installed, instead of doing the regular ``pip install`` as -referenced in that document, pass an additional special flag, -``RAY_USE_NEW_GCS=on``: - -.. code-block:: bash - - git clone https://github.com/ray-project/ray.git - cd ray/python - RAY_USE_NEW_GCS=on pip install -e . --verbose # Add --user if you see a permission denied error. - -Running Ray applications -~~~~~~~~~~~~~~~~~~~~~~~~ - -At run time the environment variables ``RAY_USE_NEW_GCS=on`` and -``RAY_USE_XRAY=1`` are required. - -.. code-block:: bash - - export RAY_USE_NEW_GCS=on - export RAY_USE_XRAY=1 - python my_ray_script.py # Or launch python/ipython. - -Activate memory flushing ------------------------- - -After building Ray using the method above, simply add these two lines after -``ray.init()`` to activate automatic memory flushing: - -.. code-block:: python - - ray.init(...) - - policy = ray.experimental.SimpleGcsFlushPolicy() - ray.experimental.set_flushing_policy(policy) - - # My awesome Ray application logic follows. - -Paramaters of the flushing policy ---------------------------------- - -There are three `user-configurable parameters -`_ -of the ``SimpleGcsFlushPolicy``: - -* ``flush_when_at_least_bytes``: Wait until this many bytes of memory usage - accumulated in the redis server before flushing kicks in. -* ``flush_period_secs``: Issue a flush to the Redis server every this many - seconds. -* ``flush_num_entries_each_time``: A hint to the system on the number of entries - to flush on each request. - -The default values should serve to be non-invasive for lightweight Ray -applications. ``flush_when_at_least_bytes`` is set to ``(1<<31)`` or 2GB, -``flush_period_secs`` to 10, and ``flush_num_entries_each_time`` to 10000: - -.. code-block:: python - - # Default parameters. - ray.experimental.SimpleGcsFlushPolicy( - flush_when_at_least_bytes=(1 << 31), - flush_period_secs=10, - flush_num_entries_each_time=10000) - -In particular, these default values imply that - -1. the Redis server would accumulate memory usage up to 2GB without any entries -being flushed, then the flushing would kick in; and - -2. generally, "older" metadata entries would be flushed first, and the Redis -server would always keep the most recent window of metadata of 2GB in size. - -**For advanced users.** Advanced users can tune the above parameters to their -applications' needs; note that the desired flush rate is equal to (flush -period) * (num entries each flush). +Note that profiling is disabled when ``redis_max_memory`` is set. This is because +profiling data cannot be LRU evicted. diff --git a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index 1d0501215745c..cb12ade0f1504 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -8,7 +8,7 @@ Distributed Prioritized Experience Replay (Ape-X) ------------------------------------------------- `[paper] `__ `[implementation] `__ -Ape-X variations of DQN and DDPG (`APEX_DQN `__, `APEX_DDPG `__ in RLlib) use a single GPU learner and many CPU workers for experience collection. 
Experience collection can scale to hundreds of CPU workers due to the distributed prioritization of experience prior to storage in replay buffers. +Ape-X variations of DQN, DDPG, and QMIX (`APEX_DQN `__, `APEX_DDPG `__, `APEX_QMIX `__) use a single GPU learner and many CPU workers for experience collection. Experience collection can scale to hundreds of CPU workers due to the distributed prioritization of experience prior to storage in replay buffers. Tuned examples: `PongNoFrameskip-v4 `__, `Pendulum-v0 `__, `MountainCarContinuous-v0 `__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 `__. @@ -245,3 +245,18 @@ Tuned examples: `Humanoid-v1 `__ `[implementation] `__ Q-Mix is a specialized multi-agent algorithm. Code here is adapted from https://github.com/oxwhirl/pymarl_alpha to integrate with RLlib multi-agent APIs. To use Q-Mix, you must specify an agent `grouping `__ in the environment (see the `two-step game example `__). Currently, all agents in the group must be homogeneous. The algorithm can be scaled by increasing the number of workers or using Ape-X. + +Q-Mix is implemented in `PyTorch `__ and is currently *experimental*. + +Tuned examples: `Two-step game `__ + +**QMIX-specific configs** (see also `common configs `__): + +.. literalinclude:: ../../python/ray/rllib/agents/qmix/qmix.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ diff --git a/doc/source/rllib-dev.rst b/doc/source/rllib-dev.rst new file mode 100644 index 0000000000000..688b59b7eda57 --- /dev/null +++ b/doc/source/rllib-dev.rst @@ -0,0 +1,70 @@ +RLlib Development +================= + +Development Install +------------------- + +You can develop RLlib locally without needing to compile Ray by using the `setup-rllib-dev.py `__ script. This sets up links between the ``rllib`` dir in your git repo and the one bundled with the ``ray`` package. When using this script, make sure that your git branch is in sync with the installed Ray binaries (i.e., you are up-to-date on `master `__ and have the latest `wheel `__ installed.) + +Features +-------- + +Feature development and upcoming priorities are tracked on the `RLlib project board `__ (note that this may not include all development efforts). For discussion of issues and new features, we use the `Ray dev list `__ and `GitHub issues page `__. + +Benchmarks +---------- + +A number of training run results are available in the `rl-experiments repo `__, and there is also a list of working hyperparameter configurations in `tuned_examples `__. Benchmark results are extremely valuable to the community, so if you happen to have results that may be of interest, consider making a pull request to either repo. 
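One way to reproduce a benchmark from ``tuned_examples`` without the CLI is to load its YAML spec and hand it to Tune directly; a sketch, with a placeholder file path:

.. code-block:: python

    import yaml

    import ray
    from ray import tune

    # Placeholder path -- substitute any config file from rllib/tuned_examples.
    with open("python/ray/rllib/tuned_examples/pong-a3c.yaml") as f:
        experiments = yaml.safe_load(f)

    ray.init()
    tune.run_experiments(experiments)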
+ +Contributing Algorithms +----------------------- + +These are the guidelines for merging new algorithms into RLlib: + +* Contributed algorithms (`rllib/contrib `__): + - must subclass Agent and implement the ``_train()`` method + - must include a lightweight test (`example `__) to ensure the algorithm runs + - should include tuned hyperparameter examples and documentation + - should offer functionality not present in existing algorithms + +* Fully integrated algorithms (`rllib/agents `__) have the following additional requirements: + - must fully implement the Agent API + - must offer substantial new functionality not possible to add to other algorithms + - should support custom models and preprocessors + - should use RLlib abstractions and support distributed execution + +Both integrated and contributed algorithms ship with the ``ray`` PyPI package, and are tested as part of Ray's automated tests. The main difference between contributed and fully integrated algorithms is that the latter will be maintained by the Ray team to a much greater extent with respect to bugs and integration with RLlib features. + +How to add an algorithm to ``contrib`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +It takes just two changes to add an algorithm to `contrib `__. A minimal example can be found `here `__. First, subclass `Agent `__ and implement the ``_init`` and ``_train`` methods: + +.. literalinclude:: ../../python/ray/rllib/contrib/random_agent/random_agent.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + +Second, register the agent with a name in `contrib/registry.py `__. + +.. code-block:: python + + def _import_random_agent(): + from ray.rllib.contrib.random_agent.random_agent import RandomAgent + return RandomAgent + + def _import_random_agent_2(): + from ray.rllib.contrib.random_agent_2.random_agent_2 import RandomAgent2 + return RandomAgent2 + + CONTRIBUTED_ALGORITHMS = { + "contrib/RandomAgent": _import_random_agent, + "contrib/RandomAgent2": _import_random_agent_2, + # ... + } + +After registration, you can run and visualize agent progress using ``rllib train``: + +.. code-block:: bash + + rllib train --run=contrib/RandomAgent --env=CartPole-v0 + tensorboard --logdir=~/ray_results diff --git a/doc/source/rllib-env.rst b/doc/source/rllib-env.rst index ca36186e1a5f6..0409f3f0e6093 100644 --- a/doc/source/rllib-env.rst +++ b/doc/source/rllib-env.rst @@ -20,31 +20,47 @@ APEX-DQN **Yes** `+parametric`_ No **Yes** No APEX-DDPG No **Yes** **Yes** No ES **Yes** **Yes** No No ARS **Yes** **Yes** No No +QMIX **Yes** No **Yes** **Yes** ============= ======================= ================== =========== ================== .. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces -In the high-level agent APIs, environments are identified with string names. By default, the string will be interpreted as a gym `environment name `__, however you can also register custom environments by name: +You can pass either a string name or a Python class to specify an environment. By default, strings will be interpreted as a gym `environment name `__. Custom env classes passed directly to the agent must take a single ``env_config`` parameter in their constructor: .. 
code-block:: python - import ray - from ray.tune.registry import register_env + import gym, ray from ray.rllib.agents import ppo - def env_creator(env_config): - import gym - return gym.make("CartPole-v0") # or return your own custom env + class MyEnv(gym.Env): + def __init__(self, env_config): + self.action_space = + self.observation_space = + def reset(self): + return + def step(self, action): + return , , , - register_env("my_env", env_creator) ray.init() - trainer = ppo.PPOAgent(env="my_env", config={ - "env_config": {}, # config to pass to env creator + trainer = ppo.PPOAgent(env=MyEnv, config={ + "env_config": {}, # config to pass to env class }) while True: print(trainer.train()) +You can also register a custom env creator function with a string name. This function must take a single ``env_config`` parameter and return an env instance: + +.. code-block:: python + + from ray.tune.registry import register_env + + def env_creator(env_config): + return MyEnv(...) # return an env instance + + register_env("my_env", env_creator) + trainer = ppo.PPOAgent(env="my_env") + Configuring Environments ------------------------ @@ -95,6 +111,10 @@ RLlib will auto-vectorize Gym envs for batch evaluation if the ``num_envs_per_wo Multi-Agent ----------- +.. note:: + + Learn more about multi-agent reinforcement learning in RLlib by reading the `blog post `__. + A multi-agent environment is one which has multiple acting entities per step, e.g., in a traffic simulation, there may be multiple "car" and "traffic light" agents in the environment. The model for multi-agent in RLlib as follows: (1) as a user you define the number of policies available up front, and (2) a function that maps agent ids to policy ids. This is summarized by the below figure: .. image:: multi-agent.svg @@ -183,6 +203,16 @@ Here is a simple `example training script 1``. +Grouping Agents +~~~~~~~~~~~~~~~ + +It is common to have groups of agents in multi-agent RL. RLlib treats agent groups like a single agent with a Tuple action and observation space. The group agent can then be assigned to a single policy for centralized execution, or to specialized multi-agent policies such as `Q-Mix `__ that implement centralized training but decentralized execution. You can use the ``MultiAgentEnv.with_agent_groups()`` method to define these groups: + +.. literalinclude:: ../../python/ray/rllib/env/multi_agent_env.py + :language: python + :start-after: __grouping_doc_begin__ + :end-before: __grouping_doc_end__ + Variable-Sharing Between Policies ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index 09c49e2751bf0..9e7070b66c489 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -233,7 +233,7 @@ In this example we overrode existing methods of the existing DDPG policy graph, Variable-length / Parametric Action Spaces ------------------------------------------ -Custom models can be used to work with environments where (1) the set of valid actions varies per step, and/or (2) the number of valid actions is very large, as in `OpenAI Five `__ and `Horizon `__. The general idea is that the meaning of actions can be completely conditioned on the observation, that is, the ``a`` in ``Q(s, a)`` is just a token in ``[0, MAX_AVAIL_ACTIONS)`` that only has meaning in the context of ``s``. 
This works with algorithms in the `DQN and policy-gradient families `__ and can be implemented as follows: +Custom models can be used to work with environments where (1) the set of valid actions varies per step, and/or (2) the number of valid actions is very large, as in `OpenAI Five `__ and `Horizon `__. The general idea is that the meaning of actions can be completely conditioned on the observation, i.e., the ``a`` in ``Q(s, a)`` becomes just a token in ``[0, MAX_AVAIL_ACTIONS)`` that only has meaning in the context of ``s``. This works with algorithms in the `DQN and policy-gradient families `__ and can be implemented as follows: 1. The environment should return a mask and/or list of valid action embeddings as part of the observation for each step. To enable batching, the number of actions can be allowed to vary from 1 to some max number: diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index dc37d22943ba7..291ae0462eb43 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -10,11 +10,11 @@ be trained, checkpointed, or an action computed. .. image:: rllib-api.svg -You can train a simple DQN agent with the following command +You can train a simple DQN agent with the following command: .. code-block:: bash - python ray/python/ray/rllib/train.py --run DQN --env CartPole-v0 + rllib train --run DQN --env CartPole-v0 By default, the results will be logged to a subdirectory of ``~/ray_results``. This subdirectory will contain a file ``params.json`` which contains the @@ -26,10 +26,12 @@ training process with TensorBoard by running tensorboard --logdir=~/ray_results -The ``train.py`` script has a number of options you can show by running +The ``rllib train`` command (same as the ``train.py`` script in the repo) has a number of options you can show by running: .. code-block:: bash + rllib train --help + -or- python ray/python/ray/rllib/train.py --help The most important options are for choosing the environment @@ -42,16 +44,16 @@ Evaluating Trained Agents In order to save checkpoints from which to evaluate agents, set ``--checkpoint-freq`` (number of training iterations between checkpoints) -when running ``train.py``. +when running ``rllib train``. An example of evaluating a previously trained DQN agent is as follows: .. code-block:: bash - python ray/python/ray/rllib/rollout.py \ - ~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1 \ - --run DQN --env CartPole-v0 --steps 10000 + rllib rollout \ + ~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1 \ + --run DQN --env CartPole-v0 --steps 10000 The ``rollout.py`` helper script reconstructs a DQN agent from the checkpoint located at ``~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1`` @@ -70,13 +72,12 @@ In an example below, we train A2C by specifying 8 workers through the config fla .. code-block:: bash - python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \ - --run=A2C --config '{"num_workers": 8}' + rllib train --env=PongDeterministic-v4 --run=A2C --config '{"num_workers": 8}' Specifying Resources ~~~~~~~~~~~~~~~~~~~~ -You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most agents. The number of GPUs the driver should use can be set via the ``num_gpus`` option. Similarly, the resource allocation to workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``. 
The number of GPUs can be a fractional quantity to allocate only a fraction of a GPU. For example, with DQN you can pack five agents onto one GPU by setting ``num_gpus: 0.2``. Note that in Ray < 0.6.0 fractional GPU support requires setting the environment variable ``RAY_USE_XRAY=1``. +You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most agents. The number of GPUs the driver should use can be set via the ``num_gpus`` option. Similarly, the resource allocation to workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``. The number of GPUs can be a fractional quantity to allocate only a fraction of a GPU. For example, with DQN you can pack five agents onto one GPU by setting ``num_gpus: 0.2``. .. image:: rllib-config.svg @@ -98,11 +99,11 @@ Some good hyperparameters and settings are available in (some of them are tuned to run on GPUs). If you find better settings or tune an algorithm on a different domain, consider submitting a Pull Request! -You can run these with the ``train.py`` script as follows: +You can run these with the ``rllib train`` command as follows: .. code-block:: bash - python ray/python/ray/rllib/train.py -f /path/to/tuned/example.yaml + rllib train -f /path/to/tuned/example.yaml Python API ---------- @@ -224,39 +225,10 @@ Sometimes, it is necessary to coordinate between pieces of code that live in dif Ray actors provide high levels of performance, so in more complex cases they can be used implement communication patterns such as parameter servers and allreduce. -Debugging ---------- - -Gym Monitor -~~~~~~~~~~~ - -The ``"monitor": true`` config can be used to save Gym episode videos to the result dir. For example: - -.. code-block:: bash - - python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \ - --run=A2C --config '{"num_workers": 2, "monitor": true}' - - # videos will be saved in the ~/ray_results/ dir, for example - openaigym.video.0.31401.video000000.meta.json - openaigym.video.0.31401.video000000.mp4 - openaigym.video.0.31403.video000000.meta.json - openaigym.video.0.31403.video000000.mp4 - -Log Verbosity -~~~~~~~~~~~~~ - -You can control the agent log level via the ``"log_level"`` flag. Valid values are "INFO" (default), "DEBUG", "WARN", and "ERROR". This can be used to increase or decrease the verbosity of internal logging. For example: - -.. code-block:: bash - - python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \ - --run=A2C --config '{"num_workers": 2, "log_level": "DEBUG"}' - Callbacks and Custom Metrics ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -You can provide callback functions to be called at points during policy evaluation. These functions have access to an info dict containing state for the current `episode `__. Custom state can be stored for the `episode `__ in the ``info["episode"].user_data`` dict, and custom scalar metrics reported by saving values to the ``info["episode"].custom_metrics`` dict. These custom metrics will be averaged and reported as part of training results. The following example (full code `here `__) logs a custom metric from the environment: +You can provide callback functions to be called at points during policy evaluation. These functions have access to an info dict containing state for the current `episode `__. Custom state can be stored for the `episode `__ in the ``info["episode"].user_data`` dict, and custom scalar metrics reported by saving values to the ``info["episode"].custom_metrics`` dict. 
These custom metrics will be aggregated and reported as part of training results. The following example (full code `here `__) logs a custom metric from the environment: .. code-block:: python @@ -273,10 +245,14 @@ You can provide callback functions to be called at points during policy evaluati def on_episode_end(info): episode = info["episode"] - mean_pole_angle = np.mean(episode.user_data["pole_angles"]) + pole_angle = np.mean(episode.user_data["pole_angles"]) print("episode {} ended with length {} and pole angles {}".format( - episode.episode_id, episode.length, mean_pole_angle)) - episode.custom_metrics["mean_pole_angle"] = mean_pole_angle + episode.episode_id, episode.length, pole_angle)) + episode.custom_metrics["pole_angle"] = pole_angle + + def on_train_result(info): + print("agent.train() result: {} -> {} episodes".format( + info["agent"].__name__, info["result"]["episodes_this_iter"])) ray.init() trials = tune.run_experiments({ @@ -288,6 +264,7 @@ You can provide callback functions to be called at points during policy evaluati "on_episode_start": tune.function(on_episode_start), "on_episode_step": tune.function(on_episode_step), "on_episode_end": tune.function(on_episode_end), + "on_train_result": tune.function(on_train_result), }, }, } @@ -297,6 +274,113 @@ Custom metrics can be accessed and visualized like any other training result: .. image:: custom_metric.png +Example: Curriculum Learning +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Let's look at two ways to use the above APIs to implement `curriculum learning `__. In curriculum learning, the agent task is adjusted over time to improve the learning process. Suppose that we have an environment class with a ``set_phase()`` method that we can call to adjust the task difficulty over time: + +Approach 1: Use the Agent API and update the environment between calls to ``train()``. This example shows the agent being run inside a Tune function: + +.. code-block:: python + + import ray + from ray import tune + from ray.rllib.agents.ppo import PPOAgent + + def train(config, reporter): + agent = PPOAgent(config=config, env=YourEnv) + while True: + result = agent.train() + reporter(**result) + if result["episode_reward_mean"] > 200: + phase = 2 + elif result["episode_reward_mean"] > 100: + phase = 1 + else: + phase = 0 + agent.optimizer.foreach_evaluator(lambda ev: ev.env.set_phase(phase)) + + ray.init() + tune.run_experiments({ + "curriculum": { + "run": train, + "config": { + "num_gpus": 0, + "num_workers": 2, + }, + "resources_per_trial": { + "cpu": 1, + "gpu": lambda spec: spec.config.num_gpus, + "extra_cpu": lambda spec: spec.config.num_workers, + }, + }, + }) + +Approach 2: Use the callbacks API to update the environment on new training results: + +.. code-block:: python + + import ray + from ray import tune + + def on_train_result(info): + result = info["result"] + if result["episode_reward_mean"] > 200: + phase = 2 + elif result["episode_reward_mean"] > 100: + phase = 1 + else: + phase = 0 + agent = info["agent"] + agent.optimizer.foreach_evaluator(lambda ev: ev.env.set_phase(phase)) + + ray.init() + tune.run_experiments({ + "curriculum": { + "run": "PPO", + "env": YourEnv, + "config": { + "callbacks": { + "on_train_result": tune.function(on_train_result), + }, + }, + }, + }) + +Debugging +--------- + +Gym Monitor +~~~~~~~~~~~ + +The ``"monitor": true`` config can be used to save Gym episode videos to the result dir. For example: + +.. 
code-block:: bash + + rllib train --env=PongDeterministic-v4 \ + --run=A2C --config '{"num_workers": 2, "monitor": true}' + + # videos will be saved in the ~/ray_results/ dir, for example + openaigym.video.0.31401.video000000.meta.json + openaigym.video.0.31401.video000000.mp4 + openaigym.video.0.31403.video000000.meta.json + openaigym.video.0.31403.video000000.mp4 + +Log Verbosity +~~~~~~~~~~~~~ + +You can control the agent log level via the ``"log_level"`` flag. Valid values are "INFO" (default), "DEBUG", "WARN", and "ERROR". This can be used to increase or decrease the verbosity of internal logging. For example: + +.. code-block:: bash + + rllib train --env=PongDeterministic-v4 \ + --run=A2C --config '{"num_workers": 2, "log_level": "DEBUG"}' + +Stack Traces +~~~~~~~~~~~~ + +You can use the ``ray stack`` command to dump the stack traces of all the Python workers on a single node. This can be useful for debugging unexpected hangs or performance issues. + REST API -------- diff --git a/doc/source/rllib.rst b/doc/source/rllib.rst index e96bd6fccbcb9..40f96b4b52adb 100644 --- a/doc/source/rllib.rst +++ b/doc/source/rllib.rst @@ -68,6 +68,10 @@ Algorithms - `Evolution Strategies `__ +* Multi-agent specific + + - `QMIX Monotonic Value Factorisation (QMIX, VDN, IQN) `__ + Models and Preprocessors ------------------------ * `RLlib Models and Preprocessors Overview `__ @@ -78,6 +82,14 @@ Models and Preprocessors * `Variable-length / Parametric Action Spaces `__ * `Model-Based Rollouts `__ +RLlib Development +----------------- + +* `Development Install `__ +* `Features `__ +* `Benchmarks `__ +* `Contributing Algorithms `__ + RLlib Concepts -------------- * `Policy Graphs `__ diff --git a/doc/source/tune-examples.rst b/doc/source/tune-examples.rst index e0af86bcb6956..65b419c3318e9 100644 --- a/doc/source/tune-examples.rst +++ b/doc/source/tune-examples.rst @@ -22,6 +22,8 @@ General Examples Example of using a Trainable class with PopulationBasedTraining scheduler. - `pbt_ppo_example `__: Example of optimizing a distributed RLlib algorithm (PPO) with the PopulationBasedTraining scheduler. +- `logging_example `__: + Example of custom loggers and custom trial directory naming. Keras Examples diff --git a/doc/source/tune-package-ref.rst b/doc/source/tune-package-ref.rst index d2531254b7814..9168966e28108 100644 --- a/doc/source/tune-package-ref.rst +++ b/doc/source/tune-package-ref.rst @@ -29,10 +29,17 @@ ray.tune.suggest .. automodule:: ray.tune.suggest :members: - :exclude-members: function, grid_search, SuggestionAlgorithm + :exclude-members: function, sample_from, grid_search, SuggestionAlgorithm :show-inheritance: .. autoclass:: ray.tune.suggest.SuggestionAlgorithm :members: :private-members: :show-inheritance: + + +ray.tune.logger +--------------- + +.. autoclass:: ray.tune.logger.Logger + :members: diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst index dc030bf6c4660..c4f846401b3cb 100644 --- a/doc/source/tune-usage.rst +++ b/doc/source/tune-usage.rst @@ -87,7 +87,7 @@ dictionary. Tune will convert the dict into an ``ray.tune.Experiment`` object. "alpha": tune.grid_search([0.2, 0.4, 0.6]), "beta": tune.grid_search([1, 2]), }, - "trial_resources": { "cpu": 1, "gpu": 0 }, + "resources_per_trial": { "cpu": 1, "gpu": 0 }, "num_samples": 10, "local_dir": "~/ray_results", "upload_dir": "s3://your_bucket/path", @@ -120,6 +120,35 @@ This function will report status on the command line until all Trials stop: An example of this can be found in `async_hyperband_example.py `__. 
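The ``run_experiments`` call above also returns the list of ``Trial`` objects once everything has stopped, so results can be inspected programmatically as well. The following is a minimal, hedged sketch: the experiment name and the ``easy_objective`` trainable are made up for illustration, and the ``status``/``last_result`` attributes are assumptions about the Trial interface rather than anything stated in this change:

.. code-block:: python

    import ray
    from ray import tune

    def easy_objective(config, reporter):
        # Report once and return, so the trial terminates immediately.
        reporter(timesteps_total=1, mean_accuracy=config["alpha"])

    ray.init()
    trials = tune.run_experiments({
        "inspect_results_example": {  # hypothetical experiment name
            "run": easy_objective,
            "config": {"alpha": tune.grid_search([0.2, 0.4, 0.6])},
        },
    })
    # Each returned Trial carries its final status and last reported result.
    for trial in trials:
        print(trial, trial.status, trial.last_result["mean_accuracy"])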
+Custom Trial Names +~~~~~~~~~~~~~~~~~~ + +To specify custom trial names, you can use the ``trial_name_creator`` argument +in the Experiment object. This takes a function with the following signature; +be sure to wrap it with ``tune.function``: + +.. code-block:: python + + def trial_name_string(trial): + """ + Args: + trial (Trial): A generated trial object. + + Returns: + trial_name (str): String representation of Trial. + """ + return str(trial) + + exp = Experiment( + name="hyperband_test", + run=MyTrainableClass, + num_samples=1, + trial_name_creator=tune.function(trial_name_string) + ) + +An example can be found in `logging_example.py `__. + + Training Features ----------------- @@ -141,8 +170,8 @@ The following shows grid search over two nested parameters combined with random "my_experiment_name": { "run": my_trainable, "config": { - "alpha": lambda spec: np.random.uniform(100), - "beta": lambda spec: spec.config.alpha * np.random.normal(), + "alpha": tune.sample_from(lambda spec: np.random.uniform(100)), + "beta": tune.sample_from(lambda spec: spec.config.alpha * np.random.normal()), "nn_layers": [ tune.grid_search([16, 64, 256]), tune.grid_search([16, 64, 256]), @@ -153,7 +182,7 @@ The following shows grid search over two nested parameters combined with random .. note:: - Lambda functions will be evaluated during trial variant generation. If you need to pass a literal function in your config, use ``tune.function(...)`` to escape it. + Use ``tune.sample_from(...)`` to sample from a function during trial variant generation. If you need to pass a literal function in your config, use ``tune.function(...)`` to escape it. For more information on variant generation, see `basic_variant.py `__. @@ -169,8 +198,8 @@ By default, each random variable and grid search point is sampled once. To take "my_experiment_name": { "run": my_trainable, "config": { - "alpha": lambda spec: np.random.uniform(100), - "beta": lambda spec: spec.config.alpha * np.random.normal(), + "alpha": tune.sample_from(lambda spec: np.random.uniform(100)), + "beta": tune.sample_from(lambda spec: spec.config.alpha * np.random.normal()), "nn_layers": [ tune.grid_search([16, 64, 256]), tune.grid_search([16, 64, 256]), @@ -186,7 +215,7 @@ E.g. in the above, ``"num_samples": 10`` repeats the 3x3 grid search 10 times, f Using GPUs (Resource Allocation) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Tune will allocate the specified GPU and CPU ``trial_resources`` to each individual trial (defaulting to 1 CPU per trial). Under the hood, Tune runs each trial as a Ray actor, using Ray's resource handling to allocate resources and place actors. A trial will not be scheduled unless at least that amount of resources is available in the cluster, preventing the cluster from being overloaded. +Tune will allocate the specified GPU and CPU ``resources_per_trial`` to each individual trial (defaulting to 1 CPU per trial). Under the hood, Tune runs each trial as a Ray actor, using Ray's resource handling to allocate resources and place actors. A trial will not be scheduled unless at least that amount of resources is available in the cluster, preventing the cluster from being overloaded. Fractional values are also supported, (i.e., ``"gpu": 0.2``). You can find an example of this in the `Keras MNIST example `__.
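To make the fractional allocation concrete, here is a small sketch, not taken from this change: the experiment name and ``train_fn`` trainable are hypothetical, and the spec simply requests one CPU and one fifth of a GPU per trial via ``resources_per_trial``, so five such trials could share a single GPU:

.. code-block:: python

    import ray
    from ray import tune

    def train_fn(config, reporter):
        # A stand-in trainable; a real one would build its model on the
        # GPU fraction that Tune allocated to this trial.
        reporter(timesteps_total=1, episode_reward_mean=0.0)

    ray.init(num_gpus=1)
    tune.run_experiments({
        "fractional_gpu_example": {  # hypothetical experiment name
            "run": train_fn,
            "num_samples": 5,
            # 1 CPU and one fifth of a GPU per trial.
            "resources_per_trial": {"cpu": 1, "gpu": 0.2},
        },
    })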
@@ -202,7 +231,7 @@ If your trainable function / class creates further Ray actors or tasks that also run_experiments({ "my_experiment_name": { "run": my_trainable, - "trial_resources": { + "resources_per_trial": { "cpu": 1, "gpu": 1, "extra_gpu": 4 @@ -317,7 +346,6 @@ The following fields will automatically show up on the console output, if provid Example_0: TERMINATED [pid=68248], 179 s, 2 iter, 60000 ts, 94 rew - Logging and Visualizing Results ------------------------------- @@ -355,6 +383,54 @@ Finally, to view the results with a `parallel coordinates visualization `__. + +You can also check out `logger.py `__ for implementation details. + +An example can be found in `logging_example.py `__. + +Custom Sync/Upload Commands +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If an upload directory is provided, Tune will automatically sync results to the given +directory with standard S3/gsutil commands. You can customize the upload command by +providing either a function or a string. + +If a string is provided, then it must include replacement fields ``{local_dir}`` and +``{remote_dir}``, like ``"aws s3 sync {local_dir} {remote_dir}"``. + +Alternatively, a function can be provided with the following signature (and must +be wrapped with ``tune.function``): + +.. code-block:: python + + def custom_sync_func(local_dir, remote_dir): + sync_cmd = "aws s3 sync {local_dir} {remote_dir}".format( + local_dir=local_dir, + remote_dir=remote_dir) + sync_process = subprocess.Popen(sync_cmd, shell=True) + sync_process.wait() + + exp = Experiment( + name="experiment_name", + run=MyTrainableClass, + sync_function=tune.function(custom_sync_func) + ) + Client API ---------- diff --git a/doc/source/using-ray-and-docker-on-a-cluster.md b/doc/source/using-ray-and-docker-on-a-cluster.md deleted file mode 100644 index 4e7b7a52d9bd6..0000000000000 --- a/doc/source/using-ray-and-docker-on-a-cluster.md +++ /dev/null @@ -1,236 +0,0 @@ -# Using Ray and Docker on a Cluster (Experimental) - -Packaging and deploying an application using Docker can provide certain advantages. It can make managing dependencies easier, help ensure that each cluster node receives a uniform configuration, and facilitate swapping hardware resources between applications. - - -## Create your Docker image - -First build a Ray Docker image by following the instructions for [Installation on Docker](install-on-docker.md). -This will allow you to create the `ray-project/deploy` image that serves as a basis for using Ray on a cluster with Docker. - -Docker images encapsulate the system state that will be used to run nodes in the cluster. -We recommend building on top of the Ray-provided Docker images to add your application code and dependencies. - -You can do this in one of two ways: by building from a customized Dockerfile or by saving an image after entering commands manually into a running container. -We describe both approaches below. - -### Creating a customized Dockerfile - -We recommend that you read the official Docker documentation for [Building your own image](https://docs.docker.com/engine/getstarted/step_four/) ahead of starting this section. -Your customized Dockerfile is a script of commands needed to set up your application, -possibly packaged in a folder with related resources. 
- -A simple template Dockerfile for a Ray application looks like this: - -``` -# Application Dockerfile template -FROM ray-project/deploy -RUN git clone -RUN -``` - -This file instructs Docker to load the image tagged `ray-project/deploy`, check out the git -repository at ``, and then run the script ``. - -Build the image by running something like: -``` -docker build -t . -``` -Replace `` with a tag of your choice. - - -### Creating a Docker image manually - -Launch the `ray-project/deploy` image interactively - -``` -docker run -t -i ray-project/deploy -``` - -Next, run whatever commands are needed to install your application. -When you are finished type `exit` to stop the container. - -Run -``` -docker ps -a -``` -to identify the id of the container you just exited. - -Next, commit the container -``` -docker commit -t -``` - -Replace `` with a name for your container and replace `` id with the hash id of the container used in configuration. - -## Publishing your Docker image to a repository - -When using Amazon EC2 it can be practical to publish images using the Repositories feature of Elastic Container Service. -Follow the steps below and see [documentation for creating a repository](http://docs.aws.amazon.com/AmazonECR/latest/userguide/repository-create.html) for additional context. - -First ensure that the AWS command-line interface is installed. - -``` -sudo apt-get install -y awscli -``` - -Next create a repository in Amazon's Elastic Container Registry. -This results in a shared resource for storing Docker images that will be accessible from all nodes. - - -``` -aws ecr create-repository --repository-name --region= -``` - -Replace `` with a string describing the application. -Replace `` with the AWS region string, e.g., `us-west-2`. -This should produce output like the following: - -``` -{ - "repository": { - "repositoryUri": "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-app", - "createdAt": 1487227244.0, - "repositoryArn": "arn:aws:ecr:us-west-2:123456789012:repository/my-app", - "registryId": "123456789012", - "repositoryName": "my-app" - } -} -``` - -Take note of the `repositoryUri` string, in this example `123456789012.dkr.ecr.us-west-2.amazonaws.com/my-app`. - - -Tag the Docker image with the repository URI. - -``` -docker tag -``` - -Replace the `` with the container name used previously and replace `` with URI returned by the command used to create the repository. - -Log into the repository: - -``` -eval $(aws ecr get-login --region ) -``` - -Replace `` with your selected AWS region. - -Push the image to the repository: -``` -docker push -``` -Replace `` with the URI of your repository. Now other hosts will be able to access your application Docker image. - - -## Starting a cluster - -We assume a cluster configuration like that described in instructions for [using Ray on a large cluster](using-ray-on-a-large-cluster.md). -In particular, we assume that there is a head node that has ssh access to all of the worker nodes, and that there is a file `workers.txt` listing the IP addresses of all worker nodes. - -### Install the Docker image on all nodes - -Create a script called `setup-docker.sh` on the head node. -``` -# setup-docker.sh -sudo apt-get install -y docker.io -sudo service docker start -sudo usermod -a -G docker ubuntu -exec sudo su -l ubuntu -eval $(aws ecr get-login --region ) -docker pull -``` - -Replace `` with the URI of the repository created in the previous section. -Replace `` with the AWS region in which you created that repository. 
-This script will install Docker, authenticate the session with the container registry, and download the container image from that registry. - -Run `setup-docker.sh` on the head node (if you used the head node to build the Docker image then you can skip this step): -``` -bash setup-docker.sh -``` - -Run `setup-docker.sh` on the worker nodes: -``` -parallel-ssh -h workers.txt -P -t 0 -I < setup-docker.sh -``` - -### Launch Ray cluster using Docker - -To start Ray on the head node run the following command: - -``` -eval $(aws ecr get-login --region ) -docker run \ - -d --shm-size= --net=host \ - \ - ray start --head \ - --object-manager-port=8076 \ - --redis-port=6379 \ - --num-workers= -``` - -Replace `` with the URI of the repository. -Replace `` with the region of the repository. -Replace `` with the number of workers, e.g., typically a number similar to the number of cores in the system. -Replace `` with the the amount of shared memory to make available within the Docker container, e.g., `8G`. - - -To start Ray on the worker nodes create a script `start-worker-docker.sh` with content like the following: -``` -eval $(aws ecr get-login --region ) -docker run -d --shm-size= --net=host \ - \ - ray start \ - --object-manager-port=8076 \ - --redis-address= \ - --num-workers= - -``` - -Replace `` with the string `:6379` where `` is the private network IP address of the head node. - -Execute the script on the worker nodes: -``` -parallel-ssh -h workers.txt -P -t 0 -I < setup-worker-docker.sh -``` - - -## Running jobs on a cluster - -On the head node, identify the id of the container that you launched as the Ray head. - -``` -docker ps -``` - -the container id appears in the first column of the output. - -Now launch an interactive shell within the container: - -``` -docker exec -t -i bash -``` - -Replace `` with the container id found in the previous step. - -Next, launch your application program. -The Python program should contain an initialization command that takes the Redis address as a parameter: - -``` -ray.init(redis_address="") -``` - - -## Shutting down a cluster - -Kill all running Docker images on the worker nodes: -``` -parallel-ssh -h workers.txt -P 'docker kill $(docker ps -q)' -``` - -Kill all running Docker images on the head node: -``` -docker kill $(docker ps -q) -``` diff --git a/doc/source/using-ray-on-a-cluster.rst b/doc/source/using-ray-on-a-cluster.rst index 611e47b79db23..2bc8b1cf6bf06 100644 --- a/doc/source/using-ray-on-a-cluster.rst +++ b/doc/source/using-ray-on-a-cluster.rst @@ -3,12 +3,12 @@ Manual Cluster Setup .. note:: - If you're using AWS or GCP you should use the automated `setup commands `__. + If you're using AWS or GCP you should use the automated `setup commands `_. The instructions in this document work well for small clusters. For larger -clusters, follow the instructions for `managing a cluster with parallel ssh`_. +clusters, consider using the pssh package: ``sudo apt-get install pssh`` or +the `setup commands for private clusters `_. -.. _`managing a cluster with parallel ssh`: http://ray.readthedocs.io/en/latest/using-ray-on-a-large-cluster.html Deploying Ray on a Cluster -------------------------- @@ -32,7 +32,7 @@ If the ``--redis-port`` argument is omitted, Ray will choose a port at random. The command will print out the address of the Redis server that was started (and some other address information). -Then on all of the other nodes, run the following. Make sure to replace +**Then on all of the other nodes**, run the following. 
Make sure to replace ```` with the value printed by the command on the head node (it should look something like ``123.45.67.89:6379``). diff --git a/doc/source/using-ray-on-a-large-cluster.rst b/doc/source/using-ray-on-a-large-cluster.rst deleted file mode 100644 index b87c8c05f5125..0000000000000 --- a/doc/source/using-ray-on-a-large-cluster.rst +++ /dev/null @@ -1,309 +0,0 @@ -Manual Cluster Setup on a Large Cluster -======================================= - -.. note:: - - If you're using AWS or GCP you should use the automated `setup commands `__. - -Deploying Ray on a cluster requires a bit of manual work. The instructions here -illustrate how to use parallel ssh commands to simplify the process of running -commands and scripts on many machines simultaneously. - -Booting up a cluster on EC2 ---------------------------- - -* Create an EC2 instance running Ray following the `installation instructions`_. - - * Add any packages that you may need for running your application. - * Install the pssh package: ``sudo apt-get install pssh``. -* `Create an AMI`_ with Ray installed and with whatever code and libraries you - want on the cluster. -* Use the EC2 console to launch additional instances using the AMI you created. -* Configure the instance security groups so that they machines can all - communicate with one another. - -.. _`installation instructions`: http://ray.readthedocs.io/en/latest/installation.html -.. _`Create an AMI`: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/creating-an-ami-ebs.html - -Deploying Ray on a Cluster --------------------------- - -This section assumes that you have a cluster of machines running and that these -nodes have network connectivity to one another. It also assumes that Ray is -installed on each machine. - -Additional assumptions: - -* All of the following commands are run from a machine designated as - the **head node**. -* The head node will run Redis and the global scheduler. -* The head node has ssh access to all other nodes. -* All nodes are accessible via ssh keys -* Ray is checked out on each node at the location ``$HOME/ray``. - -**Note:** The commands below will probably need to be customized for your -specific setup. - -Connect to the head node -~~~~~~~~~~~~~~~~~~~~~~~~ - -In order to initiate ssh commands from the cluster head node we suggest enabling -ssh agent forwarding. This will allow the session that you initiate with the -head node to connect to other nodes in the cluster to run scripts on them. You -can enable ssh forwarding by running the following command before connecting to -the head node (replacing ```` with the path to the private key that you -would use when logging in to the nodes in the cluster). - -.. code-block:: bash - - ssh-add - -Now log in to the head node with the following command, where -```` is the public IP address of the head node (just choose -one of the nodes to be the head node). - -.. code-block:: bash - - ssh -A ubuntu@ - -Build a list of node IP addresses -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -On the head node, populate a file ``workers.txt`` with one IP address on each -line. Do not include the head node IP address in this file. These IP addresses -should typically be private network IP addresses, but any IP addresses which the -head node can use to ssh to worker nodes will work here. This should look -something like the following. - -.. code-block:: bash - - 172.31.27.16 - 172.31.29.173 - 172.31.24.132 - 172.31.29.224 - -Confirm that you can ssh to all nodes -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. 
code-block:: bash - - for host in $(cat workers.txt); do - ssh -o "StrictHostKeyChecking no" $host uptime - done - -You may need to verify the host keys during this process. If so, run this step -again to verify that it worked. If you see a **permission denied** error, you -most likely forgot to run ``ssh-add `` before connecting to the head -node. - -Starting Ray -~~~~~~~~~~~~ - -**Start Ray on the head node** - -On the head node, run the following: - -.. code-block:: bash - - ray start --head --redis-port=6379 - - -**Start Ray on the worker nodes** - -Create a file ``start_worker.sh`` that contains something like the following: - -.. code-block:: bash - - # Make sure the SSH session has the correct version of Python on its path. - # You will probably have to change the line below. - export PATH=/home/ubuntu/anaconda3/bin/:$PATH - ray start --redis-address=:6379 - -This script, when run on the worker nodes, will start up Ray. You will need to -replace ```` with the IP address that worker nodes will use to -connect to the head node (most likely a **private IP address**). In this -example we also export the path to the Python installation since our remote -commands will not be executing in a login shell. - -**Warning:** You will probably need to manually export the correct path to -Python (you will need to change the first line of ``start_worker.sh`` to find -the version of Python that Ray was built against). This is necessary because the -``PATH`` environment variable used by ``parallel-ssh`` can differ from the -``PATH`` environment variable that gets set when you ``ssh`` to the machine. - -**Warning:** If the ``parallel-ssh`` command below appears to hang or otherwise -fails, ``head-node-ip`` may need to be a private IP address instead of a public -IP address (e.g., if you are using EC2). It's also possible that you forgot to -run ``ssh-add `` or that you forgot the ``-A`` flag when connecting to -the head node. - -Now use ``parallel-ssh`` to start up Ray on each worker node. - -.. code-block:: bash - - parallel-ssh -h workers.txt -P -I < start_worker.sh - -Note that on some distributions the ``parallel-ssh`` command may be called -``pssh``. - -**Verification** - -Now you have started all of the Ray processes on each node. These include: - -- Some worker processes on each machine. -- An object store on each machine. -- A local scheduler on each machine. -- Multiple Redis servers (on the head node). - -To confirm that the Ray cluster setup is working, start up Python on one of the -nodes in the cluster and enter the following commands to connect to the Ray -cluster. - -.. code-block:: python - - import ray - ray.init(redis_address="") - -Here ```` should have the form ``:6379``. - -Now you can define remote functions and execute tasks. For example, to verify -that the correct number of nodes have joined the cluster, you can run the -following. - -.. code-block:: python - - import time - - @ray.remote - def f(): - time.sleep(0.01) - return ray.services.get_node_ip_address() - - # Get a list of the IP addresses of the nodes that have joined the cluster. - set(ray.get([f.remote() for _ in range(1000)])) - - -Stopping Ray -~~~~~~~~~~~~ - -**Stop Ray on worker nodes** - -Create a file ``stop_worker.sh`` that contains something like the following: - -.. code-block:: bash - - # Make sure the SSH session has the correct version of Python on its path. - # You will probably have to change the line below. 
- export PATH=/home/ubuntu/anaconda3/bin/:$PATH - ray stop - -This script, when run on the worker nodes, will stop Ray. Note, you will need to -replace ``/home/ubuntu/anaconda3/bin/`` with the correct path to your Python -installation. - -Now use ``parallel-ssh`` to stop Ray on each worker node. - -.. code-block:: bash - - parallel-ssh -h workers.txt -P -I < stop_worker.sh - -**Stop Ray on the head node** - -.. code-block:: bash - - ray stop - -Upgrading Ray -~~~~~~~~~~~~~ - -Ray remains under active development so you may at times want to upgrade the -cluster to take advantage of improvements and fixes. - -**Create an upgrade script** - -On the head node, create a file called ``upgrade.sh`` that contains the commands -necessary to upgrade Ray. It should look something like the following: - -.. code-block:: bash - - # Make sure the SSH session has the correct version of Python on its path. - # You will probably have to change the line below. - export PATH=/home/ubuntu/anaconda3/bin/:$PATH - # Do pushd/popd to make sure we end up in the same directory. - pushd . - # Upgrade Ray. - cd ray - git checkout master - git pull - cd python - pip install -e . --verbose - popd - -This script executes a series of git commands to update the Ray source code, then builds -and installs Ray. - -**Stop Ray on the cluster** - -Follow the instructions for `Stopping Ray`_. - -**Run the upgrade script on the cluster** - -First run the upgrade script on the head node. This will upgrade the head node -and help confirm that the upgrade script is working properly. - -.. code-block:: bash - - bash upgrade.sh - -Next run the upgrade script on the worker nodes. - -.. code-block:: bash - - parallel-ssh -h workers.txt -P -t 0 -I < upgrade.sh - -Note here that we use the ``-t 0`` option to set the timeout to infinite. You -may also want to use the ``-p`` flag, which controls the degree of parallelism -used by parallel ssh. - -It is probably a good idea to ssh to one of the other nodes and verify that the -upgrade script ran as expected. - -Sync Application Files to other nodes -------------------------------------- - -If you are running an application that reads input files or uses python -libraries then you may find it useful to copy a directory on the head node to -the worker nodes. - -You can do this using the ``parallel-rsync`` command: - -.. code-block:: bash - - parallel-rsync -h workers.txt -r /home/ubuntu/ - -where ```` is the directory you want to synchronize. Note that the -destination argument for this command must represent an absolute path on the -worker node. - -Troubleshooting ---------------- - -Problems with parallel-ssh -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -If any of the above commands fail, verify that the head node has SSH access to -the other nodes by running - -.. code-block:: bash - - for host in $(cat workers.txt); do - ssh $host uptime - done - -If you get a permission denied error, then make sure you have SSH'ed to the head -node with agent forwarding enabled. This is done as follows. - -.. 
code-block:: bash - - ssh-add - ssh -A ubuntu@ diff --git a/docker/base-deps/Dockerfile b/docker/base-deps/Dockerfile index 56cff8de0ac5e..b7fa7be6b9cf7 100644 --- a/docker/base-deps/Dockerfile +++ b/docker/base-deps/Dockerfile @@ -26,6 +26,6 @@ RUN apt-get update \ && /opt/conda/bin/conda clean -y --all \ && /opt/conda/bin/pip install \ flatbuffers \ - cython==0.27.3 + cython==0.29.0 ENV PATH "/opt/conda/bin:$PATH" diff --git a/docker/examples/Dockerfile b/docker/examples/Dockerfile index 9cdee4ff117eb..db9292dc645ae 100644 --- a/docker/examples/Dockerfile +++ b/docker/examples/Dockerfile @@ -5,7 +5,7 @@ FROM ray-project/deploy # This updates numpy to 1.14 and mutes errors from other libraries RUN conda install -y numpy RUN apt-get install -y zlib1g-dev -RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow lz4 keras pytest-timeout +RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow lz4 keras pytest-timeout smart_open RUN pip install -U h5py # Mutes FutureWarnings RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git RUN conda install pytorch-cpu torchvision-cpu -c pytorch diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java index 20db30944e513..267b7a3544e51 100644 --- a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -1,5 +1,6 @@ package org.ray.api.options; +import java.util.HashMap; import java.util.Map; /** @@ -7,12 +8,24 @@ */ public class ActorCreationOptions extends BaseTaskOptions { + public static final int NO_RECONSTRUCTION = 0; + public static final int INFINITE_RECONSTRUCTIONS = (int) Math.pow(2, 30); + + public final int maxReconstructions; + public ActorCreationOptions() { super(); + this.maxReconstructions = NO_RECONSTRUCTION; } public ActorCreationOptions(Map resources) { super(resources); + this.maxReconstructions = NO_RECONSTRUCTION; } + + public ActorCreationOptions(Map resources, int maxReconstructions) { + super(resources); + this.maxReconstructions = maxReconstructions; + } } diff --git a/java/doc/installation.rst b/java/doc/installation.rst index 8daec29ace403..54be4094be48d 100644 --- a/java/doc/installation.rst +++ b/java/doc/installation.rst @@ -26,7 +26,7 @@ For Ubuntu users, run the following commands: # If you are on Ubuntu 14.04, you need the following. 
pip install cmake - pip install cython==0.27.3 + pip install cython==0.29.0 For macOS users, run the following commands: :: @@ -34,7 +34,7 @@ For macOS users, run the following commands: brew update brew install maven cmake pkg-config automake autoconf libtool openssl bison wget - pip install cython==0.27.3 + pip install cython==0.29.0 Build Ray ^^^^^^^^^ diff --git a/java/prepare.sh b/java/prepare.sh index 9554e500a8edd..dcb325c4b1db5 100755 --- a/java/prepare.sh +++ b/java/prepare.sh @@ -50,7 +50,7 @@ declare -a nativeBinaries=( declare -a nativeLibraries=( "./src/ray/gcs/redis_module/libray_redis_module.so" - "./src/ray/raylet/liblocal_scheduler_library_java.*" + "./src/ray/raylet/libraylet_library_java.*" "./src/plasma/libplasma_java.*" "./src/ray/raylet/*lib.a" ) diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 10dc172fd4d99..b3adaa11cc55a 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -27,15 +27,19 @@ import org.ray.runtime.task.TaskSpec; import org.ray.runtime.util.ResourceUtil; import org.ray.runtime.util.UniqueIdUtil; -import org.ray.runtime.util.logger.RayLog; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Core functionality to implement Ray APIs. */ public abstract class AbstractRayRuntime implements RayRuntime { + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class); + private static final int GET_TIMEOUT_MS = 1000; private static final int FETCH_BATCH_SIZE = 1000; + private static final int LIMITED_RETRY_COUNTER = 10; protected RayConfig rayConfig; protected WorkerContext workerContext; @@ -75,10 +79,26 @@ public RayObject put(T obj) { public void put(UniqueId objectId, T obj) { UniqueId taskId = workerContext.getCurrentTask().taskId; - RayLog.core.debug("Putting object {}, for task {} ", objectId, taskId); + LOGGER.debug("Putting object {}, for task {} ", objectId, taskId); objectStoreProxy.put(objectId, obj, null); } + + /** + * Store a serialized object in the object store. + * + * @param obj The serialized Java object to be stored. + * @return A RayObject instance that represents the in-store object. + */ + public RayObject putSerialized(byte[] obj) { + UniqueId objectId = UniqueIdUtil.computePutId( + workerContext.getCurrentTask().taskId, workerContext.nextPutIndex()); + UniqueId taskId = workerContext.getCurrentTask().taskId; + LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId); + objectStoreProxy.putSerialized(objectId, obj, null); + return new RayObjectImpl<>(objectId); + } + @Override public T get(UniqueId objectId) throws RayException { List ret = get(ImmutableList.of(objectId)); @@ -118,7 +138,9 @@ public List get(List objectIds) { // Try reconstructing any objects we haven't gotten yet. Try to get them // until at least PlasmaLink.GET_TIMEOUT_MS milliseconds passes, then repeat. + int retryCounter = 0; while (unreadys.size() > 0) { + retryCounter++; List unreadyList = new ArrayList<>(unreadys.keySet()); List> reconstructBatches = splitIntoBatches(unreadyList, FETCH_BATCH_SIZE); @@ -140,10 +162,20 @@ public List get(List objectIds) { unreadys.remove(id); } } + + if (retryCounter % LIMITED_RETRY_COUNTER == 0) { + LOGGER.warn("Attempted {} times to reconstruct objects {}, " + + "but haven't received response. 
If this message continues to print," + + " it may indicate that the task is hanging, or something wrong " + + "happened in raylet backend.", + retryCounter, unreadys.keySet()); + } + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Got objects {} for task {}.", Arrays.toString(objectIds.toArray()), taskId); } - RayLog.core - .debug("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray()) + " get"); List finalRet = new ArrayList<>(); for (Pair value : ret) { @@ -152,8 +184,7 @@ public List get(List objectIds) { return finalRet; } catch (RayException e) { - RayLog.core.error("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray()) - + " get with Exception", e); + LOGGER.error("Failed to get objects for task {}.", taskId, e); throw e; } finally { // If there were objects that we weren't able to get locally, let the local @@ -270,6 +301,10 @@ private TaskSpec createTaskSpec(RayFunc func, RayActorImpl actor, Object[] args, resources.put(ResourceUtil.CPU_LITERAL, 0.0); } + int maxActorReconstruction = 0; + if (taskOptions instanceof ActorCreationOptions) { + maxActorReconstruction = ((ActorCreationOptions) taskOptions).maxReconstructions; + } RayFunction rayFunction = functionManager.getFunction(current.driverId, func); return new TaskSpec( current.driverId, @@ -277,6 +312,7 @@ private TaskSpec createTaskSpec(RayFunc func, RayActorImpl actor, Object[] args, current.taskId, -1, actorCreationId, + maxActorReconstruction, actor.getId(), actor.getHandleId(), actor.increaseTaskCounter(), diff --git a/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java b/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java index fd303fa936bb9..5f0ae13eac25c 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java @@ -67,7 +67,6 @@ public int increaseTaskCounter() { return taskCounter++; } - private UniqueId computeNextActorHandleId() { byte[] bytes = Sha1Digestor.digest(handleId.getBytes(), ++numForks); return new UniqueId(bytes); diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index 2036f2319cf49..fd88bde353e30 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -61,7 +61,7 @@ public void start() throws Exception { // Load native libraries.
try { resetLibaryPath(); - System.loadLibrary("local_scheduler_library_java"); + System.loadLibrary("raylet_library_java"); System.loadLibrary("plasma_java"); } catch (Exception e) { LOGGER.error("Failed to load native libraries.", e); diff --git a/java/runtime/src/main/java/org/ray/runtime/Worker.java b/java/runtime/src/main/java/org/ray/runtime/Worker.java index 5371560bfc90f..3531b7ed80d38 100644 --- a/java/runtime/src/main/java/org/ray/runtime/Worker.java +++ b/java/runtime/src/main/java/org/ray/runtime/Worker.java @@ -5,7 +5,6 @@ import org.ray.runtime.functionmanager.RayFunction; import org.ray.runtime.task.ArgumentsBuilder; import org.ray.runtime.task.TaskSpec; -import org.ray.runtime.util.logger.RayLog; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -63,9 +62,9 @@ public void execute(TaskSpec spec) { } else { runtime.localActors.put(returnId, result); } - RayLog.core.info("Finished executing task {}", spec.taskId); + LOGGER.info("Finished executing task {}", spec.taskId); } catch (Exception e) { - RayLog.core.error("Error executing task " + spec, e); + LOGGER.error("Error executing task " + spec, e); runtime.put(returnId, new RayException("Error executing task " + spec, e)); } finally { Thread.currentThread().setContextClassLoader(oldLoader); diff --git a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java index 7708466ed5afa..fdb507689616a 100644 --- a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java @@ -79,6 +79,7 @@ private TaskSpec createDummyTask(WorkerMode workerMode, UniqueId driverId) { UniqueId.NIL, 0, UniqueId.NIL, + 0, UniqueId.NIL, UniqueId.NIL, 0, diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/StateStoreProxyImpl.java b/java/runtime/src/main/java/org/ray/runtime/gcs/StateStoreProxyImpl.java index 586d02fa9406c..fe74adc0f7418 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/StateStoreProxyImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/StateStoreProxyImpl.java @@ -11,13 +11,15 @@ import org.ray.api.id.UniqueId; import org.ray.runtime.generated.ClientTableData; import org.ray.runtime.util.NetworkUtil; -import org.ray.runtime.util.logger.RayLog; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A class used to interface with the Ray control state. 
*/ public class StateStoreProxyImpl implements StateStoreProxy { + private static final Logger LOGGER = LoggerFactory.getLogger(StateStoreProxyImpl.class); public KeyValueStoreLink rayKvStore; public ArrayList shardStoreList = new ArrayList<>(); @@ -87,11 +89,11 @@ public List getAddressInfo(final String nodeIpAddress, return doGetAddressInfo(nodeIpAddress, redisAddress); } catch (Exception e) { try { - RayLog.core.warn("Error occurred in StateStoreProxyImpl getAddressInfo, " - + (numRetries - count) + " retries remaining", e); + LOGGER.warn("Error occurred in StateStoreProxyImpl getAddressInfo, {} retries remaining", + (numRetries - count), e); TimeUnit.MILLISECONDS.sleep(1000); } catch (InterruptedException ie) { - RayLog.core.error("error at StateStoreProxyImpl getAddressInfo", e); + LOGGER.error("error at StateStoreProxyImpl getAddressInfo", e); throw new RuntimeException(e); } } diff --git a/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java b/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java index 01113096036fc..4e17e45a7d450 100644 --- a/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java +++ b/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java @@ -29,17 +29,15 @@ public final class TaskInfo extends Table { public String actorCreationDummyObjectId() { int o = __offset(14); return o != 0 ? __string(o + bb_pos) : null; } public ByteBuffer actorCreationDummyObjectIdAsByteBuffer() { return __vector_as_bytebuffer(14, 1); } public ByteBuffer actorCreationDummyObjectIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 14, 1); } - public String actorId() { int o = __offset(16); return o != 0 ? __string(o + bb_pos) : null; } - public ByteBuffer actorIdAsByteBuffer() { return __vector_as_bytebuffer(16, 1); } - public ByteBuffer actorIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 16, 1); } - public String actorHandleId() { int o = __offset(18); return o != 0 ? __string(o + bb_pos) : null; } - public ByteBuffer actorHandleIdAsByteBuffer() { return __vector_as_bytebuffer(18, 1); } - public ByteBuffer actorHandleIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 18, 1); } - public int actorCounter() { int o = __offset(20); return o != 0 ? bb.getInt(o + bb_pos) : 0; } - public boolean isActorCheckpointMethod() { int o = __offset(22); return o != 0 ? 0!=bb.get(o + bb_pos) : false; } - public String functionId() { int o = __offset(24); return o != 0 ? __string(o + bb_pos) : null; } - public ByteBuffer functionIdAsByteBuffer() { return __vector_as_bytebuffer(24, 1); } - public ByteBuffer functionIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 24, 1); } + public int maxActorReconstructions() { int o = __offset(16); return o != 0 ? bb.getInt(o + bb_pos) : 0; } + public String actorId() { int o = __offset(18); return o != 0 ? __string(o + bb_pos) : null; } + public ByteBuffer actorIdAsByteBuffer() { return __vector_as_bytebuffer(18, 1); } + public ByteBuffer actorIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 18, 1); } + public String actorHandleId() { int o = __offset(20); return o != 0 ? __string(o + bb_pos) : null; } + public ByteBuffer actorHandleIdAsByteBuffer() { return __vector_as_bytebuffer(20, 1); } + public ByteBuffer actorHandleIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 20, 1); } + public int actorCounter() { int o = __offset(22); return o != 0 ? 
bb.getInt(o + bb_pos) : 0; } + public boolean isActorCheckpointMethod() { int o = __offset(24); return o != 0 ? 0!=bb.get(o + bb_pos) : false; } public Arg args(int j) { return args(new Arg(), j); } public Arg args(Arg obj, int j) { int o = __offset(26); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } public int argsLength() { int o = __offset(26); return o != 0 ? __vector_len(o) : 0; } @@ -62,11 +60,11 @@ public static int createTaskInfo(FlatBufferBuilder builder, int parent_counter, int actor_creation_idOffset, int actor_creation_dummy_object_idOffset, + int max_actor_reconstructions, int actor_idOffset, int actor_handle_idOffset, int actor_counter, boolean is_actor_checkpoint_method, - int function_idOffset, int argsOffset, int returnsOffset, int required_resourcesOffset, @@ -80,10 +78,10 @@ public static int createTaskInfo(FlatBufferBuilder builder, TaskInfo.addRequiredResources(builder, required_resourcesOffset); TaskInfo.addReturns(builder, returnsOffset); TaskInfo.addArgs(builder, argsOffset); - TaskInfo.addFunctionId(builder, function_idOffset); TaskInfo.addActorCounter(builder, actor_counter); TaskInfo.addActorHandleId(builder, actor_handle_idOffset); TaskInfo.addActorId(builder, actor_idOffset); + TaskInfo.addMaxActorReconstructions(builder, max_actor_reconstructions); TaskInfo.addActorCreationDummyObjectId(builder, actor_creation_dummy_object_idOffset); TaskInfo.addActorCreationId(builder, actor_creation_idOffset); TaskInfo.addParentCounter(builder, parent_counter); @@ -101,11 +99,11 @@ public static int createTaskInfo(FlatBufferBuilder builder, public static void addParentCounter(FlatBufferBuilder builder, int parentCounter) { builder.addInt(3, parentCounter, 0); } public static void addActorCreationId(FlatBufferBuilder builder, int actorCreationIdOffset) { builder.addOffset(4, actorCreationIdOffset, 0); } public static void addActorCreationDummyObjectId(FlatBufferBuilder builder, int actorCreationDummyObjectIdOffset) { builder.addOffset(5, actorCreationDummyObjectIdOffset, 0); } - public static void addActorId(FlatBufferBuilder builder, int actorIdOffset) { builder.addOffset(6, actorIdOffset, 0); } - public static void addActorHandleId(FlatBufferBuilder builder, int actorHandleIdOffset) { builder.addOffset(7, actorHandleIdOffset, 0); } - public static void addActorCounter(FlatBufferBuilder builder, int actorCounter) { builder.addInt(8, actorCounter, 0); } - public static void addIsActorCheckpointMethod(FlatBufferBuilder builder, boolean isActorCheckpointMethod) { builder.addBoolean(9, isActorCheckpointMethod, false); } - public static void addFunctionId(FlatBufferBuilder builder, int functionIdOffset) { builder.addOffset(10, functionIdOffset, 0); } + public static void addMaxActorReconstructions(FlatBufferBuilder builder, int maxActorReconstructions) { builder.addInt(6, maxActorReconstructions, 0); } + public static void addActorId(FlatBufferBuilder builder, int actorIdOffset) { builder.addOffset(7, actorIdOffset, 0); } + public static void addActorHandleId(FlatBufferBuilder builder, int actorHandleIdOffset) { builder.addOffset(8, actorHandleIdOffset, 0); } + public static void addActorCounter(FlatBufferBuilder builder, int actorCounter) { builder.addInt(9, actorCounter, 0); } + public static void addIsActorCheckpointMethod(FlatBufferBuilder builder, boolean isActorCheckpointMethod) { builder.addBoolean(10, isActorCheckpointMethod, false); } public static void addArgs(FlatBufferBuilder builder, int argsOffset) { builder.addOffset(11, argsOffset, 0); 
} public static int createArgsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startArgsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } @@ -127,8 +125,12 @@ public static int endTaskInfo(FlatBufferBuilder builder) { return o; } - //this is manually added to avoid encoding/decoding cost as our object - //id is a byte array instead of a string + /** This is manually added to avoid encoding/decoding cost as our object + * id is a byte array instead of a string. + * This function is error-prone. If the fields before `returns` change, + * the offset number should be changed accordingly. + * TODO(yuhguo): fix this error-prone function. + */ public ByteBuffer returnsAsByteBuffer(int j) { int o = __offset(28); if (o == 0) { diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java index 3dbe7b61459fd..0e3c70ed94d4f 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java @@ -9,13 +9,15 @@ import org.ray.api.id.UniqueId; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.raylet.MockRayletClient; -import org.ray.runtime.util.logger.RayLog; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A mock implementation of {@code org.ray.spi.ObjectStoreLink}, which use Map to store data. */ public class MockObjectStore implements ObjectStoreLink { + private static final Logger LOGGER = LoggerFactory.getLogger(MockObjectStore.class); private final RayDevRuntime runtime; private final Map data = new ConcurrentHashMap<>(); private final Map metadata = new ConcurrentHashMap<>(); @@ -28,14 +30,15 @@ public MockObjectStore(RayDevRuntime runtime) { @Override public void put(byte[] objectId, byte[] value, byte[] metadataValue) { if (objectId == null || objectId.length == 0 || value == null) { - RayLog.core - .error(logPrefix() + "cannot put null: " + objectId + "," + Arrays.toString(value)); + LOGGER + .error("{} cannot put null: {}, {}", logPrefix(), objectId, Arrays.toString(value)); System.exit(-1); } UniqueId uniqueId = new UniqueId(objectId); data.put(uniqueId, value); - metadata.put(uniqueId, metadataValue); - + if (metadataValue != null) { + metadata.put(uniqueId, metadataValue); + } if (scheduler != null) { scheduler.onObjectPut(uniqueId); } @@ -47,7 +50,7 @@ public List get(byte[][] objectIds, int timeoutMs, boolean isMetadata) { ArrayList rets = new ArrayList<>(objectIds.length); for (byte[] objId : objectIds) { UniqueId uniqueId = new UniqueId(objId); - RayLog.core.info(logPrefix() + " is notified for objectid " + uniqueId); + LOGGER.info("{} is notified for object id {}", logPrefix(), uniqueId); rets.add(dataMap.get(uniqueId)); } return rets; diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index 5f8221ff6f028..be33150c71382 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -75,6 +75,10 @@ public void put(UniqueId id, Object obj, Object metadata) { store.put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
} + public void putSerialized(UniqueId id, byte[] obj, byte[] metadata) { + store.put(id.getBytes(), obj, metadata); + } + public enum GetStatus { SUCCESS, FAILED } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 91937ba14b1e2..f658d3b1697dd 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -10,6 +10,7 @@ import java.util.Map; import org.ray.api.RayObject; import org.ray.api.WaitResult; +import org.ray.api.exception.RayException; import org.ray.api.id.UniqueId; import org.ray.runtime.functionmanager.FunctionDescriptor; import org.ray.runtime.generated.Arg; @@ -19,10 +20,13 @@ import org.ray.runtime.task.FunctionArg; import org.ray.runtime.task.TaskSpec; import org.ray.runtime.util.UniqueIdUtil; -import org.ray.runtime.util.logger.RayLog; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class RayletClientImpl implements RayletClient { + private static final Logger LOGGER = LoggerFactory.getLogger(RayletClientImpl.class); + private static final int TASK_SPEC_BUFFER_SIZE = 2 * 1024 * 1024; /** @@ -46,6 +50,11 @@ public RayletClientImpl(String schedulerSockName, UniqueId clientId, @Override public WaitResult wait(List> waitFor, int numReturns, int timeoutMs, UniqueId currentTaskId) { + Preconditions.checkNotNull(waitFor); + if (waitFor.isEmpty()) { + return new WaitResult<>(new ArrayList<>(), new ArrayList<>()); + } + List ids = new ArrayList<>(); for (RayObject element : waitFor) { ids.add(element.getId()); @@ -69,7 +78,7 @@ public WaitResult wait(List> waitFor, int numReturns, int @Override public void submitTask(TaskSpec spec) { - RayLog.core.debug("Submitting task: {}", spec); + LOGGER.debug("Submitting task: {}", spec); ByteBuffer info = convertTaskSpecToFlatbuffer(spec); byte[] cursorId = null; if (!spec.getExecutionDependencies().isEmpty()) { @@ -90,12 +99,15 @@ public TaskSpec getTask() { @Override public void fetchOrReconstruct(List objectIds, boolean fetchOnly, UniqueId currentTaskId) { - if (RayLog.core.isDebugEnabled()) { - RayLog.core.debug("Blocked on objects for task {}, object IDs are {}", + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Blocked on objects for task {}, object IDs are {}", UniqueIdUtil.computeTaskId(objectIds.get(0)), objectIds); } - nativeFetchOrReconstruct(client, UniqueIdUtil.getIdBytes(objectIds), + int ret = nativeFetchOrReconstruct(client, UniqueIdUtil.getIdBytes(objectIds), fetchOnly, currentTaskId.getBytes()); + if (ret != 0) { + throw new RayException("Connection closed by Raylet"); + } } @Override @@ -123,6 +135,7 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { UniqueId parentTaskId = UniqueId.fromByteBuffer(info.parentTaskIdAsByteBuffer()); int parentCounter = info.parentCounter(); UniqueId actorCreationId = UniqueId.fromByteBuffer(info.actorCreationIdAsByteBuffer()); + int maxActorReconstructions = info.maxActorReconstructions(); UniqueId actorId = UniqueId.fromByteBuffer(info.actorIdAsByteBuffer()); UniqueId actorHandleId = UniqueId.fromByteBuffer(info.actorHandleIdAsByteBuffer()); int actorCounter = info.actorCounter(); @@ -158,8 +171,9 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { FunctionDescriptor functionDescriptor = new FunctionDescriptor( info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2) ); - return new 
TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId, actorId, - actorHandleId, actorCounter, args, returnIds, resources, functionDescriptor); + return new TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId, + maxActorReconstructions, actorId, actorHandleId, actorCounter, args, returnIds, resources, + functionDescriptor); } private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { @@ -173,10 +187,10 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { final int parentCounter = task.parentCounter; final int actorCreateIdOffset = fbb.createString(task.actorCreationId.toByteBuffer()); final int actorCreateDummyIdOffset = fbb.createString(task.actorId.toByteBuffer()); + final int maxActorReconstructions = task.maxActorReconstructions; final int actorIdOffset = fbb.createString(task.actorId.toByteBuffer()); final int actorHandleIdOffset = fbb.createString(task.actorHandleId.toByteBuffer()); final int actorCounter = task.actorCounter; - final int functionIdOffset = fbb.createString(UniqueId.NIL.toByteBuffer()); // Serialize args int[] argsOffsets = new int[task.args.length]; for (int i = 0; i < argsOffsets.length; i++) { @@ -226,22 +240,32 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { int functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets); int root = TaskInfo.createTaskInfo( - fbb, driverIdOffset, taskIdOffset, - parentTaskIdOffset, parentCounter, - actorCreateIdOffset, actorCreateDummyIdOffset, - actorIdOffset, actorHandleIdOffset, actorCounter, - false, functionIdOffset, - argsOffset, returnsOffset, requiredResourcesOffset, - requiredPlacementResourcesOffset, Language.JAVA, + fbb, + driverIdOffset, + taskIdOffset, + parentTaskIdOffset, + parentCounter, + actorCreateIdOffset, + actorCreateDummyIdOffset, + maxActorReconstructions, + actorIdOffset, + actorHandleIdOffset, + actorCounter, + false, + argsOffset, + returnsOffset, + requiredResourcesOffset, + requiredPlacementResourcesOffset, + Language.JAVA, functionDescriptorOffset); fbb.finish(root); ByteBuffer buffer = fbb.dataBuffer(); if (buffer.remaining() > TASK_SPEC_BUFFER_SIZE) { - RayLog.core.error( - "Allocated buffer is not enough to transfer the task specification: " - + TASK_SPEC_BUFFER_SIZE + " vs " + buffer.remaining()); - assert (false); + LOGGER.error( + "Allocated buffer is not enough to transfer the task specification: {} vs {}", + TASK_SPEC_BUFFER_SIZE, buffer.remaining()); + throw new RuntimeException("Allocated buffer is not enough to transfer the task."); } return buffer; } @@ -274,7 +298,7 @@ private static native void nativeSubmitTask(long client, byte[] cursorId, ByteBu private static native void nativeDestroy(long client); - private static native void nativeFetchOrReconstruct(long client, byte[][] objectIds, + private static native int nativeFetchOrReconstruct(long client, byte[][] objectIds, boolean fetchOnly, byte[] currentTaskId); private static native void nativeNotifyUnblocked(long client, byte[] currentTaskId); diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java index 56940e33cbcfd..7b25882dd600b 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java @@ -38,6 +38,8 @@ public class RunManager { private List processes; + private static final int KILL_PROCESS_WAIT_TIMEOUT_SECONDS = 1; + public 
RunManager(RayConfig rayConfig) { this.rayConfig = rayConfig; processes = new ArrayList<>(); @@ -45,8 +47,24 @@ public RunManager(RayConfig rayConfig) { } public void cleanup() { - for (Process p : processes) { + // Terminate the processes in the reversed order of creating them. + // Because raylet needs to exit before object store, otherwise it + // cannot exit gracefully. + + for (int i = processes.size() - 1; i >= 0; --i) { + Process p = processes.get(i); p.destroy(); + + try { + p.waitFor(KILL_PROCESS_WAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS); + } catch (InterruptedException e) { + LOGGER.warn("Got InterruptedException while waiting for process {}" + + " to be terminated.", processes.get(i)); + } + + if (p.isAlive()) { + p.destroyForcibly(); + } } } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java index 58dc3d8030958..83714a6dec983 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java @@ -6,14 +6,17 @@ import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.id.UniqueId; +import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.util.Serializer; public class ArgumentsBuilder { - private static boolean checkSimpleValue(Object o) { - // TODO(raulchen): implement this. - return true; - } + /** + * If the the size of an argument's serialized data is smaller than this number, + * the argument will be passed by value. Otherwise it'll be passed by reference. + */ + private static final int LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024; + /** * Convert real function arguments to task spec arguments. @@ -30,10 +33,13 @@ public static FunctionArg[] wrap(Object[] args) { data = Serializer.encode(arg); } else if (arg instanceof RayObject) { id = ((RayObject) arg).getId(); - } else if (checkSimpleValue(arg)) { - data = Serializer.encode(arg); } else { - id = Ray.put(arg).getId(); + byte[] serialized = Serializer.encode(arg); + if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) { + id = ((AbstractRayRuntime)Ray.internal()).putSerialized(serialized).getId(); + } else { + data = serialized; + } } if (id != null) { ret[i] = FunctionArg.passByReference(id); diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java index 864be37544b87..8988c933fe3b7 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java @@ -28,6 +28,8 @@ public class TaskSpec { // Id for createActor a target actor public final UniqueId actorCreationId; + public final int maxActorReconstructions; + // Actor ID of the task. This is the actor that this task is executed on // or NIL_ACTOR_ID if the task is just a normal task. 
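
For context on the new LARGEST_SIZE_PASS_BY_VALUE threshold introduced in ArgumentsBuilder above: arguments whose serialized size stays under the limit are embedded in the task spec, larger ones are put into the object store and passed by reference. Below is a minimal Python sketch of that decision only; the helper names and the 100 KB threshold are illustrative, not the actual Java implementation.

```python
# Illustrative sketch only -- mirrors the pass-by-value / pass-by-reference
# decision that ArgumentsBuilder now makes; not the real Java code.
import pickle

LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024  # 100 KB, mirroring the Java constant


def wrap_argument(arg, put_serialized):
    """Return ("by_value", bytes) or ("by_reference", object_id)."""
    serialized = pickle.dumps(arg)
    if len(serialized) > LARGEST_SIZE_PASS_BY_VALUE:
        # Too large to inline: store it once and pass only the resulting ID.
        return ("by_reference", put_serialized(serialized))
    # Small enough to embed directly in the task spec.
    return ("by_value", serialized)
```
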
public final UniqueId actorId; @@ -62,14 +64,15 @@ public boolean isActorCreationTask() { } public TaskSpec(UniqueId driverId, UniqueId taskId, UniqueId parentTaskId, int parentCounter, - UniqueId actorCreationId, UniqueId actorId, UniqueId actorHandleId, int actorCounter, - FunctionArg[] args, UniqueId[] returnIds, + UniqueId actorCreationId, int maxActorReconstructions, UniqueId actorId, + UniqueId actorHandleId, int actorCounter, FunctionArg[] args, UniqueId[] returnIds, Map resources, FunctionDescriptor functionDescriptor) { this.driverId = driverId; this.taskId = taskId; this.parentTaskId = parentTaskId; this.parentCounter = parentCounter; this.actorCreationId = actorCreationId; + this.maxActorReconstructions = maxActorReconstructions; this.actorId = actorId; this.actorHandleId = actorHandleId; this.actorCounter = actorCounter; diff --git a/java/runtime/src/main/java/org/ray/runtime/util/NetworkUtil.java b/java/runtime/src/main/java/org/ray/runtime/util/NetworkUtil.java index eeaec75b9dec6..40d0860a8914a 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/NetworkUtil.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/NetworkUtil.java @@ -7,10 +7,13 @@ import java.net.NetworkInterface; import java.net.ServerSocket; import java.util.Enumeration; -import org.ray.runtime.util.logger.RayLog; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class NetworkUtil { + private static final Logger LOGGER = LoggerFactory.getLogger(NetworkUtil.class); + public static String getIpAddress(String interfaceName) { try { Enumeration interfaces = NetworkInterface.getNetworkInterfaces(); @@ -35,9 +38,9 @@ public static String getIpAddress(String interfaceName) { return addr.getHostAddress(); } } - RayLog.core.warn("You need to correctly specify [ray.java] net_interface in config."); + LOGGER.warn("You need to correctly specify [ray.java] net_interface in config."); } catch (Exception e) { - RayLog.core.error("Can't get ip address, use 127.0.0.1 as default.", e); + LOGGER.error("Can't get ip address, use 127.0.0.1 as default.", e); } return "127.0.0.1"; diff --git a/java/runtime/src/main/java/org/ray/runtime/util/Sha1Digestor.java b/java/runtime/src/main/java/org/ray/runtime/util/Sha1Digestor.java index 6454775430d0b..761e7192fbb19 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/Sha1Digestor.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/Sha1Digestor.java @@ -2,15 +2,17 @@ import java.nio.ByteBuffer; import java.security.MessageDigest; -import org.ray.runtime.util.logger.RayLog; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class Sha1Digestor { + private static final Logger LOGGER = LoggerFactory.getLogger(Sha1Digestor.class); private static final ThreadLocal md = ThreadLocal.withInitial(() -> { try { return MessageDigest.getInstance("SHA1"); } catch (Exception e) { - RayLog.core.error("Cannot get SHA1 MessageDigest", e); + LOGGER.error("Cannot get SHA1 MessageDigest", e); throw new RuntimeException("Cannot get SHA1 digest", e); } }); diff --git a/java/runtime/src/main/java/org/ray/runtime/util/SystemUtil.java b/java/runtime/src/main/java/org/ray/runtime/util/SystemUtil.java index 858cf3c37dc9b..3234cd055229a 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/SystemUtil.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/SystemUtil.java @@ -3,13 +3,16 @@ import java.lang.management.ManagementFactory; import java.lang.management.RuntimeMXBean; import java.util.concurrent.locks.ReentrantLock; -import 
org.ray.runtime.util.logger.RayLog; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * some utilities for system process. */ public class SystemUtil { + private static final Logger LOGGER = LoggerFactory.getLogger(SystemUtil.class); + static final ReentrantLock pidlock = new ReentrantLock(); static Integer pid; @@ -34,7 +37,7 @@ public static boolean startWithJar(String clsName) { } catch (ClassNotFoundException e) { // TODO Auto-generated catch block e.printStackTrace(); - RayLog.core.error("error at SystemUtil startWithJar", e); + LOGGER.error("error at SystemUtil startWithJar", e); return false; } } diff --git a/java/test/src/main/java/org/ray/api/benchmark/ActorPressTest.java b/java/test/src/main/java/org/ray/api/benchmark/ActorPressTest.java index 1918daa43771e..9945e17772a55 100644 --- a/java/test/src/main/java/org/ray/api/benchmark/ActorPressTest.java +++ b/java/test/src/main/java/org/ray/api/benchmark/ActorPressTest.java @@ -1,14 +1,11 @@ package org.ray.api.benchmark; import org.junit.Test; -import org.junit.runner.RunWith; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.annotation.RayRemote; -import org.ray.api.test.MyRunner; -@RunWith(MyRunner.class) public class ActorPressTest extends RayBenchmarkTest { @Test diff --git a/java/test/src/main/java/org/ray/api/benchmark/MaxPressureTest.java b/java/test/src/main/java/org/ray/api/benchmark/MaxPressureTest.java index 923ed4dd32375..bf5d9c5ac3a77 100644 --- a/java/test/src/main/java/org/ray/api/benchmark/MaxPressureTest.java +++ b/java/test/src/main/java/org/ray/api/benchmark/MaxPressureTest.java @@ -1,14 +1,11 @@ package org.ray.api.benchmark; import org.junit.Test; -import org.junit.runner.RunWith; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.annotation.RayRemote; -import org.ray.api.test.MyRunner; -@RunWith(MyRunner.class) public class MaxPressureTest extends RayBenchmarkTest { public static final int clientNum = 2; diff --git a/java/test/src/main/java/org/ray/api/benchmark/RateLimiterPressureTest.java b/java/test/src/main/java/org/ray/api/benchmark/RateLimiterPressureTest.java index 0594dfae732f9..4c44a0332145b 100644 --- a/java/test/src/main/java/org/ray/api/benchmark/RateLimiterPressureTest.java +++ b/java/test/src/main/java/org/ray/api/benchmark/RateLimiterPressureTest.java @@ -1,14 +1,11 @@ package org.ray.api.benchmark; import org.junit.Test; -import org.junit.runner.RunWith; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.annotation.RayRemote; -import org.ray.api.test.MyRunner; -@RunWith(MyRunner.class) public class RateLimiterPressureTest extends RayBenchmarkTest { public static final int clientNum = 2; diff --git a/java/test/src/main/java/org/ray/api/benchmark/RayBenchmarkTest.java b/java/test/src/main/java/org/ray/api/benchmark/RayBenchmarkTest.java index 5c47b3174574e..ab73dd21440d7 100644 --- a/java/test/src/main/java/org/ray/api/benchmark/RayBenchmarkTest.java +++ b/java/test/src/main/java/org/ray/api/benchmark/RayBenchmarkTest.java @@ -11,10 +11,14 @@ import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.annotation.RayRemote; +import org.ray.api.test.BaseTest; import org.ray.runtime.util.logger.RayLog; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -public abstract class RayBenchmarkTest implements Serializable { +public abstract class RayBenchmarkTest extends BaseTest implements Serializable { + private 
static final Logger LOGGER = LoggerFactory.getLogger(RayBenchmarkTest.class); //not thread safe ,but we only have one thread here public static final DecimalFormat df = new DecimalFormat("00.00"); private static final long serialVersionUID = 416045641835782523L; @@ -61,12 +65,12 @@ private static List singleClient(PressureTestParameter pressureTestParamet long endTime = remoteResult.getFinishTime(); long costTime = endTime - temp.getStartTime(); counterList.add(costTime / 1000); - RayLog.core.warn(logPrefix + "_cost_time:" + costTime + "ns"); + LOGGER.warn("{}_cost_time:{}ns",logPrefix, costTime); Assert.assertTrue(rayBenchmarkTest.checkResult(remoteResult.getResult())); } return counterList; } catch (Exception e) { - RayLog.core.error("singleClient", e); + LOGGER.error("singleClient", e); return null; } @@ -83,7 +87,7 @@ public void singleLatencyTest(int times, RayActor rayActor) { long endTime = System.nanoTime(); long costTime = endTime - startTime; counterList.add(costTime / 1000); - RayLog.core.warn("SINGLE_LATENCY_cost_time: " + costTime + " us"); + LOGGER.warn("SINGLE_LATENCY_cost_time: {} us", costTime); Assert.assertTrue(checkResult(t)); } Collections.sort(counterList); @@ -103,15 +107,15 @@ private void printList(List list) { int ninety = (int) (len * 0.9); int fifty = (int) (len * 0.5); - RayLog.core.error("Final result of rt as below:"); - RayLog.core.error("max: " + list.get(len - 1) + "μs"); - RayLog.core.error("min: " + list.get(0) + "μs"); - RayLog.core.error("median: " + list.get(middle) + "μs"); - RayLog.core.error("99.99% data smaller than: " + list.get(almostHundred) + "μs"); - RayLog.core.error("99% data smaller than: " + list.get(ninetyNine) + "μs"); - RayLog.core.error("95% data smaller than: " + list.get(ninetyFive) + "μs"); - RayLog.core.error("90% data smaller than: " + list.get(ninety) + "μs"); - RayLog.core.error("50% data smaller than: " + list.get(fifty) + "μs"); + LOGGER.error("Final result of rt as below:"); + LOGGER.error("max: {}μs", list.get(len - 1)); + LOGGER.error("min: {}μs", list.get(0)); + LOGGER.error("median: {}μs", list.get(middle)); + LOGGER.error("99.99% data smaller than: {}μs", list.get(almostHundred)); + LOGGER.error("99% data smaller than: {}μs", list.get(ninetyNine)); + LOGGER.error("95% data smaller than: {}μs", list.get(ninetyFive)); + LOGGER.error("90% data smaller than: {}μs", list.get(ninety)); + LOGGER.error("50% data smaller than: {}μs", list.get(fifty)); } public void rateLimiterPressureTest(PressureTestParameter pressureTestParameter) { diff --git a/java/test/src/main/java/org/ray/api/benchmark/SingleLatencyTest.java b/java/test/src/main/java/org/ray/api/benchmark/SingleLatencyTest.java index de9c1e8385db0..5e82163bc1631 100644 --- a/java/test/src/main/java/org/ray/api/benchmark/SingleLatencyTest.java +++ b/java/test/src/main/java/org/ray/api/benchmark/SingleLatencyTest.java @@ -1,14 +1,11 @@ package org.ray.api.benchmark; import org.junit.Test; -import org.junit.runner.RunWith; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.annotation.RayRemote; -import org.ray.api.test.MyRunner; -@RunWith(MyRunner.class) public class SingleLatencyTest extends RayBenchmarkTest { public static final int totalNum = 10; diff --git a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java new file mode 100644 index 0000000000000..59bba919fa7a5 --- /dev/null +++ 
b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java @@ -0,0 +1,67 @@ +package org.ray.api.test; + +import static org.ray.runtime.util.SystemUtil.pid; + +import java.io.IOException; +import java.util.HashMap; +import java.util.concurrent.TimeUnit; +import org.junit.Assert; +import org.junit.Test; +import org.ray.api.Ray; +import org.ray.api.RayActor; +import org.ray.api.annotation.RayRemote; +import org.ray.api.options.ActorCreationOptions; + +public class ActorReconstructionTest extends BaseTest { + + @RayRemote() + public static class Counter { + + private int value = 0; + + public int increase(int delta) { + value += delta; + return value; + } + + public int getPid() { + return pid(); + } + } + + @Test + public void testActorReconstruction() throws InterruptedException, IOException { + ActorCreationOptions options = new ActorCreationOptions(new HashMap<>(), 1); + RayActor actor = Ray.createActor(Counter::new, options); + // Call increase 3 times. + for (int i = 0; i < 3; i++) { + Ray.call(Counter::increase, actor, 1).get(); + } + + // Kill the actor process. + int pid = Ray.call(Counter::getPid, actor).get(); + Runtime.getRuntime().exec("kill -9 " + pid); + // Wait for the actor to be killed. + TimeUnit.SECONDS.sleep(1); + + // Try calling increase on this actor again and check the value is now 4. + int value = Ray.call(Counter::increase, actor, 1).get(); + Assert.assertEquals(value, 4); + + // Kill the actor process again. + pid = Ray.call(Counter::getPid, actor).get(); + Runtime.getRuntime().exec("kill -9 " + pid); + TimeUnit.SECONDS.sleep(1); + + // Try calling increase on this actor again and this should fail. + try { + Ray.call(Counter::increase, actor, 1).get(); + Assert.fail("The above task didn't fail."); + } catch (StringIndexOutOfBoundsException e) { + // Raylet backend will put invalid data in task's result to indicate the task has failed. + // Thus, Java deserialization will fail and throw `StringIndexOutOfBoundsException`. + // TODO(hchen): we should use object's metadata to indicate task failure, + // instead of throwing this exception. + } + } +} diff --git a/java/test/src/main/java/org/ray/api/test/ActorTest.java b/java/test/src/main/java/org/ray/api/test/ActorTest.java index 1dffb16b89765..c8fb6ce6e973a 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorTest.java @@ -2,7 +2,6 @@ import org.junit.Assert; import org.junit.Test; -import org.junit.runner.RunWith; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; @@ -10,8 +9,7 @@ import org.ray.api.function.RayFunc2; import org.ray.api.id.UniqueId; -@RunWith(MyRunner.class) -public class ActorTest { +public class ActorTest extends BaseTest { @RayRemote public static class Counter { diff --git a/java/test/src/main/java/org/ray/api/test/BaseTest.java b/java/test/src/main/java/org/ray/api/test/BaseTest.java new file mode 100644 index 0000000000000..f7bcf01b42213 --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/BaseTest.java @@ -0,0 +1,32 @@ +package org.ray.api.test; + +import java.io.File; +import org.junit.After; +import org.junit.Before; +import org.ray.api.Ray; + +public class BaseTest { + + @Before + public void setUp() { + System.setProperty("ray.home", "../.."); + System.setProperty("ray.resources", "CPU:4,RES-A:4"); + Ray.init(); + } + + @After + public void tearDown() { + // TODO(qwang): This is double check to check that the socket file is removed actually. 
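
The new ActorReconstructionTest above exercises the max-reconstructions path end to end: three calls, a forced kill, one replayed reconstruction, then failure after a second kill. A rough Python analogue is sketched below; it assumes the @ray.remote decorator forwards the max_reconstructions option that make_actor() gains later in this diff, and that a reconstructed actor replays its earlier tasks, as the Java assertions imply.

```python
# Hedged sketch: Python analogue of ActorReconstructionTest above.
# Assumes @ray.remote accepts max_reconstructions (added to make_actor in this PR).
import os
import signal
import time

import ray

ray.init()


@ray.remote(max_reconstructions=1)
class Counter(object):
    def __init__(self):
        self.value = 0

    def increase(self):
        self.value += 1
        return self.value

    def get_pid(self):
        return os.getpid()


counter = Counter.remote()
for _ in range(3):
    ray.get(counter.increase.remote())

# Kill the actor process; with max_reconstructions=1 it may be restarted once.
os.kill(ray.get(counter.get_pid.remote()), signal.SIGKILL)
time.sleep(1)

# If earlier tasks are replayed on reconstruction, the next call returns 4,
# matching the Java test's assertion.
print(ray.get(counter.increase.remote()))
```
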
+ // We could not enable this until `systemInfo` enabled. + //File rayletSocketFIle = new File(Ray.systemInfo().rayletSocketName()); + Ray.shutdown(); + + //remove raylet socket file + //rayletSocketFIle.delete(); + + // unset system properties + System.clearProperty("ray.home"); + System.clearProperty("ray.resources"); + } + +} diff --git a/java/test/src/main/java/org/ray/api/test/HelloWorldTest.java b/java/test/src/main/java/org/ray/api/test/HelloWorldTest.java index 3dddce1f5c85c..9f31363e86ce1 100644 --- a/java/test/src/main/java/org/ray/api/test/HelloWorldTest.java +++ b/java/test/src/main/java/org/ray/api/test/HelloWorldTest.java @@ -2,7 +2,6 @@ import org.junit.Assert; import org.junit.Test; -import org.junit.runner.RunWith; import org.ray.api.Ray; import org.ray.api.RayObject; import org.ray.api.annotation.RayRemote; @@ -10,8 +9,7 @@ /** * Hello world. */ -@RunWith(MyRunner.class) -public class HelloWorldTest { +public class HelloWorldTest extends BaseTest { @RayRemote private static String hello() { diff --git a/java/test/src/main/java/org/ray/api/test/MyRunner.java b/java/test/src/main/java/org/ray/api/test/MyRunner.java deleted file mode 100644 index 3c7b755f5bbb5..0000000000000 --- a/java/test/src/main/java/org/ray/api/test/MyRunner.java +++ /dev/null @@ -1,19 +0,0 @@ -package org.ray.api.test; - -import org.junit.runner.notification.RunNotifier; -import org.junit.runners.BlockJUnit4ClassRunner; -import org.junit.runners.model.InitializationError; - -public class MyRunner extends BlockJUnit4ClassRunner { - - public MyRunner(Class klass) throws InitializationError { - super(klass); - } - - @Override - public void run(RunNotifier notifier) { - notifier.addListener(new TestListener()); - notifier.fireTestRunStarted(getDescription()); - super.run(notifier); - } -} diff --git a/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java b/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java index 961c02bd2e73c..d38c46992d76e 100644 --- a/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java +++ b/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java @@ -13,8 +13,7 @@ /** * Test putting and getting objects. 
*/ -@RunWith(MyRunner.class) -public class ObjectStoreTest { +public class ObjectStoreTest extends BaseTest { @Test public void testPutAndGet() { diff --git a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java index be24f0299e564..795e0efdb5fa0 100644 --- a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java +++ b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java @@ -13,8 +13,7 @@ import org.ray.api.id.UniqueId; -@RunWith(MyRunner.class) -public class PlasmaFreeTest { +public class PlasmaFreeTest extends BaseTest { @RayRemote private static String hello() { diff --git a/java/test/src/main/java/org/ray/api/test/RayCallTest.java b/java/test/src/main/java/org/ray/api/test/RayCallTest.java index ae1fc4483bf5b..08e90e589d0ca 100644 --- a/java/test/src/main/java/org/ray/api/test/RayCallTest.java +++ b/java/test/src/main/java/org/ray/api/test/RayCallTest.java @@ -2,6 +2,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; + +import java.io.Serializable; import java.util.List; import java.util.Map; import org.junit.Assert; @@ -13,8 +15,7 @@ /** * Test Ray.call API */ -@RunWith(MyRunner.class) -public class RayCallTest { +public class RayCallTest extends BaseTest { @RayRemote private static int testInt(int val) { @@ -66,6 +67,15 @@ private static Map testMap(Map val) { return val; } + public static class LargeObject implements Serializable { + private byte[] data = new byte[1024 * 1024]; + } + + @RayRemote + private static LargeObject testLargeObject(LargeObject largeObject) { + return largeObject; + } + /** * Test calling and returning different types. */ @@ -83,6 +93,8 @@ public void testType() { Assert.assertEquals(list, Ray.call(RayCallTest::testList, list).get()); Map map = ImmutableMap.of("1", 1, "2", 2); Assert.assertEquals(map, Ray.call(RayCallTest::testMap, map).get()); + LargeObject largeObject = new LargeObject(); + Assert.assertNotNull(Ray.call(RayCallTest::testLargeObject, largeObject).get()); } @RayRemote @@ -130,4 +142,5 @@ public void testNumberOfParameters() { Assert.assertEquals(5, (int) Ray.call(RayCallTest::testFiveParams, 1, 1, 1, 1, 1).get()); Assert.assertEquals(6, (int) Ray.call(RayCallTest::testSixParams, 1, 1, 1, 1, 1, 1).get()); } + } diff --git a/java/test/src/main/java/org/ray/api/test/RayConfigTest.java b/java/test/src/main/java/org/ray/api/test/RayConfigTest.java index 71e3d0dfff8e7..8260b39d48adb 100644 --- a/java/test/src/main/java/org/ray/api/test/RayConfigTest.java +++ b/java/test/src/main/java/org/ray/api/test/RayConfigTest.java @@ -10,22 +10,28 @@ public class RayConfigTest { @Test public void testCreateRayConfig() { - System.setProperty("ray.home", "/path/to/ray"); - System.setProperty("ray.driver.resource-path", "path/to/ray/driver/resource/path"); - RayConfig rayConfig = RayConfig.create(); - - Assert.assertEquals("/path/to/ray", rayConfig.rayHome); - Assert.assertEquals(WorkerMode.DRIVER, rayConfig.workerMode); - Assert.assertEquals(RunMode.CLUSTER, rayConfig.runMode); - - System.setProperty("ray.home", ""); - rayConfig = RayConfig.create(); - - Assert.assertEquals(System.getProperty("user.dir"), rayConfig.rayHome); - Assert.assertEquals(System.getProperty("user.dir") + - "/build/src/ray/thirdparty/redis/src/redis-server", rayConfig.redisServerExecutablePath); - - Assert.assertEquals("path/to/ray/driver/resource/path", rayConfig.driverResourcePath); + try { + System.setProperty("ray.home", "/path/to/ray"); + 
System.setProperty("ray.driver.resource-path", "path/to/ray/driver/resource/path"); + RayConfig rayConfig = RayConfig.create(); + + Assert.assertEquals("/path/to/ray", rayConfig.rayHome); + Assert.assertEquals(WorkerMode.DRIVER, rayConfig.workerMode); + Assert.assertEquals(RunMode.CLUSTER, rayConfig.runMode); + + System.setProperty("ray.home", ""); + rayConfig = RayConfig.create(); + + Assert.assertEquals(System.getProperty("user.dir"), rayConfig.rayHome); + Assert.assertEquals(System.getProperty("user.dir") + + "/build/src/ray/thirdparty/redis/src/redis-server", rayConfig.redisServerExecutablePath); + + Assert.assertEquals("path/to/ray/driver/resource/path", rayConfig.driverResourcePath); + } finally { + //unset the system property + System.clearProperty("ray.home"); + System.clearProperty("ray.driver.resource-path"); + } } } diff --git a/java/test/src/main/java/org/ray/api/test/RayMethodsTest.java b/java/test/src/main/java/org/ray/api/test/RayMethodsTest.java index b49f78e1729db..a612ef7c26355 100644 --- a/java/test/src/main/java/org/ray/api/test/RayMethodsTest.java +++ b/java/test/src/main/java/org/ray/api/test/RayMethodsTest.java @@ -5,7 +5,6 @@ import java.util.stream.Collectors; import org.junit.Assert; import org.junit.Test; -import org.junit.runner.RunWith; import org.ray.api.Ray; import org.ray.api.RayObject; import org.ray.api.WaitResult; @@ -14,8 +13,7 @@ /** * Integration test for Ray.* */ -@RunWith(MyRunner.class) -public class RayMethodsTest { +public class RayMethodsTest extends BaseTest { @Test public void test() { diff --git a/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java index e185a5f19a894..36abda2f32dd4 100644 --- a/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java +++ b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java @@ -2,10 +2,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import jdk.nashorn.internal.ir.annotations.Immutable; import org.junit.Assert; import org.junit.Test; -import org.junit.runner.RunWith; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; @@ -17,8 +15,7 @@ /** * Resources Management Test. 
*/ -@RunWith(MyRunner.class) -public class ResourcesManagementTest { +public class ResourcesManagementTest extends BaseTest { @RayRemote public static Integer echo(Integer number) { diff --git a/java/test/src/main/java/org/ray/api/test/StressTest.java b/java/test/src/main/java/org/ray/api/test/StressTest.java index 4fab74aed1991..a85a8ca2bc25b 100644 --- a/java/test/src/main/java/org/ray/api/test/StressTest.java +++ b/java/test/src/main/java/org/ray/api/test/StressTest.java @@ -5,14 +5,12 @@ import java.util.List; import org.junit.Assert; import org.junit.Test; -import org.junit.runner.RunWith; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.id.UniqueId; -@RunWith(MyRunner.class) -public class StressTest { +public class StressTest extends BaseTest { public static int echo(int x) { return x; diff --git a/java/test/src/main/java/org/ray/api/test/TestListener.java b/java/test/src/main/java/org/ray/api/test/TestListener.java deleted file mode 100644 index efc419b34720e..0000000000000 --- a/java/test/src/main/java/org/ray/api/test/TestListener.java +++ /dev/null @@ -1,21 +0,0 @@ -package org.ray.api.test; - -import org.junit.runner.Description; -import org.junit.runner.Result; -import org.junit.runner.notification.RunListener; -import org.ray.api.Ray; - -public class TestListener extends RunListener { - - @Override - public void testRunStarted(Description description) { - System.setProperty("ray.home", "../.."); - System.setProperty("ray.resources", "CPU:4,RES-A:4"); - Ray.init(); - } - - @Override - public void testRunFinished(Result result) { - Ray.shutdown(); - } -} diff --git a/java/test/src/main/java/org/ray/api/test/WaitTest.java b/java/test/src/main/java/org/ray/api/test/WaitTest.java index ac18276bf3633..49b5f8365b0e8 100644 --- a/java/test/src/main/java/org/ray/api/test/WaitTest.java +++ b/java/test/src/main/java/org/ray/api/test/WaitTest.java @@ -1,17 +1,16 @@ package org.ray.api.test; import com.google.common.collect.ImmutableList; +import java.util.ArrayList; import java.util.List; import org.junit.Assert; import org.junit.Test; -import org.junit.runner.RunWith; import org.ray.api.Ray; import org.ray.api.RayObject; import org.ray.api.WaitResult; import org.ray.api.annotation.RayRemote; -@RunWith(MyRunner.class) -public class WaitTest { +public class WaitTest extends BaseTest { @RayRemote private static String hi() { @@ -58,4 +57,18 @@ public void testWaitInWorker() { RayObject res = Ray.call(WaitTest::waitInWorker); res.get(); } + + @Test + public void testWaitForEmpty() { + WaitResult result = Ray.wait(new ArrayList<>()); + Assert.assertTrue(result.getReady().isEmpty()); + Assert.assertTrue(result.getUnready().isEmpty()); + + try { + Ray.wait(null); + Assert.fail(); + } catch (NullPointerException e) { + Assert.assertTrue(true); + } + } } diff --git a/python/build-wheel-macos.sh b/python/build-wheel-macos.sh index 30e8b19363769..1b29cd77415ba 100755 --- a/python/build-wheel-macos.sh +++ b/python/build-wheel-macos.sh @@ -70,7 +70,7 @@ for ((i=0; i<${#PY_VERSIONS[@]}; ++i)); do $PIP_CMD install -q setuptools_scm==2.1.0 # Fix the numpy version because this will be the oldest numpy version we can # support. - $PIP_CMD install -q numpy==$NUMPY_VERSION cython==0.27.3 + $PIP_CMD install -q numpy==$NUMPY_VERSION cython==0.29.0 # Install wheel to avoid the error "invalid command 'bdist_wheel'". $PIP_CMD install -q wheel # Add the correct Python to the path and build the wheel. 
This is only diff --git a/python/build-wheel-manylinux1.sh b/python/build-wheel-manylinux1.sh index db31ff55a4e6e..500fe5c491fb7 100755 --- a/python/build-wheel-manylinux1.sh +++ b/python/build-wheel-manylinux1.sh @@ -21,7 +21,7 @@ for PYTHON in cp27-cp27mu cp34-cp34m cp35-cp35m cp36-cp36m cp37-cp37m; do pushd python # Fix the numpy version because this will be the oldest numpy version we can # support. - /opt/python/${PYTHON}/bin/pip install -q numpy==1.10.4 cython==0.27.3 + /opt/python/${PYTHON}/bin/pip install -q numpy==1.10.4 cython==0.29.0 INCLUDE_UI=1 PATH=/opt/python/${PYTHON}/bin:$PATH /opt/python/${PYTHON}/bin/python setup.py bdist_wheel # In the future, run auditwheel here. mv dist/*.whl ../.whl/ diff --git a/python/ray/__init__.py b/python/ray/__init__.py index ed024a107aa50..5b47ccc75c65d 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -47,7 +47,7 @@ raise modin_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "modin") -sys.path.insert(0, modin_path) +sys.path.append(modin_path) from ray.raylet import ObjectID, _config # noqa: E402 from ray.profiling import profile # noqa: E402 @@ -65,7 +65,7 @@ from ray.actor import method # noqa: E402 # Ray version string. -__version__ = "0.6.0" +__version__ = "0.6.1" __all__ = [ "error_info", "init", "connect", "disconnect", "get", "put", "wait", diff --git a/python/ray/actor.py b/python/ray/actor.py index 926f15b293644..a29bb765a2cb2 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -10,7 +10,7 @@ import traceback import ray.cloudpickle as pickle -from ray.function_manager import FunctionActorManager +from ray.function_manager import FunctionDescriptor import ray.raylet import ray.ray_constants as ray_constants import ray.signature as signature @@ -44,8 +44,7 @@ def compute_actor_handle_id(actor_handle_id, num_forks): return ray.ObjectID(handle_id) -def compute_actor_handle_id_non_forked(actor_id, actor_handle_id, - current_task_id): +def compute_actor_handle_id_non_forked(actor_handle_id, current_task_id): """Deterministically compute an actor handle ID in the non-forked case. This code path is used whenever an actor handle is pickled and unpickled @@ -59,16 +58,13 @@ def compute_actor_handle_id_non_forked(actor_id, actor_handle_id, to the same actor handle IDs. Args: - actor_id: The actor ID. actor_handle_id: The original actor handle ID. - num_forks: The number of times the original actor handle has been - forked so far. + current_task_id: The ID of the task that is unpickling the handle. Returns: An ID for the new actor handle. """ handle_id_hash = hashlib.sha1() - handle_id_hash.update(actor_id.id()) handle_id_hash.update(actor_handle_id.id()) handle_id_hash.update(current_task_id.id()) handle_id = handle_id_hash.digest() @@ -76,18 +72,6 @@ def compute_actor_handle_id_non_forked(actor_id, actor_handle_id, return ray.ObjectID(handle_id) -def compute_actor_creation_function_id(class_id): - """Compute the function ID for an actor creation task. - - Args: - class_id: The ID of the actor class. - - Returns: - The function ID of the actor creation event. - """ - return ray.ObjectID(class_id) - - def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint, frontier): """Set the most recent checkpoint associated with a given actor ID. @@ -228,7 +212,7 @@ def remote(self, *args, **kwargs): return self._remote(args, kwargs) def _submit(self, args, kwargs, num_return_vals=None): - logger.warn( + logger.warning( "WARNING: _submit() is being deprecated. 
Please use _remote().") return self._remote( args=args, kwargs=kwargs, num_return_vals=num_return_vals) @@ -271,12 +255,14 @@ class ActorClass(object): each actor method. """ - def __init__(self, modified_class, class_id, checkpoint_interval, num_cpus, - num_gpus, resources, actor_method_cpus): + def __init__(self, modified_class, class_id, checkpoint_interval, + max_reconstructions, num_cpus, num_gpus, resources, + actor_method_cpus): self._modified_class = modified_class self._class_id = class_id self._class_name = modified_class.__name__ self._checkpoint_interval = checkpoint_interval + self._max_reconstructions = max_reconstructions self._num_cpus = num_cpus self._num_gpus = num_gpus self._resources = resources @@ -285,6 +271,23 @@ def __init__(self, modified_class, class_id, checkpoint_interval, num_cpus, self._actor_methods = inspect.getmembers( self._modified_class, ray.utils.is_function_or_method) + self._actor_method_names = [ + method_name for method_name, _ in self._actor_methods + ] + + constructor_name = "__init__" + if constructor_name not in self._actor_method_names: + # Add __init__ if it does not exist. + # Actor creation will be executed with __init__ together. + + # Assign an __init__ function will avoid many checks later on. + def __init__(self): + pass + + self._modified_class.__init__ = __init__ + self._actor_method_names.append(constructor_name) + self._actor_methods.append((constructor_name, __init__)) + # Extract the signatures of each of the methods. This will be used # to catch some errors if the methods are called with inappropriate # arguments. @@ -298,7 +301,6 @@ def __init__(self, modified_class, class_id, checkpoint_interval, num_cpus, signature.check_signature_supported(method, warn=True) self._method_signatures[method_name] = signature.extract_signature( method, ignore_first=not ray.utils.is_class_method(method)) - # Set the default number of return values for this method. if hasattr(method, "__ray_num_return_vals__"): self._actor_method_num_return_vals[method_name] = ( @@ -307,10 +309,6 @@ def __init__(self, modified_class, class_id, checkpoint_interval, num_cpus, self._actor_method_num_return_vals[method_name] = ( DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS) - self._actor_method_names = [ - method_name for method_name, _ in self._actor_methods - ] - def __call__(self, *args, **kwargs): raise Exception("Actors methods cannot be instantiated directly. " "Instead of running '{}()', try '{}.remote()'.".format( @@ -336,7 +334,7 @@ def _submit(self, num_cpus=None, num_gpus=None, resources=None): - logger.warn( + logger.warning( "WARNING: _submit() is being deprecated. Please use _remote().") return self._remote( args=args, @@ -384,14 +382,14 @@ def _remote(self, # Instead, instantiate the actor locally and add it to the worker's # dictionary if worker.mode == ray.LOCAL_MODE: - worker.actors[actor_id] = self._modified_class.__new__( - self._modified_class) + worker.actors[actor_id] = self._modified_class( + *copy.deepcopy(args), **copy.deepcopy(kwargs)) else: # Export the actor. 
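
The synthesized __init__ above means a Python actor class no longer needs to define a constructor, and the constructor arguments now ride on the actor-creation task itself instead of a separate __init__ call. A small usage sketch, with behavior inferred from this change rather than from any new public API:

```python
# Sketch: an actor class with no explicit __init__ is still creatable,
# because ActorClass now injects an empty constructor for it.
import ray

ray.init()


@ray.remote
class Echo(object):  # note: no __init__ defined
    def say(self, x):
        return x


e = Echo.remote()                   # the creation task carries the (empty) __init__ args
print(ray.get(e.say.remote("hi")))  # -> "hi"
```
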
if not self._exported: worker.function_actor_manager.export_actor_class( - self._class_id, self._modified_class, - self._actor_method_names, self._checkpoint_interval) + self._modified_class, self._actor_method_names, + self._checkpoint_interval) self._exported = True resources = ray.utils.resources_from_resource_arguments( @@ -407,33 +405,34 @@ def _remote(self, actor_placement_resources = resources.copy() actor_placement_resources["CPU"] += 1 - creation_args = [self._class_id] - function_id = compute_actor_creation_function_id(self._class_id) + if args is None: + args = [] + if kwargs is None: + kwargs = {} + function_name = "__init__" + function_signature = self._method_signatures[function_name] + creation_args = signature.extend_args(function_signature, args, + kwargs) + function_descriptor = FunctionDescriptor( + self._modified_class.__module__, function_name, + self._modified_class.__name__) [actor_cursor] = worker.submit_task( - function_id, + function_descriptor, creation_args, actor_creation_id=actor_id, + max_actor_reconstructions=self._max_reconstructions, num_return_vals=1, resources=resources, placement_resources=actor_placement_resources) - # We initialize the actor counter at 1 to account for the actor - # creation task. - actor_counter = 1 actor_handle = ActorHandle( - actor_id, self._class_name, actor_cursor, actor_counter, - self._actor_method_names, self._method_signatures, + actor_id, self._modified_class.__module__, self._class_name, + actor_cursor, self._actor_method_names, self._method_signatures, self._actor_method_num_return_vals, actor_cursor, self._actor_method_cpus, worker.task_driver_id) - - # Call __init__ as a remote function. - if "__init__" in actor_handle._ray_actor_method_names: - actor_handle.__init__.remote(*args, **kwargs) - else: - if len(args) != 0 or len(kwargs) != 0: - raise Exception("Arguments cannot be passed to the actor " - "constructor because this actor class has no " - "__init__ method.") + # We increment the actor counter by 1 to account for the actor creation + # task. + actor_handle._ray_actor_counter += 1 return actor_handle @@ -455,6 +454,7 @@ class ActorHandle(object): Attributes: _ray_actor_id: The ID of the corresponding actor. + _ray_module_name: The module name of this actor. _ray_actor_handle_id: The ID of this handle. If this is the "original" handle for an actor (as opposed to one created by passing another handle into a task), then this ID must be NIL_ID. If this @@ -485,37 +485,32 @@ class ActorHandle(object): _ray_actor_driver_id: The driver ID of the job that created the actor (it is possible that this ActorHandle exists on a driver with a different driver ID). - _ray_previous_actor_handle_id: If this actor handle is not an original - handle, (e.g., it was created by forking or pickling), then - this is the ID of the handle that this handle was created from. - Otherwise, this is None. """ def __init__(self, actor_id, + module_name, class_name, actor_cursor, - actor_counter, actor_method_names, method_signatures, method_num_return_vals, actor_creation_dummy_object_id, actor_method_cpus, actor_driver_id, - actor_handle_id=None, - previous_actor_handle_id=None): + actor_handle_id=None): + self._ray_actor_id = actor_id + self._ray_module_name = module_name # False if this actor handle was created by forking or pickling. True # if it was created by the _serialization_helper function. 
- self._ray_original_handle = previous_actor_handle_id is None - - self._ray_actor_id = actor_id + self._ray_original_handle = actor_handle_id is None if self._ray_original_handle: self._ray_actor_handle_id = ray.ObjectID( ray.worker.NIL_ACTOR_HANDLE_ID) else: self._ray_actor_handle_id = actor_handle_id self._ray_actor_cursor = actor_cursor - self._ray_actor_counter = actor_counter + self._ray_actor_counter = 0 self._ray_actor_method_names = actor_method_names self._ray_method_signatures = method_signatures self._ray_method_num_return_vals = method_num_return_vals @@ -525,8 +520,6 @@ def __init__(self, actor_creation_dummy_object_id) self._ray_actor_method_cpus = actor_method_cpus self._ray_actor_driver_id = actor_driver_id - self._ray_previous_actor_handle_id = previous_actor_handle_id - self._ray_previously_generated_actor_handle_id = None def _actor_method_call(self, method_name, @@ -580,32 +573,13 @@ def _actor_method_call(self, is_actor_checkpoint_method = (method_name == "__ray_checkpoint__") - # Right now, if the actor handle has been pickled, we create a - # temporary actor handle id for invocations. - # TODO(pcm): This still leads to a lot of actor handles being - # created, there should be a better way to handle pickled - # actor handles. - if self._ray_actor_handle_id is None: - actor_handle_id = compute_actor_handle_id_non_forked( - self._ray_actor_id, self._ray_previous_actor_handle_id, - worker.current_task_id) - # Each new task creates a new actor handle id, so we need to - # reset the actor counter to 0 - if (actor_handle_id != - self._ray_previously_generated_actor_handle_id): - self._ray_actor_counter = 0 - self._ray_previously_generated_actor_handle_id = ( - actor_handle_id) - else: - actor_handle_id = self._ray_actor_handle_id - - function_id = FunctionActorManager.compute_actor_method_function_id( - self._ray_class_name, method_name) + function_descriptor = FunctionDescriptor( + self._ray_module_name, method_name, self._ray_class_name) object_ids = worker.submit_task( - function_id, + function_descriptor, args, actor_id=self._ray_actor_id, - actor_handle_id=actor_handle_id, + actor_handle_id=self._ray_actor_handle_id, actor_counter=self._ray_actor_counter, is_actor_checkpoint_method=is_actor_checkpoint_method, actor_creation_dummy_object_id=( @@ -668,6 +642,16 @@ def __del__(self): # there are ANY handles in scope in the process that created the actor, # not just the first one. worker = ray.worker.get_global_worker() + if (worker.mode == ray.worker.SCRIPT_MODE + and self._ray_actor_driver_id.id() != worker.worker_id): + # If the worker is a driver and driver id has changed because + # Ray was shut down re-initialized, the actor is already cleaned up + # and we don't need to send `__ray_terminate__` again. + logger.warning( + "Actor is garbage collected in the wrong driver." + + " Actor id = %s, class name = %s.", self._ray_actor_id, + self._ray_class_name) + return if worker.connected and self._ray_original_handle: # TODO(rkn): Should we be passing in the actor cursor as a # dependency here? @@ -691,23 +675,28 @@ def _serialization_helper(self, ray_forking): Returns: A dictionary of the information needed to reconstruct the object. 
""" + if ray_forking: + actor_handle_id = compute_actor_handle_id( + self._ray_actor_handle_id, self._ray_actor_forks) + else: + actor_handle_id = self._ray_actor_handle_id + state = { "actor_id": self._ray_actor_id.id(), + "actor_handle_id": actor_handle_id.id(), + "module_name": self._ray_module_name, "class_name": self._ray_class_name, - "actor_forks": self._ray_actor_forks, "actor_cursor": self._ray_actor_cursor.id() if self._ray_actor_cursor is not None else None, - "actor_counter": 0, # Reset the actor counter. "actor_method_names": self._ray_actor_method_names, "method_signatures": self._ray_method_signatures, "method_num_return_vals": self._ray_method_num_return_vals, + # Actors in local mode don't have dummy objects. "actor_creation_dummy_object_id": self. _ray_actor_creation_dummy_object_id.id() if self._ray_actor_creation_dummy_object_id is not None else None, "actor_method_cpus": self._ray_actor_method_cpus, "actor_driver_id": self._ray_actor_driver_id.id(), - "previous_actor_handle_id": self._ray_actor_handle_id.id() - if self._ray_actor_handle_id else None, "ray_forking": ray_forking } @@ -728,11 +717,21 @@ def _deserialization_helper(self, state, ray_forking): worker.check_connected() if state["ray_forking"]: - actor_handle_id = compute_actor_handle_id( - ray.ObjectID(state["previous_actor_handle_id"]), - state["actor_forks"]) + actor_handle_id = ray.ObjectID(state["actor_handle_id"]) else: - actor_handle_id = None + # Right now, if the actor handle has been pickled, we create a + # temporary actor handle id for invocations. + # TODO(pcm): This still leads to a lot of actor handles being + # created, there should be a better way to handle pickled + # actor handles. + # TODO(swang): Accessing the worker's current task ID is not + # thread-safe. + # TODO(swang): Unpickling the same actor handle twice in the same + # task will break the application, and unpickling it twice in the + # same actor is likely a performance bug. We should consider + # logging a warning in these cases. + actor_handle_id = compute_actor_handle_id_non_forked( + ray.ObjectID(state["actor_handle_id"]), worker.current_task_id) # This is the driver ID of the driver that owns the actor, not # necessarily the driver that owns this actor handle. 
@@ -740,10 +739,10 @@ def _deserialization_helper(self, state, ray_forking): self.__init__( ray.ObjectID(state["actor_id"]), + state["module_name"], state["class_name"], ray.ObjectID(state["actor_cursor"]) if state["actor_cursor"] is not None else None, - state["actor_counter"], state["actor_method_names"], state["method_signatures"], state["method_num_return_vals"], @@ -751,9 +750,7 @@ def _deserialization_helper(self, state, ray_forking): if state["actor_creation_dummy_object_id"] is not None else None, state["actor_method_cpus"], actor_driver_id, - actor_handle_id=actor_handle_id, - previous_actor_handle_id=ray.ObjectID( - state["previous_actor_handle_id"])) + actor_handle_id=actor_handle_id) def __getstate__(self): """This code path is used by pickling but not by Ray forking.""" @@ -765,12 +762,19 @@ def __setstate__(self, state): def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus, - checkpoint_interval): + checkpoint_interval, max_reconstructions): if checkpoint_interval is None: checkpoint_interval = -1 + if max_reconstructions is None: + max_reconstructions = 0 if checkpoint_interval == 0: raise Exception("checkpoint_interval must be greater than 0.") + if not (ray_constants.NO_RECONSTRUCTION <= max_reconstructions <= + ray_constants.INFINITE_RECONSTRUCTION): + raise Exception("max_reconstructions must be in range [%d, %d]." % + (ray_constants.NO_RECONSTRUCTION, + ray_constants.INFINITE_RECONSTRUCTION)) # Modify the class to have an additional method that will be used for # terminating the worker. @@ -781,7 +785,7 @@ def __ray_terminate__(self): # Disconnect the worker from the local scheduler. The point of # this is so that when the worker kills itself below, the local # scheduler won't push an error message to the driver. - worker.local_scheduler_client.disconnect() + worker.raylet_client.disconnect() sys.exit(0) assert False, "This process should have terminated." @@ -822,8 +826,7 @@ def __ray_checkpoint__(self): # the local scheduler will not be included, and may not be runnable # on checkpoint resumption. actor_id = ray.ObjectID(worker.actor_id) - frontier = worker.local_scheduler_client.get_actor_frontier( - actor_id) + frontier = worker.raylet_client.get_actor_frontier(actor_id) # Save the checkpoint in Redis. TODO(rkn): Checkpoints # should not be stored in Redis. Fix this. set_actor_checkpoint(worker, worker.actor_id, checkpoint_index, @@ -853,7 +856,7 @@ def __ray_checkpoint_restore__(self): # Set the number of tasks executed so far. worker.actor_task_counter = checkpoint_index # Set the actor frontier in the local scheduler. - worker.local_scheduler_client.set_actor_frontier(frontier) + worker.raylet_client.set_actor_frontier(frontier) checkpoint_resumed = True return checkpoint_resumed @@ -863,8 +866,9 @@ def __ray_checkpoint_restore__(self): class_id = _random_string() - return ActorClass(Class, class_id, checkpoint_interval, num_cpus, num_gpus, - resources, actor_method_cpus) + return ActorClass(Class, class_id, checkpoint_interval, + max_reconstructions, num_cpus, num_gpus, resources, + actor_method_cpus) ray.worker.global_worker.make_actor = make_actor diff --git a/python/ray/autoscaler/autoscaler.py b/python/ray/autoscaler/autoscaler.py index 9c4a452ee2687..a806e3b62c155 100644 --- a/python/ray/autoscaler/autoscaler.py +++ b/python/ray/autoscaler/autoscaler.py @@ -362,12 +362,14 @@ def update(self): raise e def _update(self): + now = time.time() + # Throttle autoscaling updates to this interval to avoid exceeding # rate limits on API calls. 
- if time.time() - self.last_update_time < self.update_interval_s: + if now - self.last_update_time < self.update_interval_s: return - self.last_update_time = time.time() + self.last_update_time = now num_pending = self.num_launches_pending.value nodes = self.workers() logger.info(self.info_string(nodes)) @@ -377,7 +379,7 @@ def _update(self): # Terminate any idle or out of date nodes last_used = self.load_metrics.last_used_time_by_ip - horizon = time.time() - (60 * self.config["idle_timeout_minutes"]) + horizon = now - (60 * self.config["idle_timeout_minutes"]) num_terminated = 0 for node_id in nodes: node_ip = self.provider.internal_ip(node_id) @@ -441,7 +443,7 @@ def _update(self): # Attempt to recover unhealthy nodes for node_id in nodes: - self.recover_if_needed(node_id) + self.recover_if_needed(node_id, now) def reload_config(self, errors_fatal=False): try: @@ -488,14 +490,14 @@ def files_up_to_date(self, node_id): return False return True - def recover_if_needed(self, node_id): + def recover_if_needed(self, node_id, now): if not self.can_update(node_id): return key = self.provider.internal_ip(node_id) if key not in self.load_metrics.last_heartbeat_time_by_ip: - self.load_metrics.last_heartbeat_time_by_ip[key] = time.time() + self.load_metrics.last_heartbeat_time_by_ip[key] = now last_heartbeat_time = self.load_metrics.last_heartbeat_time_by_ip[key] - delta = time.time() - last_heartbeat_time + delta = now - last_heartbeat_time if delta < AUTOSCALER_HEARTBEAT_TIMEOUT_S: return logger.warning("StandardAutoscaler: No heartbeat from node " diff --git a/python/ray/autoscaler/aws/config.py b/python/ray/autoscaler/aws/config.py index 62e0b25ee2e2d..c0493df0d658e 100644 --- a/python/ray/autoscaler/aws/config.py +++ b/python/ray/autoscaler/aws/config.py @@ -273,8 +273,11 @@ def _get_role(role_name, config): try: role.load() return role - except botocore.errorfactory.NoSuchEntityException: - return None + except botocore.exceptions.ClientError as exc: + if exc.response.get("Error", {}).get("Code") == "NoSuchEntity": + return None + else: + raise exc def _get_instance_profile(profile_name, config): @@ -283,8 +286,11 @@ def _get_instance_profile(profile_name, config): try: profile.load() return profile - except botocore.errorfactory.NoSuchEntityException: - return None + except botocore.exceptions.ClientError as exc: + if exc.response.get("Error", {}).get("Code") == "NoSuchEntity": + return None + else: + raise exc def _get_key(key_name, config): diff --git a/python/ray/autoscaler/aws/development-example.yaml b/python/ray/autoscaler/aws/development-example.yaml index 273d1d7d94593..b9791799ff048 100644 --- a/python/ray/autoscaler/aws/development-example.yaml +++ b/python/ray/autoscaler/aws/development-example.yaml @@ -94,7 +94,7 @@ setup_commands: - echo 'export PATH="$HOME/anaconda3/bin:$PATH"' >> ~/.bashrc # Build Ray. - git clone https://github.com/ray-project/ray || true - - pip install boto3==1.4.8 cython==0.27.3 + - pip install boto3==1.4.8 cython==0.29.0 - cd ray/python; pip install -e . --verbose # Custom commands that will be run on the head node after common setup. diff --git a/python/ray/autoscaler/aws/example-full.yaml b/python/ray/autoscaler/aws/example-full.yaml index d74d45823c211..afe767a93a776 100644 --- a/python/ray/autoscaler/aws/example-full.yaml +++ b/python/ray/autoscaler/aws/example-full.yaml @@ -89,9 +89,9 @@ setup_commands: # has your Ray repo pre-cloned. Then, you can replace the pip installs # below with a git checkout (and possibly a recompile). 
- echo 'export PATH="$HOME/anaconda3/envs/tensorflow_p36/bin:$PATH"' >> ~/.bashrc - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp27-cp27mu-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp35-cp35m-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.1-cp27-cp27mu-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.1-cp35-cp35m-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.1-cp36-cp36m-manylinux1_x86_64.whl # Consider uncommenting these if you also want to run apt-get commands during setup # - sudo pkill -9 apt-get || true # - sudo pkill -9 dpkg || true diff --git a/python/ray/autoscaler/gcp/example-full.yaml b/python/ray/autoscaler/gcp/example-full.yaml index 6afbb464fa6a0..b841781e9a592 100644 --- a/python/ray/autoscaler/gcp/example-full.yaml +++ b/python/ray/autoscaler/gcp/example-full.yaml @@ -123,10 +123,10 @@ setup_commands: - >- pip install google-api-python-client==1.6.7 - cython==0.27.3 - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp27-cp27mu-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp35-cp35m-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp36-cp36m-manylinux1_x86_64.whl + cython==0.29.0 + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.1-cp27-cp27mu-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.1-cp35-cp35m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.1-cp36-cp36m-manylinux1_x86_64.whl - >- cd ~ && git clone https://github.com/ray-project/ray || true diff --git a/python/ray/autoscaler/local/development-example.yaml b/python/ray/autoscaler/local/development-example.yaml new file mode 100644 index 0000000000000..11f7c960f1f22 --- /dev/null +++ b/python/ray/autoscaler/local/development-example.yaml @@ -0,0 +1,32 @@ +cluster_name: default +min_workers: 0 +max_workers: 0 +docker: + image: "" + container_name: "" +target_utilization_fraction: 0.8 +idle_timeout_minutes: 5 +provider: + type: local + head_ip: YOUR_HEAD_NODE_HOSTNAME + worker_ips: [] +auth: + ssh_user: YOUR_USERNAME + ssh_private_key: ~/.ssh/id_rsa +head_node: {} +worker_nodes: {} +file_mounts: + "/tmp/ray_sha": "/YOUR/LOCAL/RAY/REPO/.git/refs/heads/YOUR_BRANCH" +setup_commands: [] +head_setup_commands: [] +worker_setup_commands: [] +setup_commands: + - source activate ray && test -e ray || git clone https://github.com/YOUR_GITHUB/ray.git + - source activate ray && cd ray && git fetch && git reset --hard `cat /tmp/ray_sha` +# - source activate ray && cd ray/python && pip install -e . 
+head_start_ray_commands: + - source activate ray && ray stop + - source activate ray && ulimit -c unlimited && ray start --head --redis-port=6379 --autoscaling-config=~/ray_bootstrap_config.yaml +worker_start_ray_commands: + - source activate ray && ray stop + - source activate ray && ray start --redis-address=$RAY_HEAD_IP:6379 diff --git a/python/ray/autoscaler/local/example-full.yaml b/python/ray/autoscaler/local/example-full.yaml index 11f7c960f1f22..88d20dadd616a 100644 --- a/python/ray/autoscaler/local/example-full.yaml +++ b/python/ray/autoscaler/local/example-full.yaml @@ -15,15 +15,12 @@ auth: ssh_private_key: ~/.ssh/id_rsa head_node: {} worker_nodes: {} -file_mounts: - "/tmp/ray_sha": "/YOUR/LOCAL/RAY/REPO/.git/refs/heads/YOUR_BRANCH" +file_mounts: {} setup_commands: [] head_setup_commands: [] worker_setup_commands: [] setup_commands: - - source activate ray && test -e ray || git clone https://github.com/YOUR_GITHUB/ray.git - - source activate ray && cd ray && git fetch && git reset --hard `cat /tmp/ray_sha` -# - source activate ray && cd ray/python && pip install -e . + - source activate ray && pip install -U ray head_start_ray_commands: - source activate ray && ray stop - source activate ray && ulimit -c unlimited && ray start --head --redis-port=6379 --autoscaling-config=~/ray_bootstrap_config.yaml diff --git a/python/ray/experimental/async_api.py b/python/ray/experimental/async_api.py new file mode 100644 index 0000000000000..8df8596e29aa5 --- /dev/null +++ b/python/ray/experimental/async_api.py @@ -0,0 +1,62 @@ +# Note: asyncio is only compatible with Python 3 + +import asyncio +import ray +from ray.experimental.async_plasma import PlasmaProtocol, PlasmaEventHandler + +handler = None +transport = None +protocol = None + + +async def _async_init(): + global handler, transport, protocol + if handler is None: + worker = ray.worker.global_worker + loop = asyncio.get_event_loop() + worker.plasma_client.subscribe() + rsock = worker.plasma_client.get_notification_socket() + handler = PlasmaEventHandler(loop, worker) + transport, protocol = await loop.create_connection( + lambda: PlasmaProtocol(worker.plasma_client, handler), sock=rsock) + + +def init(): + """ + Initialize synchronously. + """ + loop = asyncio.get_event_loop() + if loop.is_running(): + raise Exception("You must initialize the Ray async API by calling " + "async_api.init() or async_api.as_future(obj) before " + "the event loop starts.") + else: + asyncio.get_event_loop().run_until_complete(_async_init()) + + +def as_future(object_id): + """Turn an object_id into a Future object. + + Args: + object_id: A Ray object_id. + + Returns: + PlasmaObjectFuture: A future object that waits the object_id. + """ + if handler is None: + init() + return handler.as_future(object_id) + + +def shutdown(): + """Manually shutdown the async API. + + Cancels all related tasks and all the socket transportation. 
+ """ + global handler, transport, protocol + if handler is not None: + handler.close() + transport.close() + handler = None + transport = None + protocol = None diff --git a/python/ray/experimental/async_plasma.py b/python/ray/experimental/async_plasma.py new file mode 100644 index 0000000000000..2c0f806f2467b --- /dev/null +++ b/python/ray/experimental/async_plasma.py @@ -0,0 +1,237 @@ +import asyncio +import ctypes +import sys + +import pyarrow.plasma as plasma + +import ray +from ray.services import logger + +INT64_SIZE = ctypes.sizeof(ctypes.c_int64) + + +def _release_waiter(waiter, *_): + if not waiter.done(): + waiter.set_result(None) + + +class PlasmaProtocol(asyncio.Protocol): + """Protocol control for the asyncio connection.""" + + def __init__(self, plasma_client, plasma_event_handler): + self.plasma_client = plasma_client + self.plasma_event_handler = plasma_event_handler + self.transport = None + self._buffer = b"" + + def connection_made(self, transport): + self.transport = transport + + def data_received(self, data): + self._buffer += data + messages = [] + i = 0 + while i + INT64_SIZE <= len(self._buffer): + msg_len = int.from_bytes(self._buffer[i:i + INT64_SIZE], + sys.byteorder) + if i + INT64_SIZE + msg_len > len(self._buffer): + break + i += INT64_SIZE + segment = self._buffer[i:i + msg_len] + i += msg_len + messages.append(self.plasma_client.decode_notification(segment)) + + self._buffer = self._buffer[i:] + self.plasma_event_handler.process_notifications(messages) + + def connection_lost(self, exc): + # The socket has been closed + logger.debug("PlasmaProtocol - connection lost.") + + def eof_received(self): + logger.debug("PlasmaProtocol - EOF received.") + self.transport.close() + + +class PlasmaObjectFuture(asyncio.Future): + """This class manages the lifecycle of a Future contains an object_id. + + Note: + This Future is an item in an linked list. + + Attributes: + object_id: The object_id this Future contains. + """ + + def __init__(self, loop, object_id): + super().__init__(loop=loop) + self.object_id = object_id + self.prev = None + self.next = None + + @property + def ray_object_id(self): + return ray.ObjectID(self.object_id.binary()) + + def __repr__(self): + return super().__repr__() + "{object_id=%s}" % self.object_id + + +class PlasmaObjectLinkedList(asyncio.Future): + """This class is a doubly-linked list. + It holds a ObjectID and maintains futures assigned to the ObjectID. + + Args: + loop: an event loop. + plain_object_id (plasma.ObjectID): + The plasma ObjectID this class holds. + """ + + def __init__(self, loop, plain_object_id): + super().__init__(loop=loop) + assert isinstance(plain_object_id, plasma.ObjectID) + self.object_id = plain_object_id + self.head = None + self.tail = None + + def append(self, future): + """Append an object to the linked list. + + Args: + future (PlasmaObjectFuture): A PlasmaObjectFuture instance. + """ + future.prev = self.tail + if self.tail is None: + assert self.head is None + self.head = future + else: + self.tail.next = future + self.tail = future + # Once done, it will be removed from the list. + future.add_done_callback(self.remove) + + def remove(self, future): + """Remove an object from the linked list. + + Args: + future (PlasmaObjectFuture): A PlasmaObjectFuture instance. 
+ """ + if self._loop.get_debug(): + logger.debug("Removing %s from the linked list.", future) + if future.prev is None: + assert future is self.head + self.head = future.next + if self.head is None: + self.tail = None + if not self.cancelled(): + self.set_result(None) + else: + self.head.prev = None + elif future.next is None: + assert future is self.tail + self.tail = future.prev + if self.tail is None: + self.head = None + if not self.cancelled(): + self.set_result(None) + else: + self.tail.prev = None + + def cancel(self, *args, **kwargs): + """Manually cancel all tasks assigned to this event loop.""" + # Because remove all futures will trigger `set_result`, + # we cancel itself first. + super().cancel() + for future in self.traverse(): + # All cancelled futures should have callbacks to removed itself + # from this linked list. However, these callbacks are scheduled in + # an event loop, so we could still find them in our list. + if not future.cancelled(): + future.cancel() + + def set_result(self, result): + """Complete all tasks. """ + for future in self.traverse(): + # All cancelled futures should have callbacks to removed itself + # from this linked list. However, these callbacks are scheduled in + # an event loop, so we could still find them in our list. + future.set_result(result) + if not self.done(): + super().set_result(result) + + def traverse(self): + """Traverse this linked list. + + Yields: + PlasmaObjectFuture: PlasmaObjectFuture instances. + """ + current = self.head + while current is not None: + yield current + current = current.next + + +class PlasmaEventHandler: + """This class is an event handler for Plasma.""" + + def __init__(self, loop, worker): + super().__init__() + self._loop = loop + self._worker = worker + self._waiting_dict = {} + + def process_notifications(self, messages): + """Process notifications.""" + for object_id, object_size, metadata_size in messages: + if object_size > 0 and object_id in self._waiting_dict: + linked_list = self._waiting_dict[object_id] + self._complete_future(linked_list) + + def close(self): + """Clean up this handler.""" + for linked_list in self._waiting_dict.values(): + linked_list.cancel() + # All cancelled linked lists should have callbacks to removed itself + # from the waiting dict. However, these callbacks are scheduled in + # an event loop, so we don't check them now. + + def _unregister_callback(self, fut): + del self._waiting_dict[fut.object_id] + + def _complete_future(self, fut): + obj = self._worker.retrieve_and_deserialize([fut.object_id], 0)[0] + fut.set_result(obj) + + def as_future(self, object_id, check_ready=True): + """Turn an object_id into a Future object. + + Args: + object_id: A Ray's object_id. + check_ready (bool): If true, check if the object_id is ready. + + Returns: + PlasmaObjectFuture: A future object that waits the object_id. 
+ """ + if not isinstance(object_id, ray.ObjectID): + raise TypeError("Input should be an ObjectID.") + + plain_object_id = plasma.ObjectID(object_id.id()) + fut = PlasmaObjectFuture(loop=self._loop, object_id=plain_object_id) + + if check_ready: + ready, _ = ray.wait([object_id], timeout=0) + if ready: + if self._loop.get_debug(): + logger.debug("%s has been ready.", plain_object_id) + self._complete_future(fut) + return fut + + if plain_object_id not in self._waiting_dict: + linked_list = PlasmaObjectLinkedList(self._loop, plain_object_id) + linked_list.add_done_callback(self._unregister_callback) + self._waiting_dict[plain_object_id] = linked_list + self._waiting_dict[plain_object_id].append(fut) + if self._loop.get_debug(): + logger.debug("%s added to the waiting list.", fut) + + return fut diff --git a/python/ray/experimental/sgd/mnist_example.py b/python/ray/experimental/sgd/mnist_example.py index 8c2fff213c94b..331fbae66e5e8 100755 --- a/python/ray/experimental/sgd/mnist_example.py +++ b/python/ray/experimental/sgd/mnist_example.py @@ -57,6 +57,7 @@ def __init__(self): # Set seed and build layers tf.set_random_seed(0) + self.x = tf.placeholder(tf.float32, [None, 784], name="x") self.y_ = tf.placeholder(tf.float32, [None, 10], name="y_") y_conv, self.keep_prob = deepnn(self.x) @@ -74,6 +75,15 @@ def __init__(self): tf.argmax(y_conv, 1), tf.argmax(self.y_, 1)) self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + def get_loss(self): + return self.loss + + def get_optimizer(self): + return self.optimizer + + def get_variables(self): + return self.variables + def get_feed_dict(self): batch = self.mnist.train.next_batch(50) return { @@ -82,13 +92,14 @@ def get_feed_dict(self): self.keep_prob: 0.5, } - def test_accuracy(self): - return self.accuracy.eval( + def get_metrics(self): + accuracy = self.accuracy.eval( feed_dict={ self.x: self.mnist.test.images, self.y_: self.mnist.test.labels, self.keep_prob: 1.0, }) + return {"accuracy": accuracy} def train_mnist(config, reporter): @@ -101,14 +112,15 @@ def train_mnist(config, reporter): strategy=args.strategy) # Important: synchronize the initial weights of all model replicas - w0 = sgd.for_model(lambda m: m.variables.get_flat()) - sgd.foreach_model(lambda m: m.variables.set_flat(w0)) + w0 = sgd.for_model(lambda m: m.get_variables().get_flat()) + sgd.foreach_model(lambda m: m.get_variables().set_flat(w0)) for i in range(args.num_iters): if i % 10 == 0: start = time.time() loss = sgd.step(fetch_stats=True)["loss"] - acc = sgd.foreach_model(lambda model: model.test_accuracy()) + metrics = sgd.foreach_model(lambda model: model.get_metrics()) + acc = [m["accuracy"] for m in metrics] print("Iter", i, "loss", loss, "accuracy", acc) print("Time per iteration", time.time() - start) assert len(set(acc)) == 1, ("Models out of sync", acc) diff --git a/python/ray/experimental/sgd/model.py b/python/ray/experimental/sgd/model.py index ac8e0eedf23ea..2f12816570088 100644 --- a/python/ray/experimental/sgd/model.py +++ b/python/ray/experimental/sgd/model.py @@ -7,16 +7,38 @@ class Model(object): """Your class must implement this interface to be used with Ray SGD. This supports any form of input pipeline: it is up to you to define it - using TensorFlow. The only requirements are that the loss and optimizer - attributes must be defined. - + using TensorFlow. For an example implementation, see tfbench/test_model.py - - Attributes: - loss (tf.Tensor): Loss function to minimize. 
- optimizer (tf.train.Optimizer): Optimizer to use to minimize the loss. """ + def get_loss(self): + """Return loss of the model + + Returns: + loss + """ + raise NotImplementedError( + "get_loss of %s is not implemented" % self.__class__.__name__) + + # TODO support complex way of updating gradient, + # e.g. using different optimizers + def get_optimizer(self): + """Return optimizer for the model + + Returns: + optimizer + """ + raise NotImplementedError( + "get_optimizer of %s is not implemented" % self.__class__.__name__) + + def get_metrics(self): + """Return metrics of the model + + Returns: + metrics(dict): e.g. {"accuracy": accuracy(numpy data)} + """ + return {} + def get_feed_dict(self): """Extra values to pass in when computing gradients for the loss. diff --git a/python/ray/experimental/sgd/param_server.py b/python/ray/experimental/sgd/param_server.py index 517d419c36440..a6972772297f2 100644 --- a/python/ray/experimental/sgd/param_server.py +++ b/python/ray/experimental/sgd/param_server.py @@ -67,6 +67,7 @@ def get(self, object_id): client = ray.worker.global_worker.plasma_client assert self.acc_counter == self.num_sgd_workers, self.acc_counter oid = ray.pyarrow.plasma.ObjectID(object_id) + self.accumulated /= self.acc_counter client.put(self.accumulated.flatten(), object_id=oid) self.accumulated = np.zeros_like(self.accumulated) self.acc_counter = 0 diff --git a/python/ray/experimental/sgd/sgd.py b/python/ray/experimental/sgd/sgd.py index a663960683f79..9ce087a6a1977 100644 --- a/python/ray/experimental/sgd/sgd.py +++ b/python/ray/experimental/sgd/sgd.py @@ -69,7 +69,7 @@ def __init__(self, all_reduce_alg="simple"): if num_workers == 1 and strategy == "ps": - logger.warn( + logger.warning( "The parameter server strategy does not make sense for single " "worker operation, falling back to simple mode.") strategy = "simple" @@ -141,6 +141,7 @@ def foreach_model(self, fn): Returns: List of results from applying the function. 
""" + results = ray.get([w.foreach_model.remote(fn) for w in self.workers]) out = [] for r in results: diff --git a/python/ray/experimental/sgd/sgd_worker.py b/python/ray/experimental/sgd/sgd_worker.py index 0d4b45c7c8bc4..27c14f9dd52ca 100644 --- a/python/ray/experimental/sgd/sgd_worker.py +++ b/python/ray/experimental/sgd/sgd_worker.py @@ -9,9 +9,10 @@ import tensorflow as tf import ray -from ray.experimental.sgd.util import fetch, run_timeline, warmup -from ray.experimental.sgd.modified_allreduce import sum_gradients_all_reduce, \ - unpack_small_tensors +from ray.experimental.sgd.util import (ensure_plasma_tensorflow_op, fetch, + run_timeline, warmup) +from ray.experimental.sgd.modified_allreduce import (sum_gradients_all_reduce, + unpack_small_tensors) logger = logging.getLogger(__name__) @@ -55,9 +56,11 @@ def __init__(self, with tf.variable_scope("device_%d" % device_idx): model = model_creator(worker_index, device_idx) self.models.append(model) + optimizer = model.get_optimizer() + loss = model.get_loss() grads = [ - t for t in model.optimizer.compute_gradients( - model.loss) if t[0] is not None + t for t in optimizer.compute_gradients(loss) + if t[0] is not None ] grad_ops.append(grads) @@ -110,10 +113,7 @@ def __init__(self, if plasma_op: store_socket = ( ray.worker.global_worker.plasma_client.store_socket_name) - manager_socket = ( - ray.worker.global_worker.plasma_client.manager_socket_name) - if not plasma.tf_plasma_op: - plasma.build_plasma_tensorflow_op() + ensure_plasma_tensorflow_op() # For fetching grads -> plasma self.plasma_in_grads = [] @@ -123,12 +123,11 @@ def __init__(self, ] for j in range(num_grads): grad = self.per_device_grads[0][j] - with tf.device(self.models[0].loss.device): + with tf.device(self.models[0].get_loss().device): plasma_grad = plasma.tf_plasma_op.tensor_to_plasma( [grad], self.plasma_in_grads_oids[j], - plasma_store_socket_name=store_socket, - plasma_manager_socket_name=manager_socket) + plasma_store_socket_name=store_socket) self.plasma_in_grads.append(plasma_grad) # For applying grads <- plasma @@ -145,8 +144,7 @@ def __init__(self, grad_ph = plasma.tf_plasma_op.plasma_to_tensor( self.plasma_out_grads_oids[j], dtype=tf.float32, - plasma_store_socket_name=store_socket, - plasma_manager_socket_name=manager_socket) + plasma_store_socket_name=store_socket) grad_ph = tf.reshape(grad_ph, self.packed_grads_and_vars[0][j][0].shape) logger.debug("Packed tensor {}".format(grad_ph)) @@ -174,10 +172,9 @@ def __init__(self, apply_ops = [] to_apply = unpacked_gv[0] for ix, m in enumerate(self.models): - apply_ops.append( - m.optimizer.apply_gradients( - [(g, v) - for ((g, _), (_, v)) in zip(to_apply, unpacked_gv[ix])])) + apply_ops.append(m.get_optimizer().apply_gradients([ + (g, v) for ((g, _), (_, v)) in zip(to_apply, unpacked_gv[ix]) + ])) self.apply_op = tf.group(*apply_ops) init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) @@ -209,7 +206,7 @@ def compute_gradients(self): # averaged across all devices by allreduce. 
fetches = self.sess.run( [ - self.models[0].loss, self.per_device_grads[0], + self.models[0].get_loss(), self.per_device_grads[0], self.nccl_control_out ], feed_dict=feed_dict) @@ -229,7 +226,7 @@ def apply_gradients(self, avg_grads): def compute_apply(self): fetches = run_timeline( self.sess, - [self.models[0].loss, self.apply_op, self.nccl_control_out], + [self.models[0].get_loss(), self.apply_op, self.nccl_control_out], feed_dict=self._grad_feed_dict(), name="compute_apply") return fetches[0] @@ -247,7 +244,7 @@ def ps_compute_apply(self, fetch(agg_grad_shard_oids) fetches = run_timeline( self.sess, [ - self.models[0].loss, self.plasma_in_grads, self.apply_op, + self.models[0].get_loss(), self.plasma_in_grads, self.apply_op, self.nccl_control_out ], feed_dict=feed_dict, diff --git a/python/ray/experimental/sgd/tfbench/test_model.py b/python/ray/experimental/sgd/tfbench/test_model.py index d866668f810d5..dd74b6cd23225 100644 --- a/python/ray/experimental/sgd/tfbench/test_model.py +++ b/python/ray/experimental/sgd/tfbench/test_model.py @@ -14,6 +14,7 @@ class MockDataset(): class TFBenchModel(Model): def __init__(self, batch=64, use_cpus=False): + image_shape = [batch, 224, 224, 3] labels_shape = [batch] @@ -45,5 +46,11 @@ def __init__(self, batch=64, use_cpus=False): self.loss = tf.reduce_mean(loss, name='xentropy-loss') self.optimizer = tf.train.GradientDescentOptimizer(1e-6) + def get_loss(self): + return self.loss + + def get_optimizer(self): + return self.optimizer + def get_feed_dict(self): return {} diff --git a/python/ray/experimental/sgd/util.py b/python/ray/experimental/sgd/util.py index c8df01cb35b25..57549d2dfe281 100644 --- a/python/ray/experimental/sgd/util.py +++ b/python/ray/experimental/sgd/util.py @@ -2,10 +2,13 @@ from __future__ import division from __future__ import print_function +import filelock import json import logging import numpy as np import os +import pyarrow +import pyarrow.plasma as plasma import time import tensorflow as tf @@ -33,10 +36,10 @@ def warmup(): def fetch(oids): - local_sched_client = ray.worker.global_worker.local_scheduler_client + raylet_client = ray.worker.global_worker.raylet_client for o in oids: ray_obj_id = ray.ObjectID(o) - local_sched_client.fetch_or_reconstruct([ray_obj_id], True) + raylet_client.fetch_or_reconstruct([ray_obj_id], True) def run_timeline(sess, ops, feed_dict=None, write_timeline=False, name=""): @@ -120,6 +123,16 @@ def chrome_trace_format(self, filename): logger.info("Wrote chrome timeline to", filename) +def ensure_plasma_tensorflow_op(): + base_path = os.path.join(pyarrow.__path__[0], "tensorflow") + lock_path = os.path.join(base_path, "compile_op.lock") + with filelock.FileLock(lock_path): + if not os.path.exists(os.path.join(base_path, "plasma_op.so")): + plasma.build_plasma_tensorflow_op() + else: + plasma.load_plasma_tensorflow_op() + + if __name__ == "__main__": a = Timeline(1) b = Timeline(2) diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index d97cc274f76d6..2048444fbff44 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -9,6 +9,7 @@ import time import ray +from ray.function_manager import FunctionDescriptor import ray.gcs_utils import ray.ray_constants as ray_constants from ray.utils import (decode, binary_to_object_id, binary_to_hex, @@ -234,6 +235,9 @@ def _task_table(self, task_id): execution_spec = task_table_message.TaskExecutionSpec() task_spec = task_table_message.TaskSpecification() task_spec = 
ray.raylet.task_from_string(task_spec) + function_descriptor_list = task_spec.function_descriptor_list() + function_descriptor = FunctionDescriptor.from_bytes_list( + function_descriptor_list) task_spec_info = { "DriverID": binary_to_hex(task_spec.driver_id().id()), "TaskID": binary_to_hex(task_spec.task_id().id()), @@ -245,10 +249,14 @@ def _task_table(self, task_id): "ActorCreationDummyObjectID": binary_to_hex( task_spec.actor_creation_dummy_object_id().id()), "ActorCounter": task_spec.actor_counter(), - "FunctionID": binary_to_hex(task_spec.function_id().id()), "Args": task_spec.arguments(), "ReturnObjectIDs": task_spec.returns(), - "RequiredResources": task_spec.required_resources() + "RequiredResources": task_spec.required_resources(), + "FunctionID": binary_to_hex(function_descriptor.function_id.id()), + "FunctionHash": binary_to_hex(function_descriptor.function_hash), + "ModuleName": function_descriptor.module_name, + "ClassName": function_descriptor.class_name, + "FunctionName": function_descriptor.function_name, } return { @@ -358,12 +366,14 @@ def client_table(self): node_info[client_id] = { "ClientID": client_id, "IsInsertion": client.IsInsertion(), - "NodeManagerAddress": decode(client.NodeManagerAddress()), + "NodeManagerAddress": decode( + client.NodeManagerAddress(), allow_none=True), "NodeManagerPort": client.NodeManagerPort(), "ObjectManagerPort": client.ObjectManagerPort(), "ObjectStoreSocketName": decode( - client.ObjectStoreSocketName()), - "RayletSocketName": decode(client.RayletSocketName()), + client.ObjectStoreSocketName(), allow_none=True), + "RayletSocketName": decode( + client.RayletSocketName(), allow_none=True), "Resources": resources } return list(node_info.values()) @@ -425,7 +435,8 @@ def _profile_table(self, component_id): component_type = decode(profile_table_message.ComponentType()) component_id = binary_to_hex(profile_table_message.ComponentId()) - node_ip_address = decode(profile_table_message.NodeIpAddress()) + node_ip_address = decode( + profile_table_message.NodeIpAddress(), allow_none=True) for j in range(profile_table_message.ProfileEventsLength()): profile_event_message = profile_table_message.ProfileEvents(j) diff --git a/python/ray/experimental/test/async_test.py b/python/ray/experimental/test/async_test.py new file mode 100644 index 0000000000000..bdf45f77e8282 --- /dev/null +++ b/python/ray/experimental/test/async_test.py @@ -0,0 +1,150 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import asyncio +import time + +import pytest + +import ray +from ray.experimental import async_api + + +@pytest.fixture +def init(): + ray.init(num_cpus=4) + async_api.init() + asyncio.get_event_loop().set_debug(False) + yield + async_api.shutdown() + ray.shutdown() + + +def gen_tasks(time_scale=0.1): + @ray.remote + def f(n): + time.sleep(n * time_scale) + return n + + tasks = [f.remote(i) for i in range(5)] + return tasks + + +def test_simple(init): + @ray.remote + def f(): + time.sleep(1) + return {"key1": ["value"]} + + future = async_api.as_future(f.remote()) + result = asyncio.get_event_loop().run_until_complete(future) + assert result["key1"] == ["value"] + + +def test_gather(init): + loop = asyncio.get_event_loop() + tasks = gen_tasks() + futures = [async_api.as_future(obj_id) for obj_id in tasks] + results = loop.run_until_complete(asyncio.gather(*futures)) + assert all(a == b for a, b in zip(results, ray.get(tasks))) + + +def test_gather_benchmark(init): + @ray.remote + def f(n): + 
time.sleep(0.001 * n) + return 42 + + async def test_async(): + sum_time = 0. + for _ in range(50): + tasks = [f.remote(n) for n in range(20)] + start = time.time() + futures = [async_api.as_future(obj_id) for obj_id in tasks] + await asyncio.gather(*futures) + sum_time += time.time() - start + return sum_time + + def baseline(): + sum_time = 0. + for _ in range(50): + tasks = [f.remote(n) for n in range(20)] + start = time.time() + ray.get(tasks) + sum_time += time.time() - start + return sum_time + + # warm up + baseline() + # async get + sum_time_1 = asyncio.get_event_loop().run_until_complete(test_async()) + # get + sum_time_2 = baseline() + + # Ensure the new implementation is not too slow. + assert sum_time_2 * 1.2 > sum_time_1 + + +def test_wait(init): + loop = asyncio.get_event_loop() + tasks = gen_tasks() + futures = [async_api.as_future(obj_id) for obj_id in tasks] + results, _ = loop.run_until_complete(asyncio.wait(futures)) + assert set(results) == set(futures) + + +def test_wait_timeout(init): + loop = asyncio.get_event_loop() + tasks = gen_tasks(10) + futures = [async_api.as_future(obj_id) for obj_id in tasks] + fut = asyncio.wait(futures, timeout=5) + results, _ = loop.run_until_complete(fut) + assert list(results)[0] == futures[0] + + +def test_gather_mixup(init): + loop = asyncio.get_event_loop() + + @ray.remote + def f(n): + time.sleep(n * 0.1) + return n + + async def g(n): + await asyncio.sleep(n * 0.1) + return n + + tasks = [ + async_api.as_future(f.remote(1)), + g(2), + async_api.as_future(f.remote(3)), + g(4) + ] + results = loop.run_until_complete(asyncio.gather(*tasks)) + assert results == [1, 2, 3, 4] + + +def test_wait_mixup(init): + loop = asyncio.get_event_loop() + + @ray.remote + def f(n): + time.sleep(n) + return n + + def g(n): + async def _g(_n): + await asyncio.sleep(_n) + return _n + + return asyncio.ensure_future(_g(n)) + + tasks = [ + async_api.as_future(f.remote(0.1)), + g(7), + async_api.as_future(f.remote(5)), + g(2) + ] + ready, _ = loop.run_until_complete(asyncio.wait(tasks, timeout=4)) + assert set(ready) == {tasks[0], tasks[-1]} diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py index 72ec53651df76..659f18decd940 100644 --- a/python/ray/function_manager.py +++ b/python/ray/function_manager.py @@ -5,6 +5,7 @@ import hashlib import inspect import json +import logging import sys import time import traceback @@ -18,6 +19,7 @@ from ray import ray_constants from ray import cloudpickle as pickle from ray.utils import ( + binary_to_hex, is_cython, is_function_or_method, is_class_method, @@ -31,6 +33,228 @@ ["function", "function_name", "max_calls"]) """FunctionExecutionInfo: A named tuple storing remote function information.""" +logger = logging.getLogger(__name__) + + +class FunctionDescriptor(object): + """A class used to describe a python function. + + Attributes: + module_name: the module name that the function belongs to. + class_name: the class name that the function belongs to if exists. + It could be empty is the function is not a class method. + function_name: the function name of the function. + function_hash: the hash code of the function source code if the + function code is available. + function_id: the function id calculated from this descriptor. + is_for_driver_task: whether this descriptor is for driver task. 
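+
+    Example:
+        A descriptor is typically built from a plain Python function and then
+        used to derive the function id that is passed to the backend:
+
+        >>> def echo(x):
+        ...     return x
+        >>> descriptor = FunctionDescriptor.from_function(echo)
+        >>> function_id = descriptor.function_id  # ray.ObjectID of ID_SIZE bytes
+        >>> descriptor_list = descriptor.get_function_descriptor_list()
+        >>> # descriptor_list == [module, class name (empty), function name, source hash]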
+ """ + + def __init__(self, + module_name, + function_name, + class_name="", + function_source_hash=b""): + self._module_name = module_name + self._class_name = class_name + self._function_name = function_name + self._function_source_hash = function_source_hash + self._function_id = self._get_function_id() + + def __repr__(self): + return ("FunctionDescriptor:" + self._module_name + "." + + self._class_name + "." + self._function_name + "." + + binary_to_hex(self._function_source_hash)) + + @classmethod + def from_bytes_list(cls, function_descriptor_list): + """Create a FunctionDescriptor instance from list of bytes. + + This function is used to create the function descriptor from + backend data. + + Args: + cls: Current class which is required argument for classmethod. + function_descriptor_list: list of bytes to represent the + function descriptor. + + Returns: + The FunctionDescriptor instance created from the bytes list. + """ + assert isinstance(function_descriptor_list, list) + if len(function_descriptor_list) == 0: + # This is a function descriptor of driver task. + return FunctionDescriptor.for_driver_task() + elif (len(function_descriptor_list) == 3 + or len(function_descriptor_list) == 4): + module_name = function_descriptor_list[0].decode() + class_name = function_descriptor_list[1].decode() + function_name = function_descriptor_list[2].decode() + if len(function_descriptor_list) == 4: + return cls(module_name, function_name, class_name, + function_descriptor_list[3]) + else: + return cls(module_name, function_name, class_name) + else: + raise Exception( + "Invalid input for FunctionDescriptor.from_bytes_list") + + @classmethod + def from_function(cls, function): + """Create a FunctionDescriptor from a function instance. + + This function is used to create the function descriptor from + a python function. If a function is a class function, it should + not be used by this function. + + Args: + cls: Current class which is required argument for classmethod. + function: the python function used to create the function + descriptor. + + Returns: + The FunctionDescriptor instance created according to the function. + """ + module_name = function.__module__ + function_name = function.__name__ + class_name = "" + + function_source_hasher = hashlib.sha1() + try: + # If we are running a script or are in IPython, include the source + # code in the hash. + source = inspect.getsource(function).encode("ascii") + function_source_hasher.update(source) + function_source_hash = function_source_hasher.digest() + except (IOError, OSError, TypeError): + # Source code may not be available: + # e.g. Cython or Python interpreter. + function_source_hash = b"" + + return cls(module_name, function_name, class_name, + function_source_hash) + + @classmethod + def from_class(cls, target_class): + """Create a FunctionDescriptor from a class. + + Args: + cls: Current class which is required argument for classmethod. + target_class: the python class used to create the function + descriptor. + + Returns: + The FunctionDescriptor instance created according to the class. + """ + module_name = target_class.__module__ + class_name = target_class.__name__ + return cls(module_name, "__init__", class_name) + + @classmethod + def for_driver_task(cls): + """Create a FunctionDescriptor instance for a driver task.""" + return cls("", "", "", b"") + + @property + def is_for_driver_task(self): + """See whether this function descriptor is for a driver or not. + + Returns: + True if this function descriptor is for driver tasks. 
+ """ + return all( + len(x) == 0 + for x in [self.module_name, self.class_name, self.function_name]) + + @property + def module_name(self): + """Get the module name of current function descriptor. + + Returns: + The module name of the function descriptor. + """ + return self._module_name + + @property + def class_name(self): + """Get the class name of current function descriptor. + + Returns: + The class name of the function descriptor. It could be + empty if the function is not a class method. + """ + return self._class_name + + @property + def function_name(self): + """Get the function name of current function descriptor. + + Returns: + The function name of the function descriptor. + """ + return self._function_name + + @property + def function_hash(self): + """Get the hash code of the function source code. + + Returns: + The bytes with length of ray_constants.ID_SIZE if the source + code is available. Otherwise, the bytes length will be 0. + """ + return self._function_source_hash + + @property + def function_id(self): + """Get the function id calculated from this descriptor. + + Returns: + The value of ray.ObjectID that represents the function id. + """ + return ray.ObjectID(self._function_id) + + def _get_function_id(self): + """Calculate the function id of current function descriptor. + + This function id is calculated from all the fields of function + descriptor. + + Returns: + bytes with length of ray_constants.ID_SIZE. + """ + if self.is_for_driver_task: + return ray_constants.NIL_FUNCTION_ID.id() + function_id_hash = hashlib.sha1() + # Include the function module and name in the hash. + function_id_hash.update(self.module_name.encode("ascii")) + function_id_hash.update(self.function_name.encode("ascii")) + function_id_hash.update(self.class_name.encode("ascii")) + function_id_hash.update(self._function_source_hash) + # Compute the function ID. + function_id = function_id_hash.digest() + assert len(function_id) == ray_constants.ID_SIZE + return function_id + + def get_function_descriptor_list(self): + """Return a list of bytes representing the function descriptor. + + This function is used to pass this function descriptor to backend. + + Returns: + A list of bytes. + """ + descriptor_list = [] + if self.is_for_driver_task: + # Driver task returns an empty list. + return descriptor_list + else: + descriptor_list.append(self.module_name.encode("ascii")) + descriptor_list.append(self.class_name.encode("ascii")) + descriptor_list.append(self.function_name.encode("ascii")) + if len(self._function_source_hash) != 0: + descriptor_list.append(self._function_source_hash) + return descriptor_list + class FunctionActorManager(object): """A class used to export/load remote functions and actors. @@ -45,6 +269,8 @@ class FunctionActorManager(object): and execution_info. _num_task_executions: The map from driver_id to function execution times. + imported_actor_classes: The set of actor classes keys (format: + ActorClass:function_id) that are already in GCS. """ def __init__(self, worker): @@ -58,11 +284,17 @@ def __init__(self, worker): # workers that execute remote functions. self._function_execution_info = defaultdict(lambda: {}) self._num_task_executions = defaultdict(lambda: {}) + # A set of all of the actor class keys that have been imported by the + # import thread. It is safe to convert this worker into an actor of + # these types. 
+ self.imported_actor_classes = set() - def increase_task_counter(self, driver_id, function_id): + def increase_task_counter(self, driver_id, function_descriptor): + function_id = function_descriptor.function_id.id() self._num_task_executions[driver_id][function_id] += 1 - def get_task_counter(self, driver_id, function_id): + def get_task_counter(self, driver_id, function_descriptor): + function_id = function_descriptor.function_id.id() return self._num_task_executions[driver_id][function_id] def export_cached(self): @@ -124,13 +356,13 @@ def _do_export(self, remote_function): check_oversized_pickle(pickled_function, remote_function._function_name, "remote function", self._worker) - key = (b"RemoteFunction:" + self._worker.task_driver_id.id() + b":" + - remote_function._function_id) + remote_function._function_descriptor.function_id.id()) self._worker.redis_client.hmset( key, { "driver_id": self._worker.task_driver_id.id(), - "function_id": remote_function._function_id, + "function_id": remote_function._function_descriptor. + function_id.id(), "name": remote_function._function_name, "module": function.__module__, "function": pickled_function, @@ -193,24 +425,35 @@ def f(): self._worker.redis_client.rpush( b"FunctionTable:" + function_id.id(), self._worker.worker_id) - def get_execution_info(self, driver_id, function_id): + def get_execution_info(self, driver_id, function_descriptor): """Get the FunctionExecutionInfo of a remote function. Args: driver_id: ID of the driver that the function belongs to. - function_id: ID of the function to get. + function_descriptor: The FunctionDescriptor of the function to get. Returns: A FunctionExecutionInfo object. """ - # Wait until the function to be executed has actually been registered - # on this worker. We will push warnings to the user if we spend too - # long in this loop. - with profiling.profile("wait_for_function", worker=self._worker): - self._wait_for_function(function_id, driver_id) - return self._function_execution_info[driver_id][function_id.id()] + function_id = function_descriptor.function_id.id() - def _wait_for_function(self, function_id, driver_id, timeout=10): + # Wait until the function to be executed has actually been + # registered on this worker. We will push warnings to the user if + # we spend too long in this loop. + # The driver function may not be found in sys.path. Try to load + # the function from GCS. + with profiling.profile("wait_for_function", worker=self._worker): + self._wait_for_function(function_descriptor, driver_id) + try: + info = self._function_execution_info[driver_id][function_id] + except KeyError as e: + message = ("Error occurs in get_execution_info: " + "driver_id: %s, function_descriptor: %s. Message: %s" % + (binary_to_hex(driver_id), function_descriptor, e)) + raise KeyError(message) + return info + + def _wait_for_function(self, function_descriptor, driver_id, timeout=10): """Wait until the function to be executed is present on this worker. This method will simply loop until the import thread has imported the @@ -221,7 +464,8 @@ def _wait_for_function(self, function_id, driver_id, timeout=10): been defined. Args: - function_id (str): The ID of the function that we want to execute. + function_descriptor : The FunctionDescriptor of the function that + we want to execute. driver_id (str): The ID of the driver to push the error message to if this times out. 
""" @@ -231,7 +475,7 @@ def _wait_for_function(self, function_id, driver_id, timeout=10): while True: with self._worker.lock: if (self._worker.actor_id == ray.worker.NIL_ACTOR_ID - and (function_id.id() in + and (function_descriptor.function_id.id() in self._function_execution_info[driver_id])): break elif self._worker.actor_id != ray.worker.NIL_ACTOR_ID and ( @@ -251,24 +495,6 @@ def _wait_for_function(self, function_id, driver_id, timeout=10): warning_sent = True time.sleep(0.001) - @classmethod - def compute_actor_method_function_id(cls, class_name, attr): - """Get the function ID corresponding to an actor method. - - Args: - class_name (str): The class name of the actor. - attr (str): The attribute name of the method. - - Returns: - Function ID corresponding to the method. - """ - function_id_hash = hashlib.sha1() - function_id_hash.update(class_name.encode("ascii")) - function_id_hash.update(attr.encode("ascii")) - function_id = function_id_hash.digest() - assert len(function_id) == ray_constants.ID_SIZE - return ray.ObjectID(function_id) - def _publish_actor_class_to_key(self, key, actor_class_info): """Push an actor class definition to Redis. @@ -287,9 +513,11 @@ def _publish_actor_class_to_key(self, key, actor_class_info): self._worker.redis_client.hmset(key, actor_class_info) self._worker.redis_client.rpush("Exports", key) - def export_actor_class(self, class_id, Class, actor_method_names, + def export_actor_class(self, Class, actor_method_names, checkpoint_interval): - key = b"ActorClass:" + class_id + function_descriptor = FunctionDescriptor.from_class(Class) + key = (b"ActorClass:" + self._worker.task_driver_id.id() + b":" + + function_descriptor.function_id.id()) actor_class_info = { "class_name": Class.__name__, "module": Class.__module__, @@ -318,6 +546,18 @@ def export_actor_class(self, class_id, Class, actor_method_names, # within tasks. I tried to disable this, but it may be necessary # because of https://github.com/ray-project/ray/issues/1146. + def load_actor(self, driver_id, function_descriptor): + key = (b"ActorClass:" + driver_id + b":" + + function_descriptor.function_id.id()) + # Wait for the actor class key to have been imported by the + # import thread. TODO(rkn): It shouldn't be possible to end + # up in an infinite loop here, but we should push an error to + # the driver if too much time is spent here. + while key not in self.imported_actor_classes: + time.sleep(0.001) + with self._worker.lock: + self.fetch_and_register_actor(key) + def fetch_and_register_actor(self, actor_class_key): """Import an actor. @@ -330,11 +570,10 @@ def fetch_and_register_actor(self, actor_class_key): worker: The worker to use. """ actor_id_str = self._worker.actor_id - (driver_id, class_id, class_name, module, pickled_class, - checkpoint_interval, + (driver_id, class_name, module, pickled_class, checkpoint_interval, actor_method_names) = self._worker.redis_client.hmget( actor_class_key, [ - "driver_id", "class_id", "class_name", "module", "class", + "driver_id", "class_name", "module", "class", "checkpoint_interval", "actor_method_names" ]) @@ -368,9 +607,9 @@ def temporary_actor_method(*xs): # Register the actor method executors. 
for actor_method_name in actor_method_names: - function_id = ( - FunctionActorManager.compute_actor_method_function_id( - class_name, actor_method_name).id()) + function_descriptor = FunctionDescriptor(module, actor_method_name, + class_name) + function_id = function_descriptor.function_id.id() temporary_executor = self._make_actor_method_executor( actor_method_name, temporary_actor_method, @@ -409,9 +648,9 @@ def temporary_actor_method(*xs): actor_methods = inspect.getmembers( unpickled_class, predicate=is_function_or_method) for actor_method_name, actor_method in actor_methods: - function_id = ( - FunctionActorManager.compute_actor_method_function_id( - class_name, actor_method_name).id()) + function_descriptor = FunctionDescriptor( + module, actor_method_name, class_name) + function_id = function_descriptor.function_id.id() executor = self._make_actor_method_executor( actor_method_name, actor_method, actor_imported=True) self._function_execution_info[driver_id][function_id] = ( @@ -452,7 +691,9 @@ def actor_method_executor(dummy_return_id, actor, *args): # If this is the first task to execute on the actor, try to resume # from a checkpoint. - if actor_imported and self._worker.actor_task_counter == 1: + # Current __init__ will be called by default. So the real function + # call will start from 2. + if actor_imported and self._worker.actor_task_counter == 2: checkpoint_resumed = ray.actor.restore_and_log_checkpoint( self._worker, actor) if checkpoint_resumed: diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 347f7ab9f8064..c477f4bcc9e83 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -3,26 +3,26 @@ from __future__ import print_function import flatbuffers - import ray.core.generated.ErrorTableData -from ray.core.generated.GcsTableEntry import GcsTableEntry from ray.core.generated.ClientTableData import ClientTableData +from ray.core.generated.DriverTableData import DriverTableData from ray.core.generated.ErrorTableData import ErrorTableData -from ray.core.generated.ProfileTableData import ProfileTableData -from ray.core.generated.HeartbeatTableData import HeartbeatTableData +from ray.core.generated.GcsTableEntry import GcsTableEntry from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData -from ray.core.generated.DriverTableData import DriverTableData +from ray.core.generated.HeartbeatTableData import HeartbeatTableData +from ray.core.generated.Language import Language from ray.core.generated.ObjectTableData import ObjectTableData -from ray.core.generated.ray.protocol.Task import Task - +from ray.core.generated.ProfileTableData import ProfileTableData from ray.core.generated.TablePrefix import TablePrefix from ray.core.generated.TablePubsub import TablePubsub +from ray.core.generated.ray.protocol.Task import Task + __all__ = [ "GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData", "HeartbeatBatchTableData", "DriverTableData", "ProfileTableData", - "ObjectTableData", "Task", "TablePrefix", "TablePubsub", + "ObjectTableData", "Task", "TablePrefix", "TablePubsub", "Language", "construct_error_message" ] diff --git a/python/ray/import_thread.py b/python/ray/import_thread.py index 70dba322370bb..08031c7b603b2 100644 --- a/python/ray/import_thread.py +++ b/python/ray/import_thread.py @@ -98,7 +98,7 @@ def _process_key(self, key): # Keep track of the fact that this actor class has been # exported so that we know it is safe to turn this worker # into an actor of that class. 
- self.worker.imported_actor_classes.add(key) + self.worker.function_actor_manager.imported_actor_classes.add(key) # TODO(rkn): We may need to bring back the case of # fetching actor classes here. else: diff --git a/python/ray/internal/internal_api.py b/python/ray/internal/internal_api.py index 7772974319aea..c4acbbbf60cf6 100644 --- a/python/ray/internal/internal_api.py +++ b/python/ray/internal/internal_api.py @@ -42,4 +42,4 @@ def free(object_ids, local_only=False, worker=None): if len(object_ids) == 0: return - worker.local_scheduler_client.free(object_ids, local_only) + worker.raylet_client.free_objects(object_ids, local_only) diff --git a/python/ray/memory_monitor.py b/python/ray/memory_monitor.py index 00cf86816dbf8..a52f98d7077df 100644 --- a/python/ray/memory_monitor.py +++ b/python/ray/memory_monitor.py @@ -37,7 +37,8 @@ def get_message(used_gb, total_gb, threshold): round(psutil.virtual_memory().shared / 1e9, 2)) + "currently being used by the Ray object store. You can set " "the object store size with the `object_store_memory` " - "parameter when starting Ray.") + "parameter when starting Ray, and the max Redis size with " + "`redis_max_memory`.") class MemoryMonitor(object): diff --git a/python/ray/profiling.py b/python/ray/profiling.py index 42b02f8926be8..d57d827cdd037 100644 --- a/python/ray/profiling.py +++ b/python/ray/profiling.py @@ -119,7 +119,7 @@ def flush_profile_data(self): else: component_type = "driver" - self.worker.local_scheduler_client.push_profile_events( + self.worker.raylet_client.push_profile_events( component_type, ray.ObjectID(self.worker.worker_id), self.worker.node_ip_address, events) @@ -128,6 +128,19 @@ def add_event(self, event): self.events.append(event) +class NoopProfiler(object): + """A no-op profile used when collect_profile_data=False.""" + + def start_flush_thread(self): + pass + + def flush_profile_data(self): + pass + + def add_event(self, event): + pass + + class RayLogSpanRaylet(object): """An object used to enable logging a span of events with a with statement. diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py index a1d5e1a765438..fc89d48ed63ca 100644 --- a/python/ray/ray_constants.py +++ b/python/ray/ray_constants.py @@ -16,6 +16,7 @@ def env_integer(key, default): ID_SIZE = 20 NIL_JOB_ID = ObjectID(ID_SIZE * b"\xff") +NIL_FUNCTION_ID = NIL_JOB_ID # If a remote function or actor (or some other export) has serialized size # greater than this quantity, print an warning. @@ -76,3 +77,8 @@ def env_integer(key, default): LOGGER_LEVEL_CHOICES = ['debug', 'info', 'warning', 'error', 'critical'] LOGGER_LEVEL_HELP = ("The logging level threshold, choices=['debug', 'info'," " 'warning', 'error', 'critical'], default='info'") + +# A constant indicating that an actor doesn't need reconstructions. +NO_RECONSTRUCTION = 0 +# A constant indicating that an actor should be reconstructed infinite times. 
+INFINITE_RECONSTRUCTION = 2**30 diff --git a/python/ray/raylet/__init__.py b/python/ray/raylet/__init__.py index 8757f59741567..69545f5c69366 100644 --- a/python/ray/raylet/__init__.py +++ b/python/ray/raylet/__init__.py @@ -2,12 +2,12 @@ from __future__ import division from __future__ import print_function -from ray.core.src.ray.raylet.liblocal_scheduler_library_python import ( - Task, LocalSchedulerClient, ObjectID, check_simple_value, compute_task_id, +from ray.core.src.ray.raylet.libraylet_library_python import ( + Task, RayletClient, ObjectID, check_simple_value, compute_task_id, task_from_string, task_to_string, _config, common_error) __all__ = [ - "Task", "LocalSchedulerClient", "ObjectID", "check_simple_value", + "Task", "RayletClient", "ObjectID", "check_simple_value", "compute_task_id", "task_from_string", "task_to_string", "start_local_scheduler", "_config", "common_error" ] diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index fb2a29e45c512..144fbbc441abc 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -3,11 +3,9 @@ from __future__ import print_function import copy -import hashlib -import inspect import logging -import ray.ray_constants as ray_constants +from ray.function_manager import FunctionDescriptor import ray.signature # Default parameters for remote functions. @@ -18,33 +16,6 @@ logger = logging.getLogger(__name__) -def compute_function_id(function): - """Compute an function ID for a function. - - Args: - func: The actual function. - - Returns: - Raw bytes of the function id - """ - function_id_hash = hashlib.sha1() - # Include the function module and name in the hash. - function_id_hash.update(function.__module__.encode("ascii")) - function_id_hash.update(function.__name__.encode("ascii")) - try: - # If we are running a script or are in IPython, include the source code - # in the hash. - source = inspect.getsource(function).encode("ascii") - function_id_hash.update(source) - except (IOError, OSError, TypeError): - # Source code may not be available: e.g. Cython or Python interpreter. - pass - # Compute the function ID. - function_id = function_id_hash.digest() - assert len(function_id) == ray_constants.ID_SIZE - return function_id - - class RemoteFunction(object): """A remote function. @@ -52,7 +23,7 @@ class RemoteFunction(object): Attributes: _function: The original function. - _function_id: The ID of the function. + _function_descriptor: The function descriptor. _function_name: The module and function name. _num_cpus: The default number of CPUs to use for invocations of this remote function. @@ -70,10 +41,7 @@ class RemoteFunction(object): def __init__(self, function, num_cpus, num_gpus, resources, num_return_vals, max_calls): self._function = function - # TODO(rkn): We store the function ID as a string, so that - # RemoteFunction objects can be pickled. We should undo this when - # we allow ObjectIDs to be pickled. - self._function_id = compute_function_id(function) + self._function_descriptor = FunctionDescriptor.from_function(function) self._function_name = ( self._function.__module__ + '.' + self._function.__name__) self._num_cpus = (DEFAULT_REMOTE_FUNCTION_CPUS @@ -109,7 +77,7 @@ def _submit(self, num_cpus=None, num_gpus=None, resources=None): - logger.warn( + logger.warning( "WARNING: _submit() is being deprecated. 
Please use _remote().") return self._remote( args=args, @@ -147,7 +115,7 @@ def _remote(self, result = self._function(*copy.deepcopy(args)) return result object_ids = worker.submit_task( - ray.ObjectID(self._function_id), + self._function_descriptor, args, num_return_vals=num_return_vals, resources=resources) diff --git a/python/ray/rllib/__init__.py b/python/ray/rllib/__init__.py index fd6ba3407eaec..b499f107123f6 100644 --- a/python/ray/rllib/__init__.py +++ b/python/ray/rllib/__init__.py @@ -31,12 +31,11 @@ def _setup_logger(): def _register_all(): - for key in [ - "PPO", "ES", "DQN", "APEX", "A3C", "PG", "DDPG", "APEX_DDPG", - "IMPALA", "ARS", "A2C", "__fake", "__sigmoid_fake_data", - "__parameter_tuning" - ]: - from ray.rllib.agents.agent import get_agent_class + from ray.rllib.agents.registry import ALGORITHMS + from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS + for key in list(ALGORITHMS.keys()) + list(CONTRIBUTED_ALGORITHMS.keys( + )) + ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]: + from ray.rllib.agents.registry import get_agent_class register_trainable(key, get_agent_class(key)) diff --git a/python/ray/rllib/agents/a3c/a2c.py b/python/ray/rllib/agents/a3c/a2c.py index f4e7f394afb3d..c344592b90863 100644 --- a/python/ray/rllib/agents/a3c/a2c.py +++ b/python/ray/rllib/agents/a3c/a2c.py @@ -4,6 +4,7 @@ from ray.rllib.agents.a3c.a3c import A3CAgent, DEFAULT_CONFIG as A3C_CONFIG from ray.rllib.optimizers import SyncSamplesOptimizer +from ray.rllib.utils.annotations import override from ray.rllib.utils import merge_dicts A2C_DEFAULT_CONFIG = merge_dicts( @@ -22,6 +23,7 @@ class A2CAgent(A3CAgent): _agent_name = "A2C" _default_config = A2C_DEFAULT_CONFIG + @override(A3CAgent) def _make_optimizer(self): return SyncSamplesOptimizer(self.local_evaluator, self.remote_evaluators, diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py index ebfec99e3b241..43daa0b3ef781 100644 --- a/python/ray/rllib/agents/a3c/a3c.py +++ b/python/ray/rllib/agents/a3c/a3c.py @@ -7,6 +7,7 @@ from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph from ray.rllib.agents.agent import Agent, with_common_config from ray.rllib.optimizers import AsyncGradientsOptimizer +from ray.rllib.utils.annotations import override # yapf: disable # __sphinx_doc_begin__ @@ -44,6 +45,7 @@ class A3CAgent(Agent): _default_config = DEFAULT_CONFIG _policy_graph = A3CPolicyGraph + @override(Agent) def _init(self): if self.config["use_pytorch"]: from ray.rllib.agents.a3c.a3c_torch_policy_graph import \ @@ -58,11 +60,7 @@ def _init(self): self.env_creator, policy_cls, self.config["num_workers"]) self.optimizer = self._make_optimizer() - def _make_optimizer(self): - return AsyncGradientsOptimizer(self.local_evaluator, - self.remote_evaluators, - self.config["optimizer"]) - + @override(Agent) def _train(self): prev_steps = self.optimizer.num_steps_sampled start = time.time() @@ -73,3 +71,8 @@ def _train(self): result.update(timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps) return result + + def _make_optimizer(self): + return AsyncGradientsOptimizer(self.local_evaluator, + self.remote_evaluators, + self.config["optimizer"]) diff --git a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py b/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py index 8aa60645aaebd..50258f58ac3aa 100644 --- a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py +++ b/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py @@ -10,10 +10,12 @@ import ray from ray.rllib.utils.error import 
UnsupportedSpaceException from ray.rllib.utils.explained_variance import explained_variance +from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.postprocessing import compute_advantages from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ LearningRateSchedule from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.utils.annotations import override class A3CLoss(object): @@ -118,30 +120,11 @@ def __init__(self, observation_space, action_space, config): self.sess.run(tf.global_variables_initializer()) - def extra_compute_action_fetches(self): - return {"vf_preds": self.vf} - - def value(self, ob, *args): - feed_dict = {self.observations: [ob], self.model.seq_lens: [1]} - assert len(args) == len(self.model.state_in), \ - (args, self.model.state_in) - for k, v in zip(self.model.state_in, args): - feed_dict[k] = v - vf = self.sess.run(self.vf, feed_dict) - return vf[0] - - def gradients(self, optimizer): - grads = tf.gradients(self.loss.total_loss, self.var_list) - self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) - clipped_grads = list(zip(self.grads, self.var_list)) - return clipped_grads - - def extra_compute_grad_fetches(self): - return self.stats_fetches - + @override(PolicyGraph) def get_initial_state(self): return self.model.state_init + @override(PolicyGraph) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, @@ -153,6 +136,30 @@ def postprocess_trajectory(self, next_state = [] for i in range(len(self.model.state_in)): next_state.append([sample_batch["state_out_{}".format(i)][-1]]) - last_r = self.value(sample_batch["new_obs"][-1], *next_state) + last_r = self._value(sample_batch["new_obs"][-1], *next_state) return compute_advantages(sample_batch, last_r, self.config["gamma"], self.config["lambda"]) + + @override(TFPolicyGraph) + def gradients(self, optimizer): + grads = tf.gradients(self.loss.total_loss, self.var_list) + self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) + clipped_grads = list(zip(self.grads, self.var_list)) + return clipped_grads + + @override(TFPolicyGraph) + def extra_compute_grad_fetches(self): + return self.stats_fetches + + @override(TFPolicyGraph) + def extra_compute_action_fetches(self): + return {"vf_preds": self.vf} + + def _value(self, ob, *args): + feed_dict = {self.observations: [ob], self.model.seq_lens: [1]} + assert len(args) == len(self.model.state_in), \ + (args, self.model.state_in) + for k, v in zip(self.model.state_in, args): + feed_dict[k] = v + vf = self.sess.run(self.vf, feed_dict) + return vf[0] diff --git a/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py b/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py index 3eecc3bb1921d..c24340d8d10a0 100644 --- a/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py +++ b/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py @@ -10,7 +10,9 @@ from ray.rllib.models.pytorch.misc import var_to_np from ray.rllib.models.catalog import ModelCatalog from ray.rllib.evaluation.postprocessing import compute_advantages +from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph +from ray.rllib.utils.annotations import override class A3CLoss(nn.Module): @@ -56,12 +58,15 @@ def __init__(self, obs_space, action_space, config): loss, loss_inputs=["obs", "actions", "advantages", "value_targets"]) + @override(TorchPolicyGraph) def extra_action_out(self, model_out): return {"vf_preds": var_to_np(model_out[1])} + 
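The `@override(...)` decorator applied to these policy-graph methods comes from the new
`ray.rllib.utils.annotations` module, whose implementation is not shown in this diff. As a
rough sketch only (an assumption, not the actual module), an override-checking decorator
along these lines merely verifies that the wrapped method name exists on the given parent
class:

    def override(cls):
        """Mark a method as overriding `cls` and check that it really does."""
        def check_override(method):
            if method.__name__ not in dir(cls):
                raise NameError("{} does not override any method of {}".format(
                    method, cls))
            return method
        return check_override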
@override(TorchPolicyGraph) def optimizer(self): return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"]) + @override(PolicyGraph) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index f0d9510756b93..6f151b8e089cd 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -2,20 +2,25 @@ from __future__ import division from __future__ import print_function +from datetime import datetime import copy -import os import logging +import os import pickle +import six import tempfile -from datetime import datetime import tensorflow as tf +from types import FunctionType import ray +from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter from ray.rllib.models import MODEL_DEFAULTS from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator +from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer +from ray.rllib.utils.annotations import override from ray.rllib.utils import FilterManager, deep_update, merge_dicts -from ray.tune.registry import ENV_CREATOR, _global_registry +from ray.tune.registry import ENV_CREATOR, register_env, _global_registry from ray.tune.trainable import Trainable from ray.tune.trial import Resources from ray.tune.logger import UnifiedLogger @@ -40,6 +45,7 @@ "on_episode_step": None, # arg: {"env": .., "episode": ...} "on_episode_end": None, # arg: {"env": .., "episode": ...} "on_sample_end": None, # arg: {"samples": .., "evaluator": ...} + "on_train_result": None, # arg: {"agent": ..., "result": ...} }, # === Policy === @@ -61,6 +67,8 @@ # Whether to clip rewards prior to experience postprocessing. Setting to # None means clip for Atari only. "clip_rewards": None, + # Whether to np.clip() actions to the action space low/high range spec. + "clip_actions": True, # Whether to use rllib or deepmind preprocessors by default "preprocessor_pref": "deepmind", @@ -117,11 +125,45 @@ "intra_op_parallelism_threads": 8, "inter_op_parallelism_threads": 8, }, - # Whether to LZ4 compress observations + # Whether to LZ4 compress individual observations "compress_observations": False, # Drop metric batches from unresponsive workers after this many seconds "collect_metrics_timeout": 180, + # === Offline Data Input / Output (Experimental) === + # Specify how to generate experiences: + # - "sampler": generate experiences via online simulation (default) + # - a local directory or file glob expression (e.g., "/tmp/*.json") + # - a list of individual file paths/URIs (e.g., ["/tmp/1.json", + # "s3://bucket/2.json"]) + # - a dict with string keys and sampling probabilities as values (e.g., + # {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}). + # - a function that returns a rllib.offline.InputReader + "input": "sampler", + # Specify how to evaluate the current policy. This only makes sense to set + # when the input is not already generating simulation data: + # - None: don't evaluate the policy. The episode reward and other + # metrics will be NaN if using offline data. + # - "simulation": run the environment in the background, but use + # this data for evaluation only and not for learning. + # - "counterfactual": use counterfactual policy evaluation to estimate + # performance (this option is not implemented yet). 
+ "input_evaluation": None, + # Specify where experiences should be saved: + # - None: don't save any experiences + # - "logdir" to save to the agent log dir + # - a path/URI to save to a custom output directory (e.g., "s3://bucket/") + # - a function that returns a rllib.offline.OutputWriter + "output": None, + # What sample batch columns to LZ4 compress in the output data. + "output_compress_columns": ["obs", "new_obs"], + # Max output file size before rolling over to a new file. + "output_max_file_size": 64 * 1024 * 1024, + # Whether to run postprocess_trajectory() on the trajectory fragments from + # offline inputs. Whether this makes sense is algorithm-specific. + # TODO(ekl) implement this and multi-agent batch handling + # "postprocess_inputs": False, + # === Multiagent === "multiagent": { # Map from policy ids to tuples of (policy_graph_cls, obs_space, @@ -162,100 +204,6 @@ class Agent(Trainable): "tf_session_args", "env_config", "model", "optimizer", "multiagent" ] - @classmethod - def default_resource_request(cls, config): - cf = dict(cls._default_config, **config) - Agent._validate_config(cf) - # TODO(ekl): add custom resources here once tune supports them - return Resources( - cpu=cf["num_cpus_for_driver"], - gpu=cf["num_gpus"], - extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"], - extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) - - def make_local_evaluator(self, env_creator, policy_graph): - """Convenience method to return configured local evaluator.""" - - return self._make_evaluator( - PolicyEvaluator, - env_creator, - policy_graph, - 0, - # important: allow local tf to use more CPUs for optimization - merge_dicts(self.config, { - "tf_session_args": self. - config["local_evaluator_tf_session_args"] - })) - - def make_remote_evaluators(self, env_creator, policy_graph, count): - """Convenience method to return a number of remote evaluators.""" - - remote_args = { - "num_cpus": self.config["num_cpus_per_worker"], - "num_gpus": self.config["num_gpus_per_worker"], - "resources": self.config["custom_resources_per_worker"], - } - - cls = PolicyEvaluator.as_remote(**remote_args).remote - return [ - self._make_evaluator(cls, env_creator, policy_graph, i + 1, - self.config) for i in range(count) - ] - - def _make_evaluator(self, cls, env_creator, policy_graph, worker_index, - config): - def session_creator(): - logger.debug("Creating TF session {}".format( - config["tf_session_args"])) - return tf.Session( - config=tf.ConfigProto(**config["tf_session_args"])) - - return cls( - env_creator, - self.config["multiagent"]["policy_graphs"] or policy_graph, - policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"], - policies_to_train=self.config["multiagent"]["policies_to_train"], - tf_session_creator=(session_creator - if config["tf_session_args"] else None), - batch_steps=config["sample_batch_size"], - batch_mode=config["batch_mode"], - episode_horizon=config["horizon"], - preprocessor_pref=config["preprocessor_pref"], - sample_async=config["sample_async"], - compress_observations=config["compress_observations"], - num_envs=config["num_envs_per_worker"], - observation_filter=config["observation_filter"], - clip_rewards=config["clip_rewards"], - env_config=config["env_config"], - model_config=config["model"], - policy_config=config, - worker_index=worker_index, - monitor_path=self.logdir if config["monitor"] else None, - log_level=config["log_level"], - callbacks=config["callbacks"]) - - @classmethod - def resource_help(cls, config): - return ("\n\nYou can adjust 
the resource requests of RLlib agents by " - "setting `num_workers` and other configs. See the " - "DEFAULT_CONFIG defined by each agent for more info.\n\n" - "The config of this agent is: {}".format(config)) - - @staticmethod - def _validate_config(config): - if "gpu" in config: - raise ValueError( - "The `gpu` config is deprecated, please use `num_gpus=0|1` " - "instead.") - if "gpu_fraction" in config: - raise ValueError( - "The `gpu_fraction` config is deprecated, please use " - "`num_gpus=` instead.") - if "use_gpu_for_workers" in config: - raise ValueError( - "The `use_gpu_for_workers` config is deprecated, please use " - "`num_gpus_per_worker=1` instead.") - def __init__(self, config=None, env=None, logger_creator=None): """Initialize an RLLib agent. @@ -268,13 +216,12 @@ def __init__(self, config=None, env=None, logger_creator=None): """ config = config or {} - Agent._validate_config(config) # Vars to synchronize to evaluators on each train call self.global_vars = {"timestep": 0} # Agents allow env ids to be passed directly to the constructor. - self._env_id = env or config.get("env") + self._env_id = _register_if_needed(env or config.get("env")) # Create a default logger creator if no logger_creator is specified if logger_creator is None: @@ -296,6 +243,19 @@ def default_logger_creator(config): Trainable.__init__(self, config, logger_creator) + @classmethod + @override(Trainable) + def default_resource_request(cls, config): + cf = dict(cls._default_config, **config) + Agent._validate_config(cf) + # TODO(ekl): add custom resources here once tune supports them + return Resources( + cpu=cf["num_cpus_for_driver"], + gpu=cf["num_gpus"], + extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"], + extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) + + @override(Trainable) def train(self): """Overrides super.train to synchronize global vars.""" @@ -316,8 +276,15 @@ def train(self): logger.debug("synchronized filters: {}".format( self.local_evaluator.filters)) - return Trainable.train(self) + result = Trainable.train(self) + if self.config["callbacks"].get("on_train_result"): + self.config["callbacks"]["on_train_result"]({ + "agent": self, + "result": result, + }) + return result + @override(Trainable) def _setup(self, config): env = self._env_id if env: @@ -336,6 +303,7 @@ def _setup(self, config): self._allow_unknown_configs, self._allow_unknown_subkeys) self.config = merged_config + Agent._validate_config(self.config) if self.config.get("log_level"): logging.getLogger("ray.rllib").setLevel(self.config["log_level"]) @@ -343,26 +311,29 @@ def _setup(self, config): with tf.Graph().as_default(): self._init() - def _init(self): - """Subclasses should override this for custom initialization.""" - - raise NotImplementedError - - @property - def iteration(self): - """Current training iter, auto-incremented with each train() call.""" - - return self._iteration + @override(Trainable) + def _stop(self): + # workaround for https://github.com/ray-project/ray/issues/1516 + if hasattr(self, "remote_evaluators"): + for ev in self.remote_evaluators: + ev.__ray_terminate__.remote() + if hasattr(self, "optimizer"): + self.optimizer.stop() - @property - def _agent_name(self): - """Subclasses should override this to declare their name.""" + @override(Trainable) + def _save(self, checkpoint_dir): + checkpoint_path = os.path.join(checkpoint_dir, + "checkpoint-{}".format(self.iteration)) + pickle.dump(self.__getstate__(), open(checkpoint_path, "wb")) + return checkpoint_path - raise NotImplementedError 
+ @override(Trainable) + def _restore(self, checkpoint_path): + extra_data = pickle.load(open(checkpoint_path, "rb")) + self.__setstate__(extra_data) - @property - def _default_config(self): - """Subclasses should override this to declare their default config.""" + def _init(self): + """Subclasses should override this for custom initialization.""" raise NotImplementedError @@ -381,8 +352,10 @@ def compute_action(self, observation, state=None, policy_id="default"): if state is None: state = [] + preprocessed = self.local_evaluator.preprocessors[policy_id].transform( + observation) filtered_obs = self.local_evaluator.filters[policy_id]( - observation, update=False) + preprocessed, update=False) if state: return self.local_evaluator.for_policy( lambda p: p.compute_single_action(filtered_obs, state), @@ -391,6 +364,24 @@ def compute_action(self, observation, state=None, policy_id="default"): lambda p: p.compute_single_action(filtered_obs, state)[0], policy_id=policy_id) + @property + def iteration(self): + """Current training iter, auto-incremented with each train() call.""" + + return self._iteration + + @property + def _agent_name(self): + """Subclasses should override this to declare their name.""" + + raise NotImplementedError + + @property + def _default_config(self): + """Subclasses should override this to declare their default config.""" + + raise NotImplementedError + def get_weights(self, policies=None): """Return a dictionary of policy ids to weights. @@ -408,13 +399,158 @@ def set_weights(self, weights): """ self.local_evaluator.set_weights(weights) - def _stop(self): - # workaround for https://github.com/ray-project/ray/issues/1516 - if hasattr(self, "remote_evaluators"): - for ev in self.remote_evaluators: - ev.__ray_terminate__.remote() - if hasattr(self, "optimizer"): - self.optimizer.stop() + def make_local_evaluator(self, env_creator, policy_graph): + """Convenience method to return configured local evaluator.""" + + return self._make_evaluator( + PolicyEvaluator, + env_creator, + policy_graph, + 0, + # important: allow local tf to use more CPUs for optimization + merge_dicts(self.config, { + "tf_session_args": self. + config["local_evaluator_tf_session_args"] + })) + + def make_remote_evaluators(self, env_creator, policy_graph, count): + """Convenience method to return a number of remote evaluators.""" + + remote_args = { + "num_cpus": self.config["num_cpus_per_worker"], + "num_gpus": self.config["num_gpus_per_worker"], + "resources": self.config["custom_resources_per_worker"], + } + + cls = PolicyEvaluator.as_remote(**remote_args).remote + return [ + self._make_evaluator(cls, env_creator, policy_graph, i + 1, + self.config) for i in range(count) + ] + + def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID): + """Export policy model with given policy_id to local directory. + + Arguments: + export_dir (string): Writable local directory. + policy_id (string): Optional policy id to export. + + Example: + >>> agent = MyAgent() + >>> for _ in range(10): + >>> agent.train() + >>> agent.export_policy_model("/tmp/export_dir") + """ + self.local_evaluator.export_policy_model(export_dir, policy_id) + + def export_policy_checkpoint(self, + export_dir, + filename_prefix="model", + policy_id=DEFAULT_POLICY_ID): + """Export tensorflow policy model checkpoint to local directory. + + Arguments: + export_dir (string): Writable local directory. + filename_prefix (string): file name prefix of checkpoint files. + policy_id (string): Optional policy id to export. 
+ + Example: + >>> agent = MyAgent() + >>> for _ in range(10): + >>> agent.train() + >>> agent.export_policy_checkpoint("/tmp/export_dir") + """ + self.local_evaluator.export_policy_checkpoint( + export_dir, filename_prefix, policy_id) + + @classmethod + def resource_help(cls, config): + return ("\n\nYou can adjust the resource requests of RLlib agents by " + "setting `num_workers` and other configs. See the " + "DEFAULT_CONFIG defined by each agent for more info.\n\n" + "The config of this agent is: {}".format(config)) + + @staticmethod + def _validate_config(config): + if "gpu" in config: + raise ValueError( + "The `gpu` config is deprecated, please use `num_gpus=0|1` " + "instead.") + if "gpu_fraction" in config: + raise ValueError( + "The `gpu_fraction` config is deprecated, please use " + "`num_gpus=` instead.") + if "use_gpu_for_workers" in config: + raise ValueError( + "The `use_gpu_for_workers` config is deprecated, please use " + "`num_gpus_per_worker=1` instead.") + if (config["input"] == "sampler" + and config["input_evaluation"] is not None): + raise ValueError( + "`input_evaluation` should not be set when input=sampler") + + def _make_evaluator(self, cls, env_creator, policy_graph, worker_index, + config): + def session_creator(): + logger.debug("Creating TF session {}".format( + config["tf_session_args"])) + return tf.Session( + config=tf.ConfigProto(**config["tf_session_args"])) + + if isinstance(config["input"], FunctionType): + input_creator = config["input"] + elif config["input"] == "sampler": + input_creator = (lambda ioctx: ioctx.default_sampler_input()) + elif isinstance(config["input"], dict): + input_creator = (lambda ioctx: MixedInput(ioctx, config["input"])) + else: + input_creator = (lambda ioctx: JsonReader(ioctx, config["input"])) + + if isinstance(config["output"], FunctionType): + output_creator = config["output"] + elif config["output"] is None: + output_creator = (lambda ioctx: NoopOutput()) + elif config["output"] == "logdir": + output_creator = (lambda ioctx: JsonWriter( + ioctx, + ioctx.log_dir, + max_file_size=config["output_max_file_size"], + compress_columns=config["output_compress_columns"])) + else: + output_creator = (lambda ioctx: JsonWriter( + ioctx, + config["output"], + max_file_size=config["output_max_file_size"], + compress_columns=config["output_compress_columns"])) + + return cls( + env_creator, + self.config["multiagent"]["policy_graphs"] or policy_graph, + policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"], + policies_to_train=self.config["multiagent"]["policies_to_train"], + tf_session_creator=(session_creator + if config["tf_session_args"] else None), + batch_steps=config["sample_batch_size"], + batch_mode=config["batch_mode"], + episode_horizon=config["horizon"], + preprocessor_pref=config["preprocessor_pref"], + sample_async=config["sample_async"], + compress_observations=config["compress_observations"], + num_envs=config["num_envs_per_worker"], + observation_filter=config["observation_filter"], + clip_rewards=config["clip_rewards"], + clip_actions=config["clip_actions"], + env_config=config["env_config"], + model_config=config["model"], + policy_config=config, + worker_index=worker_index, + monitor_path=self.logdir if config["monitor"] else None, + log_dir=self.logdir, + log_level=config["log_level"], + callbacks=config["callbacks"], + input_creator=input_creator, + input_evaluation_method=config["input_evaluation"], + output_creator=output_creator) def __getstate__(self): state = {} @@ -433,64 +569,11 @@ def 
__setstate__(self, state): if "optimizer" in state: self.optimizer.restore(state["optimizer"]) - def _save(self, checkpoint_dir): - checkpoint_path = os.path.join(checkpoint_dir, - "checkpoint-{}".format(self.iteration)) - pickle.dump(self.__getstate__(), open(checkpoint_path, "wb")) - return checkpoint_path - - def _restore(self, checkpoint_path): - extra_data = pickle.load(open(checkpoint_path, "rb")) - self.__setstate__(extra_data) - -def get_agent_class(alg): - """Returns the class of a known agent given its name.""" - - if alg == "DDPG": - from ray.rllib.agents import ddpg - return ddpg.DDPGAgent - elif alg == "APEX_DDPG": - from ray.rllib.agents import ddpg - return ddpg.ApexDDPGAgent - elif alg == "PPO": - from ray.rllib.agents import ppo - return ppo.PPOAgent - elif alg == "ES": - from ray.rllib.agents import es - return es.ESAgent - elif alg == "ARS": - from ray.rllib.agents import ars - return ars.ARSAgent - elif alg == "DQN": - from ray.rllib.agents import dqn - return dqn.DQNAgent - elif alg == "APEX": - from ray.rllib.agents import dqn - return dqn.ApexAgent - elif alg == "A3C": - from ray.rllib.agents import a3c - return a3c.A3CAgent - elif alg == "A2C": - from ray.rllib.agents import a3c - return a3c.A2CAgent - elif alg == "PG": - from ray.rllib.agents import pg - return pg.PGAgent - elif alg == "IMPALA": - from ray.rllib.agents import impala - return impala.ImpalaAgent - elif alg == "script": - from ray.tune import script_runner - return script_runner.ScriptRunner - elif alg == "__fake": - from ray.rllib.agents.mock import _MockAgent - return _MockAgent - elif alg == "__sigmoid_fake_data": - from ray.rllib.agents.mock import _SigmoidFakeData - return _SigmoidFakeData - elif alg == "__parameter_tuning": - from ray.rllib.agents.mock import _ParameterTuningAgent - return _ParameterTuningAgent - else: - raise Exception(("Unknown algorithm {}.").format(alg)) +def _register_if_needed(env_object): + if isinstance(env_object, six.string_types): + return env_object + elif isinstance(env_object, type): + name = env_object.__name__ + register_env(name, lambda config: env_object(config)) + return name diff --git a/python/ray/rllib/agents/ars/ars.py b/python/ray/rllib/agents/ars/ars.py index 67e9ba2429e0f..aafcee7f4e362 100644 --- a/python/ray/rllib/agents/ars/ars.py +++ b/python/ray/rllib/agents/ars/ars.py @@ -17,6 +17,7 @@ from ray.rllib.agents.ars import optimizers from ray.rllib.agents.ars import policies from ray.rllib.agents.ars import utils +from ray.rllib.utils.annotations import override from ray.rllib.utils import FilterManager logger = logging.getLogger(__name__) @@ -161,6 +162,7 @@ class ARSAgent(Agent): _agent_name = "ARS" _default_config = DEFAULT_CONFIG + @override(Agent) def _init(self): env = self.env_creator(self.config["env_config"]) from ray.rllib import models @@ -193,28 +195,7 @@ def _init(self): self.reward_list = [] self.tstart = time.time() - def _collect_results(self, theta_id, min_episodes): - num_episodes, num_timesteps = 0, 0 - results = [] - while num_episodes < min_episodes: - logger.info( - "Collected {} episodes {} timesteps so far this iter".format( - num_episodes, num_timesteps)) - rollout_ids = [ - worker.do_rollouts.remote(theta_id) for worker in self.workers - ] - # Get the results of the rollouts. - for result in ray.get(rollout_ids): - results.append(result) - # Update the number of episodes and the number of timesteps - # keeping in mind that result.noisy_lengths is a list of lists, - # where the inner lists have length 2. 
- num_episodes += sum(len(pair) for pair in result.noisy_lengths) - num_timesteps += sum( - sum(pair) for pair in result.noisy_lengths) - - return results, num_episodes, num_timesteps - + @override(Agent) def _train(self): config = self.config @@ -310,11 +291,38 @@ def _train(self): return result + @override(Agent) def _stop(self): # workaround for https://github.com/ray-project/ray/issues/1516 for w in self.workers: w.__ray_terminate__.remote() + @override(Agent) + def compute_action(self, observation): + return self.policy.compute(observation, update=True)[0] + + def _collect_results(self, theta_id, min_episodes): + num_episodes, num_timesteps = 0, 0 + results = [] + while num_episodes < min_episodes: + logger.debug( + "Collected {} episodes {} timesteps so far this iter".format( + num_episodes, num_timesteps)) + rollout_ids = [ + worker.do_rollouts.remote(theta_id) for worker in self.workers + ] + # Get the results of the rollouts. + for result in ray.get(rollout_ids): + results.append(result) + # Update the number of episodes and the number of timesteps + # keeping in mind that result.noisy_lengths is a list of lists, + # where the inner lists have length 2. + num_episodes += sum(len(pair) for pair in result.noisy_lengths) + num_timesteps += sum( + sum(pair) for pair in result.noisy_lengths) + + return results, num_episodes, num_timesteps + def __getstate__(self): return { "weights": self.policy.get_weights(), @@ -329,6 +337,3 @@ def __setstate__(self, state): FilterManager.synchronize({ "default": self.policy.get_filter() }, self.workers) - - def compute_action(self, observation): - return self.policy.compute(observation, update=True)[0] diff --git a/python/ray/rllib/agents/ddpg/apex.py b/python/ray/rllib/agents/ddpg/apex.py index c1699364a1398..6b3465013da36 100644 --- a/python/ray/rllib/agents/ddpg/apex.py +++ b/python/ray/rllib/agents/ddpg/apex.py @@ -3,6 +3,7 @@ from __future__ import print_function from ray.rllib.agents.ddpg.ddpg import DDPGAgent, DEFAULT_CONFIG as DDPG_CONFIG +from ray.rllib.utils.annotations import override from ray.rllib.utils import merge_dicts APEX_DDPG_DEFAULT_CONFIG = merge_dicts( @@ -42,6 +43,7 @@ class ApexDDPGAgent(DDPGAgent): _agent_name = "APEX_DDPG" _default_config = APEX_DDPG_DEFAULT_CONFIG + @override(DDPGAgent) def update_target_if_needed(self): # Ape-X updates based on num steps trained, not sampled if self.optimizer.num_steps_trained - self.last_target_update_ts > \ diff --git a/python/ray/rllib/agents/ddpg/ddpg.py b/python/ray/rllib/agents/ddpg/ddpg.py index 564d8e12b9131..04aba0e3ee956 100644 --- a/python/ray/rllib/agents/ddpg/ddpg.py +++ b/python/ray/rllib/agents/ddpg/ddpg.py @@ -5,14 +5,9 @@ from ray.rllib.agents.agent import with_common_config from ray.rllib.agents.dqn.dqn import DQNAgent from ray.rllib.agents.ddpg.ddpg_policy_graph import DDPGPolicyGraph +from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule -OPTIMIZER_SHARED_CONFIGS = [ - "buffer_size", "prioritized_replay", "prioritized_replay_alpha", - "prioritized_replay_beta", "prioritized_replay_eps", "sample_batch_size", - "train_batch_size", "learning_starts" -] - # yapf: disable # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ @@ -131,6 +126,7 @@ class DDPGAgent(DQNAgent): _default_config = DEFAULT_CONFIG _policy_graph = DDPGPolicyGraph + @override(DQNAgent) def _make_exploration_schedule(self, worker_index): # Override DQN's schedule to take into account `noise_scale` if 
self.config["per_worker_exploration"]: diff --git a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py index eb5f14c2d1c99..b8b625734793d 100644 --- a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py +++ b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py @@ -11,7 +11,9 @@ from ray.rllib.agents.dqn.dqn_policy_graph import _huber_loss, \ _minimize_and_clip, _scope_vars, _postprocess_dqn from ray.rllib.models import ModelCatalog +from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException +from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph A_SCOPE = "a_func" @@ -366,51 +368,11 @@ def __init__(self, observation_space, action_space, config): # Hard initial update self.update_target(tau=1.0) - def _build_q_network(self, obs, obs_space, actions): - q_net = QNetwork( - ModelCatalog.get_model({ - "obs": obs, - "is_training": self._get_is_training_placeholder(), - }, obs_space, 1, self.config["model"]), actions, - self.config["critic_hiddens"], - self.config["critic_hidden_activation"]) - return q_net.value, q_net.model - - def _build_p_network(self, obs, obs_space): - return PNetwork( - ModelCatalog.get_model({ - "obs": obs, - "is_training": self._get_is_training_placeholder(), - }, obs_space, 1, self.config["model"]), self.dim_actions, - self.config["actor_hiddens"], - self.config["actor_hidden_activation"]).action_scores - - def _build_action_network(self, p_values, stochastic, eps, - is_target=False): - return ActionNetwork( - p_values, self.low_action, self.high_action, stochastic, eps, - self.config["exploration_theta"], self.config["exploration_sigma"], - self.config["smooth_target_policy"], self.config["act_noise"], - is_target, self.config["target_noise"], - self.config["noise_clip"]).actions - - def _build_actor_critic_loss(self, - q_t, - q_tp1, - q_tp0, - twin_q_t=None, - twin_q_tp1=None): - return ActorCriticLoss( - q_t, q_tp1, q_tp0, self.importance_weights, self.rew_t, - self.done_mask, twin_q_t, twin_q_tp1, - self.config["actor_loss_coeff"], self.config["critic_loss_coeff"], - self.config["gamma"], self.config["n_step"], - self.config["use_huber"], self.config["huber_threshold"], - self.config["twin_q"]) - + @override(TFPolicyGraph) def optimizer(self): return tf.train.AdamOptimizer(learning_rate=self.config["lr"]) + @override(TFPolicyGraph) def gradients(self, optimizer): if self.config["grad_norm_clipping"] is not None: actor_grads_and_vars = _minimize_and_clip( @@ -438,23 +400,85 @@ def gradients(self, optimizer): grads_and_vars = actor_grads_and_vars + critic_grads_and_vars return grads_and_vars + @override(TFPolicyGraph) def extra_compute_action_feed_dict(self): return { self.stochastic: True, self.eps: self.cur_epsilon, } + @override(TFPolicyGraph) def extra_compute_grad_fetches(self): return { "td_error": self.loss.td_error, } + @override(PolicyGraph) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, episode=None): return _postprocess_dqn(self, sample_batch) + @override(TFPolicyGraph) + def get_weights(self): + return self.variables.get_weights() + + @override(TFPolicyGraph) + def set_weights(self, weights): + self.variables.set_weights(weights) + + @override(PolicyGraph) + def get_state(self): + return [TFPolicyGraph.get_state(self), self.cur_epsilon] + + @override(PolicyGraph) + def set_state(self, state): + TFPolicyGraph.set_state(self, state[0]) + self.set_epsilon(state[1]) + + def 
_build_q_network(self, obs, obs_space, actions): + q_net = QNetwork( + ModelCatalog.get_model({ + "obs": obs, + "is_training": self._get_is_training_placeholder(), + }, obs_space, 1, self.config["model"]), actions, + self.config["critic_hiddens"], + self.config["critic_hidden_activation"]) + return q_net.value, q_net.model + + def _build_p_network(self, obs, obs_space): + return PNetwork( + ModelCatalog.get_model({ + "obs": obs, + "is_training": self._get_is_training_placeholder(), + }, obs_space, 1, self.config["model"]), self.dim_actions, + self.config["actor_hiddens"], + self.config["actor_hidden_activation"]).action_scores + + def _build_action_network(self, p_values, stochastic, eps, + is_target=False): + return ActionNetwork( + p_values, self.low_action, self.high_action, stochastic, eps, + self.config["exploration_theta"], self.config["exploration_sigma"], + self.config["smooth_target_policy"], self.config["act_noise"], + is_target, self.config["target_noise"], + self.config["noise_clip"]).actions + + def _build_actor_critic_loss(self, + q_t, + q_tp1, + q_tp0, + twin_q_t=None, + twin_q_tp1=None): + return ActorCriticLoss( + q_t, q_tp1, q_tp0, self.importance_weights, self.rew_t, + self.done_mask, twin_q_t, twin_q_tp1, + self.config["actor_loss_coeff"], self.config["critic_loss_coeff"], + self.config["gamma"], self.config["n_step"], + self.config["use_huber"], self.config["huber_threshold"], + self.config["twin_q"]) + def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights): td_err = self.sess.run( @@ -480,16 +504,3 @@ def update_target(self, tau=None): def set_epsilon(self, epsilon): self.cur_epsilon = epsilon - - def get_weights(self): - return self.variables.get_weights() - - def set_weights(self, weights): - self.variables.set_weights(weights) - - def get_state(self): - return [TFPolicyGraph.get_state(self), self.cur_epsilon] - - def set_state(self, state): - TFPolicyGraph.set_state(self, state[0]) - self.set_epsilon(state[1]) diff --git a/python/ray/rllib/agents/dqn/apex.py b/python/ray/rllib/agents/dqn/apex.py index a6738d661e478..c9b15e0eca792 100644 --- a/python/ray/rllib/agents/dqn/apex.py +++ b/python/ray/rllib/agents/dqn/apex.py @@ -4,6 +4,7 @@ from ray.rllib.agents.dqn.dqn import DQNAgent, DEFAULT_CONFIG as DQN_CONFIG from ray.rllib.utils import merge_dicts +from ray.rllib.utils.annotations import override # yapf: disable # __sphinx_doc_begin__ @@ -45,6 +46,7 @@ class ApexAgent(DQNAgent): _agent_name = "APEX" _default_config = APEX_DEFAULT_CONFIG + @override(DQNAgent) def update_target_if_needed(self): # Ape-X updates based on num steps trained, not sampled if self.optimizer.num_steps_trained - self.last_target_update_ts > \ diff --git a/python/ray/rllib/agents/dqn/dqn.py b/python/ray/rllib/agents/dqn/dqn.py index cdace08653518..e6d0263c86972 100644 --- a/python/ray/rllib/agents/dqn/dqn.py +++ b/python/ray/rllib/agents/dqn/dqn.py @@ -7,7 +7,7 @@ from ray.rllib import optimizers from ray.rllib.agents.agent import Agent, with_common_config from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph -from ray.rllib.evaluation.metrics import collect_metrics +from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule OPTIMIZER_SHARED_CONFIGS = [ @@ -117,11 +117,13 @@ class DQNAgent(Agent): _agent_name = "DQN" _default_config = DEFAULT_CONFIG _policy_graph = DQNPolicyGraph + _optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS + @override(Agent) def _init(self): # Update effective 
batch size to include n-step adjusted_batch_size = max(self.config["sample_batch_size"], - self.config["n_step"]) + self.config.get("n_step", 1)) self.config["sample_batch_size"] = adjusted_batch_size self.exploration0 = self._make_exploration_schedule(-1) @@ -130,7 +132,7 @@ def _init(self): for i in range(self.config["num_workers"]) ] - for k in OPTIMIZER_SHARED_CONFIGS: + for k in self._optimizer_shared_configs: if self._agent_name != "DQN" and k in [ "schedule_max_timesteps", "beta_annealing_fraction", "final_prioritized_replay_beta" @@ -160,43 +162,12 @@ def create_remote_evaluators(): # Create the remote evaluators *after* the replay actors if self.remote_evaluators is None: self.remote_evaluators = create_remote_evaluators() - self.optimizer.set_evaluators(self.remote_evaluators) + self.optimizer._set_evaluators(self.remote_evaluators) self.last_target_update_ts = 0 self.num_target_updates = 0 - def _make_exploration_schedule(self, worker_index): - # Use either a different `eps` per worker, or a linear schedule. - if self.config["per_worker_exploration"]: - assert self.config["num_workers"] > 1, \ - "This requires multiple workers" - if worker_index >= 0: - exponent = ( - 1 + - worker_index / float(self.config["num_workers"] - 1) * 7) - return ConstantSchedule(0.4**exponent) - else: - # local ev should have zero exploration so that eval rollouts - # run properly - return ConstantSchedule(0.0) - return LinearSchedule( - schedule_timesteps=int(self.config["exploration_fraction"] * - self.config["schedule_max_timesteps"]), - initial_p=1.0, - final_p=self.config["exploration_final_eps"]) - - @property - def global_timestep(self): - return self.optimizer.num_steps_sampled - - def update_target_if_needed(self): - if self.global_timestep - self.last_target_update_ts > \ - self.config["target_network_update_freq"]: - self.local_evaluator.foreach_trainable_policy( - lambda p, _: p.update_target()) - self.last_target_update_ts = self.global_timestep - self.num_target_updates += 1 - + @override(Agent) def _train(self): start_timestep = self.global_timestep @@ -220,14 +191,12 @@ def _train(self): if self.config["per_worker_exploration"]: # Only collect metrics from the third of workers with lowest eps - result = collect_metrics( - self.local_evaluator, - self.remote_evaluators[-len(self.remote_evaluators) // 3:], - timeout_seconds=self.config["collect_metrics_timeout"]) + result = self.optimizer.collect_metrics( + timeout_seconds=self.config["collect_metrics_timeout"], + selected_evaluators=self.remote_evaluators[ + -len(self.remote_evaluators) // 3:]) else: - result = collect_metrics( - self.local_evaluator, - self.remote_evaluators, + result = self.optimizer.collect_metrics( timeout_seconds=self.config["collect_metrics_timeout"]) result.update( @@ -239,6 +208,38 @@ def _train(self): }, **self.optimizer.stats())) return result + def update_target_if_needed(self): + if self.global_timestep - self.last_target_update_ts > \ + self.config["target_network_update_freq"]: + self.local_evaluator.foreach_trainable_policy( + lambda p, _: p.update_target()) + self.last_target_update_ts = self.global_timestep + self.num_target_updates += 1 + + @property + def global_timestep(self): + return self.optimizer.num_steps_sampled + + def _make_exploration_schedule(self, worker_index): + # Use either a different `eps` per worker, or a linear schedule. 
+ if self.config["per_worker_exploration"]: + assert self.config["num_workers"] > 1, \ + "This requires multiple workers" + if worker_index >= 0: + exponent = ( + 1 + + worker_index / float(self.config["num_workers"] - 1) * 7) + return ConstantSchedule(0.4**exponent) + else: + # local ev should have zero exploration so that eval rollouts + # run properly + return ConstantSchedule(0.0) + return LinearSchedule( + schedule_timesteps=int(self.config["exploration_fraction"] * + self.config["schedule_max_timesteps"]), + initial_p=1.0, + final_p=self.config["exploration_final_eps"]) + def __getstate__(self): state = Agent.__getstate__(self) state.update({ diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py index c883ef25067dc..625e577fff164 100644 --- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py @@ -10,7 +10,9 @@ import ray from ray.rllib.models import ModelCatalog from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException +from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph Q_SCOPE = "q_func" @@ -253,6 +255,10 @@ def __init__(self, self.td_error = tf.nn.softmax_cross_entropy_with_logits( labels=m, logits=q_logits_t_selected) self.loss = tf.reduce_mean(self.td_error * importance_weights) + self.stats = { + # TODO: better Q stats for dist dqn + "mean_td_error": tf.reduce_mean(self.td_error), + } else: q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best @@ -264,6 +270,12 @@ def __init__(self, q_t_selected - tf.stop_gradient(q_t_selected_target)) self.loss = tf.reduce_mean( importance_weights * _huber_loss(self.td_error)) + self.stats = { + "mean_q": tf.reduce_mean(q_t_selected), + "min_q": tf.reduce_min(q_t_selected), + "max_q": tf.reduce_max(q_t_selected), + "mean_td_error": tf.reduce_mean(self.td_error), + } class DQNPolicyGraph(TFPolicyGraph): @@ -380,34 +392,13 @@ def __init__(self, observation_space, action_space, config): update_ops=q_batchnorm_update_ops) self.sess.run(tf.global_variables_initializer()) - def _build_q_network(self, obs, space): - qnet = QNetwork( - ModelCatalog.get_model({ - "obs": obs, - "is_training": self._get_is_training_placeholder(), - }, space, self.num_actions, self.config["model"]), - self.num_actions, self.config["dueling"], self.config["hiddens"], - self.config["noisy"], self.config["num_atoms"], - self.config["v_min"], self.config["v_max"], self.config["sigma0"]) - return qnet.value, qnet.logits, qnet.dist, qnet.model - - def _build_q_value_policy(self, q_values): - return QValuePolicy(q_values, self.cur_observations, self.num_actions, - self.stochastic, self.eps).action - - def _build_q_loss(self, q_t_selected, q_logits_t_selected, q_tp1_best, - q_dist_tp1_best): - return QLoss(q_t_selected, q_logits_t_selected, q_tp1_best, - q_dist_tp1_best, self.importance_weights, self.rew_t, - self.done_mask, self.config["gamma"], - self.config["n_step"], self.config["num_atoms"], - self.config["v_min"], self.config["v_max"]) - + @override(TFPolicyGraph) def optimizer(self): return tf.train.AdamOptimizer( learning_rate=self.config["lr"], epsilon=self.config["adam_epsilon"]) + @override(TFPolicyGraph) def gradients(self, optimizer): if self.config["grad_norm_clipping"] is not None: grads_and_vars = _minimize_and_clip( @@ -421,23 +412,36 @@ def gradients(self, optimizer): grads_and_vars 
= [(g, v) for (g, v) in grads_and_vars if g is not None] return grads_and_vars + @override(TFPolicyGraph) def extra_compute_action_feed_dict(self): return { self.stochastic: True, self.eps: self.cur_epsilon, } + @override(TFPolicyGraph) def extra_compute_grad_fetches(self): return { "td_error": self.loss.td_error, + "stats": self.loss.stats, } + @override(PolicyGraph) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, episode=None): return _postprocess_dqn(self, sample_batch) + @override(PolicyGraph) + def get_state(self): + return [TFPolicyGraph.get_state(self), self.cur_epsilon] + + @override(PolicyGraph) + def set_state(self, state): + TFPolicyGraph.set_state(self, state[0]) + self.set_epsilon(state[1]) + def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights): td_err = self.sess.run( @@ -458,15 +462,31 @@ def update_target(self): def set_epsilon(self, epsilon): self.cur_epsilon = epsilon - def get_state(self): - return [TFPolicyGraph.get_state(self), self.cur_epsilon] + def _build_q_network(self, obs, space): + qnet = QNetwork( + ModelCatalog.get_model({ + "obs": obs, + "is_training": self._get_is_training_placeholder(), + }, space, self.num_actions, self.config["model"]), + self.num_actions, self.config["dueling"], self.config["hiddens"], + self.config["noisy"], self.config["num_atoms"], + self.config["v_min"], self.config["v_max"], self.config["sigma0"]) + return qnet.value, qnet.logits, qnet.dist, qnet.model - def set_state(self, state): - TFPolicyGraph.set_state(self, state[0]) - self.set_epsilon(state[1]) + def _build_q_value_policy(self, q_values): + return QValuePolicy(q_values, self.cur_observations, self.num_actions, + self.stochastic, self.eps).action + + def _build_q_loss(self, q_t_selected, q_logits_t_selected, q_tp1_best, + q_dist_tp1_best): + return QLoss(q_t_selected, q_logits_t_selected, q_tp1_best, + q_dist_tp1_best, self.importance_weights, self.rew_t, + self.done_mask, self.config["gamma"], + self.config["n_step"], self.config["num_atoms"], + self.config["v_min"], self.config["v_max"]) -def adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones): +def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones): """Rewrites the given trajectory fragments to encode n-step rewards. 
reward[i] = ( @@ -499,9 +519,9 @@ def _postprocess_dqn(policy_graph, sample_batch): # N-step Q adjustments if policy_graph.config["n_step"] > 1: - adjust_nstep(policy_graph.config["n_step"], - policy_graph.config["gamma"], obs, actions, rewards, - new_obs, dones) + _adjust_nstep(policy_graph.config["n_step"], + policy_graph.config["gamma"], obs, actions, rewards, + new_obs, dones) batch = SampleBatch({ "obs": obs, diff --git a/python/ray/rllib/agents/es/es.py b/python/ray/rllib/agents/es/es.py index 550296812025c..4aa4a86aac889 100644 --- a/python/ray/rllib/agents/es/es.py +++ b/python/ray/rllib/agents/es/es.py @@ -16,6 +16,7 @@ from ray.rllib.agents.es import optimizers from ray.rllib.agents.es import policies from ray.rllib.agents.es import utils +from ray.rllib.utils.annotations import override from ray.rllib.utils import FilterManager logger = logging.getLogger(__name__) @@ -167,6 +168,7 @@ class ESAgent(Agent): _agent_name = "ES" _default_config = DEFAULT_CONFIG + @override(Agent) def _init(self): policy_params = {"action_noise_std": 0.01} @@ -198,28 +200,7 @@ def _init(self): self.reward_list = [] self.tstart = time.time() - def _collect_results(self, theta_id, min_episodes, min_timesteps): - num_episodes, num_timesteps = 0, 0 - results = [] - while num_episodes < min_episodes or num_timesteps < min_timesteps: - logger.info( - "Collected {} episodes {} timesteps so far this iter".format( - num_episodes, num_timesteps)) - rollout_ids = [ - worker.do_rollouts.remote(theta_id) for worker in self.workers - ] - # Get the results of the rollouts. - for result in ray.get(rollout_ids): - results.append(result) - # Update the number of episodes and the number of timesteps - # keeping in mind that result.noisy_lengths is a list of lists, - # where the inner lists have length 2. - num_episodes += sum(len(pair) for pair in result.noisy_lengths) - num_timesteps += sum( - sum(pair) for pair in result.noisy_lengths) - - return results, num_episodes, num_timesteps - + @override(Agent) def _train(self): config = self.config @@ -307,11 +288,38 @@ def _train(self): return result + @override(Agent) + def compute_action(self, observation): + return self.policy.compute(observation, update=False)[0] + + @override(Agent) def _stop(self): # workaround for https://github.com/ray-project/ray/issues/1516 for w in self.workers: w.__ray_terminate__.remote() + def _collect_results(self, theta_id, min_episodes, min_timesteps): + num_episodes, num_timesteps = 0, 0 + results = [] + while num_episodes < min_episodes or num_timesteps < min_timesteps: + logger.info( + "Collected {} episodes {} timesteps so far this iter".format( + num_episodes, num_timesteps)) + rollout_ids = [ + worker.do_rollouts.remote(theta_id) for worker in self.workers + ] + # Get the results of the rollouts. + for result in ray.get(rollout_ids): + results.append(result) + # Update the number of episodes and the number of timesteps + # keeping in mind that result.noisy_lengths is a list of lists, + # where the inner lists have length 2. 
+ num_episodes += sum(len(pair) for pair in result.noisy_lengths) + num_timesteps += sum( + sum(pair) for pair in result.noisy_lengths) + + return results, num_episodes, num_timesteps + def __getstate__(self): return { "weights": self.policy.get_weights(), @@ -326,6 +334,3 @@ def __setstate__(self, state): FilterManager.synchronize({ "default": self.policy.get_filter() }, self.workers) - - def compute_action(self, observation): - return self.policy.compute(observation, update=False)[0] diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index 45af922000715..bab04f48239ff 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -8,6 +8,7 @@ from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph from ray.rllib.agents.agent import Agent, with_common_config from ray.rllib.optimizers import AsyncSamplesOptimizer +from ray.rllib.utils.annotations import override OPTIMIZER_SHARED_CONFIGS = [ "lr", @@ -17,10 +18,11 @@ "train_batch_size", "replay_buffer_num_slots", "replay_proportion", - "num_parallel_data_loaders", - "grad_clip", + "num_data_loader_buffers", "max_sample_requests_in_flight_per_worker", "broadcast_interval", + "num_sgd_iter", + "minibatch_buffer_size", ] # yapf: disable @@ -32,6 +34,17 @@ "vtrace_clip_pg_rho_threshold": 1.0, # System params. + # + # == Overview of data flow in IMPALA == + # 1. Policy evaluation in parallel across `num_workers` actors produces + # batches of size `sample_batch_size * num_envs_per_worker`. + # 2. If enabled, the replay buffer stores and produces batches of size + # `sample_batch_size * num_envs_per_worker`. + # 3. If enabled, the minibatch ring buffer stores and replays batches of + # size `train_batch_size` up to `num_sgd_iter` times per batch. + # 4. The learner thread executes data parallel SGD across `num_gpus` GPUs + # on batches of size `train_batch_size`. + # "sample_batch_size": 50, "train_batch_size": 500, "min_iter_time_s": 10, @@ -39,18 +52,23 @@ # number of GPUs the learner should use. "num_gpus": 1, # set >1 to load data into GPUs in parallel. Increases GPU memory usage - # proportionally with the number of loaders. - "num_parallel_data_loaders": 1, - # level of queuing for sampling. - "max_sample_requests_in_flight_per_worker": 2, - # max number of workers to broadcast one set of weights to - "broadcast_interval": 1, + # proportionally with the number of buffers. + "num_data_loader_buffers": 1, + # how many train batches should be retained for minibatching. This conf + # only has an effect if `num_sgd_iter > 1`. + "minibatch_buffer_size": 1, + # number of passes to make over each train batch + "num_sgd_iter": 1, # set >0 to enable experience replay. Saved samples will be replayed with # a p:1 proportion to new data samples. "replay_proportion": 0.0, # number of sample batches to store for replay. The number of transitions # saved total will be (replay_buffer_num_slots * sample_batch_size). "replay_buffer_num_slots": 100, + # level of queuing for sampling. + "max_sample_requests_in_flight_per_worker": 2, + # max number of workers to broadcast one set of weights to + "broadcast_interval": 1, # Learning params. 
"grad_clip": 40.0, @@ -77,6 +95,7 @@ class ImpalaAgent(Agent): _default_config = DEFAULT_CONFIG _policy_graph = VTracePolicyGraph + @override(Agent) def _init(self): for k in OPTIMIZER_SHARED_CONFIGS: if k not in self.config["optimizer"]: @@ -93,6 +112,7 @@ def _init(self): self.remote_evaluators, self.config["optimizer"]) + @override(Agent) def _train(self): prev_steps = self.optimizer.num_steps_sampled start = time.time() diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy_graph.py index cfa2f1373aae2..12c0c30fb3c78 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py +++ b/python/ray/rllib/agents/impala/vtrace_policy_graph.py @@ -11,9 +11,11 @@ import ray from ray.rllib.agents.impala import vtrace +from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ LearningRateSchedule from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.models.action_dist import Categorical @@ -166,7 +168,7 @@ def to_batches(tensor): mask = tf.sequence_mask(self.model.seq_lens, max_seq_len) mask = tf.reshape(mask, [-1]) else: - mask = tf.ones_like(rewards) + mask = tf.ones_like(rewards, dtype=tf.bool) # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc. self.loss = VTraceLoss( @@ -242,6 +244,15 @@ def to_batches(tensor): }, } + @override(TFPolicyGraph) + def copy(self, existing_inputs): + return VTracePolicyGraph( + self.observation_space, + self.action_space, + self.config, + existing_inputs=existing_inputs) + + @override(TFPolicyGraph) def optimizer(self): if self.config["opt_type"] == "adam": return tf.train.AdamOptimizer(self.cur_lr) @@ -250,18 +261,22 @@ def optimizer(self): self.config["momentum"], self.config["epsilon"]) + @override(TFPolicyGraph) def gradients(self, optimizer): grads = tf.gradients(self.loss.total_loss, self.var_list) self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) clipped_grads = list(zip(self.grads, self.var_list)) return clipped_grads + @override(TFPolicyGraph) def extra_compute_action_fetches(self): return {"behaviour_logits": self.model.outputs} + @override(TFPolicyGraph) def extra_compute_grad_fetches(self): return self.stats_fetches + @override(PolicyGraph) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, @@ -269,12 +284,6 @@ def postprocess_trajectory(self, del sample_batch.data["new_obs"] # not used, so save some bandwidth return sample_batch + @override(PolicyGraph) def get_initial_state(self): return self.model.state_init - - def copy(self, existing_inputs): - return VTracePolicyGraph( - self.observation_space, - self.action_space, - self.config, - existing_inputs=existing_inputs) diff --git a/python/ray/rllib/agents/mock.py b/python/ray/rllib/agents/mock.py index f4bf909918095..89d17848ec5d0 100644 --- a/python/ray/rllib/agents/mock.py +++ b/python/ray/rllib/agents/mock.py @@ -100,3 +100,16 @@ def _train(self): timesteps_this_iter=self.config["iter_timesteps"], time_this_iter_s=self.config["iter_time"], info={}) + + +def _agent_import_failed(trace): + """Returns dummy agent class for if PyTorch etc. 
is not installed.""" + + class _AgentImportFailed(Agent): + _agent_name = "AgentImportFailed" + _default_config = with_common_config({}) + + def _setup(self, config): + raise ImportError(trace) + + return _AgentImportFailed diff --git a/python/ray/rllib/agents/pg/pg.py b/python/ray/rllib/agents/pg/pg.py index 925cbc1a10aa4..ba525887ec0ec 100644 --- a/python/ray/rllib/agents/pg/pg.py +++ b/python/ray/rllib/agents/pg/pg.py @@ -5,6 +5,7 @@ from ray.rllib.agents.agent import Agent, with_common_config from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph from ray.rllib.optimizers import SyncSamplesOptimizer +from ray.rllib.utils.annotations import override # yapf: disable # __sphinx_doc_begin__ @@ -29,15 +30,19 @@ class PGAgent(Agent): _default_config = DEFAULT_CONFIG _policy_graph = PGPolicyGraph + @override(Agent) def _init(self): self.local_evaluator = self.make_local_evaluator( self.env_creator, self._policy_graph) self.remote_evaluators = self.make_remote_evaluators( self.env_creator, self._policy_graph, self.config["num_workers"]) - self.optimizer = SyncSamplesOptimizer(self.local_evaluator, - self.remote_evaluators, - self.config["optimizer"]) + optimizer_config = dict( + self.config["optimizer"], + **{"train_batch_size": self.config["train_batch_size"]}) + self.optimizer = SyncSamplesOptimizer( + self.local_evaluator, self.remote_evaluators, optimizer_config) + @override(Agent) def _train(self): prev_steps = self.optimizer.num_steps_sampled self.optimizer.step() diff --git a/python/ray/rllib/agents/pg/pg_policy_graph.py b/python/ray/rllib/agents/pg/pg_policy_graph.py index 2a342c117fb3f..59e9a9effc12b 100644 --- a/python/ray/rllib/agents/pg/pg_policy_graph.py +++ b/python/ray/rllib/agents/pg/pg_policy_graph.py @@ -7,7 +7,9 @@ import ray from ray.rllib.models.catalog import ModelCatalog from ray.rllib.evaluation.postprocessing import compute_advantages +from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.utils.annotations import override class PGLoss(object): @@ -75,6 +77,7 @@ def __init__(self, obs_space, action_space, config): max_seq_len=config["model"]["max_seq_len"]) sess.run(tf.global_variables_initializer()) + @override(PolicyGraph) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, @@ -83,5 +86,6 @@ def postprocess_trajectory(self, return compute_advantages( sample_batch, 0.0, self.config["gamma"], use_gae=False) + @override(PolicyGraph) def get_initial_state(self): return self.model.state_init diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index d5e50832f4515..59eba8aced580 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -7,6 +7,7 @@ from ray.rllib.agents import Agent, with_common_config from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer +from ray.rllib.utils.annotations import override logger = logging.getLogger(__name__) @@ -52,6 +53,9 @@ # Uses the sync samples optimizer instead of the multi-gpu one. This does # not support minibatches. "simple_optimizer": False, + # (Deprecated) Use the sampling behavior as of 0.6, which launches extra + # sampling tasks for performance but can waste a large portion of samples. 
+ "straggler_mitigation": False, }) # __sphinx_doc_end__ # yapf: enable @@ -64,6 +68,7 @@ class PPOAgent(Agent): _default_config = DEFAULT_CONFIG _policy_graph = PPOPolicyGraph + @override(Agent) def _init(self): self._validate_config() self.local_evaluator = self.make_local_evaluator( @@ -82,41 +87,15 @@ def _init(self): "sgd_batch_size": self.config["sgd_minibatch_size"], "num_sgd_iter": self.config["num_sgd_iter"], "num_gpus": self.config["num_gpus"], + "sample_batch_size": self.config["sample_batch_size"], + "num_envs_per_worker": self.config["num_envs_per_worker"], "train_batch_size": self.config["train_batch_size"], "standardize_fields": ["advantages"], + "straggler_mitigation": ( + self.config["straggler_mitigation"]), }) - def _validate_config(self): - waste_ratio = ( - self.config["sample_batch_size"] * self.config["num_workers"] / - self.config["train_batch_size"]) - if waste_ratio > 1: - msg = ("sample_batch_size * num_workers >> train_batch_size. " - "This means that many steps will be discarded. Consider " - "reducing sample_batch_size, or increase train_batch_size.") - if waste_ratio > 1.5: - raise ValueError(msg) - else: - logger.warn(msg) - if self.config["sgd_minibatch_size"] > self.config["train_batch_size"]: - raise ValueError( - "Minibatch size {} must be <= train batch size {}.".format( - self.config["sgd_minibatch_size"], - self.config["train_batch_size"])) - if (self.config["batch_mode"] == "truncate_episodes" - and not self.config["use_gae"]): - raise ValueError( - "Episode truncation is not supported without a value function") - if (self.config["multiagent"]["policy_graphs"] - and not self.config["simple_optimizer"]): - logger.warn("forcing simple_optimizer=True in multi-agent mode") - self.config["simple_optimizer"] = True - if self.config["observation_filter"] != "NoFilter": - # TODO(ekl): consider setting the default to be NoFilter - logger.warn( - "By default, observations will be normalized with {}".format( - self.config["observation_filter"])) - + @override(Agent) def _train(self): prev_steps = self.optimizer.num_steps_sampled fetches = self.optimizer.step() @@ -134,3 +113,25 @@ def _train(self): timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps, info=dict(fetches, **res.get("info", {}))) return res + + def _validate_config(self): + if self.config["sgd_minibatch_size"] > self.config["train_batch_size"]: + raise ValueError( + "Minibatch size {} must be <= train batch size {}.".format( + self.config["sgd_minibatch_size"], + self.config["train_batch_size"])) + if (self.config["batch_mode"] == "truncate_episodes" + and not self.config["use_gae"]): + raise ValueError( + "Episode truncation is not supported without a value function") + if (self.config["multiagent"]["policy_graphs"] + and not self.config["simple_optimizer"]): + logger.info( + "In multi-agent mode, policies will be optimized sequentially " + "by the multi-GPU optimizer. 
Consider setting " + "simple_optimizer=True if this doesn't work for you.") + if self.config["observation_filter"] != "NoFilter": + # TODO(ekl): consider setting the default to be NoFilter + logger.warning( + "By default, observations will be normalized with {}".format( + self.config["observation_filter"])) diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py index 3762f16f9084e..80ec01ea50300 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py @@ -6,9 +6,11 @@ import ray from ray.rllib.evaluation.postprocessing import compute_advantages +from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ LearningRateSchedule from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.utils.annotations import override from ray.rllib.utils.explained_variance import explained_variance @@ -204,7 +206,7 @@ def __init__(self, mask = tf.sequence_mask(self.model.seq_lens, max_seq_len) mask = tf.reshape(mask, [-1]) else: - mask = tf.ones_like(adv_ph) + mask = tf.ones_like(adv_ph, dtype=tf.bool) self.loss_obj = PPOLoss( action_space, @@ -245,6 +247,7 @@ def __init__(self, self.explained_variance = explained_variance(value_targets_ph, self.value_function) self.stats_fetches = { + "cur_kl_coeff": self.kl_coeff, "cur_lr": tf.cast(self.cur_lr, tf.float64), "total_loss": self.loss_obj.loss, "policy_loss": self.loss_obj.mean_policy_loss, @@ -254,6 +257,7 @@ def __init__(self, "entropy": self.loss_obj.mean_entropy } + @override(TFPolicyGraph) def copy(self, existing_inputs): """Creates a copy of self using existing input placeholders.""" return PPOPolicyGraph( @@ -262,29 +266,7 @@ def copy(self, existing_inputs): self.config, existing_inputs=existing_inputs) - def extra_compute_action_fetches(self): - return {"vf_preds": self.value_function, "logits": self.logits} - - def extra_compute_grad_fetches(self): - return self.stats_fetches - - def update_kl(self, sampled_kl): - if sampled_kl > 2.0 * self.kl_target: - self.kl_coeff_val *= 1.5 - elif sampled_kl < 0.5 * self.kl_target: - self.kl_coeff_val *= 0.5 - self.kl_coeff.load(self.kl_coeff_val, session=self.sess) - return self.kl_coeff_val - - def value(self, ob, *args): - feed_dict = {self.observations: [ob], self.model.seq_lens: [1]} - assert len(args) == len(self.model.state_in), \ - (args, self.model.state_in) - for k, v in zip(self.model.state_in, args): - feed_dict[k] = v - vf = self.sess.run(self.value_function, feed_dict) - return vf[0] - + @override(PolicyGraph) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, @@ -296,7 +278,7 @@ def postprocess_trajectory(self, next_state = [] for i in range(len(self.model.state_in)): next_state.append([sample_batch["state_out_{}".format(i)][-1]]) - last_r = self.value(sample_batch["new_obs"][-1], *next_state) + last_r = self._value(sample_batch["new_obs"][-1], *next_state) batch = compute_advantages( sample_batch, last_r, @@ -305,9 +287,36 @@ def postprocess_trajectory(self, use_gae=self.config["use_gae"]) return batch + @override(TFPolicyGraph) def gradients(self, optimizer): return optimizer.compute_gradients( self._loss, colocate_gradients_with_ops=True) + @override(PolicyGraph) def get_initial_state(self): return self.model.state_init + + @override(TFPolicyGraph) + def extra_compute_action_fetches(self): + return {"vf_preds": self.value_function, "logits": self.logits} + + @override(TFPolicyGraph) + def 
extra_compute_grad_fetches(self): + return self.stats_fetches + + def update_kl(self, sampled_kl): + if sampled_kl > 2.0 * self.kl_target: + self.kl_coeff_val *= 1.5 + elif sampled_kl < 0.5 * self.kl_target: + self.kl_coeff_val *= 0.5 + self.kl_coeff.load(self.kl_coeff_val, session=self.sess) + return self.kl_coeff_val + + def _value(self, ob, *args): + feed_dict = {self.observations: [ob], self.model.seq_lens: [1]} + assert len(args) == len(self.model.state_in), \ + (args, self.model.state_in) + for k, v in zip(self.model.state_in, args): + feed_dict[k] = v + vf = self.sess.run(self.value_function, feed_dict) + return vf[0] diff --git a/python/ray/rllib/agents/ppo/rollout.py b/python/ray/rllib/agents/ppo/rollout.py deleted file mode 100644 index 4084e9ba063af..0000000000000 --- a/python/ray/rllib/agents/ppo/rollout.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.rllib.evaluation.sample_batch import SampleBatch - - -def collect_samples(agents, train_batch_size): - num_timesteps_so_far = 0 - trajectories = [] - # This variable maps the object IDs of trajectories that are currently - # computed to the agent that they are computed on; we start some initial - # tasks here. - - agent_dict = {} - - for agent in agents: - fut_sample = agent.sample.remote() - agent_dict[fut_sample] = agent - - while num_timesteps_so_far < train_batch_size: - # TODO(pcm): Make wait support arbitrary iterators and remove the - # conversion to list here. - [fut_sample], _ = ray.wait(list(agent_dict)) - agent = agent_dict.pop(fut_sample) - # Start task with next trajectory and record it in the dictionary. - fut_sample2 = agent.sample.remote() - agent_dict[fut_sample2] = agent - - next_sample = ray.get(fut_sample) - num_timesteps_so_far += next_sample.count - trajectories.append(next_sample) - return SampleBatch.concat_samples(trajectories) diff --git a/python/ray/rllib/agents/qmix/README.md b/python/ray/rllib/agents/qmix/README.md new file mode 100644 index 0000000000000..d023a7fc70dd6 --- /dev/null +++ b/python/ray/rllib/agents/qmix/README.md @@ -0,0 +1 @@ +Code in this package is adapted from https://github.com/oxwhirl/pymarl_alpha. 
diff --git a/python/ray/rllib/agents/qmix/__init__.py b/python/ray/rllib/agents/qmix/__init__.py new file mode 100644 index 0000000000000..a5a1e4993b675 --- /dev/null +++ b/python/ray/rllib/agents/qmix/__init__.py @@ -0,0 +1,8 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.agents.qmix.qmix import QMixAgent, DEFAULT_CONFIG +from ray.rllib.agents.qmix.apex import ApexQMixAgent + +__all__ = ["QMixAgent", "ApexQMixAgent", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/qmix/apex.py b/python/ray/rllib/agents/qmix/apex.py new file mode 100644 index 0000000000000..9f471faefceb6 --- /dev/null +++ b/python/ray/rllib/agents/qmix/apex.py @@ -0,0 +1,55 @@ +"""Experimental: scalable Ape-X variant of QMIX""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.agents.qmix.qmix import QMixAgent, DEFAULT_CONFIG as QMIX_CONFIG +from ray.rllib.utils.annotations import override +from ray.rllib.utils import merge_dicts + +APEX_QMIX_DEFAULT_CONFIG = merge_dicts( + QMIX_CONFIG, # see also the options in qmix.py, which are also supported + { + "optimizer_class": "AsyncReplayOptimizer", + "optimizer": merge_dicts( + QMIX_CONFIG["optimizer"], + { + "max_weight_sync_delay": 400, + "num_replay_buffer_shards": 4, + "batch_replay": True, # required for RNN. Disables prio. + "debug": False + }), + "num_gpus": 0, + "num_workers": 32, + "buffer_size": 2000000, + "learning_starts": 50000, + "train_batch_size": 512, + "sample_batch_size": 50, + "max_weight_sync_delay": 400, + "target_network_update_freq": 500000, + "timesteps_per_iteration": 25000, + "per_worker_exploration": True, + "min_iter_time_s": 30, + }, +) + + +class ApexQMixAgent(QMixAgent): + """QMIX variant that uses the Ape-X distributed policy optimizer. + + By default, this is configured for a large single node (32 cores). For + running in a large cluster, increase the `num_workers` config var. 
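+
+    As an illustrative sketch only (the env id below is a placeholder, not
+    an environment shipped with RLlib), scaling out might look like:
+
+        from ray.rllib.agents.qmix.apex import ApexQMixAgent
+        agent = ApexQMixAgent(
+            env="your_grouped_multiagent_env",
+            config={"num_workers": 128})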
+ """ + + _agent_name = "APEX_QMIX" + _default_config = APEX_QMIX_DEFAULT_CONFIG + + @override(QMixAgent) + def update_target_if_needed(self): + # Ape-X updates based on num steps trained, not sampled + if self.optimizer.num_steps_trained - self.last_target_update_ts > \ + self.config["target_network_update_freq"]: + self.local_evaluator.for_policy(lambda p: p.update_target()) + self.last_target_update_ts = self.optimizer.num_steps_trained + self.num_target_updates += 1 diff --git a/python/ray/rllib/agents/qmix/mixers.py b/python/ray/rllib/agents/qmix/mixers.py new file mode 100644 index 0000000000000..3f8fbbce4f9b0 --- /dev/null +++ b/python/ray/rllib/agents/qmix/mixers.py @@ -0,0 +1,64 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch as th +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + + +class VDNMixer(nn.Module): + def __init__(self): + super(VDNMixer, self).__init__() + + def forward(self, agent_qs, batch): + return th.sum(agent_qs, dim=2, keepdim=True) + + +class QMixer(nn.Module): + def __init__(self, n_agents, state_shape, mixing_embed_dim): + super(QMixer, self).__init__() + + self.n_agents = n_agents + self.embed_dim = mixing_embed_dim + self.state_dim = int(np.prod(state_shape)) + + self.hyper_w_1 = nn.Linear(self.state_dim, + self.embed_dim * self.n_agents) + self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim) + + # State dependent bias for hidden layer + self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim) + + # V(s) instead of a bias for the last layers + self.V = nn.Sequential( + nn.Linear(self.state_dim, self.embed_dim), nn.ReLU(), + nn.Linear(self.embed_dim, 1)) + + def forward(self, agent_qs, states): + """Forward pass for the mixer. 
+ + Arguments: + agent_qs: Tensor of shape [B, T, n_agents, n_actions] + states: Tensor of shape [B, T, state_dim] + """ + bs = agent_qs.size(0) + states = states.reshape(-1, self.state_dim) + agent_qs = agent_qs.view(-1, 1, self.n_agents) + # First layer + w1 = th.abs(self.hyper_w_1(states)) + b1 = self.hyper_b_1(states) + w1 = w1.view(-1, self.n_agents, self.embed_dim) + b1 = b1.view(-1, 1, self.embed_dim) + hidden = F.elu(th.bmm(agent_qs, w1) + b1) + # Second layer + w_final = th.abs(self.hyper_w_final(states)) + w_final = w_final.view(-1, self.embed_dim, 1) + # State-dependent bias + v = self.V(states).view(-1, 1, 1) + # Compute final output + y = th.bmm(hidden, w_final) + v + # Reshape and return + q_tot = y.view(bs, -1, 1) + return q_tot diff --git a/python/ray/rllib/agents/qmix/model.py b/python/ray/rllib/agents/qmix/model.py new file mode 100644 index 0000000000000..679e2c659a861 --- /dev/null +++ b/python/ray/rllib/agents/qmix/model.py @@ -0,0 +1,28 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from torch import nn +import torch.nn.functional as F + + +# TODO(ekl) we should have common models for pytorch like we do for TF +class RNNModel(nn.Module): + def __init__(self, obs_size, rnn_hidden_dim, n_actions): + nn.Module.__init__(self) + self.rnn_hidden_dim = rnn_hidden_dim + self.n_actions = n_actions + self.fc1 = nn.Linear(obs_size, rnn_hidden_dim) + self.rnn = nn.GRUCell(rnn_hidden_dim, rnn_hidden_dim) + self.fc2 = nn.Linear(rnn_hidden_dim, n_actions) + + def init_hidden(self): + # make hidden states on same device as model + return self.fc1.weight.new(1, self.rnn_hidden_dim).zero_() + + def forward(self, inputs, hidden_state): + x = F.relu(self.fc1(inputs.float())) + h_in = hidden_state.reshape(-1, self.rnn_hidden_dim) + h = self.rnn(x, h_in) + q = self.fc2(h) + return q, h diff --git a/python/ray/rllib/agents/qmix/qmix.py b/python/ray/rllib/agents/qmix/qmix.py new file mode 100644 index 0000000000000..2bc4c6b23b4ca --- /dev/null +++ b/python/ray/rllib/agents/qmix/qmix.py @@ -0,0 +1,92 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.agents.agent import with_common_config +from ray.rllib.agents.dqn.dqn import DQNAgent +from ray.rllib.agents.qmix.qmix_policy_graph import QMixPolicyGraph + +# yapf: disable +# __sphinx_doc_begin__ +DEFAULT_CONFIG = with_common_config({ + # === QMix === + # Mixing network. Either "qmix", "vdn", or None + "mixer": "qmix", + # Size of the mixing network embedding + "mixing_embed_dim": 32, + # Whether to use Double_Q learning + "double_q": True, + # Optimize over complete episodes by default. + "batch_mode": "complete_episodes", + + # === Exploration === + # Max num timesteps for annealing schedules. Exploration is annealed from + # 1.0 to exploration_fraction over this number of timesteps scaled by + # exploration_fraction + "schedule_max_timesteps": 100000, + # Number of env steps to optimize for before returning + "timesteps_per_iteration": 1000, + # Fraction of entire training period over which the exploration rate is + # annealed + "exploration_fraction": 0.1, + # Final value of random action probability + "exploration_final_eps": 0.02, + # Update the target network every `target_network_update_freq` steps. + "target_network_update_freq": 500, + + # === Replay buffer === + # Size of the replay buffer in steps. 
+ "buffer_size": 10000, + + # === Optimization === + # Learning rate for adam optimizer + "lr": 0.0005, + # RMSProp alpha + "optim_alpha": 0.99, + # RMSProp epsilon + "optim_eps": 0.00001, + # If not None, clip gradients during optimization at this value + "grad_norm_clipping": 10, + # How many steps of the model to sample before learning starts. + "learning_starts": 1000, + # Update the replay buffer with this many samples at once. Note that + # this setting applies per-worker if num_workers > 1. + "sample_batch_size": 4, + # Size of a batched sampled from replay buffer for training. Note that + # if async_updates is set, then each worker returns gradients for a + # batch of this size. + "train_batch_size": 32, + + # === Parallelism === + # Number of workers for collecting samples with. This only makes sense + # to increase if your environment is particularly slow to sample, or if + # you"re using the Async or Ape-X optimizers. + "num_workers": 0, + # Optimizer class to use. + "optimizer_class": "SyncBatchReplayOptimizer", + # Whether to use a distribution of epsilons across workers for exploration. + "per_worker_exploration": False, + # Whether to compute priorities on workers. + "worker_side_prioritization": False, + # Prevent iterations from going lower than this time span + "min_iter_time_s": 1, + + # === Model === + "model": { + "lstm_cell_size": 64, + "max_seq_len": 999999, + }, +}) +# __sphinx_doc_end__ +# yapf: enable + + +class QMixAgent(DQNAgent): + """QMix implementation in PyTorch.""" + + _agent_name = "QMIX" + _default_config = DEFAULT_CONFIG + _policy_graph = QMixPolicyGraph + _optimizer_shared_configs = [ + "learning_starts", "buffer_size", "train_batch_size" + ] diff --git a/python/ray/rllib/agents/qmix/qmix_policy_graph.py b/python/ray/rllib/agents/qmix/qmix_policy_graph.py new file mode 100644 index 0000000000000..445275f3f054c --- /dev/null +++ b/python/ray/rllib/agents/qmix/qmix_policy_graph.py @@ -0,0 +1,411 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from gym.spaces import Tuple, Discrete, Dict +import logging +import numpy as np +import torch as th +import torch.nn as nn +from torch.optim import RMSprop +from torch.distributions import Categorical + +import ray +from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer +from ray.rllib.agents.qmix.model import RNNModel +from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.models.action_dist import TupleActions +from ray.rllib.models.pytorch.misc import var_to_np +from ray.rllib.models.lstm import chop_into_sequences +from ray.rllib.models.model import _unpack_obs +from ray.rllib.models.preprocessors import get_preprocessor +from ray.rllib.env.constants import GROUP_REWARDS +from ray.rllib.utils.annotations import override + +logger = logging.getLogger(__name__) + + +class QMixLoss(nn.Module): + def __init__(self, + model, + target_model, + mixer, + target_mixer, + n_agents, + n_actions, + double_q=True, + gamma=0.99): + nn.Module.__init__(self) + self.model = model + self.target_model = target_model + self.mixer = mixer + self.target_mixer = target_mixer + self.n_agents = n_agents + self.n_actions = n_actions + self.double_q = double_q + self.gamma = gamma + + def forward(self, rewards, actions, terminated, mask, obs, action_mask): + """Forward pass of the loss. 
+ + Arguments: + rewards: Tensor of shape [B, T-1, n_agents] + actions: Tensor of shape [B, T-1, n_agents] + terminated: Tensor of shape [B, T-1, n_agents] + mask: Tensor of shape [B, T-1, n_agents] + obs: Tensor of shape [B, T, n_agents, obs_size] + action_mask: Tensor of shape [B, T, n_agents, n_actions] + """ + + B, T = obs.size(0), obs.size(1) + + # Calculate estimated Q-Values + mac_out = [] + h = self.model.init_hidden().expand([B, self.n_agents, -1]) + for t in range(T): + q, h = _mac(self.model, obs[:, t], h) + mac_out.append(q) + mac_out = th.stack(mac_out, dim=1) # Concat over time + + # Pick the Q-Values for the actions taken -> [B * n_agents, T-1] + chosen_action_qvals = th.gather( + mac_out[:, :-1], dim=3, index=actions.unsqueeze(3)).squeeze(3) + + # Calculate the Q-Values necessary for the target + target_mac_out = [] + target_h = self.target_model.init_hidden().expand( + [B, self.n_agents, -1]) + for t in range(T): + target_q, target_h = _mac(self.target_model, obs[:, t], target_h) + target_mac_out.append(target_q) + + # We don't need the first timesteps Q-Value estimate for targets + target_mac_out = th.stack( + target_mac_out[1:], dim=1) # Concat across time + + # Mask out unavailable actions + target_mac_out[action_mask[:, 1:] == 0] = -9999999 + + # Max over target Q-Values + if self.double_q: + # Get actions that maximise live Q (for double q-learning) + mac_out[action_mask == 0] = -9999999 + cur_max_actions = mac_out[:, 1:].max(dim=3, keepdim=True)[1] + target_max_qvals = th.gather(target_mac_out, 3, + cur_max_actions).squeeze(3) + else: + target_max_qvals = target_mac_out.max(dim=3)[0] + + # Mix + if self.mixer is not None: + # TODO(ekl) add support for handling global state? This is just + # treating the stacked agent obs as the state. + chosen_action_qvals = self.mixer(chosen_action_qvals, obs[:, :-1]) + target_max_qvals = self.target_mixer(target_max_qvals, obs[:, 1:]) + + # Calculate 1-step Q-Learning targets + targets = rewards + self.gamma * (1 - terminated) * target_max_qvals + + # Td-error + td_error = (chosen_action_qvals - targets.detach()) + + mask = mask.expand_as(td_error) + + # 0-out the targets that came from padded data + masked_td_error = td_error * mask + + # Normal L2 loss, take mean over actual data + loss = (masked_td_error**2).sum() / mask.sum() + return loss, mask, masked_td_error, chosen_action_qvals, targets + + +class QMixPolicyGraph(PolicyGraph): + """QMix impl. Assumes homogeneous agents for now. + + You must use MultiAgentEnv.with_agent_groups() to group agents + together for QMix. This creates the proper Tuple obs/action spaces and + populates the '_group_rewards' info field. + + Action masking: to specify an action mask for individual agents, use a + dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}. + The mask space must be `Box(0, 1, (n_actions,))`. 
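+
+    For illustration only (hypothetical shapes, assuming gym.spaces), a
+    per-agent observation space carrying a mask for 5 discrete actions
+    could be declared as:
+
+        from gym.spaces import Box, Dict
+        agent_obs_space = Dict({
+            "obs": Box(-1.0, 1.0, (10,)),
+            "action_mask": Box(0, 1, (5,)),
+        })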
+ """ + + def __init__(self, obs_space, action_space, config): + _validate(obs_space, action_space) + config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config) + self.config = config + self.observation_space = obs_space + self.action_space = action_space + self.n_agents = len(obs_space.original_space.spaces) + self.n_actions = action_space.spaces[0].n + self.h_size = config["model"]["lstm_cell_size"] + + agent_obs_space = obs_space.original_space.spaces[0] + if isinstance(agent_obs_space, Dict): + space_keys = set(agent_obs_space.spaces.keys()) + if space_keys != {"obs", "action_mask"}: + raise ValueError( + "Dict obs space for agent must have keyset " + "['obs', 'action_mask'], got {}".format(space_keys)) + mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape) + if mask_shape != (self.n_actions, ): + raise ValueError("Action mask shape must be {}, got {}".format( + (self.n_actions, ), mask_shape)) + self.has_action_mask = True + self.obs_size = _get_size(agent_obs_space.spaces["obs"]) + else: + self.has_action_mask = False + self.obs_size = _get_size(agent_obs_space) + + self.model = RNNModel(self.obs_size, self.h_size, self.n_actions) + self.target_model = RNNModel(self.obs_size, self.h_size, + self.n_actions) + + # Setup the mixer network. + # The global state is just the stacked agent observations for now. + self.state_shape = [self.obs_size, self.n_agents] + if config["mixer"] is None: + self.mixer = None + self.target_mixer = None + elif config["mixer"] == "qmix": + self.mixer = QMixer(self.n_agents, self.state_shape, + config["mixing_embed_dim"]) + self.target_mixer = QMixer(self.n_agents, self.state_shape, + config["mixing_embed_dim"]) + elif config["mixer"] == "vdn": + self.mixer = VDNMixer() + self.target_mixer = VDNMixer() + else: + raise ValueError("Unknown mixer type {}".format(config["mixer"])) + + self.cur_epsilon = 1.0 + self.update_target() # initial sync + + # Setup optimizer + self.params = list(self.model.parameters()) + self.loss = QMixLoss(self.model, self.target_model, self.mixer, + self.target_mixer, self.n_agents, self.n_actions, + self.config["double_q"], self.config["gamma"]) + self.optimiser = RMSprop( + params=self.params, + lr=config["lr"], + alpha=config["optim_alpha"], + eps=config["optim_eps"]) + + @override(PolicyGraph) + def compute_actions(self, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None, + info_batch=None, + episodes=None, + **kwargs): + obs_batch, action_mask = self._unpack_observation(obs_batch) + assert len(state_batches) == self.n_agents, state_batches + state_batches = np.stack(state_batches, axis=1) + + # Compute actions + with th.no_grad(): + q_values, hiddens = _mac(self.model, th.from_numpy(obs_batch), + th.from_numpy(state_batches)) + avail = th.from_numpy(action_mask).float() + masked_q_values = q_values.clone() + masked_q_values[avail == 0.0] = -float("inf") + # epsilon-greedy action selector + random_numbers = th.rand_like(q_values[:, :, 0]) + pick_random = (random_numbers < self.cur_epsilon).long() + random_actions = Categorical(avail).sample().long() + actions = (pick_random * random_actions + + (1 - pick_random) * masked_q_values.max(dim=2)[1]) + actions = var_to_np(actions) + hiddens = var_to_np(hiddens) + + return (TupleActions(list(actions.transpose([1, 0]))), + hiddens.transpose([1, 0, 2]), {}) + + @override(PolicyGraph) + def compute_apply(self, samples): + obs_batch, action_mask = self._unpack_observation(samples["obs"]) + group_rewards = 
self._get_group_rewards(samples["infos"]) + + # These will be padded to shape [B * T, ...] + [rew, action_mask, act, dones, obs], initial_states, seq_lens = \ + chop_into_sequences( + samples["eps_id"], + samples["agent_index"], [ + group_rewards, action_mask, samples["actions"], + samples["dones"], obs_batch + ], + [samples["state_in_{}".format(k)] + for k in range(self.n_agents)], + max_seq_len=self.config["model"]["max_seq_len"], + dynamic_max=True, + _extra_padding=1) + # TODO(ekl) adding 1 extra unit of padding here, since otherwise we + # lose the terminating reward and the Q-values will be unanchored! + B, T = len(seq_lens), max(seq_lens) + 1 + + def to_batches(arr): + new_shape = [B, T] + list(arr.shape[1:]) + return th.from_numpy(np.reshape(arr, new_shape)) + + rewards = to_batches(rew)[:, :-1].float() + actions = to_batches(act)[:, :-1].long() + obs = to_batches(obs).reshape([B, T, self.n_agents, + self.obs_size]).float() + action_mask = to_batches(action_mask) + + # TODO(ekl) this treats group termination as individual termination + terminated = to_batches(dones.astype(np.float32)).unsqueeze(2).expand( + B, T, self.n_agents)[:, :-1] + filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) < + np.expand_dims(seq_lens, 1)).astype(np.float32) + mask = th.from_numpy(filled).unsqueeze(2).expand(B, T, + self.n_agents)[:, :-1] + mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) + + # Compute loss + loss_out, mask, masked_td_error, chosen_action_qvals, targets = \ + self.loss(rewards, actions, terminated, mask, obs, action_mask) + + # Optimise + self.optimiser.zero_grad() + loss_out.backward() + grad_norm = th.nn.utils.clip_grad_norm_( + self.params, self.config["grad_norm_clipping"]) + self.optimiser.step() + + mask_elems = mask.sum().item() + stats = { + "loss": loss_out.item(), + "grad_norm": grad_norm + if isinstance(grad_norm, float) else grad_norm.item(), + "td_error_abs": masked_td_error.abs().sum().item() / mask_elems, + "q_taken_mean": (chosen_action_qvals * mask).sum().item() / + mask_elems, + "target_mean": (targets * mask).sum().item() / mask_elems, + } + return {"stats": stats}, {} + + @override(PolicyGraph) + def get_initial_state(self): + return [ + self.model.init_hidden().numpy().squeeze() + for _ in range(self.n_agents) + ] + + @override(PolicyGraph) + def get_weights(self): + return {"model": self.model.state_dict()} + + @override(PolicyGraph) + def set_weights(self, weights): + self.model.load_state_dict(weights["model"]) + + @override(PolicyGraph) + def get_state(self): + return { + "model": self.model.state_dict(), + "target_model": self.target_model.state_dict(), + "mixer": self.mixer.state_dict() if self.mixer else None, + "target_mixer": self.target_mixer.state_dict() + if self.mixer else None, + "cur_epsilon": self.cur_epsilon, + } + + @override(PolicyGraph) + def set_state(self, state): + self.model.load_state_dict(state["model"]) + self.target_model.load_state_dict(state["target_model"]) + if state["mixer"] is not None: + self.mixer.load_state_dict(state["mixer"]) + self.target_mixer.load_state_dict(state["target_mixer"]) + self.set_epsilon(state["cur_epsilon"]) + self.update_target() + + def update_target(self): + self.target_model.load_state_dict(self.model.state_dict()) + if self.mixer is not None: + self.target_mixer.load_state_dict(self.mixer.state_dict()) + logger.debug("Updated target networks") + + def set_epsilon(self, epsilon): + self.cur_epsilon = epsilon + + def _get_group_rewards(self, info_batch): + group_rewards = np.array([ + 
info.get(GROUP_REWARDS, [0.0] * self.n_agents) + for info in info_batch + ]) + return group_rewards + + def _unpack_observation(self, obs_batch): + unpacked = _unpack_obs( + np.array(obs_batch), + self.observation_space.original_space, + tensorlib=np) + if self.has_action_mask: + obs = np.concatenate( + [o["obs"] for o in unpacked], + axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size]) + action_mask = np.concatenate( + [o["action_mask"] for o in unpacked], axis=1).reshape( + [len(obs_batch), self.n_agents, self.n_actions]) + else: + obs = np.concatenate( + unpacked, + axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size]) + action_mask = np.ones( + [len(obs_batch), self.n_agents, self.n_actions]) + return obs, action_mask + + +def _validate(obs_space, action_space): + if not hasattr(obs_space, "original_space") or \ + not isinstance(obs_space.original_space, Tuple): + raise ValueError("Obs space must be a Tuple, got {}. Use ".format( + obs_space) + "MultiAgentEnv.with_agent_groups() to group related " + "agents for QMix.") + if not isinstance(action_space, Tuple): + raise ValueError( + "Action space must be a Tuple, got {}. ".format(action_space) + + "Use MultiAgentEnv.with_agent_groups() to group related " + "agents for QMix.") + if not isinstance(action_space.spaces[0], Discrete): + raise ValueError( + "QMix requires a discrete action space, got {}".format( + action_space.spaces[0])) + if len({str(x) for x in obs_space.original_space.spaces}) > 1: + raise ValueError( + "Implementation limitation: observations of grouped agents " + "must be homogeneous, got {}".format( + obs_space.original_space.spaces)) + if len({str(x) for x in action_space.spaces}) > 1: + raise ValueError( + "Implementation limitation: action space of grouped agents " + "must be homogeneous, got {}".format(action_space.spaces)) + + +def _get_size(obs_space): + return get_preprocessor(obs_space)(obs_space).size + + +def _mac(model, obs, h): + """Forward pass of the multi-agent controller. 
+ + Arguments: + model: Model that produces q-values for a 1d agent batch + obs: Tensor of shape [B, n_agents, obs_size] + h: Tensor of shape [B, n_agents, h_size] + + Returns: + q_vals: Tensor of shape [B, n_agents, n_actions] + h: Tensor of shape [B, n_agents, h_size] + """ + B, n_agents = obs.size(0), obs.size(1) + obs_flat = obs.reshape([B * n_agents, -1]) + h_flat = h.reshape([B * n_agents, -1]) + q_flat, h_flat = model.forward(obs_flat, h_flat) + return q_flat.reshape([B, n_agents, -1]), h_flat.reshape([B, n_agents, -1]) diff --git a/python/ray/rllib/agents/registry.py b/python/ray/rllib/agents/registry.py new file mode 100644 index 0000000000000..98468026f837f --- /dev/null +++ b/python/ray/rllib/agents/registry.py @@ -0,0 +1,122 @@ +"""Registry of algorithm names for `rllib train --run=`""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import traceback + +from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS + + +def _import_qmix(): + from ray.rllib.agents import qmix + return qmix.QMixAgent + + +def _import_apex_qmix(): + from ray.rllib.agents import qmix + return qmix.ApexQMixAgent + + +def _import_ddpg(): + from ray.rllib.agents import ddpg + return ddpg.DDPGAgent + + +def _import_apex_ddpg(): + from ray.rllib.agents import ddpg + return ddpg.ApexDDPGAgent + + +def _import_ppo(): + from ray.rllib.agents import ppo + return ppo.PPOAgent + + +def _import_es(): + from ray.rllib.agents import es + return es.ESAgent + + +def _import_ars(): + from ray.rllib.agents import ars + return ars.ARSAgent + + +def _import_dqn(): + from ray.rllib.agents import dqn + return dqn.DQNAgent + + +def _import_apex(): + from ray.rllib.agents import dqn + return dqn.ApexAgent + + +def _import_a3c(): + from ray.rllib.agents import a3c + return a3c.A3CAgent + + +def _import_a2c(): + from ray.rllib.agents import a3c + return a3c.A2CAgent + + +def _import_pg(): + from ray.rllib.agents import pg + return pg.PGAgent + + +def _import_impala(): + from ray.rllib.agents import impala + return impala.ImpalaAgent + + +ALGORITHMS = { + "DDPG": _import_ddpg, + "APEX_DDPG": _import_apex_ddpg, + "PPO": _import_ppo, + "ES": _import_es, + "ARS": _import_ars, + "DQN": _import_dqn, + "APEX": _import_apex, + "A3C": _import_a3c, + "A2C": _import_a2c, + "PG": _import_pg, + "IMPALA": _import_impala, + "QMIX": _import_qmix, + "APEX_QMIX": _import_apex_qmix, +} + + +def get_agent_class(alg): + """Returns the class of a known agent given its name.""" + + try: + return _get_agent_class(alg) + except ImportError: + from ray.rllib.agents.mock import _agent_import_failed + return _agent_import_failed(traceback.format_exc()) + + +def _get_agent_class(alg): + if alg in ALGORITHMS: + return ALGORITHMS[alg]() + elif alg in CONTRIBUTED_ALGORITHMS: + return CONTRIBUTED_ALGORITHMS[alg]() + elif alg == "script": + from ray.tune import script_runner + return script_runner.ScriptRunner + elif alg == "__fake": + from ray.rllib.agents.mock import _MockAgent + return _MockAgent + elif alg == "__sigmoid_fake_data": + from ray.rllib.agents.mock import _SigmoidFakeData + return _SigmoidFakeData + elif alg == "__parameter_tuning": + from ray.rllib.agents.mock import _ParameterTuningAgent + return _ParameterTuningAgent + else: + raise Exception(("Unknown algorithm {}.").format(alg)) diff --git a/python/ray/rllib/contrib/README.rst b/python/ray/rllib/contrib/README.rst new file mode 100644 index 0000000000000..0532a47674fa1 --- /dev/null +++ 
b/python/ray/rllib/contrib/README.rst @@ -0,0 +1,3 @@ +Contributed algorithms, which can be run via ``rllib train --run=contrib/`` + +See https://ray.readthedocs.io/en/latest/rllib-dev.html for guidelines. diff --git a/python/ray/rllib/tuned_examples/regression_tests/__init__.py b/python/ray/rllib/contrib/__init__.py similarity index 100% rename from python/ray/rllib/tuned_examples/regression_tests/__init__.py rename to python/ray/rllib/contrib/__init__.py diff --git a/python/ray/rllib/contrib/random_agent/random_agent.py b/python/ray/rllib/contrib/random_agent/random_agent.py new file mode 100644 index 0000000000000..803f33ac8c8a9 --- /dev/null +++ b/python/ray/rllib/contrib/random_agent/random_agent.py @@ -0,0 +1,52 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from ray.rllib.agents.agent import Agent, with_common_config +from ray.rllib.utils.annotations import override + + +# yapf: disable +# __sphinx_doc_begin__ +class RandomAgent(Agent): + """Agent that takes random actions and never learns.""" + + _agent_name = "RandomAgent" + _default_config = with_common_config({ + "rollouts_per_iteration": 10, + }) + + @override(Agent) + def _init(self): + self.env = self.env_creator(self.config["env_config"]) + + @override(Agent) + def _train(self): + rewards = [] + steps = 0 + for _ in range(self.config["rollouts_per_iteration"]): + obs = self.env.reset() + done = False + reward = 0.0 + while not done: + action = self.env.action_space.sample() + obs, r, done, info = self.env.step(action) + reward += r + steps += 1 + rewards.append(reward) + return { + "episode_reward_mean": np.mean(rewards), + "timesteps_this_iter": steps, + } +# __sphinx_doc_end__ +# don't enable yapf after, it's buggy here + + +if __name__ == "__main__": + agent = RandomAgent( + env="CartPole-v0", config={"rollouts_per_iteration": 10}) + result = agent.train() + assert result["episode_reward_mean"] > 10, result + print("Test: OK") diff --git a/python/ray/rllib/contrib/registry.py b/python/ray/rllib/contrib/registry.py new file mode 100644 index 0000000000000..650a4429d3b92 --- /dev/null +++ b/python/ray/rllib/contrib/registry.py @@ -0,0 +1,15 @@ +"""Registry of algorithm names for `rllib train --run=`""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def _import_random_agent(): + from ray.rllib.contrib.random_agent.random_agent import RandomAgent + return RandomAgent + + +CONTRIBUTED_ALGORITHMS = { + "contrib/RandomAgent": _import_random_agent, +} diff --git a/python/ray/rllib/env/async_vector_env.py b/python/ray/rllib/env/async_vector_env.py index aff3738026b82..68ff1f2f7f221 100644 --- a/python/ray/rllib/env/async_vector_env.py +++ b/python/ray/rllib/env/async_vector_env.py @@ -5,6 +5,7 @@ from ray.rllib.env.external_env import ExternalEnv from ray.rllib.env.vector_env import VectorEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.utils.annotations import override class AsyncVectorEnv(object): @@ -158,6 +159,7 @@ def __init__(self, external_env, preprocessor=None): self.observation_space = external_env.observation_space external_env.start() + @override(AsyncVectorEnv) def poll(self): with self.external_env._results_avail_condition: results = self._poll() @@ -172,6 +174,12 @@ def poll(self): "ExternalEnv was created with max_concurrent={}".format(limit)) return results + @override(AsyncVectorEnv) + def send_actions(self, action_dict): + 
for eid, action in action_dict.items(): + self.external_env._episodes[eid].action_queue.put( + action[_DUMMY_AGENT_ID]) + def _poll(self): all_obs, all_rewards, all_dones, all_infos = {}, {}, {}, {} off_policy_actions = {} @@ -195,11 +203,6 @@ def _poll(self): _with_dummy_agent_id(all_infos), \ _with_dummy_agent_id(off_policy_actions) - def send_actions(self, action_dict): - for eid, action in action_dict.items(): - self.external_env._episodes[eid].action_queue.put( - action[_DUMMY_AGENT_ID]) - class _VectorEnvToAsync(AsyncVectorEnv): """Internal adapter of VectorEnv to AsyncVectorEnv. @@ -219,6 +222,7 @@ def __init__(self, vector_env): self.cur_dones = [False for _ in range(self.num_envs)] self.cur_infos = [None for _ in range(self.num_envs)] + @override(AsyncVectorEnv) def poll(self): if self.new_obs is None: self.new_obs = self.vector_env.vector_reset() @@ -235,6 +239,7 @@ def poll(self): _with_dummy_agent_id(dones, "__all__"), \ _with_dummy_agent_id(infos), {} + @override(AsyncVectorEnv) def send_actions(self, action_dict): action_vector = [None] * self.num_envs for i in range(self.num_envs): @@ -242,9 +247,11 @@ def send_actions(self, action_dict): self.new_obs, self.cur_rewards, self.cur_dones, self.cur_infos = \ self.vector_env.vector_step(action_vector) + @override(AsyncVectorEnv) def try_reset(self, env_id): return {_DUMMY_AGENT_ID: self.vector_env.reset_at(env_id)} + @override(AsyncVectorEnv) def get_unwrapped(self): return self.vector_env.get_unwrapped() @@ -275,12 +282,14 @@ def __init__(self, make_env, existing_envs, num_envs): assert isinstance(env, MultiAgentEnv) self.env_states = [_MultiAgentEnvState(env) for env in self.envs] + @override(AsyncVectorEnv) def poll(self): obs, rewards, dones, infos = {}, {}, {}, {} for i, env_state in enumerate(self.env_states): obs[i], rewards[i], dones[i], infos[i] = env_state.poll() return obs, rewards, dones, infos, {} + @override(AsyncVectorEnv) def send_actions(self, action_dict): for env_id, agent_dict in action_dict.items(): if env_id in self.dones: @@ -291,10 +300,18 @@ def send_actions(self, action_dict): assert isinstance(rewards, dict), "Not a multi-agent reward" assert isinstance(dones, dict), "Not a multi-agent return" assert isinstance(infos, dict), "Not a multi-agent info" + if set(obs.keys()) != set(rewards.keys()): + raise ValueError( + "Key set for obs and rewards must be the same: " + "{} vs {}".format(obs.keys(), rewards.keys())) + if set(infos).difference(set(obs)): + raise ValueError("Key set for infos must be a subset of obs: " + "{} vs {}".format(infos.keys(), obs.keys())) if dones["__all__"]: self.dones.add(env_id) self.env_states[env_id].observe(obs, rewards, dones, infos) + @override(AsyncVectorEnv) def try_reset(self, env_id): obs = self.env_states[env_id].reset() assert isinstance(obs, dict), "Not a multi-agent obs" diff --git a/python/ray/rllib/env/constants.py b/python/ray/rllib/env/constants.py new file mode 100644 index 0000000000000..2b6c460f52d4d --- /dev/null +++ b/python/ray/rllib/env/constants.py @@ -0,0 +1,19 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# info key for the individual rewards of an agent, for example: +# info: { +# group_1: { +# _group_rewards: [5, -1, 1], # 3 agents in this group +# } +# } +GROUP_REWARDS = "_group_rewards" + +# info key for the individual infos of an agent, for example: +# info: { +# group_1: { +# _group_infos: [{"foo": ...}, {}], # 2 agents in this group +# } +# } +GROUP_INFO = "_group_info" diff 
--git a/python/ray/rllib/env/group_agents_wrapper.py b/python/ray/rllib/env/group_agents_wrapper.py new file mode 100644 index 0000000000000..8f6051bf07d71 --- /dev/null +++ b/python/ray/rllib/env/group_agents_wrapper.py @@ -0,0 +1,107 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +from ray.rllib.env.constants import GROUP_REWARDS, GROUP_INFO +from ray.rllib.env.multi_agent_env import MultiAgentEnv + + +# TODO(ekl) we should add some unit tests for this +class _GroupAgentsWrapper(MultiAgentEnv): + """Wraps a MultiAgentEnv environment with agents grouped as specified. + + See multi_agent_env.py for the specification of groups. + + This API is experimental. + """ + + def __init__(self, env, groups, obs_space=None, act_space=None): + """Wrap an existing multi-agent env to group agents together. + + See MultiAgentEnv.with_agent_groups() for usage info. + + Arguments: + env (MultiAgentEnv): env to wrap + groups (dict): Grouping spec as documented in MultiAgentEnv + obs_space (Space): Optional observation space for the grouped + env. Must be a tuple space. + act_space (Space): Optional action space for the grouped env. + Must be a tuple space. + """ + + self.env = env + self.groups = groups + self.agent_id_to_group = {} + for group_id, agent_ids in groups.items(): + for agent_id in agent_ids: + if agent_id in self.agent_id_to_group: + raise ValueError( + "Agent id {} is in multiple groups".format( + agent_id, groups)) + self.agent_id_to_group[agent_id] = group_id + if obs_space is not None: + self.observation_space = obs_space + if act_space is not None: + self.action_space = act_space + + def reset(self): + obs = self.env.reset() + return self._group_items(obs) + + def step(self, action_dict): + # Ungroup and send actions + action_dict = self._ungroup_items(action_dict) + obs, rewards, dones, infos = self.env.step(action_dict) + + # Apply grouping transforms to the env outputs + obs = self._group_items(obs) + rewards = self._group_items( + rewards, agg_fn=lambda gvals: list(gvals.values())) + dones = self._group_items( + dones, agg_fn=lambda gvals: all(gvals.values())) + infos = self._group_items( + infos, agg_fn=lambda gvals: {GROUP_INFO: list(gvals.values())}) + + # Aggregate rewards, but preserve the original values in infos + for agent_id, rew in rewards.items(): + if isinstance(rew, list): + rewards[agent_id] = sum(rew) + if agent_id not in infos: + infos[agent_id] = {} + infos[agent_id][GROUP_REWARDS] = rew + + return obs, rewards, dones, infos + + def _ungroup_items(self, items): + out = {} + for agent_id, value in items.items(): + if agent_id in self.groups: + assert len(value) == len(self.groups[agent_id]), \ + (agent_id, value, self.groups) + for a, v in zip(self.groups[agent_id], value): + out[a] = v + else: + out[agent_id] = value + return out + + def _group_items(self, items, agg_fn=lambda gvals: list(gvals.values())): + grouped_items = {} + for agent_id, item in items.items(): + if agent_id in self.agent_id_to_group: + group_id = self.agent_id_to_group[agent_id] + if group_id in grouped_items: + continue # already added + group_out = OrderedDict() + for a in self.groups[group_id]: + if a in items: + group_out[a] = items[a] + else: + raise ValueError( + "Missing member of group {}: {}: {}".format( + group_id, a, items)) + grouped_items[group_id] = agg_fn(group_out) + else: + grouped_items[agent_id] = item + return grouped_items diff --git 
a/python/ray/rllib/env/multi_agent_env.py b/python/ray/rllib/env/multi_agent_env.py index 42f7cee8c0428..b1aaf96807056 100644 --- a/python/ray/rllib/env/multi_agent_env.py +++ b/python/ray/rllib/env/multi_agent_env.py @@ -30,9 +30,14 @@ class MultiAgentEnv(object): } >>> print(dones) { - "car_0": False, - "car_1": True, - "__all__": False, + "car_0": False, # car_0 is still running + "car_1": True, # car_1 is done + "__all__": False, # the env is not done + } + >>> print(infos) + { + "car_0": {}, # info for car_0 + "car_1": {}, # info for car_1 } """ @@ -56,7 +61,48 @@ def step(self, action_dict): rewards (dict): Reward values for each ready agent. If the episode is just started, the value will be None. dones (dict): Done values for each ready agent. The special key - "__all__" is used to indicate env termination. - infos (dict): Info values for each ready agent. + "__all__" (required) is used to indicate env termination. + infos (dict): Optional info values for each agent id. """ raise NotImplementedError + +# yapf: disable +# __grouping_doc_begin__ + def with_agent_groups(self, groups, obs_space=None, act_space=None): + """Convenience method for grouping together agents in this env. + + An agent group is a list of agent ids that are mapped to a single + logical agent. All agents of the group must act at the same time in the + environment. The grouped agent exposes Tuple action and observation + spaces that are the concatenated action and obs spaces of the + individual agents. + + The rewards of all the agents in a group are summed. The individual + agent rewards are available under the "individual_rewards" key of the + group info return. + + Agent grouping is required to leverage algorithms such as Q-Mix. + + This API is experimental. + + Arguments: + groups (dict): Mapping from group id to a list of the agent ids + of group members. If an agent id is not present in any group + value, it will be left ungrouped. + obs_space (Space): Optional observation space for the grouped + env. Must be a tuple space. + act_space (Space): Optional action space for the grouped env. + Must be a tuple space. + + Examples: + >>> env = YourMultiAgentEnv(...) + >>> grouped_env = env.with_agent_groups(env, { + ... "group1": ["agent1", "agent2", "agent3"], + ... "group2": ["agent4", "agent5"], + ... }) + """ + + from ray.rllib.env.group_agents_wrapper import _GroupAgentsWrapper + return _GroupAgentsWrapper(self, groups, obs_space, act_space) +# __grouping_doc_end__ +# yapf: enable diff --git a/python/ray/rllib/env/vector_env.py b/python/ray/rllib/env/vector_env.py index 8d2289cf41448..c2eb1692061ce 100644 --- a/python/ray/rllib/env/vector_env.py +++ b/python/ray/rllib/env/vector_env.py @@ -2,6 +2,8 @@ from __future__ import division from __future__ import print_function +from ray.rllib.utils.annotations import override + class VectorEnv(object): """An environment that supports batch evaluation. 
@@ -72,12 +74,15 @@ def __init__(self, make_env, existing_envs, num_envs): self.action_space = self.envs[0].action_space self.observation_space = self.envs[0].observation_space + @override(VectorEnv) def vector_reset(self): return [e.reset() for e in self.envs] + @override(VectorEnv) def reset_at(self, index): return self.envs[index].reset() + @override(VectorEnv) def vector_step(self, actions): obs_batch, rew_batch, done_batch, info_batch = [], [], [], [] for i in range(self.num_envs): @@ -88,5 +93,6 @@ def vector_step(self, actions): info_batch.append(info) return obs_batch, rew_batch, done_batch, info_batch + @override(VectorEnv) def get_unwrapped(self): return self.envs diff --git a/python/ray/rllib/evaluation/episode.py b/python/ray/rllib/evaluation/episode.py index 24fa431f9b461..11977745184d5 100644 --- a/python/ray/rllib/evaluation/episode.py +++ b/python/ray/rllib/evaluation/episode.py @@ -60,6 +60,7 @@ def __init__(self, policies, policy_mapping_fn, batch_builder_factory, self._agent_to_policy = {} self._agent_to_rnn_state = {} self._agent_to_last_obs = {} + self._agent_to_last_info = {} self._agent_to_last_action = {} self._agent_to_last_pi_info = {} self._agent_to_prev_action = {} @@ -81,6 +82,11 @@ def last_observation_for(self, agent_id=_DUMMY_AGENT_ID): return self._agent_to_last_obs.get(agent_id) + def last_info_for(self, agent_id=_DUMMY_AGENT_ID): + """Returns the last info for the specified agent.""" + + return self._agent_to_last_info.get(agent_id) + def last_action_for(self, agent_id=_DUMMY_AGENT_ID): """Returns the last action for the specified agent, or zeros.""" @@ -137,6 +143,9 @@ def _set_rnn_state(self, agent_id, rnn_state): def _set_last_observation(self, agent_id, obs): self._agent_to_last_obs[agent_id] = obs + def _set_last_info(self, agent_id, info): + self._agent_to_last_info[agent_id] = info + def _set_last_action(self, agent_id, action): self._agent_to_last_action[agent_id] = action diff --git a/python/ray/rllib/evaluation/metrics.py b/python/ray/rllib/evaluation/metrics.py index fadf2a5a249ae..1b270be3738c1 100644 --- a/python/ray/rllib/evaluation/metrics.py +++ b/python/ray/rllib/evaluation/metrics.py @@ -53,7 +53,7 @@ def summarize_episodes(episodes, new_episodes, num_dropped): """ if num_dropped > 0: - logger.warn("WARNING: {} workers have NOT returned metrics".format( + logger.warning("WARNING: {} workers have NOT returned metrics".format( num_dropped)) episode_rewards = [] @@ -80,8 +80,16 @@ def summarize_episodes(episodes, new_episodes, num_dropped): for policy_id, rewards in policy_rewards.copy().items(): policy_rewards[policy_id] = np.mean(rewards) - for k, v_list in custom_metrics.items(): - custom_metrics[k] = np.mean(v_list) + for k, v_list in custom_metrics.copy().items(): + custom_metrics[k + "_mean"] = np.mean(v_list) + filt = [v for v in v_list if not np.isnan(v)] + if filt: + custom_metrics[k + "_min"] = np.min(filt) + custom_metrics[k + "_max"] = np.max(filt) + else: + custom_metrics[k + "_min"] = float("nan") + custom_metrics[k + "_max"] = float("nan") + del custom_metrics[k] return dict( episode_reward_max=max_reward, diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index db5f7ee887b2f..f5c250aa92847 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -8,7 +8,6 @@ import tensorflow as tf import ray -from ray.rllib.models import ModelCatalog from ray.rllib.env.async_vector_env import AsyncVectorEnv from 
ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari from ray.rllib.env.env_context import EnvContext @@ -19,7 +18,11 @@ from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader +from ray.rllib.models import ModelCatalog +from ray.rllib.models.preprocessors import NoPreprocessor from ray.rllib.utils import merge_dicts +from ray.rllib.utils.annotations import override from ray.rllib.utils.compression import pack from ray.rllib.utils.filter import get_filter from ray.rllib.utils.tf_run_builder import TFRunBuilder @@ -100,13 +103,18 @@ def __init__(self, num_envs=1, observation_filter="NoFilter", clip_rewards=None, + clip_actions=True, env_config=None, model_config=None, policy_config=None, worker_index=0, monitor_path=None, + log_dir=None, log_level=None, - callbacks=None): + callbacks=None, + input_creator=lambda ioctx: ioctx.default_sampler_input(), + input_evaluation_method=None, + output_creator=lambda ioctx: NoopOutput()): """Initialize a policy evaluator. Arguments: @@ -155,6 +163,8 @@ def __init__(self, clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to experience postprocessing. Setting to None means clip for Atari only. + clip_actions (bool): Whether to clip action values to the range + specified by the policy action space. env_config (dict): Config to pass to the env creator. model_config (dict): Config to use when creating the policy model. policy_config (dict): Config to pass to the policy. In the @@ -165,8 +175,22 @@ def __init__(self, through EnvContext so that envs can be configured per worker. monitor_path (str): Write out episode stats and videos to this directory if specified. + log_dir (str): Directory where logs can be placed. log_level (str): Set the root log level on creation. callbacks (dict): Dict of custom debug callbacks. + input_creator (func): Function that returns an InputReader object + for loading previous generated experiences. + input_evaluation_method (str): How to evaluate the current policy. + This only applies when the input is reading offline data. + Options are: + - None: don't evaluate the policy. The episode reward and + other metrics will be NaN. + - "simulation": run the environment in the background, but + use this data for evaluation only and never for learning. + - "counterfactual": use counterfactual policy evaluation to + estimate performance. + output_creator (func): Function that returns an OutputWriter object + for saving generated experiences. """ if log_level: @@ -188,23 +212,21 @@ def __init__(self, self.sample_batch_size = batch_steps * num_envs self.batch_mode = batch_mode self.compress_observations = compress_observations + self.preprocessing_enabled = True self.env = env_creator(env_context) if isinstance(self.env, MultiAgentEnv) or \ isinstance(self.env, AsyncVectorEnv): - if model_config.get("custom_preprocessor"): - raise ValueError( - "Custom preprocessors are not supported for env types " - "MultiAgentEnv and AsyncVectorEnv. 
Please preprocess " - "observations in your env directly.") - def wrap(env): return env # we can't auto-wrap these env types elif is_atari(self.env) and \ not model_config.get("custom_preprocessor") and \ preprocessor_pref == "deepmind": + # Deepmind wrappers already handle all preprocessing + self.preprocessing_enabled = False + if clip_rewards is None: clip_rewards = True @@ -219,8 +241,6 @@ def wrap(env): else: def wrap(env): - env = ModelCatalog.get_preprocessor_as_wrapper( - env, model_config) if monitor_path: env = _monitor(env, monitor_path) return env @@ -235,6 +255,11 @@ def make_env(vector_index): policy_dict = _validate_and_canonicalize(policy_graph, self.env) self.policies_to_train = policies_to_train or list(policy_dict.keys()) if _has_tensorflow_graph(policy_dict): + if (ray.worker._mode() != ray.worker.LOCAL_MODE + and not ray.get_gpu_ids()): + logger.info("Creating policy evaluation worker {}".format( + worker_index) + + " on CPU (please ignore any CUDA init errors)") with tf.Graph().as_default(): if tf_session_creator: self.tf_sess = tf_session_creator() @@ -243,11 +268,11 @@ def make_env(vector_index): config=tf.ConfigProto( gpu_options=tf.GPUOptions(allow_growth=True))) with self.tf_sess.as_default(): - self.policy_map = self._build_policy_map( - policy_dict, policy_config) + self.policy_map, self.preprocessors = \ + self._build_policy_map(policy_dict, policy_config) else: - self.policy_map = self._build_policy_map(policy_dict, - policy_config) + self.policy_map, self.preprocessors = self._build_policy_map( + policy_dict, policy_config) self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID} if self.multiagent: @@ -278,56 +303,61 @@ def make_env(vector_index): else: raise ValueError("Unsupported batch mode: {}".format( self.batch_mode)) + + if input_evaluation_method == "simulation": + logger.warning( + "Requested 'simulation' input evaluation method: " + "will discard all sampler outputs and keep only metrics.") + sample_async = True + elif input_evaluation_method == "counterfactual": + raise NotImplementedError + elif input_evaluation_method is None: + pass + else: + raise ValueError("Unknown evaluation method: {}".format( + input_evaluation_method)) + if sample_async: self.sampler = AsyncSampler( self.async_env, self.policy_map, policy_mapping_fn, + self.preprocessors, self.filters, clip_rewards, unroll_length, self.callbacks, horizon=episode_horizon, pack=pack_episodes, - tf_sess=self.tf_sess) + tf_sess=self.tf_sess, + clip_actions=clip_actions, + blackhole_outputs=input_evaluation_method == "simulation") self.sampler.start() else: self.sampler = SyncSampler( self.async_env, self.policy_map, policy_mapping_fn, + self.preprocessors, self.filters, clip_rewards, unroll_length, self.callbacks, horizon=episode_horizon, pack=pack_episodes, - tf_sess=self.tf_sess) + tf_sess=self.tf_sess, + clip_actions=clip_actions) + + self.io_context = IOContext(log_dir, policy_config, worker_index, self) + self.input_reader = input_creator(self.io_context) + assert isinstance(self.input_reader, InputReader), self.input_reader + self.output_writer = output_creator(self.io_context) + assert isinstance(self.output_writer, OutputWriter), self.output_writer logger.debug("Created evaluator with env {} ({}), policies {}".format( self.async_env, self.env, self.policy_map)) - def _build_policy_map(self, policy_dict, policy_config): - policy_map = {} - for name, (cls, obs_space, act_space, - conf) in sorted(policy_dict.items()): - merged_conf = merge_dicts(policy_config, conf) - with 
tf.variable_scope(name): - if isinstance(obs_space, gym.spaces.Dict): - raise ValueError( - "Found raw Dict space as input to policy graph. " - "Please preprocess your environment observations " - "with DictFlatteningPreprocessor and set the " - "obs space to `preprocessor.observation_space`.") - elif isinstance(obs_space, gym.spaces.Tuple): - raise ValueError( - "Found raw Tuple space as input to policy graph. " - "Please preprocess your environment observations " - "with TupleFlatteningPreprocessor and set the " - "obs space to `preprocessor.observation_space`.") - policy_map[name] = cls(obs_space, act_space, merged_conf) - return policy_map - + @override(EvaluatorInterface) def sample(self): """Evaluate the current policies and return a batch of experiences. @@ -335,7 +365,7 @@ def sample(self): SampleBatch|MultiAgentBatch from evaluating the current policies. """ - batches = [self.sampler.get_data()] + batches = [self.input_reader.next()] steps_so_far = batches[0].count # In truncate_episodes mode, never pull more than 1 batch per env. @@ -347,10 +377,9 @@ def sample(self): while steps_so_far < self.sample_batch_size and len( batches) < max_batches: - batch = self.sampler.get_data() + batch = self.input_reader.next() steps_so_far += batch.count batches.append(batch) - batches.extend(self.sampler.get_extra_batches()) batch = batches[0].concat_samples(batches) if self.callbacks.get("on_sample_end"): @@ -368,6 +397,7 @@ def sample(self): batch["obs"] = [pack(o) for o in batch["obs"]] batch["new_obs"] = [pack(o) for o in batch["new_obs"]] + self.output_writer.write(batch) return batch @ray.method(num_return_vals=2) @@ -376,52 +406,7 @@ def sample_with_count(self): batch = self.sample() return batch, batch.count - def for_policy(self, func, policy_id=DEFAULT_POLICY_ID): - """Apply the given function to the specified policy graph.""" - - return func(self.policy_map[policy_id]) - - def foreach_policy(self, func): - """Apply the given function to each (policy, policy_id) tuple.""" - - return [func(policy, pid) for pid, policy in self.policy_map.items()] - - def foreach_trainable_policy(self, func): - """Apply the given function to each (policy, policy_id) tuple. - - This only applies func to policies in `self.policies_to_train`.""" - - return [ - func(policy, pid) for pid, policy in self.policy_map.items() - if pid in self.policies_to_train - ] - - def sync_filters(self, new_filters): - """Changes self's filter to given and rebases any accumulated delta. - - Args: - new_filters (dict): Filters with new state to update local copy. - """ - assert all(k in new_filters for k in self.filters) - for k in self.filters: - self.filters[k].sync(new_filters[k]) - - def get_filters(self, flush_after=False): - """Returns a snapshot of filters. - - Args: - flush_after (bool): Clears the filter buffer state. 
- - Returns: - return_filters (dict): Dict for serializable filters - """ - return_filters = {} - for k, f in self.filters.items(): - return_filters[k] = f.as_serializable() - if flush_after: - f.clear_buffer() - return return_filters - + @override(EvaluatorInterface) def get_weights(self, policies=None): if policies is None: policies = self.policy_map.keys() @@ -430,10 +415,12 @@ def get_weights(self, policies=None): for pid, policy in self.policy_map.items() if pid in policies } + @override(EvaluatorInterface) def set_weights(self, weights): for pid, w in weights.items(): self.policy_map[pid].set_weights(w) + @override(EvaluatorInterface) def compute_gradients(self, samples): if isinstance(samples, MultiAgentBatch): grad_out, info_out = {}, {} @@ -443,7 +430,7 @@ def compute_gradients(self, samples): if pid not in self.policies_to_train: continue grad_out[pid], info_out[pid] = ( - self.policy_map[pid].build_compute_gradients( + self.policy_map[pid]._build_compute_gradients( builder, batch)) grad_out = {k: builder.get(v) for k, v in grad_out.items()} info_out = {k: builder.get(v) for k, v in info_out.items()} @@ -459,12 +446,13 @@ def compute_gradients(self, samples): info_out["batch_count"] = samples.count return grad_out, info_out + @override(EvaluatorInterface) def apply_gradients(self, grads): if isinstance(grads, dict): if self.tf_sess is not None: builder = TFRunBuilder(self.tf_sess, "apply_gradients") outputs = { - pid: self.policy_map[pid].build_apply_gradients( + pid: self.policy_map[pid]._build_apply_gradients( builder, grad) for pid, grad in grads.items() } @@ -477,6 +465,7 @@ def apply_gradients(self, grads): else: return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads) + @override(EvaluatorInterface) def compute_apply(self, samples): if isinstance(samples, MultiAgentBatch): info_out = {} @@ -486,7 +475,7 @@ def compute_apply(self, samples): if pid not in self.policies_to_train: continue info_out[pid], _ = ( - self.policy_map[pid].build_compute_apply( + self.policy_map[pid]._build_compute_apply( builder, batch)) info_out = {k: builder.get(v) for k, v in info_out.items()} else: @@ -501,6 +490,52 @@ def compute_apply(self, samples): self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples)) return grad_fetch + def for_policy(self, func, policy_id=DEFAULT_POLICY_ID): + """Apply the given function to the specified policy graph.""" + + return func(self.policy_map[policy_id]) + + def foreach_policy(self, func): + """Apply the given function to each (policy, policy_id) tuple.""" + + return [func(policy, pid) for pid, policy in self.policy_map.items()] + + def foreach_trainable_policy(self, func): + """Apply the given function to each (policy, policy_id) tuple. + + This only applies func to policies in `self.policies_to_train`.""" + + return [ + func(policy, pid) for pid, policy in self.policy_map.items() + if pid in self.policies_to_train + ] + + def sync_filters(self, new_filters): + """Changes self's filter to given and rebases any accumulated delta. + + Args: + new_filters (dict): Filters with new state to update local copy. + """ + assert all(k in new_filters for k in self.filters) + for k in self.filters: + self.filters[k].sync(new_filters[k]) + + def get_filters(self, flush_after=False): + """Returns a snapshot of filters. + + Args: + flush_after (bool): Clears the filter buffer state. 
+ + Returns: + return_filters (dict): Dict for serializable filters + """ + return_filters = {} + for k, f in self.filters.items(): + return_filters[k] = f.as_serializable() + if flush_after: + f.clear_buffer() + return return_filters + def save(self): filters = self.get_filters(flush_after=True) state = { @@ -518,6 +553,43 @@ def restore(self, objs): def set_global_vars(self, global_vars): self.foreach_policy(lambda p, _: p.on_global_var_update(global_vars)) + def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID): + self.policy_map[policy_id].export_model(export_dir) + + def export_policy_checkpoint(self, + export_dir, + filename_prefix="model", + policy_id=DEFAULT_POLICY_ID): + self.policy_map[policy_id].export_checkpoint(export_dir, + filename_prefix) + + def _build_policy_map(self, policy_dict, policy_config): + policy_map = {} + preprocessors = {} + for name, (cls, obs_space, act_space, + conf) in sorted(policy_dict.items()): + merged_conf = merge_dicts(policy_config, conf) + if self.preprocessing_enabled: + preprocessor = ModelCatalog.get_preprocessor_for_space( + obs_space, merged_conf.get("model")) + preprocessors[name] = preprocessor + obs_space = preprocessor.observation_space + else: + preprocessors[name] = NoPreprocessor(obs_space) + if isinstance(obs_space, gym.spaces.Dict) or \ + isinstance(obs_space, gym.spaces.Tuple): + raise ValueError( + "Found raw Tuple|Dict space as input to policy graph. " + "Please preprocess these observations with a " + "Tuple|DictFlatteningPreprocessor.") + with tf.variable_scope(name): + policy_map[name] = cls(obs_space, act_space, merged_conf) + return policy_map, preprocessors + + def __del__(self): + if isinstance(self.sampler, AsyncSampler): + self.sampler.shutdown = True + def _validate_and_canonicalize(policy_graph, env): if isinstance(policy_graph, dict): @@ -549,6 +621,11 @@ def _validate_and_canonicalize(policy_graph, env): elif not issubclass(policy_graph, PolicyGraph): raise ValueError("policy_graph must be a rllib.PolicyGraph class") else: + if (isinstance(env, MultiAgentEnv) + and not hasattr(env, "observation_space")): + raise ValueError( + "MultiAgentEnv must have observation_space defined if run " + "in a single-agent configuration.") return { DEFAULT_POLICY_ID: (policy_graph, env.observation_space, env.action_space, {}) diff --git a/python/ray/rllib/evaluation/policy_graph.py b/python/ray/rllib/evaluation/policy_graph.py index c19da286b0b9a..fc4be570614f4 100644 --- a/python/ray/rllib/evaluation/policy_graph.py +++ b/python/ray/rllib/evaluation/policy_graph.py @@ -42,7 +42,9 @@ def compute_actions(self, state_batches, prev_action_batch=None, prev_reward_batch=None, - episodes=None): + info_batch=None, + episodes=None, + **kwargs): """Compute actions for the current policy. Arguments: @@ -50,9 +52,11 @@ def compute_actions(self, state_batches (list): list of RNN state input batches, if any prev_action_batch (np.ndarray): batch of previous action values prev_reward_batch (np.ndarray): batch of previous rewards + info_batch (info): batch of info objects episodes (list): MultiAgentEpisode for each obs in obs_batch. This provides access to all of the internal episode state, which may be useful for model-based or multiagent algorithms. 
+ kwargs: forward compatibility placeholder Returns: actions (np.ndarray): batch of output actions, with shape like @@ -69,7 +73,9 @@ def compute_single_action(self, state, prev_action_batch=None, prev_reward_batch=None, - episode=None): + info_batch=None, + episode=None, + **kwargs): """Unbatched version of compute_actions. Arguments: @@ -77,9 +83,11 @@ def compute_single_action(self, state_batches (list): list of RNN state inputs, if any prev_action_batch (np.ndarray): batch of previous action values prev_reward_batch (np.ndarray): batch of previous rewards + info_batch (list): batch of info objects episode (MultiAgentEpisode): this provides access to all of the internal episode state, which may be useful for model-based or multi-agent algorithms. + kwargs: forward compatibility placeholder Returns: actions (obj): single action @@ -192,3 +200,19 @@ def on_global_var_update(self, global_vars): global_vars (dict): Global variables broadcast from the driver. """ pass + + def export_model(self, export_dir): + """Export PolicyGraph to local directory for serving. + + Arguments: + export_dir (str): Local writable directory. + """ + raise NotImplementedError + + def export_checkpoint(self, export_dir): + """Export PolicyGraph checkpoint to local directory. + + Argument: + export_dir (str): Local writable directory. + """ + raise NotImplementedError diff --git a/python/ray/rllib/evaluation/sample_batch.py b/python/ray/rllib/evaluation/sample_batch.py index caec1bf4352eb..f576e4f140d2a 100644 --- a/python/ray/rllib/evaluation/sample_batch.py +++ b/python/ray/rllib/evaluation/sample_batch.py @@ -2,6 +2,7 @@ from __future__ import division from __future__ import print_function +import six import collections import numpy as np @@ -79,6 +80,11 @@ def __init__(self, policy_map, clip_rewards): self.agent_to_policy = {} self.count = 0 # increment this manually + def total(self): + """Returns summed number of steps across all agent buffers.""" + + return sum(p.count for p in self.policy_builders.values()) + def has_pending_data(self): """Returns whether there is pending unprocessed data.""" @@ -195,6 +201,11 @@ def concat_samples(samples): out[policy_id] = SampleBatch.concat_samples(batches) return MultiAgentBatch(out, total_count) + def copy(self): + return MultiAgentBatch( + {k: v.copy() + for (k, v) in self.policy_batches.items()}, self.count) + def total(self): ct = 0 for batch in self.policy_batches.values(): @@ -223,8 +234,9 @@ def __init__(self, *args, **kwargs): self.data = dict(*args, **kwargs) lengths = [] for k, v in self.data.copy().items(): - assert type(k) == str, self + assert isinstance(k, six.string_types), self lengths.append(len(v)) + self.data[k] = np.array(v, copy=False) if not lengths: raise ValueError("Empty sample batch") assert len(set(lengths)) == 1, "data columns must be same length" @@ -256,6 +268,11 @@ def concat(self, other): out[k] = np.concatenate([self[k], other[k]]) return SampleBatch(out) + def copy(self): + return SampleBatch( + {k: np.array(v, copy=True) + for (k, v) in self.data.items()}) + def rows(self): """Returns an iterator over data rows, i.e. dicts with column values. 
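The `SampleBatch` changes above relax the column-name check to `six.string_types`, coerce every column to a NumPy array at construction time, and add `copy()` helpers to both `SampleBatch` and `MultiAgentBatch`. A minimal sketch of the resulting behavior (import path taken from the file touched above; the column names and values are arbitrary):

```python
import numpy as np

from ray.rllib.evaluation.sample_batch import SampleBatch

# Columns may be passed as plain lists; the constructor converts each one to a
# numpy array (np.array(v, copy=False)) and checks that all share one length.
batch = SampleBatch({"obs": [[0.1, 0.2], [0.3, 0.4]], "rewards": [1.0, 0.5]})
assert isinstance(batch["rewards"], np.ndarray)

# copy() deep-copies every column, so mutating the copy leaves the original
# batch untouched.
clone = batch.copy()
clone["rewards"][0] = 99.0
assert batch["rewards"][0] == 1.0
```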
diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index 2c6411f33510f..d19530707029a 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -2,6 +2,7 @@ from __future__ import division from __future__ import print_function +import gym from collections import defaultdict, namedtuple import logging import numpy as np @@ -9,8 +10,7 @@ import threading from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action -from ray.rllib.evaluation.sample_batch import MultiAgentSampleBatchBuilder, \ - MultiAgentBatch +from ray.rllib.evaluation.sample_batch import MultiAgentSampleBatchBuilder from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph from ray.rllib.env.async_vector_env import AsyncVectorEnv from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv @@ -18,47 +18,45 @@ from ray.rllib.utils.tf_run_builder import TFRunBuilder logger = logging.getLogger(__name__) +_large_batch_warned = False RolloutMetrics = namedtuple( "RolloutMetrics", ["episode_length", "episode_reward", "agent_rewards", "custom_metrics"]) -PolicyEvalData = namedtuple( - "PolicyEvalData", - ["env_id", "agent_id", "obs", "rnn_state", "prev_action", "prev_reward"]) +PolicyEvalData = namedtuple("PolicyEvalData", [ + "env_id", "agent_id", "obs", "info", "rnn_state", "prev_action", + "prev_reward" +]) class SyncSampler(object): - """This class interacts with the environment and tells it what to do. - - Note that batch_size is only a unit of measure here. Batches can - accumulate and the gradient can be calculated on up to 5 batches. - - This class provides data on invocation, rather than on a separate - thread.""" - def __init__(self, env, policies, policy_mapping_fn, + preprocessors, obs_filters, clip_rewards, unroll_length, callbacks, horizon=None, pack=False, - tf_sess=None): + tf_sess=None, + clip_actions=True): self.async_vector_env = AsyncVectorEnv.wrap_async(env) self.unroll_length = unroll_length self.horizon = horizon self.policies = policies self.policy_mapping_fn = policy_mapping_fn - self._obs_filters = obs_filters + self.preprocessors = preprocessors + self.obs_filters = obs_filters self.extra_batches = queue.Queue() self.rollout_provider = _env_runner( self.async_vector_env, self.extra_batches.put, self.policies, self.policy_mapping_fn, self.unroll_length, self.horizon, - self._obs_filters, clip_rewards, pack, callbacks, tf_sess) + self.preprocessors, self.obs_filters, clip_rewards, clip_actions, + pack, callbacks, tf_sess) self.metrics_queue = queue.Queue() def get_data(self): @@ -89,22 +87,20 @@ def get_extra_batches(self): class AsyncSampler(threading.Thread): - """This class interacts with the environment and tells it what to do. - - Note that batch_size is only a unit of measure here. Batches can - accumulate and the gradient can be calculated on up to 5 batches.""" - def __init__(self, env, policies, policy_mapping_fn, + preprocessors, obs_filters, clip_rewards, unroll_length, callbacks, horizon=None, pack=False, - tf_sess=None): + tf_sess=None, + clip_actions=True, + blackhole_outputs=False): for _, f in obs_filters.items(): assert getattr(f, "is_concurrent", False), \ "Observation Filter must support concurrent updates." 
@@ -117,12 +113,16 @@ def __init__(self, self.horizon = horizon self.policies = policies self.policy_mapping_fn = policy_mapping_fn - self._obs_filters = obs_filters + self.preprocessors = preprocessors + self.obs_filters = obs_filters self.clip_rewards = clip_rewards self.daemon = True self.pack = pack self.tf_sess = tf_sess self.callbacks = callbacks + self.clip_actions = clip_actions + self.blackhole_outputs = blackhole_outputs + self.shutdown = False def run(self): try: @@ -132,12 +132,19 @@ def run(self): raise e def _run(self): + if self.blackhole_outputs: + queue_putter = (lambda x: None) + extra_batches_putter = (lambda x: None) + else: + queue_putter = self.queue.put + extra_batches_putter = ( + lambda x: self.extra_batches.put(x, timeout=600.0)) rollout_provider = _env_runner( - self.async_vector_env, self.extra_batches.put, self.policies, + self.async_vector_env, extra_batches_putter, self.policies, self.policy_mapping_fn, self.unroll_length, self.horizon, - self._obs_filters, self.clip_rewards, self.pack, self.callbacks, - self.tf_sess) - while True: + self.preprocessors, self.obs_filters, self.clip_rewards, + self.clip_actions, self.pack, self.callbacks, self.tf_sess) + while not self.shutdown: # The timeout variable exists because apparently, if one worker # dies, the other workers won't die with it, unless the timeout is # set to some large number. This is an empirical observation. @@ -145,7 +152,7 @@ def _run(self): if isinstance(item, RolloutMetrics): self.metrics_queue.put(item) else: - self.queue.put(item, timeout=600.0) + queue_putter(item) def get_data(self): rollout = self.queue.get(timeout=600.0) @@ -154,20 +161,6 @@ def get_data(self): if isinstance(rollout, BaseException): raise rollout - # We can't auto-concat rollouts in these modes - if self.async_vector_env.num_envs > 1 or \ - isinstance(rollout, MultiAgentBatch): - return rollout - - # Auto-concat rollouts; TODO(ekl) is this important for A3C perf? - while not rollout["dones"][-1]: - try: - part = self.queue.get_nowait() - if isinstance(part, BaseException): - raise rollout - rollout = rollout.concat(part) - except queue.Empty: - break return rollout def get_metrics(self): @@ -195,8 +188,10 @@ def _env_runner(async_vector_env, policy_mapping_fn, unroll_length, horizon, + preprocessors, obs_filters, clip_rewards, + clip_actions, pack, callbacks, tf_sess=None): @@ -212,11 +207,14 @@ def _env_runner(async_vector_env, unroll_length (int): Number of episode steps before `SampleBatch` is yielded. Set to infinity to yield complete episodes. horizon (int): Horizon of the episode. + preprocessors (dict): Map of policy id to preprocessor for the + observations prior to filtering. obs_filters (dict): Map of policy id to filter used to process observations for the policy. clip_rewards (bool): Whether to clip rewards before postprocessing. pack (bool): Whether to pack multiple episodes into each batch. This guarantees batches will be exactly `unroll_length` in size. + clip_actions (bool): Whether to clip actions to the space range. callbacks (dict): User callbacks to run on episode events. tf_sess (Session|None): Optional tensorflow session to use for batching TF policy evaluations. 
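The new `preprocessors` argument documented above sits in front of the existing observation filters: each raw observation is first run through the policy's preprocessor and only then through its filter. A small illustrative helper (written for this note, not part of the diff) showing that ordering:

```python
# Mirrors the preprocess-then-filter order described in the _env_runner
# docstring above. `preprocessors` and `obs_filters` are the per-policy-id
# dicts that _env_runner receives.
def transform_observation(policy_id, raw_obs, preprocessors, obs_filters):
    prep_obs = preprocessors[policy_id].transform(raw_obs)
    filtered_obs = obs_filters[policy_id](prep_obs)
    return filtered_obs
```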
@@ -231,7 +229,7 @@ def _env_runner(async_vector_env, horizon = ( async_vector_env.get_unwrapped()[0].spec.max_episode_steps) except Exception: - logger.warn("no episode horizon specified, assuming inf") + logger.debug("no episode horizon specified, assuming inf") if not horizon: horizon = float("inf") @@ -266,7 +264,7 @@ def new_episode(): active_envs, to_eval, outputs = _process_observations( async_vector_env, policies, batch_builder_pool, active_episodes, unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon, - obs_filters, unroll_length, pack, callbacks) + preprocessors, obs_filters, unroll_length, pack, callbacks) for o in outputs: yield o @@ -277,7 +275,7 @@ def new_episode(): # Process results and update episode state actions_to_send = _process_policy_eval_results( to_eval, eval_results, active_episodes, active_envs, - off_policy_actions) + off_policy_actions, policies, clip_actions) # Return computed actions to ready envs. We also send to envs that have # taken off-policy actions; those envs are free to ignore the action. @@ -286,8 +284,8 @@ def new_episode(): def _process_observations(async_vector_env, policies, batch_builder_pool, active_episodes, unfiltered_obs, rewards, dones, - infos, off_policy_actions, horizon, obs_filters, - unroll_length, pack, callbacks): + infos, off_policy_actions, horizon, preprocessors, + obs_filters, unroll_length, pack, callbacks): """Record new data from the environment and prepare for policy evaluation. Returns: @@ -309,6 +307,21 @@ def _process_observations(async_vector_env, policies, batch_builder_pool, episode.batch_builder.count += 1 episode._add_agent_rewards(rewards[env_id]) + global _large_batch_warned + if (not _large_batch_warned and + episode.batch_builder.total() > max(1000, unroll_length * 10)): + _large_batch_warned = True + logger.warning( + "More than {} observations for {} env steps ".format( + episode.batch_builder.total(), + episode.batch_builder.count) + "are buffered in " + "the sampler. If this is more than you expected, check that " + "you set a horizon on your environment correctly. 
Note " + "that in multi-agent environments, `sample_batch_size` sets " + "the batch size based on environment steps, not the steps of " + "individual agents, which can result in unexpectedly large " + "batches.") + # Check episode termination conditions if dones[env_id]["__all__"] or episode.length >= horizon: all_done = True @@ -329,21 +342,25 @@ def _process_observations(async_vector_env, policies, batch_builder_pool, # For each agent in the environment for agent_id, raw_obs in agent_obs.items(): policy_id = episode.policy_for(agent_id) - filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs) + prep_obs = _get_or_raise(preprocessors, + policy_id).transform(raw_obs) + filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs) agent_done = bool(all_done or dones[env_id].get(agent_id)) if not agent_done: to_eval[policy_id].append( PolicyEvalData(env_id, agent_id, filtered_obs, + infos[env_id].get(agent_id, {}), episode.rnn_state_for(agent_id), episode.last_action_for(agent_id), rewards[env_id][agent_id] or 0.0)) last_observation = episode.last_observation_for(agent_id) episode._set_last_observation(agent_id, filtered_obs) + episode._set_last_info(agent_id, infos[env_id].get(agent_id, {})) # Record transition info if applicable - if last_observation is not None and \ - infos[env_id][agent_id].get("training_enabled", True): + if (last_observation is not None and infos[env_id].get( + agent_id, {}).get("training_enabled", True)): episode.batch_builder.add_values( agent_id, policy_id, @@ -356,7 +373,7 @@ def _process_observations(async_vector_env, policies, batch_builder_pool, prev_actions=episode.prev_action_for(agent_id), prev_rewards=episode.prev_reward_for(agent_id), dones=agent_done, - infos=infos[env_id][agent_id], + infos=infos[env_id].get(agent_id, {}), new_obs=filtered_obs, **episode.last_pi_info_for(agent_id)) @@ -399,12 +416,15 @@ def _process_observations(async_vector_env, policies, batch_builder_pool, for agent_id, raw_obs in resetted_obs.items(): policy_id = episode.policy_for(agent_id) policy = _get_or_raise(policies, policy_id) + prep_obs = _get_or_raise(preprocessors, + policy_id).transform(raw_obs) filtered_obs = _get_or_raise(obs_filters, - policy_id)(raw_obs) + policy_id)(prep_obs) episode._set_last_observation(agent_id, filtered_obs) to_eval[policy_id].append( PolicyEvalData( env_id, agent_id, filtered_obs, + episode.last_info_for(agent_id) or {}, episode.rnn_state_for(agent_id), np.zeros_like( _flatten_action(policy.action_space.sample())), @@ -432,7 +452,8 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes): policy = _get_or_raise(policies, policy_id) if builder and (policy.compute_actions.__code__ is TFPolicyGraph.compute_actions.__code__): - pending_fetches[policy_id] = policy.build_compute_actions( + # TODO(ekl): how can we make info batch available to TF code? 
+ pending_fetches[policy_id] = policy._build_compute_actions( builder, [t.obs for t in eval_data], rnn_in_cols, prev_action_batch=[t.prev_action for t in eval_data], @@ -443,6 +464,7 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes): rnn_in_cols, prev_action_batch=[t.prev_action for t in eval_data], prev_reward_batch=[t.prev_reward for t in eval_data], + info_batch=[t.info for t in eval_data], episodes=[active_episodes[t.env_id] for t in eval_data]) if builder: for k, v in pending_fetches.items(): @@ -452,7 +474,8 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes): def _process_policy_eval_results(to_eval, eval_results, active_episodes, - active_envs, off_policy_actions): + active_envs, off_policy_actions, policies, + clip_actions): """Process the output of policy neural network evaluation. Records policy evaluation results into the given episode objects and @@ -479,10 +502,15 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes, pi_info_cols["state_out_{}".format(f_i)] = column # Save output rows actions = _unbatch_tuple_actions(actions) + policy = _get_or_raise(policies, policy_id) for i, action in enumerate(actions): env_id = eval_data[i].env_id agent_id = eval_data[i].agent_id - actions_to_send[env_id][agent_id] = action + if clip_actions: + actions_to_send[env_id][agent_id] = _clip_actions( + action, policy.action_space) + else: + actions_to_send[env_id][agent_id] = action episode = active_episodes[env_id] episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols]) episode._set_last_pi_info( @@ -516,6 +544,31 @@ def _fetch_atari_metrics(async_vector_env): return atari_out +def _clip_actions(actions, space): + """Called to clip actions to the specified range of this policy. + + Arguments: + actions: Single action. + space: Action space the actions should be present in. + + Returns: + Clipped batch of actions. 
+ """ + + if isinstance(space, gym.spaces.Box): + return np.clip(actions, space.low, space.high) + elif isinstance(space, gym.spaces.Tuple): + if type(actions) not in (tuple, list): + raise ValueError("Expected tuple space for actions {}: {}".format( + actions, space)) + out = [] + for a, s in zip(actions, space.spaces): + out.append(_clip_actions(a, s)) + return out + else: + return actions + + def _unbatch_tuple_actions(action_batch): # convert list of batches -> batch of lists if isinstance(action_batch, TupleActions): diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index 95e7a5d66bcbf..7574864c9c83f 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -2,6 +2,7 @@ from __future__ import division from __future__ import print_function +import os import logging import tensorflow as tf import numpy as np @@ -9,8 +10,9 @@ import ray from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.models.lstm import chop_into_sequences -from ray.rllib.utils.tf_run_builder import TFRunBuilder +from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule +from ray.rllib.utils.tf_run_builder import TFRunBuilder logger = logging.getLogger(__name__) @@ -146,158 +148,216 @@ def __init__(self, logger.debug("Created {} with loss inputs: {}".format( self, self._loss_input_dict)) - def build_compute_actions(self, - builder, - obs_batch, - state_batches=None, - prev_action_batch=None, - prev_reward_batch=None, - episodes=None): - state_batches = state_batches or [] - assert len(self._state_inputs) == len(state_batches), \ - (self._state_inputs, state_batches) - builder.add_feed_dict(self.extra_compute_action_feed_dict()) - builder.add_feed_dict({self._obs_input: obs_batch}) - if state_batches: - builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))}) - if self._prev_action_input is not None and prev_action_batch: - builder.add_feed_dict({self._prev_action_input: prev_action_batch}) - if self._prev_reward_input is not None and prev_reward_batch: - builder.add_feed_dict({self._prev_reward_input: prev_reward_batch}) - builder.add_feed_dict({self._is_training: False}) - builder.add_feed_dict(dict(zip(self._state_inputs, state_batches))) - fetches = builder.add_fetches([self._sampler] + self._state_outputs + - [self.extra_compute_action_fetches()]) - return fetches[0], fetches[1:-1], fetches[-1] - + @override(PolicyGraph) def compute_actions(self, obs_batch, state_batches=None, prev_action_batch=None, prev_reward_batch=None, - episodes=None): + info_batch=None, + episodes=None, + **kwargs): builder = TFRunBuilder(self._sess, "compute_actions") - fetches = self.build_compute_actions(builder, obs_batch, state_batches, - prev_action_batch, - prev_reward_batch) + fetches = self._build_compute_actions(builder, obs_batch, + state_batches, prev_action_batch, + prev_reward_batch) return builder.get(fetches) - def _get_loss_inputs_dict(self, batch): - feed_dict = {} - if self._batch_divisibility_req > 1: - meets_divisibility_reqs = ( - len(batch["obs"]) % self._batch_divisibility_req == 0 - and max(batch["agent_index"]) == 0) # not multiagent - else: - meets_divisibility_reqs = True - - # Simple case: not RNN nor do we need to pad - if not self._state_inputs and meets_divisibility_reqs: - for k, ph in self._loss_inputs: - feed_dict[ph] = batch[k] - return feed_dict - - if self._state_inputs: - max_seq_len = 
self._max_seq_len - dynamic_max = True - else: - max_seq_len = self._batch_divisibility_req - dynamic_max = False - - # RNN or multi-agent case - feature_keys = [k for k, v in self._loss_inputs] - state_keys = [ - "state_in_{}".format(i) for i in range(len(self._state_inputs)) - ] - feature_sequences, initial_states, seq_lens = chop_into_sequences( - batch["eps_id"], - batch["agent_index"], [batch[k] for k in feature_keys], - [batch[k] for k in state_keys], - max_seq_len, - dynamic_max=dynamic_max) - for k, v in zip(feature_keys, feature_sequences): - feed_dict[self._loss_input_dict[k]] = v - for k, v in zip(state_keys, initial_states): - feed_dict[self._loss_input_dict[k]] = v - feed_dict[self._seq_lens] = seq_lens - return feed_dict - - def build_compute_gradients(self, builder, postprocessed_batch): - builder.add_feed_dict(self.extra_compute_grad_feed_dict()) - builder.add_feed_dict({self._is_training: True}) - builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) - fetches = builder.add_fetches( - [self._grads, self.extra_compute_grad_fetches()]) - return fetches[0], fetches[1] - + @override(PolicyGraph) def compute_gradients(self, postprocessed_batch): builder = TFRunBuilder(self._sess, "compute_gradients") - fetches = self.build_compute_gradients(builder, postprocessed_batch) + fetches = self._build_compute_gradients(builder, postprocessed_batch) return builder.get(fetches) - def build_apply_gradients(self, builder, gradients): - assert len(gradients) == len(self._grads), (gradients, self._grads) - builder.add_feed_dict(self.extra_apply_grad_feed_dict()) - builder.add_feed_dict({self._is_training: True}) - builder.add_feed_dict(dict(zip(self._grads, gradients))) - fetches = builder.add_fetches( - [self._apply_op, self.extra_apply_grad_fetches()]) - return fetches[1] - + @override(PolicyGraph) def apply_gradients(self, gradients): builder = TFRunBuilder(self._sess, "apply_gradients") - fetches = self.build_apply_gradients(builder, gradients) + fetches = self._build_apply_gradients(builder, gradients) return builder.get(fetches) - def build_compute_apply(self, builder, postprocessed_batch): - builder.add_feed_dict(self.extra_compute_grad_feed_dict()) - builder.add_feed_dict(self.extra_apply_grad_feed_dict()) - builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) - builder.add_feed_dict({self._is_training: True}) - fetches = builder.add_fetches([ - self._apply_op, - self.extra_compute_grad_fetches(), - self.extra_apply_grad_fetches() - ]) - return fetches[1], fetches[2] - + @override(PolicyGraph) def compute_apply(self, postprocessed_batch): builder = TFRunBuilder(self._sess, "compute_apply") - fetches = self.build_compute_apply(builder, postprocessed_batch) + fetches = self._build_compute_apply(builder, postprocessed_batch) return builder.get(fetches) + @override(PolicyGraph) def get_weights(self): return self._variables.get_flat() + @override(PolicyGraph) def set_weights(self, weights): return self._variables.set_flat(weights) + @override(PolicyGraph) + def export_model(self, export_dir): + """Export tensorflow graph to export_dir for serving.""" + with self._sess.graph.as_default(): + builder = tf.saved_model.builder.SavedModelBuilder(export_dir) + signature_def_map = self._build_signature_def() + builder.add_meta_graph_and_variables( + self._sess, [tf.saved_model.tag_constants.SERVING], + signature_def_map=signature_def_map) + builder.save() + + @override(PolicyGraph) + def export_checkpoint(self, export_dir, filename_prefix="model"): + """Export 
tensorflow checkpoint to export_dir.""" + save_path = os.path.join(export_dir, filename_prefix) + with self._sess.graph.as_default(): + saver = tf.train.Saver() + saver.save(self._sess, save_path) + + def copy(self, existing_inputs): + """Creates a copy of self using existing input placeholders. + + Optional, only required to work with the multi-GPU optimizer.""" + raise NotImplementedError + def extra_compute_action_feed_dict(self): + """Extra dict to pass to the compute actions session run.""" return {} def extra_compute_action_fetches(self): + """Extra values to fetch and return from compute_actions().""" return {} # e.g, value function def extra_compute_grad_feed_dict(self): + """Extra dict to pass to the compute gradients session run.""" return {} # e.g, kl_coeff def extra_compute_grad_fetches(self): + """Extra values to fetch and return from compute_gradients().""" return {} # e.g, td error def extra_apply_grad_feed_dict(self): + """Extra dict to pass to the apply gradients session run.""" return {} def extra_apply_grad_fetches(self): + """Extra values to fetch and return from apply_gradients().""" return {} # e.g., batch norm updates + def _extra_input_signature_def(self): + """Extra input signatures to add when exporting tf model. + Inferred from extra_compute_action_feed_dict() + """ + feed_dict = self.extra_compute_action_feed_dict() + return { + k.name: tf.saved_model.utils.build_tensor_info(k) + for k in feed_dict.keys() + } + + def _extra_output_signature_def(self): + """Extra output signatures to add when exporting tf model. + Inferred from extra_compute_action_fetches() + """ + fetches = self.extra_compute_action_fetches() + return { + k: tf.saved_model.utils.build_tensor_info(fetches[k]) + for k in fetches.keys() + } + def optimizer(self): + """TF optimizer to use for policy optimization.""" return tf.train.AdamOptimizer() def gradients(self, optimizer): + """Override for custom gradient computation.""" return optimizer.compute_gradients(self._loss) - def loss_inputs(self): - return self._loss_inputs + def _build_signature_def(self): + """Build signature def map for tensorflow SavedModelBuilder. 
+ """ + # build input signatures + input_signature = self._extra_input_signature_def() + input_signature["observations"] = \ + tf.saved_model.utils.build_tensor_info(self._obs_input) + + if self._seq_lens is not None: + input_signature["seq_lens"] = \ + tf.saved_model.utils.build_tensor_info(self._seq_lens) + if self._prev_action_input is not None: + input_signature["prev_action"] = \ + tf.saved_model.utils.build_tensor_info(self._prev_action_input) + if self._prev_reward_input is not None: + input_signature["prev_reward"] = \ + tf.saved_model.utils.build_tensor_info(self._prev_reward_input) + input_signature["is_training"] = \ + tf.saved_model.utils.build_tensor_info(self._is_training) + + for state_input in self._state_inputs: + input_signature[state_input.name] = \ + tf.saved_model.utils.build_tensor_info(state_input) + + # build output signatures + output_signature = self._extra_output_signature_def() + output_signature["actions"] = \ + tf.saved_model.utils.build_tensor_info(self._sampler) + for state_output in self._state_outputs: + output_signature[state_output.name] = \ + tf.saved_model.utils.build_tensor_info(state_output) + signature_def = ( + tf.saved_model.signature_def_utils.build_signature_def( + input_signature, output_signature, + tf.saved_model.signature_constants.PREDICT_METHOD_NAME)) + signature_def_key = \ + tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # noqa: E501 + signature_def_map = {signature_def_key: signature_def} + return signature_def_map + + def _build_compute_actions(self, + builder, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None, + episodes=None): + state_batches = state_batches or [] + assert len(self._state_inputs) == len(state_batches), \ + (self._state_inputs, state_batches) + builder.add_feed_dict(self.extra_compute_action_feed_dict()) + builder.add_feed_dict({self._obs_input: obs_batch}) + if state_batches: + builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))}) + if self._prev_action_input is not None and prev_action_batch: + builder.add_feed_dict({self._prev_action_input: prev_action_batch}) + if self._prev_reward_input is not None and prev_reward_batch: + builder.add_feed_dict({self._prev_reward_input: prev_reward_batch}) + builder.add_feed_dict({self._is_training: False}) + builder.add_feed_dict(dict(zip(self._state_inputs, state_batches))) + fetches = builder.add_fetches([self._sampler] + self._state_outputs + + [self.extra_compute_action_fetches()]) + return fetches[0], fetches[1:-1], fetches[-1] + + def _build_compute_gradients(self, builder, postprocessed_batch): + builder.add_feed_dict(self.extra_compute_grad_feed_dict()) + builder.add_feed_dict({self._is_training: True}) + builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) + fetches = builder.add_fetches( + [self._grads, self.extra_compute_grad_fetches()]) + return fetches[0], fetches[1] + + def _build_apply_gradients(self, builder, gradients): + assert len(gradients) == len(self._grads), (gradients, self._grads) + builder.add_feed_dict(self.extra_apply_grad_feed_dict()) + builder.add_feed_dict({self._is_training: True}) + builder.add_feed_dict(dict(zip(self._grads, gradients))) + fetches = builder.add_fetches( + [self._apply_op, self.extra_apply_grad_fetches()]) + return fetches[1] + + def _build_compute_apply(self, builder, postprocessed_batch): + builder.add_feed_dict(self.extra_compute_grad_feed_dict()) + builder.add_feed_dict(self.extra_apply_grad_feed_dict()) + 
builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) + builder.add_feed_dict({self._is_training: True}) + fetches = builder.add_fetches([ + self._apply_op, + self.extra_compute_grad_fetches(), + self.extra_apply_grad_fetches() + ]) + return fetches[1], fetches[2] def _get_is_training_placeholder(self): """Get the placeholder for _is_training, i.e., for batch norm layers. @@ -308,6 +368,46 @@ def _get_is_training_placeholder(self): self._is_training = tf.placeholder_with_default(False, ()) return self._is_training + def _get_loss_inputs_dict(self, batch): + feed_dict = {} + if self._batch_divisibility_req > 1: + meets_divisibility_reqs = ( + len(batch["obs"]) % self._batch_divisibility_req == 0 + and max(batch["agent_index"]) == 0) # not multiagent + else: + meets_divisibility_reqs = True + + # Simple case: not RNN nor do we need to pad + if not self._state_inputs and meets_divisibility_reqs: + for k, ph in self._loss_inputs: + feed_dict[ph] = batch[k] + return feed_dict + + if self._state_inputs: + max_seq_len = self._max_seq_len + dynamic_max = True + else: + max_seq_len = self._batch_divisibility_req + dynamic_max = False + + # RNN or multi-agent case + feature_keys = [k for k, v in self._loss_inputs] + state_keys = [ + "state_in_{}".format(i) for i in range(len(self._state_inputs)) + ] + feature_sequences, initial_states, seq_lens = chop_into_sequences( + batch["eps_id"], + batch["agent_index"], [batch[k] for k in feature_keys], + [batch[k] for k in state_keys], + max_seq_len, + dynamic_max=dynamic_max) + for k, v in zip(feature_keys, feature_sequences): + feed_dict[self._loss_input_dict[k]] = v + for k, v in zip(state_keys, initial_states): + feed_dict[self._loss_input_dict[k]] = v + feed_dict[self._seq_lens] = seq_lens + return feed_dict + class LearningRateSchedule(object): """Mixin for TFPolicyGraph that adds a learning rate schedule.""" @@ -320,11 +420,13 @@ def __init__(self, lr, lr_schedule): self.lr_schedule = PiecewiseSchedule( lr_schedule, outside_value=lr_schedule[-1][-1]) + @override(PolicyGraph) def on_global_var_update(self, global_vars): super(LearningRateSchedule, self).on_global_var_update(global_vars) self.cur_lr.load( self.lr_schedule.value(global_vars["timestep"]), session=self._sess) + @override(TFPolicyGraph) def optimizer(self): return tf.train.AdamOptimizer(self.cur_lr) diff --git a/python/ray/rllib/evaluation/torch_policy_graph.py b/python/ray/rllib/evaluation/torch_policy_graph.py index a762927bab442..a64ac3e74352c 100644 --- a/python/ray/rllib/evaluation/torch_policy_graph.py +++ b/python/ray/rllib/evaluation/torch_policy_graph.py @@ -13,6 +13,7 @@ pass # soft dep from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.utils.annotations import override class TorchPolicyGraph(PolicyGraph): @@ -56,23 +57,15 @@ def __init__(self, observation_space, action_space, model, loss, self._loss_inputs = loss_inputs self._optimizer = self.optimizer() - def extra_action_out(self, model_out): - """Returns dict of extra info to include in experience batch. 
- - Arguments: - model_out (list): Outputs of the policy model module.""" - return {} - - def optimizer(self): - """Custom PyTorch optimizer to use.""" - return torch.optim.Adam(self._model.parameters()) - + @override(PolicyGraph) def compute_actions(self, obs_batch, state_batches=None, prev_action_batch=None, prev_reward_batch=None, - episodes=None): + info_batch=None, + episodes=None, + **kwargs): if state_batches: raise NotImplementedError("Torch RNN support") with self.lock: @@ -83,6 +76,7 @@ def compute_actions(self, actions = F.softmax(logits, dim=1).multinomial(1).squeeze(0) return var_to_np(actions), [], self.extra_action_out(model_out) + @override(PolicyGraph) def compute_gradients(self, postprocessed_batch): with self.lock: loss_in = [] @@ -96,6 +90,7 @@ def compute_gradients(self, postprocessed_batch): grads = [var_to_np(p.grad.data) for p in self._model.parameters()] return grads, {} + @override(PolicyGraph) def apply_gradients(self, gradients): with self.lock: for g, p in zip(gradients, self._model.parameters()): @@ -103,10 +98,23 @@ def apply_gradients(self, gradients): self._optimizer.step() return {} + @override(PolicyGraph) def get_weights(self): with self.lock: return self._model.state_dict() + @override(PolicyGraph) def set_weights(self, weights): with self.lock: self._model.load_state_dict(weights) + + def extra_action_out(self, model_out): + """Returns dict of extra info to include in experience batch. + + Arguments: + model_out (list): Outputs of the policy model module.""" + return {} + + def optimizer(self): + """Custom PyTorch optimizer to use.""" + return torch.optim.Adam(self._model.parameters()) diff --git a/python/ray/rllib/examples/carla/a3c_lane_keep.py b/python/ray/rllib/examples/carla/a3c_lane_keep.py deleted file mode 100644 index 9629808ba4c7c..0000000000000 --- a/python/ray/rllib/examples/carla/a3c_lane_keep.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.tune import register_env, run_experiments - -from env import CarlaEnv, ENV_CONFIG -from models import register_carla_model -from scenarios import LANE_KEEP - -env_name = "carla_env" -env_config = ENV_CONFIG.copy() -env_config.update({ - "verbose": False, - "x_res": 80, - "y_res": 80, - "use_depth_camera": False, - "discrete_actions": False, - "server_map": "/Game/Maps/Town02", - "reward_function": "lane_keep", - "enable_planner": False, - "scenarios": [LANE_KEEP], -}) - -register_env(env_name, lambda env_config: CarlaEnv(env_config)) -register_carla_model() - -ray.init() -run_experiments({ - "carla-a3c": { - "run": "A3C", - "env": "carla_env", - "config": { - "env_config": env_config, - "model": { - "custom_model": "carla", - "custom_options": { - "image_shape": [80, 80, 6], - }, - "conv_filters": [ - [16, [8, 8], 4], - [32, [4, 4], 2], - [512, [10, 10], 1], - ], - }, - "gamma": 0.8, - "num_workers": 1, - }, - }, -}) diff --git a/python/ray/rllib/examples/carla/dqn_lane_keep.py b/python/ray/rllib/examples/carla/dqn_lane_keep.py deleted file mode 100644 index 84fed98cd5f90..0000000000000 --- a/python/ray/rllib/examples/carla/dqn_lane_keep.py +++ /dev/null @@ -1,53 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.tune import register_env, run_experiments - -from env import CarlaEnv, ENV_CONFIG -from models import register_carla_model -from scenarios import LANE_KEEP - -env_name = "carla_env" 
-env_config = ENV_CONFIG.copy() -env_config.update({ - "verbose": False, - "x_res": 80, - "y_res": 80, - "use_depth_camera": False, - "discrete_actions": True, - "server_map": "/Game/Maps/Town02", - "reward_function": "lane_keep", - "enable_planner": False, - "scenarios": [LANE_KEEP], -}) - -register_env(env_name, lambda env_config: CarlaEnv(env_config)) -register_carla_model() - -ray.init() -run_experiments({ - "carla-dqn": { - "run": "DQN", - "env": "carla_env", - "config": { - "env_config": env_config, - "model": { - "custom_model": "carla", - "custom_options": { - "image_shape": [80, 80, 6], - }, - "conv_filters": [ - [16, [8, 8], 4], - [32, [4, 4], 2], - [512, [10, 10], 1], - ], - }, - "timesteps_per_iteration": 100, - "learning_starts": 1000, - "schedule_max_timesteps": 100000, - "gamma": 0.8, - }, - }, -}) diff --git a/python/ray/rllib/examples/carla/ppo_lane_keep.py b/python/ray/rllib/examples/carla/ppo_lane_keep.py deleted file mode 100644 index ac0f6ff8aff0e..0000000000000 --- a/python/ray/rllib/examples/carla/ppo_lane_keep.py +++ /dev/null @@ -1,63 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.tune import register_env, run_experiments - -from env import CarlaEnv, ENV_CONFIG -from models import register_carla_model -from scenarios import LANE_KEEP - -env_name = "carla_env" -env_config = ENV_CONFIG.copy() -env_config.update({ - "verbose": False, - "x_res": 80, - "y_res": 80, - "use_depth_camera": False, - "discrete_actions": False, - "server_map": "/Game/Maps/Town02", - "reward_function": "lane_keep", - "enable_planner": False, - "scenarios": [LANE_KEEP], -}) - -register_env(env_name, lambda env_config: CarlaEnv(env_config)) -register_carla_model() - -ray.init() -run_experiments({ - "carla-ppo": { - "run": "PPO", - "env": "carla_env", - "config": { - "env_config": env_config, - "model": { - "custom_model": "carla", - "custom_options": { - "image_shape": [80, 80, 6], - }, - "conv_filters": [ - [16, [8, 8], 4], - [32, [4, 4], 2], - [512, [10, 10], 1], - ], - }, - "num_workers": 1, - "timesteps_per_batch": 2000, - "min_steps_per_task": 100, - "lambda": 0.95, - "clip_param": 0.2, - "num_sgd_iter": 20, - "sgd_stepsize": 0.0001, - "sgd_batchsize": 32, - "devices": ["/gpu:0"], - "tf_session_args": { - "gpu_options": { - "allow_growth": True - } - } - }, - }, -}) diff --git a/python/ray/rllib/examples/carla/train_a3c.py b/python/ray/rllib/examples/carla/train_a3c.py index 2c12cd8245cfe..8fbcfbc576d1e 100644 --- a/python/ray/rllib/examples/carla/train_a3c.py +++ b/python/ray/rllib/examples/carla/train_a3c.py @@ -3,13 +3,12 @@ from __future__ import print_function import ray -from ray.tune import grid_search, register_env, run_experiments +from ray.tune import grid_search, run_experiments from env import CarlaEnv, ENV_CONFIG from models import register_carla_model from scenarios import TOWN2_STRAIGHT -env_name = "carla_env" env_config = ENV_CONFIG.copy() env_config.update({ "verbose": False, @@ -23,7 +22,6 @@ "scenarios": TOWN2_STRAIGHT, }) -register_env(env_name, lambda env_config: CarlaEnv(env_config)) register_carla_model() redis_address = ray.services.get_node_ip_address() + ":6379" @@ -31,7 +29,7 @@ run_experiments({ "carla-a3c": { "run": "A3C", - "env": "carla_env", + "env": CarlaEnv, "config": { "env_config": env_config, "use_gpu_for_workers": True, diff --git a/python/ray/rllib/examples/carla/train_dqn.py b/python/ray/rllib/examples/carla/train_dqn.py index fa2dba1053aa5..27aa65444d386 
100644 --- a/python/ray/rllib/examples/carla/train_dqn.py +++ b/python/ray/rllib/examples/carla/train_dqn.py @@ -3,13 +3,12 @@ from __future__ import print_function import ray -from ray.tune import register_env, run_experiments +from ray.tune import run_experiments from env import CarlaEnv, ENV_CONFIG from models import register_carla_model from scenarios import TOWN2_ONE_CURVE -env_name = "carla_env" env_config = ENV_CONFIG.copy() env_config.update({ "verbose": False, @@ -21,7 +20,6 @@ "scenarios": TOWN2_ONE_CURVE, }) -register_env(env_name, lambda env_config: CarlaEnv(env_config)) register_carla_model() ray.init() @@ -35,7 +33,7 @@ def shape_out(spec): run_experiments({ "carla-dqn": { "run": "DQN", - "env": "carla_env", + "env": CarlaEnv, "config": { "env_config": env_config, "model": { diff --git a/python/ray/rllib/examples/carla/train_ppo.py b/python/ray/rllib/examples/carla/train_ppo.py index a9339ca794819..6c49240142c26 100644 --- a/python/ray/rllib/examples/carla/train_ppo.py +++ b/python/ray/rllib/examples/carla/train_ppo.py @@ -3,13 +3,12 @@ from __future__ import print_function import ray -from ray.tune import register_env, run_experiments +from ray.tune import run_experiments from env import CarlaEnv, ENV_CONFIG from models import register_carla_model from scenarios import TOWN2_STRAIGHT -env_name = "carla_env" env_config = ENV_CONFIG.copy() env_config.update({ "verbose": False, @@ -20,14 +19,13 @@ "server_map": "/Game/Maps/Town02", "scenarios": TOWN2_STRAIGHT, }) -register_env(env_name, lambda env_config: CarlaEnv(env_config)) register_carla_model() ray.init(redirect_output=True) run_experiments({ "carla": { "run": "PPO", - "env": "carla_env", + "env": CarlaEnv, "config": { "env_config": env_config, "model": { diff --git a/python/ray/rllib/examples/custom_env.py b/python/ray/rllib/examples/custom_env.py index 66c0288081f9c..0d96eef6acb64 100644 --- a/python/ray/rllib/examples/custom_env.py +++ b/python/ray/rllib/examples/custom_env.py @@ -11,7 +11,6 @@ import ray from ray.tune import run_experiments -from ray.tune.registry import register_env class SimpleCorridor(gym.Env): @@ -42,13 +41,13 @@ def step(self, action): if __name__ == "__main__": - env_creator_name = "corridor" - register_env(env_creator_name, lambda config: SimpleCorridor(config)) + # Can also register the env creator function explicitly with: + # register_env("corridor", lambda config: SimpleCorridor(config)) ray.init() run_experiments({ "demo": { "run": "PPO", - "env": "corridor", + "env": SimpleCorridor, # or "corridor" if registered above "config": { "env_config": { "corridor_length": 5, diff --git a/python/ray/rllib/examples/custom_metrics_and_callbacks.py b/python/ray/rllib/examples/custom_metrics_and_callbacks.py index eec7bffb571fd..af1d25f16cadf 100644 --- a/python/ray/rllib/examples/custom_metrics_and_callbacks.py +++ b/python/ray/rllib/examples/custom_metrics_and_callbacks.py @@ -25,16 +25,23 @@ def on_episode_step(info): def on_episode_end(info): episode = info["episode"] - mean_pole_angle = np.mean(episode.user_data["pole_angles"]) + pole_angle = np.mean(episode.user_data["pole_angles"]) print("episode {} ended with length {} and pole angles {}".format( - episode.episode_id, episode.length, mean_pole_angle)) - episode.custom_metrics["mean_pole_angle"] = mean_pole_angle + episode.episode_id, episode.length, pole_angle)) + episode.custom_metrics["pole_angle"] = pole_angle def on_sample_end(info): print("returned sample batch of size {}".format(info["samples"].count)) +def on_train_result(info): + 
print("agent.train() result: {} -> {} episodes".format( + info["agent"], info["result"]["episodes_this_iter"])) + # you can mutate the result dict to add new fields to return + info["result"]["callback_ok"] = True + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--num-iters", type=int, default=2000) @@ -54,6 +61,7 @@ def on_sample_end(info): "on_episode_step": tune.function(on_episode_step), "on_episode_end": tune.function(on_episode_end), "on_sample_end": tune.function(on_sample_end), + "on_train_result": tune.function(on_train_result), }, }, } @@ -62,5 +70,8 @@ def on_sample_end(info): # verify custom metrics for integration tests custom_metrics = trials[0].last_result["custom_metrics"] print(custom_metrics) - assert "mean_pole_angle" in custom_metrics - assert type(custom_metrics["mean_pole_angle"]) is float + assert "pole_angle_mean" in custom_metrics + assert "pole_angle_min" in custom_metrics + assert "pole_angle_max" in custom_metrics + assert type(custom_metrics["pole_angle_mean"]) is float + assert "callback_ok" in trials[0].last_result diff --git a/python/ray/rllib/examples/export/cartpole_dqn_export.py b/python/ray/rllib/examples/export/cartpole_dqn_export.py new file mode 100644 index 0000000000000..6bfcae060d136 --- /dev/null +++ b/python/ray/rllib/examples/export/cartpole_dqn_export.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import ray +import tensorflow as tf + +from ray.rllib.agents.registry import get_agent_class + +ray.init(num_cpus=10) + + +def train_and_export(algo_name, num_steps, model_dir, ckpt_dir, prefix): + cls = get_agent_class(algo_name) + alg = cls(config={}, env="CartPole-v0") + for _ in range(num_steps): + alg.train() + + # Export tensorflow checkpoint for fine-tuning + alg.export_policy_checkpoint(ckpt_dir, filename_prefix=prefix) + # Export tensorflow SavedModel for online serving + alg.export_policy_model(model_dir) + + +def restore_saved_model(export_dir): + signature_key = \ + tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + g = tf.Graph() + with g.as_default(): + with tf.Session(graph=g) as sess: + meta_graph_def = \ + tf.saved_model.load(sess, + [tf.saved_model.tag_constants.SERVING], + export_dir) + print("Model restored!") + print("Signature Def Information:") + print(meta_graph_def.signature_def[signature_key]) + print("You can inspect the model using TensorFlow SavedModel CLI.") + print("https://www.tensorflow.org/guide/saved_model") + + +def restore_checkpoint(export_dir, prefix): + sess = tf.Session() + meta_file = "%s.meta" % prefix + saver = tf.train.import_meta_graph(os.path.join(export_dir, meta_file)) + saver.restore(sess, os.path.join(export_dir, prefix)) + print("Checkpoint restored!") + print("Variables Information:") + for v in tf.trainable_variables(): + value = sess.run(v) + print(v.name, value) + + +if __name__ == "__main__": + algo = "DQN" + model_dir = "/tmp/model_export_dir" + ckpt_dir = "/tmp/ckpt_export_dir" + prefix = "model.ckpt" + num_steps = 3 + train_and_export(algo, num_steps, model_dir, ckpt_dir, prefix) + restore_saved_model(model_dir) + restore_checkpoint(ckpt_dir, prefix) diff --git a/python/ray/rllib/examples/multiagent_cartpole.py b/python/ray/rllib/examples/multiagent_cartpole.py index 87a8eb9282ff0..e2ab5270f9d87 100644 --- a/python/ray/rllib/examples/multiagent_cartpole.py +++ b/python/ray/rllib/examples/multiagent_cartpole.py @@ 
-107,7 +107,8 @@ def gen_policy(i): "training_iteration": args.num_iters }, "config": { - "simple_optimizer": True, + "log_level": "DEBUG", + "num_sgd_iter": 10, "multiagent": { "policy_graphs": policy_graphs, "policy_mapping_fn": tune.function( diff --git a/python/ray/rllib/examples/multiagent_two_trainers.py b/python/ray/rllib/examples/multiagent_two_trainers.py index e2c8bc97a8c23..46831db452b6d 100644 --- a/python/ray/rllib/examples/multiagent_two_trainers.py +++ b/python/ray/rllib/examples/multiagent_two_trainers.py @@ -57,7 +57,6 @@ def policy_mapping_fn(agent_id): "policy_mapping_fn": policy_mapping_fn, "policies_to_train": ["ppo_policy"], }, - "simple_optimizer": True, # disable filters, otherwise we would need to synchronize those # as well to the DQN agent "observation_filter": "NoFilter", diff --git a/python/ray/rllib/examples/starcraft/README.rst b/python/ray/rllib/examples/starcraft/README.rst new file mode 100644 index 0000000000000..b10cc5d6f87d6 --- /dev/null +++ b/python/ray/rllib/examples/starcraft/README.rst @@ -0,0 +1,18 @@ +StarCraft on RLlib +================== + +This builds off the StarCraft env in https://github.com/oxwhirl/pymarl_alpha. + +Temporary instructions +---------------------- + +To install, run + +``` +git clone https://github.com/oxwhirl/pymarl_alpha +mv pymarl_alpha ~/pymarl +cd ~/pymarl +install_sc1.sh +install_sc2.sh +export PYMARL_PATH="~/pymarl" +``` diff --git a/python/ray/rllib/examples/starcraft/sc2.yaml b/python/ray/rllib/examples/starcraft/sc2.yaml new file mode 100644 index 0000000000000..db108e226bf53 --- /dev/null +++ b/python/ray/rllib/examples/starcraft/sc2.yaml @@ -0,0 +1,32 @@ +## Adapted from `https://github.com/oxwhirl/pymarl_alpha`. + +env: sc2 + +env_args: + map_name: "3m_3m" # SC2 map name + difficulty: "7" # Very hard + move_amount: 2 # How much units are ordered to move per step + step_mul: 8 # How many frames are skiped per step + reward_sparse: False # Only +1/-1 reward for win/defeat (the rest of reward configs are ignored if True) + reward_only_positive: True # Reward is always positive + reward_negative_scale: 0.5 # How much to scale negative rewards, ignored if reward_only_positive=True + reward_death_value: 10 # Reward for killing an enemy unit and penalty for having an allied unit killed (if reward_only_poitive=False) + reward_scale: True # Whether or not to scale rewards before returning to agents + reward_scale_rate: 20 # If reward_scale=True, the agents receive the reward of (max_reward / reward_scale_rate), where max_reward is the maximum possible reward per episode + reward_win: 200 # Reward for win + reward_defeat: 0 # Reward for defeat (should be nonpositive) + state_last_action: True # Whether the last actions of units is a part of the state + obs_instead_of_state: False # Use combination of all agnets' observations as state + obs_own_health: True # Whether agents receive their own health as a part of observation + obs_all_health: True # Whether agents receive the health of all units (in the sight range) as a part of observataion + continuing_episode: False # Stop/continue episode after its termination + game_version: "4.1.2" # Ignored for Mac/Windows + save_replay_prefix: "" # Prefix of the replay to be saved + heuristic: False # Whether or not use a simple nonlearning hearistic as a controller + +test_nepisode: 32 +test_interval: 10000 +log_interval: 2000 +runner_log_interval: 2000 +learner_log_interval: 2000 +t_max: 2000000 diff --git a/python/ray/rllib/examples/starcraft/starcraft_env.py 
b/python/ray/rllib/examples/starcraft/starcraft_env.py new file mode 100644 index 0000000000000..7cfd3f266109f --- /dev/null +++ b/python/ray/rllib/examples/starcraft/starcraft_env.py @@ -0,0 +1,153 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +from gym.spaces import Discrete, Box, Dict, Tuple +import os +import sys +import tensorflow as tf +import tensorflow.contrib.slim as slim +import yaml + +import ray +from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.tune.registry import register_env +from ray.rllib.models import Model, ModelCatalog +from ray.rllib.models.misc import normc_initializer +from ray.rllib.agents.qmix import QMixAgent +from ray.rllib.agents.pg import PGAgent +from ray.rllib.agents.ppo import PPOAgent +from ray.tune.logger import pretty_print + + +class MaskedActionsModel(Model): + def _build_layers_v2(self, input_dict, num_outputs, options): + action_mask = input_dict["obs"]["action_mask"] + if num_outputs != action_mask.shape[1].value: + raise ValueError( + "This model assumes num outputs is equal to max avail actions", + num_outputs, action_mask) + + # Standard FC net component. + last_layer = input_dict["obs"]["obs"] + hiddens = [256, 256] + for i, size in enumerate(hiddens): + label = "fc{}".format(i) + last_layer = slim.fully_connected( + last_layer, + size, + weights_initializer=normc_initializer(1.0), + activation_fn=tf.nn.tanh, + scope=label) + action_logits = slim.fully_connected( + last_layer, + num_outputs, + weights_initializer=normc_initializer(0.01), + activation_fn=None, + scope="fc_out") + + # Mask out invalid actions (use tf.float32.min for stability) + inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min) + masked_logits = inf_mask + action_logits + + return masked_logits, last_layer + + +class SC2MultiAgentEnv(MultiAgentEnv): + """RLlib Wrapper around StarCraft2.""" + + def __init__(self, override_cfg): + PYMARL_PATH = override_cfg.pop("pymarl_path") + os.environ["SC2PATH"] = os.path.join(PYMARL_PATH, + "3rdparty/StarCraftII") + sys.path.append(os.path.join(PYMARL_PATH, "src")) + from envs.starcraft2 import StarCraft2Env + curpath = os.path.dirname(os.path.abspath(__file__)) + with open(os.path.join(curpath, "sc2.yaml")) as f: + pymarl_args = yaml.load(f) + pymarl_args.update(override_cfg) + pymarl_args["env_args"].setdefault("seed", 0) + + self._starcraft_env = StarCraft2Env(**pymarl_args) + obs_size = self._starcraft_env.get_obs_size() + num_actions = self._starcraft_env.get_total_actions() + self.observation_space = Dict({ + "action_mask": Box(0, 1, shape=(num_actions, )), + "obs": Box(-1, 1, shape=(obs_size, )) + }) + self.action_space = Discrete(self._starcraft_env.get_total_actions()) + + def reset(self): + obs_list, state_list = self._starcraft_env.reset() + return_obs = {} + for i, obs in enumerate(obs_list): + return_obs[i] = { + "action_mask": self._starcraft_env.get_avail_agent_actions(i), + "obs": obs + } + return return_obs + + def step(self, action_dict): + # TODO(rliaw): Check to handle missing agents, if any + actions = [action_dict[k] for k in sorted(action_dict)] + rew, done, info = self._starcraft_env.step(actions) + obs_list = self._starcraft_env.get_obs() + return_obs = {} + for i, obs in enumerate(obs_list): + return_obs[i] = { + "action_mask": self._starcraft_env.get_avail_agent_actions(i), + "obs": obs + } + rews = {i: rew / len(obs_list) for i in range(len(obs_list))} + dones = {i: done for i in range(len(obs_list))} + 
dones["__all__"] = done + infos = {i: info for i in range(len(obs_list))} + return return_obs, rews, dones, infos + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-iters", type=int, default=100) + parser.add_argument("--run", type=str, default="qmix") + args = parser.parse_args() + + path_to_pymarl = os.environ.get("PYMARL_PATH", + os.path.expanduser("~/pymarl/")) + + ray.init() + ModelCatalog.register_custom_model("mask_model", MaskedActionsModel) + + register_env("starcraft", lambda cfg: SC2MultiAgentEnv(cfg)) + agent_cfg = { + "observation_filter": "NoFilter", + "num_workers": 4, + "model": { + "custom_model": "mask_model", + }, + "env_config": { + "pymarl_path": path_to_pymarl + } + } + if args.run.lower() == "qmix": + + def grouped_sc2(cfg): + env = SC2MultiAgentEnv(cfg) + agent_list = list(range(env._starcraft_env.n_agents)) + grouping = { + "group_1": agent_list, + } + obs_space = Tuple([env.observation_space for i in agent_list]) + act_space = Tuple([env.action_space for i in agent_list]) + return env.with_agent_groups( + grouping, obs_space=obs_space, act_space=act_space) + + register_env("grouped_starcraft", grouped_sc2) + agent = QMixAgent(env="grouped_starcraft", config=agent_cfg) + elif args.run.lower() == "pg": + agent = PGAgent(env="starcraft", config=agent_cfg) + elif args.run.lower() == "ppo": + agent_cfg.update({"vf_share_layers": True}) + agent = PPOAgent(env="starcraft", config=agent_cfg) + for i in range(args.num_iters): + print(pretty_print(agent.train())) diff --git a/python/ray/rllib/examples/twostep_game.py b/python/ray/rllib/examples/twostep_game.py new file mode 100644 index 0000000000000..63c860979a151 --- /dev/null +++ b/python/ray/rllib/examples/twostep_game.py @@ -0,0 +1,117 @@ +"""The two-step game from QMIX: https://arxiv.org/pdf/1803.11485.pdf""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +from gym.spaces import Tuple, Discrete + +import ray +from ray.tune import register_env, run_experiments, grid_search +from ray.rllib.env.multi_agent_env import MultiAgentEnv + +parser = argparse.ArgumentParser() +parser.add_argument("--stop", type=int, default=50000) +parser.add_argument("--run", type=str, default="QMIX") + + +class TwoStepGame(MultiAgentEnv): + action_space = Discrete(2) + + # Each agent gets a separate [3] obs space, to ensure that they can + # learn meaningfully different Q values even with a shared Q model. 
+ observation_space = Discrete(6) + + def __init__(self, env_config): + self.state = None + + def reset(self): + self.state = 0 + return {"agent_1": self.state, "agent_2": self.state + 3} + + def step(self, action_dict): + if self.state == 0: + action = action_dict["agent_1"] + assert action in [0, 1], action + if action == 0: + self.state = 1 + else: + self.state = 2 + global_rew = 0 + done = False + elif self.state == 1: + global_rew = 7 + done = True + else: + if action_dict["agent_1"] == 0 and action_dict["agent_2"] == 0: + global_rew = 0 + elif action_dict["agent_1"] == 1 and action_dict["agent_2"] == 1: + global_rew = 8 + else: + global_rew = 1 + done = True + + rewards = {"agent_1": global_rew / 2.0, "agent_2": global_rew / 2.0} + obs = {"agent_1": self.state, "agent_2": self.state + 3} + dones = {"__all__": done} + infos = {} + return obs, rewards, dones, infos + + +if __name__ == "__main__": + args = parser.parse_args() + + grouping = { + "group_1": ["agent_1", "agent_2"], + } + obs_space = Tuple([ + TwoStepGame.observation_space, + TwoStepGame.observation_space, + ]) + act_space = Tuple([ + TwoStepGame.action_space, + TwoStepGame.action_space, + ]) + register_env( + "grouped_twostep", + lambda config: TwoStepGame(config).with_agent_groups( + grouping, obs_space=obs_space, act_space=act_space)) + + if args.run == "QMIX": + config = { + "sample_batch_size": 4, + "train_batch_size": 32, + "exploration_final_eps": 0.0, + "num_workers": 0, + "mixer": grid_search([None, "qmix", "vdn"]), + } + elif args.run == "APEX_QMIX": + config = { + "num_gpus": 0, + "num_workers": 2, + "optimizer": { + "num_replay_buffer_shards": 1, + }, + "min_iter_time_s": 3, + "buffer_size": 1000, + "learning_starts": 1000, + "train_batch_size": 128, + "sample_batch_size": 32, + "target_network_update_freq": 500, + "timesteps_per_iteration": 1000, + } + else: + config = {} + + ray.init() + run_experiments({ + "two_step": { + "run": args.run, + "env": "grouped_twostep", + "stop": { + "timesteps_total": args.stop, + }, + "config": config, + }, + }) diff --git a/python/ray/rllib/models/action_dist.py b/python/ray/rllib/models/action_dist.py index 75a43deeb7894..f2a69efaf9b03 100644 --- a/python/ray/rllib/models/action_dist.py +++ b/python/ray/rllib/models/action_dist.py @@ -4,10 +4,11 @@ from collections import namedtuple import distutils.version - import tensorflow as tf import numpy as np +from ray.rllib.utils.annotations import override + use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= distutils.version.LooseVersion("1.5.0")) @@ -42,10 +43,12 @@ def sample(self): class Categorical(ActionDistribution): """Categorical distribution for discrete action spaces.""" + @override(ActionDistribution) def logp(self, x): return -tf.nn.sparse_softmax_cross_entropy_with_logits( logits=self.inputs, labels=x) + @override(ActionDistribution) def entropy(self): if use_tf150_api: a0 = self.inputs - tf.reduce_max( @@ -61,6 +64,7 @@ def entropy(self): p0 = ea0 / z0 return tf.reduce_sum(p0 * (tf.log(z0) - a0), reduction_indices=[1]) + @override(ActionDistribution) def kl(self, other): if use_tf150_api: a0 = self.inputs - tf.reduce_max( @@ -84,6 +88,7 @@ def kl(self, other): return tf.reduce_sum( p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), reduction_indices=[1]) + @override(ActionDistribution) def sample(self): return tf.squeeze(tf.multinomial(self.inputs, 1), axis=1) @@ -95,28 +100,21 @@ class DiagGaussian(ActionDistribution): second half the gaussian standard deviations. 
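Note: many methods in this diff gain an @override(...) decorator from ray.rllib.utils.annotations. A rough sketch of what such an annotation typically does, shown only to illustrate the pattern (the actual implementation may differ):

    def override(cls):
        """Assert that the decorated method overrides a method of `cls`."""

        def check_override(method):
            if method.__name__ not in dir(cls):
                raise NameError("{} does not override any method of {}".format(
                    method.__name__, cls))
            return method

        return check_override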
""" - def __init__(self, inputs, low=None, high=None): + def __init__(self, inputs): ActionDistribution.__init__(self, inputs) mean, log_std = tf.split(inputs, 2, axis=1) self.mean = mean - self.low = low - self.high = high - - # Squash to range if specified. We use a sigmoid here this to avoid the - # mean drifting too far past the bounds and causing nan outputs. - # https://github.com/ray-project/ray/issues/1862 - if low is not None: - self.mean = low + tf.sigmoid(self.mean) * (high - low) - self.log_std = log_std self.std = tf.exp(log_std) + @override(ActionDistribution) def logp(self, x): return (-0.5 * tf.reduce_sum( tf.square((x - self.mean) / self.std), reduction_indices=[1]) - 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) - tf.reduce_sum(self.log_std, reduction_indices=[1])) + @override(ActionDistribution) def kl(self, other): assert isinstance(other, DiagGaussian) return tf.reduce_sum( @@ -125,16 +123,15 @@ def kl(self, other): (2.0 * tf.square(other.std)) - 0.5, reduction_indices=[1]) + @override(ActionDistribution) def entropy(self): return tf.reduce_sum( .5 * self.log_std + .5 * np.log(2.0 * np.pi * np.e), reduction_indices=[1]) + @override(ActionDistribution) def sample(self): - out = self.mean + self.std * tf.random_normal(tf.shape(self.mean)) - if self.low is not None: - out = tf.clip_by_value(out, self.low, self.high) - return out + return self.mean + self.std * tf.random_normal(tf.shape(self.mean)) class Deterministic(ActionDistribution): @@ -143,38 +140,11 @@ class Deterministic(ActionDistribution): This is similar to DiagGaussian with standard deviation zero. """ + @override(ActionDistribution) def sample(self): return self.inputs -def squash_to_range(dist_cls, low, high): - """Squashes an action distribution to a range in (low, high). - - Arguments: - dist_cls (class): ActionDistribution class to wrap. - low (float|array): Scalar value or array of values. - high (float|array): Scalar value or array of values. - """ - - class SquashToRangeWrapper(dist_cls): - def __init__(self, inputs): - dist_cls.__init__(self, inputs, low=low, high=high) - - def logp(self, x): - return dist_cls.logp(self, x) - - def kl(self, other): - return dist_cls.kl(self, other) - - def entropy(self): - return dist_cls.entropy(self) - - def sample(self): - return dist_cls.sample(self) - - return SquashToRangeWrapper - - class MultiActionDistribution(ActionDistribution): """Action distribution that operates for list of actions. 
@@ -190,8 +160,8 @@ def __init__(self, inputs, action_space, child_distributions, input_lens): child_list.append(distribution(split_inputs[i])) self.child_distributions = child_list + @override(ActionDistribution) def logp(self, x): - """The log-likelihood of the action distribution.""" split_indices = [] for dist in self.child_distributions: if isinstance(dist, Categorical): @@ -210,8 +180,8 @@ def logp(self, x): ]) return np.sum(log_list) + @override(ActionDistribution) def kl(self, other): - """The KL-divergence between two action distributions.""" kl_list = np.asarray([ distribution.kl(other_distribution) for distribution, other_distribution in zip( @@ -219,15 +189,14 @@ def kl(self, other): ]) return np.sum(kl_list) + @override(ActionDistribution) def entropy(self): - """The entropy of the action distribution.""" entropy_list = np.array( [s.entropy() for s in self.child_distributions]) return np.sum(entropy_list) + @override(ActionDistribution) def sample(self): - """Draw a sample from the action distribution.""" - return TupleActions([s.sample() for s in self.child_distributions]) diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index 63a7e73890ccf..b3a84cef30773 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -11,12 +11,8 @@ from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \ _global_registry -from ray.rllib.env.async_vector_env import _ExternalEnvToAsync -from ray.rllib.env.external_env import ExternalEnv -from ray.rllib.env.vector_env import VectorEnv from ray.rllib.models.action_dist import ( - Categorical, Deterministic, DiagGaussian, MultiActionDistribution, - squash_to_range) + Categorical, Deterministic, DiagGaussian, MultiActionDistribution) from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.models.fcnet import FullyConnectedNetwork from ray.rllib.models.visionnet import VisionNetwork @@ -38,7 +34,7 @@ "fcnet_hiddens": [256, 256], # For control envs, documented in ray.rllib.models.Model "free_log_std": False, - # Whether to squash the action output to space range + # (deprecated) Whether to use sigmoid to squash actions to space range "squash_to_range": False, # == LSTM == @@ -114,8 +110,9 @@ def get_action_dist(action_space, config, dist_type=None): if dist_type is None: dist = DiagGaussian if config.get("squash_to_range"): - dist = squash_to_range(dist, action_space.low, - action_space.high) + raise ValueError( + "The squash_to_range option is deprecated. See the " + "clip_actions agent option instead.") return dist, action_space.shape[0] * 2 elif dist_type == "deterministic": return Deterministic, action_space.shape[0] @@ -271,15 +268,26 @@ def get_torch_model(input_shape, num_outputs, options=None): @staticmethod def get_preprocessor(env, options=None): - """Returns a suitable processor for the given environment. + """Returns a suitable preprocessor for the given env. + + This is a wrapper for get_preprocessor_for_space(). + """ + + return ModelCatalog.get_preprocessor_for_space(env.observation_space, + options) + + @staticmethod + def get_preprocessor_for_space(observation_space, options=None): + """Returns a suitable preprocessor for the given observation space. Args: - env (gym.Env|VectorEnv|ExternalEnv): The environment to wrap. + observation_space (Space): The input observation space. options (dict): Options to pass to the preprocessor. Returns: - preprocessor (Preprocessor): Preprocessor for the env observations. 
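Note: with squash_to_range removed above, bounded Box actions are no longer squashed through a sigmoid; the error message points users to a clip_actions agent option instead. A sketch of the clipping idea, assuming it simply clamps sampled actions to the space bounds before they reach the environment (illustrative, not the agent code itself):

    import numpy as np
    from gym.spaces import Box

    action_space = Box(low=-1.0, high=1.0, shape=(2, ))
    sampled = np.array([1.7, -0.3])  # e.g. drawn from an unbounded DiagGaussian

    # Clamp element-wise to the Box bounds before stepping the env.
    clipped = np.clip(sampled, action_space.low, action_space.high)  # [1.0, -0.3]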
+ preprocessor (Preprocessor): Preprocessor for the observations. """ + options = options or MODEL_DEFAULTS for k in options.keys(): if k not in MODEL_DEFAULTS: @@ -290,38 +298,15 @@ def get_preprocessor(env, options=None): preprocessor = options["custom_preprocessor"] logger.info("Using custom preprocessor {}".format(preprocessor)) prep = _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)( - env.observation_space, options) + observation_space, options) else: - cls = get_preprocessor(env.observation_space) - prep = cls(env.observation_space, options) + cls = get_preprocessor(observation_space) + prep = cls(observation_space, options) logger.debug("Created preprocessor {}: {} -> {}".format( - prep, env.observation_space, prep.shape)) + prep, observation_space, prep.shape)) return prep - @staticmethod - def get_preprocessor_as_wrapper(env, options=None): - """Returns a preprocessor as a gym observation wrapper. - - Args: - env (gym.Env|VectorEnv|ExternalEnv): The environment to wrap. - options (dict): Options to pass to the preprocessor. - - Returns: - env (RLlib env): Wrapped environment - """ - - options = options or MODEL_DEFAULTS - preprocessor = ModelCatalog.get_preprocessor(env, options) - if isinstance(env, gym.Env): - return _RLlibPreprocessorWrapper(env, preprocessor) - elif isinstance(env, VectorEnv): - return _RLlibVectorPreprocessorWrapper(env, preprocessor) - elif isinstance(env, ExternalEnv): - return _ExternalEnvToAsync(env, preprocessor) - else: - raise ValueError("Don't know how to wrap {}".format(env)) - @staticmethod def register_custom_preprocessor(preprocessor_name, preprocessor_class): """Register a custom preprocessor class by name. @@ -348,40 +333,3 @@ def register_custom_model(model_name, model_class): model_class (type): Python class of the model. 
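Note: the new ModelCatalog.get_preprocessor_for_space() takes a gym space directly rather than an env instance. A small usage sketch with an illustrative Dict space:

    from gym.spaces import Box, Dict
    from ray.rllib.models import ModelCatalog

    obs_space = Dict({
        "action_mask": Box(0, 1, shape=(5, )),
        "obs": Box(-1, 1, shape=(10, )),
    })
    prep = ModelCatalog.get_preprocessor_for_space(obs_space)
    flat = prep.transform(obs_space.sample())  # 1-D array of length prep.size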
""" _global_registry.register(RLLIB_MODEL, model_name, model_class) - - -class _RLlibPreprocessorWrapper(gym.ObservationWrapper): - """Adapts a RLlib preprocessor for use as an observation wrapper.""" - - def __init__(self, env, preprocessor): - super(_RLlibPreprocessorWrapper, self).__init__(env) - self.preprocessor = preprocessor - self.observation_space = preprocessor.observation_space - - def observation(self, observation): - return self.preprocessor.transform(observation) - - -class _RLlibVectorPreprocessorWrapper(VectorEnv): - """Preprocessing wrapper for vector envs.""" - - def __init__(self, env, preprocessor): - self.env = env - self.prep = preprocessor - self.action_space = env.action_space - self.observation_space = preprocessor.observation_space - self.num_envs = env.num_envs - - def vector_reset(self): - return [self.prep.transform(obs) for obs in self.env.vector_reset()] - - def reset_at(self, index): - return self.prep.transform(self.env.reset_at(index)) - - def vector_step(self, actions): - obs, rewards, dones, infos = self.env.vector_step(actions) - obs = [self.prep.transform(o) for o in obs] - return obs, rewards, dones, infos - - def get_unwrapped(self): - return self.env.get_unwrapped() diff --git a/python/ray/rllib/models/fcnet.py b/python/ray/rllib/models/fcnet.py index 5a759fd59ef8a..19745b9e7a3ca 100644 --- a/python/ray/rllib/models/fcnet.py +++ b/python/ray/rllib/models/fcnet.py @@ -7,11 +7,13 @@ from ray.rllib.models.model import Model from ray.rllib.models.misc import normc_initializer, get_activation_fn +from ray.rllib.utils.annotations import override class FullyConnectedNetwork(Model): """Generic fully connected network.""" + @override(Model) def _build_layers(self, inputs, num_outputs, options): """Process the flattened inputs. diff --git a/python/ray/rllib/models/lstm.py b/python/ray/rllib/models/lstm.py index fdb7af6021001..d87cc5c4b4d19 100644 --- a/python/ray/rllib/models/lstm.py +++ b/python/ray/rllib/models/lstm.py @@ -23,6 +23,72 @@ from ray.rllib.models.misc import linear, normc_initializer from ray.rllib.models.model import Model +from ray.rllib.utils.annotations import override + + +class LSTM(Model): + """Adds a LSTM cell on top of some other model output. + + Uses a linear layer at the end for output. + + Important: we assume inputs is a padded batch of sequences denoted by + self.seq_lens. See add_time_dimension() for more information. 
+ """ + + @override(Model) + def _build_layers_v2(self, input_dict, num_outputs, options): + cell_size = options.get("lstm_cell_size") + if options.get("lstm_use_prev_action_reward"): + action_dim = int( + np.product( + input_dict["prev_actions"].get_shape().as_list()[1:])) + features = tf.concat( + [ + input_dict["obs"], + tf.reshape( + tf.cast(input_dict["prev_actions"], tf.float32), + [-1, action_dim]), + tf.reshape(input_dict["prev_rewards"], [-1, 1]), + ], + axis=1) + else: + features = input_dict["obs"] + last_layer = add_time_dimension(features, self.seq_lens) + + # Setup the LSTM cell + lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True) + self.state_init = [ + np.zeros(lstm.state_size.c, np.float32), + np.zeros(lstm.state_size.h, np.float32) + ] + + # Setup LSTM inputs + if self.state_in: + c_in, h_in = self.state_in + else: + c_in = tf.placeholder( + tf.float32, [None, lstm.state_size.c], name="c") + h_in = tf.placeholder( + tf.float32, [None, lstm.state_size.h], name="h") + self.state_in = [c_in, h_in] + + # Setup LSTM outputs + state_in = rnn.LSTMStateTuple(c_in, h_in) + lstm_out, lstm_state = tf.nn.dynamic_rnn( + lstm, + last_layer, + initial_state=state_in, + sequence_length=self.seq_lens, + time_major=False, + dtype=tf.float32) + + self.state_out = list(lstm_state) + + # Compute outputs + last_layer = tf.reshape(lstm_out, [-1, cell_size]) + logits = linear(last_layer, num_outputs, "action", + normc_initializer(0.01)) + return logits, last_layer def add_time_dimension(padded_inputs, seq_lens): @@ -57,7 +123,8 @@ def chop_into_sequences(episode_ids, feature_columns, state_columns, max_seq_len, - dynamic_max=True): + dynamic_max=True, + _extra_padding=0): """Truncate and pad experiences into fixed-length sequences. Arguments: @@ -70,6 +137,7 @@ def chop_into_sequences(episode_ids, dynamic_max (bool): Whether to dynamically shrink the max seq len. For example, if max len is 20 and the actual max seq len in the data is 7, it will be shrunk to 7. + _extra_padding (int): Add extra padding to the end of sequences. Returns: f_pad (list): Padded feature columns. These will be of shape @@ -111,7 +179,7 @@ def chop_into_sequences(episode_ids, # Dynamically shrink max len as needed to optimize memory usage if dynamic_max: - max_seq_len = max(seq_lens) + max_seq_len = max(seq_lens) + _extra_padding feature_sequences = [] for f in feature_columns: @@ -138,67 +206,3 @@ def chop_into_sequences(episode_ids, initial_states.append(np.array(s_init)) return feature_sequences, initial_states, np.array(seq_lens) - - -class LSTM(Model): - """Adds a LSTM cell on top of some other model output. - - Uses a linear layer at the end for output. - - Important: we assume inputs is a padded batch of sequences denoted by - self.seq_lens. See add_time_dimension() for more information. 
- """ - - def _build_layers_v2(self, input_dict, num_outputs, options): - cell_size = options.get("lstm_cell_size") - if options.get("lstm_use_prev_action_reward"): - action_dim = int( - np.product( - input_dict["prev_actions"].get_shape().as_list()[1:])) - features = tf.concat( - [ - input_dict["obs"], - tf.reshape( - tf.cast(input_dict["prev_actions"], tf.float32), - [-1, action_dim]), - tf.reshape(input_dict["prev_rewards"], [-1, 1]), - ], - axis=1) - else: - features = input_dict["obs"] - last_layer = add_time_dimension(features, self.seq_lens) - - # Setup the LSTM cell - lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True) - self.state_init = [ - np.zeros(lstm.state_size.c, np.float32), - np.zeros(lstm.state_size.h, np.float32) - ] - - # Setup LSTM inputs - if self.state_in: - c_in, h_in = self.state_in - else: - c_in = tf.placeholder( - tf.float32, [None, lstm.state_size.c], name="c") - h_in = tf.placeholder( - tf.float32, [None, lstm.state_size.h], name="h") - self.state_in = [c_in, h_in] - - # Setup LSTM outputs - state_in = rnn.LSTMStateTuple(c_in, h_in) - lstm_out, lstm_state = tf.nn.dynamic_rnn( - lstm, - last_layer, - initial_state=state_in, - sequence_length=self.seq_lens, - time_major=False, - dtype=tf.float32) - - self.state_out = list(lstm_state) - - # Compute outputs - last_layer = tf.reshape(lstm_out, [-1, cell_size]) - logits = linear(last_layer, num_outputs, "action", - normc_initializer(0.01)) - return logits, last_layer diff --git a/python/ray/rllib/models/model.py b/python/ray/rllib/models/model.py index 561b636dc863e..cda435e3ef4bf 100644 --- a/python/ray/rllib/models/model.py +++ b/python/ray/rllib/models/model.py @@ -82,19 +82,6 @@ def __init__(self, self.outputs = tf.concat( [self.outputs, 0.0 * self.outputs + log_std], 1) - def _validate_output_shape(self): - """Checks that the model has the correct number of outputs.""" - try: - out = tf.convert_to_tensor(self.outputs) - shape = out.shape.as_list() - except Exception: - raise ValueError("Output is not a tensor: {}".format(self.outputs)) - else: - if len(shape) != 2 or shape[1] != self._num_outputs: - raise ValueError( - "Expected output shape of [None, {}], got {}".format( - self._num_outputs, shape)) - def _build_layers(self, inputs, num_outputs, options): """Builds and returns the output and last layer of the network. @@ -159,6 +146,19 @@ def loss(self): """ return tf.constant(0.0) + def _validate_output_shape(self): + """Checks that the model has the correct number of outputs.""" + try: + out = tf.convert_to_tensor(self.outputs) + shape = out.shape.as_list() + except Exception: + raise ValueError("Output is not a tensor: {}".format(self.outputs)) + else: + if len(shape) != 2 or shape[1] != self._num_outputs: + raise ValueError( + "Expected output shape of [None, {}], got {}".format( + self._num_outputs, shape)) + def _restore_original_dimensions(input_dict, obs_space): if hasattr(obs_space, "original_space"): @@ -168,7 +168,15 @@ def _restore_original_dimensions(input_dict, obs_space): return input_dict -def _unpack_obs(obs, space): +def _unpack_obs(obs, space, tensorlib=tf): + """Unpack a flattened Dict or Tuple observation array/tensor. 
+ + Arguments: + obs: The flattened observation tensor + space: The original space prior to flattening + tensorlib: The library used to unflatten (reshape) the array/tensor + """ + if (isinstance(space, gym.spaces.Dict) or isinstance(space, gym.spaces.Tuple)): prep = get_preprocessor(space)(space) @@ -186,14 +194,18 @@ def _unpack_obs(obs, space): offset += p.size u.append( _unpack_obs( - tf.reshape(obs_slice, [-1] + list(p.shape)), v)) + tensorlib.reshape(obs_slice, [-1] + list(p.shape)), + v, + tensorlib=tensorlib)) else: u = OrderedDict() for p, (k, v) in zip(prep.preprocessors, space.spaces.items()): obs_slice = obs[:, offset:offset + p.size] offset += p.size u[k] = _unpack_obs( - tf.reshape(obs_slice, [-1] + list(p.shape)), v) + tensorlib.reshape(obs_slice, [-1] + list(p.shape)), + v, + tensorlib=tensorlib) return u else: return obs diff --git a/python/ray/rllib/models/preprocessors.py b/python/ray/rllib/models/preprocessors.py index a4af708b79151..0238ef2d8d889 100644 --- a/python/ray/rllib/models/preprocessors.py +++ b/python/ray/rllib/models/preprocessors.py @@ -8,6 +8,8 @@ import numpy as np import gym +from ray.rllib.utils.annotations import override + ATARI_OBS_SHAPE = (210, 160, 3) ATARI_RAM_OBS_SHAPE = (128, ) @@ -57,6 +59,7 @@ class GenericPixelPreprocessor(Preprocessor): instead for deepmind-style Atari preprocessing. """ + @override(Preprocessor) def _init_shape(self, obs_space, options): self._grayscale = options.get("grayscale") self._zero_mean = options.get("zero_mean") @@ -72,6 +75,7 @@ def _init_shape(self, obs_space, options): shape = shape[-1:] + shape[:-1] return shape + @override(Preprocessor) def transform(self, observation): """Downsamples images from (210, 160, 3) by the configured factor.""" scaled = observation[25:-25, :, :] @@ -96,27 +100,36 @@ def transform(self, observation): class AtariRamPreprocessor(Preprocessor): + @override(Preprocessor) def _init_shape(self, obs_space, options): return (128, ) + @override(Preprocessor) def transform(self, observation): return (observation - 128) / 128 class OneHotPreprocessor(Preprocessor): + @override(Preprocessor) def _init_shape(self, obs_space, options): return (self._obs_space.n, ) + @override(Preprocessor) def transform(self, observation): arr = np.zeros(self._obs_space.n) + if not self._obs_space.contains(observation): + raise ValueError("Observation outside expected value range", + self._obs_space, observation) arr[observation] = 1 return arr class NoPreprocessor(Preprocessor): + @override(Preprocessor) def _init_shape(self, obs_space, options): return self._obs_space.shape + @override(Preprocessor) def transform(self, observation): return observation @@ -127,6 +140,7 @@ class TupleFlatteningPreprocessor(Preprocessor): RLlib models will unpack the flattened output before _build_layers_v2(). """ + @override(Preprocessor) def _init_shape(self, obs_space, options): assert isinstance(self._obs_space, gym.spaces.Tuple) size = 0 @@ -139,6 +153,7 @@ def _init_shape(self, obs_space, options): size += preprocessor.size return (size, ) + @override(Preprocessor) def transform(self, observation): assert len(observation) == len(self.preprocessors), observation return np.concatenate([ @@ -153,6 +168,7 @@ class DictFlatteningPreprocessor(Preprocessor): RLlib models will unpack the flattened output before _build_layers_v2(). 
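Note: _unpack_obs() now accepts a tensorlib argument so the same slicing logic can restore a flattened Dict/Tuple observation with either TensorFlow or NumPy. A toy NumPy-only sketch of that slicing (the component sizes are illustrative):

    import numpy as np

    # Flattened observation: first 5 entries are "action_mask", next 10 are "obs".
    flat_obs = np.random.rand(1, 15)
    layout = [("action_mask", 5), ("obs", 10)]

    unpacked, offset = {}, 0
    for key, size in layout:
        # The real code calls tensorlib.reshape(); plain NumPy reshape here.
        unpacked[key] = flat_obs[:, offset:offset + size].reshape(-1, size)
        offset += size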
""" + @override(Preprocessor) def _init_shape(self, obs_space, options): assert isinstance(self._obs_space, gym.spaces.Dict) size = 0 @@ -164,6 +180,7 @@ def _init_shape(self, obs_space, options): size += preprocessor.size return (size, ) + @override(Preprocessor) def transform(self, observation): if not isinstance(observation, OrderedDict): observation = OrderedDict(sorted(list(observation.items()))) diff --git a/python/ray/rllib/models/visionnet.py b/python/ray/rllib/models/visionnet.py index 4105af7dd3675..0638c4fc83c59 100644 --- a/python/ray/rllib/models/visionnet.py +++ b/python/ray/rllib/models/visionnet.py @@ -7,16 +7,18 @@ from ray.rllib.models.model import Model from ray.rllib.models.misc import get_activation_fn, flatten +from ray.rllib.utils.annotations import override class VisionNetwork(Model): """Generic vision network.""" + @override(Model) def _build_layers_v2(self, input_dict, num_outputs, options): inputs = input_dict["obs"] filters = options.get("conv_filters") if not filters: - filters = get_filter_config(options) + filters = _get_filter_config(inputs) activation = get_activation_fn(options.get("conv_activation")) @@ -47,7 +49,7 @@ def _build_layers_v2(self, input_dict, num_outputs, options): return flatten(fc2), flatten(fc1) -def get_filter_config(options): +def _get_filter_config(inputs): filters_84x84 = [ [16, [8, 8], 4], [32, [4, 4], 2], @@ -58,12 +60,15 @@ def get_filter_config(options): [32, [4, 4], 2], [256, [11, 11], 1], ] - dim = options.get("dim") - if dim == 84: + shape = inputs.shape.as_list()[1:] + if len(shape) == 3 and shape[:2] == [84, 84]: return filters_84x84 - elif dim == 42: + elif len(shape) == 3 and shape[:2] == [42, 42]: return filters_42x42 else: raise ValueError( - "No default configuration for image size={}".format(dim) + - ", you must specify `conv_filters` manually as a model option.") + "No default configuration for obs input {}".format(inputs) + + ", you must specify `conv_filters` manually as a model option. " + "Default configurations are only available for inputs of size " + "[?, 42, 42, K] and [?, 84, 84, K]. 
You may alternatively want " + "to use a custom model or preprocessor.") diff --git a/python/ray/rllib/offline/__init__.py b/python/ray/rllib/offline/__init__.py new file mode 100644 index 0000000000000..195d9e7763291 --- /dev/null +++ b/python/ray/rllib/offline/__init__.py @@ -0,0 +1,20 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.offline.io_context import IOContext +from ray.rllib.offline.json_reader import JsonReader +from ray.rllib.offline.json_writer import JsonWriter +from ray.rllib.offline.output_writer import OutputWriter, NoopOutput +from ray.rllib.offline.input_reader import InputReader +from ray.rllib.offline.mixed_input import MixedInput + +__all__ = [ + "IOContext", + "JsonReader", + "JsonWriter", + "NoopOutput", + "OutputWriter", + "InputReader", + "MixedInput", +] diff --git a/python/ray/rllib/offline/input_reader.py b/python/ray/rllib/offline/input_reader.py new file mode 100644 index 0000000000000..3ee7356a4250c --- /dev/null +++ b/python/ray/rllib/offline/input_reader.py @@ -0,0 +1,30 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.utils.annotations import override + + +class InputReader(object): + """Input object for loading experiences in policy evaluation.""" + + def next(self): + """Return the next batch of experiences read.""" + + raise NotImplementedError + + +class SamplerInput(InputReader): + """Reads input experiences from an existing sampler.""" + + def __init__(self, sampler): + self.sampler = sampler + + @override(InputReader) + def next(self): + batches = [self.sampler.get_data()] + batches.extend(self.sampler.get_extra_batches()) + if len(batches) > 1: + return batches[0].concat_samples(batches) + else: + return batches[0] diff --git a/python/ray/rllib/offline/io_context.py b/python/ray/rllib/offline/io_context.py new file mode 100644 index 0000000000000..055bd714b152c --- /dev/null +++ b/python/ray/rllib/offline/io_context.py @@ -0,0 +1,28 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.offline.input_reader import SamplerInput + + +class IOContext(object): + """Attributes to pass to input / output class constructors. + + RLlib auto-sets these attributes when constructing input / output classes. + + Attributes: + log_dir (str): Default logging directory. + config (dict): Configuration of the agent. + worker_index (int): When there are multiple workers created, this + uniquely identifies the current worker. + evaluator (PolicyEvaluator): policy evaluator object reference. 
+ """ + + def __init__(self, log_dir, config, worker_index, evaluator): + self.log_dir = log_dir + self.config = config + self.worker_index = worker_index + self.evaluator = evaluator + + def default_sampler_input(self): + return SamplerInput(self.evaluator.sampler) diff --git a/python/ray/rllib/offline/json_reader.py b/python/ray/rllib/offline/json_reader.py new file mode 100644 index 0000000000000..61dadfc4ccc97 --- /dev/null +++ b/python/ray/rllib/offline/json_reader.py @@ -0,0 +1,126 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import json +import logging +import os +import random +import six +from six.moves.urllib.parse import urlparse + +try: + from smart_open import smart_open +except ImportError: + smart_open = None + +from ray.rllib.offline.input_reader import InputReader +from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override +from ray.rllib.utils.compression import unpack_if_needed + +logger = logging.getLogger(__name__) + + +class JsonReader(InputReader): + """Reader object that loads experiences from JSON file chunks. + + The input files will be read from in an random order.""" + + def __init__(self, ioctx, inputs): + """Initialize a JsonReader. + + Arguments: + ioctx (IOContext): current IO context object. + inputs (str|list): either a glob expression for files, e.g., + "/tmp/**/*.json", or a list of single file paths or URIs, e.g., + ["s3://bucket/file.json", "s3://bucket/file2.json"]. + """ + + self.ioctx = ioctx + if isinstance(inputs, six.string_types): + if os.path.isdir(inputs): + inputs = os.path.join(inputs, "*.json") + logger.warning( + "Treating input directory as glob pattern: {}".format( + inputs)) + if urlparse(inputs).scheme: + raise ValueError( + "Don't know how to glob over `{}`, ".format(inputs) + + "please specify a list of files to read instead.") + else: + self.files = glob.glob(inputs) + elif type(inputs) is list: + self.files = inputs + else: + raise ValueError( + "type of inputs must be list or str, not {}".format(inputs)) + if self.files: + logger.info("Found {} input files.".format(len(self.files))) + else: + raise ValueError("No files found matching {}".format(inputs)) + self.cur_file = None + + @override(InputReader) + def next(self): + batch = self._try_parse(self._next_line()) + tries = 0 + while not batch and tries < 100: + tries += 1 + logger.debug("Skipping empty line in {}".format(self.cur_file)) + batch = self._try_parse(self._next_line()) + if not batch: + raise ValueError( + "Failed to read valid experience batch from file: {}".format( + self.cur_file)) + return batch + + def _try_parse(self, line): + line = line.strip() + if not line: + return None + try: + return _from_json(line) + except Exception: + logger.exception("Ignoring corrupt json record in {}: {}".format( + self.cur_file, line)) + return None + + def _next_line(self): + if not self.cur_file: + self.cur_file = self._next_file() + line = self.cur_file.readline() + tries = 0 + while not line and tries < 100: + tries += 1 + if hasattr(self.cur_file, "close"): # legacy smart_open impls + self.cur_file.close() + self.cur_file = self._next_file() + line = self.cur_file.readline() + if not line: + logger.debug("Ignoring empty file {}".format(self.cur_file)) + if not line: + raise ValueError("Failed to read next line from files: {}".format( + self.files)) + return line + + def _next_file(self): + path = random.choice(self.files) + if 
urlparse(path).scheme: + if smart_open is None: + raise ValueError( + "You must install the `smart_open` module to read " + "from URIs like {}".format(path)) + return smart_open(path, "r") + else: + return open(path, "r") + + +def _from_json(batch): + if isinstance(batch, bytes): # smart_open S3 doesn't respect "r" + batch = batch.decode("utf-8") + data = json.loads(batch) + for k, v in data.items(): + data[k] = [unpack_if_needed(x) for x in unpack_if_needed(v)] + return SampleBatch(data) diff --git a/python/ray/rllib/offline/json_writer.py b/python/ray/rllib/offline/json_writer.py new file mode 100644 index 0000000000000..03401e9d18c7f --- /dev/null +++ b/python/ray/rllib/offline/json_writer.py @@ -0,0 +1,108 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from datetime import datetime +import json +import logging +import numpy as np +import os +from six.moves.urllib.parse import urlparse +import time + +try: + from smart_open import smart_open +except ImportError: + smart_open = None + +from ray.rllib.offline.output_writer import OutputWriter +from ray.rllib.utils.annotations import override +from ray.rllib.utils.compression import pack + +logger = logging.getLogger(__name__) + + +class JsonWriter(OutputWriter): + """Writer object that saves experiences in JSON file chunks.""" + + def __init__(self, + ioctx, + path, + max_file_size=64 * 1024 * 1024, + compress_columns=frozenset(["obs", "new_obs"])): + """Initialize a JsonWriter. + + Arguments: + ioctx (IOContext): current IO context object. + path (str): a path/URI of the output directory to save files in. + max_file_size (int): max size of single files before rolling over. + compress_columns (list): list of sample batch columns to compress. 
+ """ + + self.ioctx = ioctx + self.path = path + self.max_file_size = max_file_size + self.compress_columns = compress_columns + if urlparse(path).scheme: + self.path_is_uri = True + else: + # Try to create local dirs if they don't exist + try: + os.makedirs(path) + except OSError: + pass # already exists + assert os.path.exists(path), "Failed to create {}".format(path) + self.path_is_uri = False + self.file_index = 0 + self.bytes_written = 0 + self.cur_file = None + + @override(OutputWriter) + def write(self, sample_batch): + start = time.time() + data = _to_json(sample_batch, self.compress_columns) + f = self._get_file() + f.write(data) + f.write("\n") + if hasattr(f, "flush"): # legacy smart_open impls + f.flush() + self.bytes_written += len(data) + logger.debug("Wrote {} bytes to {} in {}s".format( + len(data), f, + time.time() - start)) + + def _get_file(self): + if not self.cur_file or self.bytes_written >= self.max_file_size: + if self.cur_file: + self.cur_file.close() + timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S") + path = os.path.join( + self.path, "output-{}_worker-{}_{}.json".format( + timestr, self.ioctx.worker_index, self.file_index)) + if self.path_is_uri: + if smart_open is None: + raise ValueError( + "You must install the `smart_open` module to write " + "to URIs like {}".format(path)) + self.cur_file = smart_open(path, "w") + else: + self.cur_file = open(path, "w") + self.file_index += 1 + self.bytes_written = 0 + logger.info("Writing to new output file {}".format(self.cur_file)) + return self.cur_file + + +def _to_jsonable(v, compress): + if compress: + return str(pack(v)) + elif isinstance(v, np.ndarray): + return v.tolist() + return v + + +def _to_json(batch, compress_columns): + return json.dumps({ + k: _to_jsonable(v, compress=k in compress_columns) + for k, v in batch.data.items() + }) diff --git a/python/ray/rllib/offline/mixed_input.py b/python/ray/rllib/offline/mixed_input.py new file mode 100644 index 0000000000000..9e9a53e6974bc --- /dev/null +++ b/python/ray/rllib/offline/mixed_input.py @@ -0,0 +1,45 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from ray.rllib.offline.input_reader import InputReader +from ray.rllib.offline.json_reader import JsonReader +from ray.rllib.utils.annotations import override + + +class MixedInput(InputReader): + """Mixes input from a number of other input sources. + + Examples: + >>> MixedInput(ioctx, { + "sampler": 0.4, + "/tmp/experiences/*.json": 0.4, + "s3://bucket/expert.json": 0.2, + }) + """ + + def __init__(self, ioctx, dist): + """Initialize a MixedInput. + + Arguments: + ioctx (IOContext): current IO context object. + dist (dict): dict mapping JSONReader paths or "sampler" to + probabilities. The probabilities must sum to 1.0. 
+ """ + if sum(dist.values()) != 1.0: + raise ValueError("Values must sum to 1.0: {}".format(dist)) + self.choices = [] + self.p = [] + for k, v in dist.items(): + if k == "sampler": + self.choices.append(ioctx.default_sampler_input()) + else: + self.choices.append(JsonReader(ioctx, k)) + self.p.append(v) + + @override(InputReader) + def next(self): + source = np.random.choice(self.choices, p=self.p) + return source.next() diff --git a/python/ray/rllib/offline/output_writer.py b/python/ray/rllib/offline/output_writer.py new file mode 100644 index 0000000000000..34a38ed85fc63 --- /dev/null +++ b/python/ray/rllib/offline/output_writer.py @@ -0,0 +1,25 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.utils.annotations import override + + +class OutputWriter(object): + """Writer object for saving experiences from policy evaluation.""" + + def write(self, sample_batch): + """Save a batch of experiences. + + Arguments: + sample_batch: SampleBatch or MultiAgentBatch to save. + """ + raise NotImplementedError + + +class NoopOutput(OutputWriter): + """Output writer that discards its outputs.""" + + @override(OutputWriter) + def write(self, sample_batch): + pass diff --git a/python/ray/rllib/optimizers/__init__.py b/python/ray/rllib/optimizers/__init__.py index f7ede66f72872..41ea5446652be 100644 --- a/python/ray/rllib/optimizers/__init__.py +++ b/python/ray/rllib/optimizers/__init__.py @@ -5,10 +5,17 @@ AsyncGradientsOptimizer from ray.rllib.optimizers.sync_samples_optimizer import SyncSamplesOptimizer from ray.rllib.optimizers.sync_replay_optimizer import SyncReplayOptimizer +from ray.rllib.optimizers.sync_batch_replay_optimizer import \ + SyncBatchReplayOptimizer from ray.rllib.optimizers.multi_gpu_optimizer import LocalMultiGPUOptimizer __all__ = [ - "PolicyOptimizer", "AsyncReplayOptimizer", "AsyncSamplesOptimizer", - "AsyncGradientsOptimizer", "SyncSamplesOptimizer", "SyncReplayOptimizer", - "LocalMultiGPUOptimizer" + "PolicyOptimizer", + "AsyncReplayOptimizer", + "AsyncSamplesOptimizer", + "AsyncGradientsOptimizer", + "SyncSamplesOptimizer", + "SyncReplayOptimizer", + "LocalMultiGPUOptimizer", + "SyncBatchReplayOptimizer", ] diff --git a/python/ray/rllib/optimizers/async_gradients_optimizer.py b/python/ray/rllib/optimizers/async_gradients_optimizer.py index 499d2a91f9247..b1e5ebe846ca6 100644 --- a/python/ray/rllib/optimizers/async_gradients_optimizer.py +++ b/python/ray/rllib/optimizers/async_gradients_optimizer.py @@ -4,6 +4,7 @@ import ray from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer +from ray.rllib.utils.annotations import override from ray.rllib.utils.timer import TimerStat @@ -15,6 +16,7 @@ class AsyncGradientsOptimizer(PolicyOptimizer): gradient computations on the remote workers. 
""" + @override(PolicyOptimizer) def _init(self, grads_per_step=100): self.apply_timer = TimerStat() self.wait_timer = TimerStat() @@ -25,6 +27,7 @@ def _init(self, grads_per_step=100): raise ValueError( "Async optimizer requires at least 1 remote evaluator") + @override(PolicyOptimizer) def step(self): weights = ray.put(self.local_evaluator.get_weights()) pending_gradients = {} @@ -64,6 +67,7 @@ def step(self): pending_gradients[future] = e num_gradients += 1 + @override(PolicyOptimizer) def stats(self): return dict( PolicyOptimizer.stats(self), **{ diff --git a/python/ray/rllib/optimizers/async_replay_optimizer.py b/python/ray/rllib/optimizers/async_replay_optimizer.py index 3cd5a16ad1112..4eccc0bd5c93b 100644 --- a/python/ray/rllib/optimizers/async_replay_optimizer.py +++ b/python/ray/rllib/optimizers/async_replay_optimizer.py @@ -20,6 +20,7 @@ MultiAgentBatch from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer +from ray.rllib.utils.annotations import override from ray.rllib.utils.actors import TaskPool, create_colocated from ray.rllib.utils.timer import TimerStat from ray.rllib.utils.window_stat import WindowStat @@ -29,145 +30,22 @@ LEARNER_QUEUE_MAX_SIZE = 16 -@ray.remote(num_cpus=0) -class ReplayActor(object): - """A replay buffer shard. - - Ray actors are single-threaded, so for scalability multiple replay actors - may be created to increase parallelism.""" - - def __init__(self, num_shards, learning_starts, buffer_size, - train_batch_size, prioritized_replay_alpha, - prioritized_replay_beta, prioritized_replay_eps): - self.replay_starts = learning_starts // num_shards - self.buffer_size = buffer_size // num_shards - self.train_batch_size = train_batch_size - self.prioritized_replay_beta = prioritized_replay_beta - self.prioritized_replay_eps = prioritized_replay_eps - - def new_buffer(): - return PrioritizedReplayBuffer( - self.buffer_size, alpha=prioritized_replay_alpha) - - self.replay_buffers = collections.defaultdict(new_buffer) - - # Metrics - self.add_batch_timer = TimerStat() - self.replay_timer = TimerStat() - self.update_priorities_timer = TimerStat() - self.num_added = 0 - - def get_host(self): - return os.uname()[1] - - def add_batch(self, batch): - # Handle everything as if multiagent - if isinstance(batch, SampleBatch): - batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count) - with self.add_batch_timer: - for policy_id, s in batch.policy_batches.items(): - for row in s.rows(): - self.replay_buffers[policy_id].add( - row["obs"], row["actions"], row["rewards"], - row["new_obs"], row["dones"], row["weights"]) - self.num_added += batch.count - - def replay(self): - if self.num_added < self.replay_starts: - return None - - with self.replay_timer: - samples = {} - for policy_id, replay_buffer in self.replay_buffers.items(): - (obses_t, actions, rewards, obses_tp1, dones, weights, - batch_indexes) = replay_buffer.sample( - self.train_batch_size, beta=self.prioritized_replay_beta) - samples[policy_id] = SampleBatch({ - "obs": obses_t, - "actions": actions, - "rewards": rewards, - "new_obs": obses_tp1, - "dones": dones, - "weights": weights, - "batch_indexes": batch_indexes - }) - return MultiAgentBatch(samples, self.train_batch_size) - - def update_priorities(self, prio_dict): - with self.update_priorities_timer: - for policy_id, (batch_indexes, td_errors) in prio_dict.items(): - new_priorities = ( - np.abs(td_errors) + self.prioritized_replay_eps) - 
self.replay_buffers[policy_id].update_priorities( - batch_indexes, new_priorities) - - def stats(self, debug=False): - stat = { - "add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3), - "replay_time_ms": round(1000 * self.replay_timer.mean, 3), - "update_priorities_time_ms": round( - 1000 * self.update_priorities_timer.mean, 3), - } - for policy_id, replay_buffer in self.replay_buffers.items(): - stat.update({ - "policy_{}".format(policy_id): replay_buffer.stats(debug=debug) - }) - return stat - - -class LearnerThread(threading.Thread): - """Background thread that updates the local model from replay data. - - The learner thread communicates with the main thread through Queues. This - is needed since Ray operations can only be run on the main thread. In - addition, moving heavyweight gradient ops session runs off the main thread - improves overall throughput. - """ - - def __init__(self, local_evaluator): - threading.Thread.__init__(self) - self.learner_queue_size = WindowStat("size", 50) - self.local_evaluator = local_evaluator - self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE) - self.outqueue = queue.Queue() - self.queue_timer = TimerStat() - self.grad_timer = TimerStat() - self.daemon = True - self.weights_updated = False - self.stopped = False - - def run(self): - while not self.stopped: - self.step() - - def step(self): - with self.queue_timer: - ra, replay = self.inqueue.get() - if replay is not None: - prio_dict = {} - with self.grad_timer: - grad_out = self.local_evaluator.compute_apply(replay) - for pid, info in grad_out.items(): - prio_dict[pid] = ( - replay.policy_batches[pid]["batch_indexes"], - info["td_error"]) - # send `replay` back also so that it gets released by the original - # thread: https://github.com/ray-project/ray/issues/2610 - self.outqueue.put((ra, replay, prio_dict, replay.count)) - self.learner_queue_size.push(self.inqueue.qsize()) - self.weights_updated = True - - class AsyncReplayOptimizer(PolicyOptimizer): """Main event loop of the Ape-X optimizer (async sampling with replay). This class coordinates the data transfers between the learner thread, remote evaluators (Ape-X actors), and replay buffer actors. + This has two modes of operation: + - normal replay: replays independent samples. + - batch replay: simplified mode where entire sample batches are + replayed. This supports RNNs, but not prioritization. + This optimizer requires that policy evaluators return an additional "td_error" array in the info return of compute_gradients(). 
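Note: the new batch_replay flag switches AsyncReplayOptimizer to the BatchReplayActor below, which replays whole sample batches (RNN-friendly, no prioritization). A sketch of how it might be enabled, assuming the optimizer kwargs are forwarded through an agent's "optimizer" sub-config as in the APEX_QMIX example earlier in this diff:

    # Illustrative agent config fragment; forwards batch_replay=True to _init().
    config = {
        "optimizer": {
            "batch_replay": True,           # replay entire batches (supports RNNs)
            "num_replay_buffer_shards": 1,
        },
        "buffer_size": 10000,
        "learning_starts": 1000,
    }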
This error term will be used for sample prioritization.""" + @override(PolicyOptimizer) def _init(self, learning_starts=1000, buffer_size=10000, @@ -179,9 +57,11 @@ def _init(self, sample_batch_size=50, num_replay_buffer_shards=1, max_weight_sync_delay=400, - debug=False): + debug=False, + batch_replay=False): self.debug = debug + self.batch_replay = batch_replay self.replay_starts = learning_starts self.prioritized_replay_beta = prioritized_replay_beta self.prioritized_replay_eps = prioritized_replay_eps @@ -190,7 +70,11 @@ def _init(self, self.learner = LearnerThread(self.local_evaluator) self.learner.start() - self.replay_actors = create_colocated(ReplayActor, [ + if self.batch_replay: + replay_cls = BatchReplayActor + else: + replay_cls = ReplayActor + self.replay_actors = create_colocated(replay_cls, [ num_replay_buffer_shards, learning_starts, buffer_size, @@ -224,19 +108,11 @@ def _init(self, # Kick off async background sampling self.sample_tasks = TaskPool() if self.remote_evaluators: - self.set_evaluators(self.remote_evaluators) - - # For https://github.com/ray-project/ray/issues/2541 only - def set_evaluators(self, remote_evaluators): - self.remote_evaluators = remote_evaluators - weights = self.local_evaluator.get_weights() - for ev in self.remote_evaluators: - ev.set_weights.remote(weights) - self.steps_since_update[ev] = 0 - for _ in range(SAMPLE_QUEUE_DEPTH): - self.sample_tasks.add(ev, ev.sample_with_count.remote()) + self._set_evaluators(self.remote_evaluators) + @override(PolicyOptimizer) def step(self): + assert self.learner.is_alive() assert len(self.remote_evaluators) > 0 start = time.time() sample_timesteps, train_timesteps = self._step() @@ -251,6 +127,53 @@ def step(self): self.num_steps_sampled += sample_timesteps self.num_steps_trained += train_timesteps + @override(PolicyOptimizer) + def stop(self): + for r in self.replay_actors: + r.__ray_terminate__.remote() + self.learner.stopped = True + + @override(PolicyOptimizer) + def stats(self): + replay_stats = ray.get(self.replay_actors[0].stats.remote(self.debug)) + timing = { + "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3) + for k in self.timers + } + timing["learner_grad_time_ms"] = round( + 1000 * self.learner.grad_timer.mean, 3) + timing["learner_dequeue_time_ms"] = round( + 1000 * self.learner.queue_timer.mean, 3) + stats = { + "sample_throughput": round(self.timers["sample"].mean_throughput, + 3), + "train_throughput": round(self.timers["train"].mean_throughput, 3), + "num_weight_syncs": self.num_weight_syncs, + "num_samples_dropped": self.num_samples_dropped, + "learner_queue": self.learner.learner_queue_size.stats(), + "replay_shard_0": replay_stats, + } + debug_stats = { + "timing_breakdown": timing, + "pending_sample_tasks": self.sample_tasks.count, + "pending_replay_tasks": self.replay_tasks.count, + } + if self.debug: + stats.update(debug_stats) + if self.learner.stats: + stats["learner"] = self.learner.stats + return dict(PolicyOptimizer.stats(self), **stats) + + # For https://github.com/ray-project/ray/issues/2541 only + def _set_evaluators(self, remote_evaluators): + self.remote_evaluators = remote_evaluators + weights = self.local_evaluator.get_weights() + for ev in self.remote_evaluators: + ev.set_weights.remote(weights) + self.steps_since_update[ev] = 0 + for _ in range(SAMPLE_QUEUE_DEPTH): + self.sample_tasks.add(ev, ev.sample_with_count.remote()) + def _step(self): sample_timesteps, train_timesteps = 0, 0 weights = None @@ -290,45 +213,191 @@ def _step(self): else: with 
self.timers["get_samples"]: samples = ray.get(replay) - self.learner.inqueue.put((ra, samples)) + # Defensive copy against plasma crashes, see #2610 #3452 + self.learner.inqueue.put((ra, samples and samples.copy())) with self.timers["update_priorities"]: while not self.learner.outqueue.empty(): - ra, _, prio_dict, count = self.learner.outqueue.get() + ra, prio_dict, count = self.learner.outqueue.get() ra.update_priorities.remote(prio_dict) train_timesteps += count return sample_timesteps, train_timesteps - def stop(self): - for r in self.replay_actors: - r.__ray_terminate__.remote() - self.learner.stopped = True - def stats(self): - replay_stats = ray.get(self.replay_actors[0].stats.remote(self.debug)) - timing = { - "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3) - for k in self.timers - } - timing["learner_grad_time_ms"] = round( - 1000 * self.learner.grad_timer.mean, 3) - timing["learner_dequeue_time_ms"] = round( - 1000 * self.learner.queue_timer.mean, 3) - stats = { - "sample_throughput": round(self.timers["sample"].mean_throughput, - 3), - "train_throughput": round(self.timers["train"].mean_throughput, 3), - "num_weight_syncs": self.num_weight_syncs, - "num_samples_dropped": self.num_samples_dropped, - "learner_queue": self.learner.learner_queue_size.stats(), - "replay_shard_0": replay_stats, +@ray.remote(num_cpus=0) +class ReplayActor(object): + """A replay buffer shard. + + Ray actors are single-threaded, so for scalability multiple replay actors + may be created to increase parallelism.""" + + def __init__(self, num_shards, learning_starts, buffer_size, + train_batch_size, prioritized_replay_alpha, + prioritized_replay_beta, prioritized_replay_eps): + self.replay_starts = learning_starts // num_shards + self.buffer_size = buffer_size // num_shards + self.train_batch_size = train_batch_size + self.prioritized_replay_beta = prioritized_replay_beta + self.prioritized_replay_eps = prioritized_replay_eps + + def new_buffer(): + return PrioritizedReplayBuffer( + self.buffer_size, alpha=prioritized_replay_alpha) + + self.replay_buffers = collections.defaultdict(new_buffer) + + # Metrics + self.add_batch_timer = TimerStat() + self.replay_timer = TimerStat() + self.update_priorities_timer = TimerStat() + self.num_added = 0 + + def get_host(self): + return os.uname()[1] + + def add_batch(self, batch): + # Handle everything as if multiagent + if isinstance(batch, SampleBatch): + batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count) + with self.add_batch_timer: + for policy_id, s in batch.policy_batches.items(): + for row in s.rows(): + self.replay_buffers[policy_id].add( + row["obs"], row["actions"], row["rewards"], + row["new_obs"], row["dones"], row["weights"]) + self.num_added += batch.count + + def replay(self): + if self.num_added < self.replay_starts: + return None + + with self.replay_timer: + samples = {} + for policy_id, replay_buffer in self.replay_buffers.items(): + (obses_t, actions, rewards, obses_tp1, dones, weights, + batch_indexes) = replay_buffer.sample( + self.train_batch_size, beta=self.prioritized_replay_beta) + samples[policy_id] = SampleBatch({ + "obs": obses_t, + "actions": actions, + "rewards": rewards, + "new_obs": obses_tp1, + "dones": dones, + "weights": weights, + "batch_indexes": batch_indexes + }) + return MultiAgentBatch(samples, self.train_batch_size) + + def update_priorities(self, prio_dict): + with self.update_priorities_timer: + for policy_id, (batch_indexes, td_errors) in prio_dict.items(): + new_priorities = ( + np.abs(td_errors) 
+ self.prioritized_replay_eps) + self.replay_buffers[policy_id].update_priorities( + batch_indexes, new_priorities) + + def stats(self, debug=False): + stat = { + "add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3), + "replay_time_ms": round(1000 * self.replay_timer.mean, 3), + "update_priorities_time_ms": round( + 1000 * self.update_priorities_timer.mean, 3), } - debug_stats = { - "timing_breakdown": timing, - "pending_sample_tasks": self.sample_tasks.count, - "pending_replay_tasks": self.replay_tasks.count, + for policy_id, replay_buffer in self.replay_buffers.items(): + stat.update({ + "policy_{}".format(policy_id): replay_buffer.stats(debug=debug) + }) + return stat + + +@ray.remote(num_cpus=0) +class BatchReplayActor(object): + """The batch replay version of the replay actor. + + This allows for RNN models, but ignores prioritization params. + """ + + def __init__(self, num_shards, learning_starts, buffer_size, + train_batch_size, prioritized_replay_alpha, + prioritized_replay_beta, prioritized_replay_eps): + self.replay_starts = learning_starts // num_shards + self.buffer_size = buffer_size // num_shards + self.train_batch_size = train_batch_size + self.buffer = [] + + # Metrics + self.num_added = 0 + self.cur_size = 0 + + def get_host(self): + return os.uname()[1] + + def add_batch(self, batch): + # Handle everything as if multiagent + if isinstance(batch, SampleBatch): + batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count) + self.buffer.append(batch) + self.cur_size += batch.count + self.num_added += batch.count + while self.cur_size > self.buffer_size: + self.cur_size -= self.buffer.pop(0).count + + def replay(self): + if self.num_added < self.replay_starts: + return None + return random.choice(self.buffer) + + def update_priorities(self, prio_dict): + pass + + def stats(self, debug=False): + stat = { + "cur_size": self.cur_size, + "num_added": self.num_added, } - if self.debug: - stats.update(debug_stats) - return dict(PolicyOptimizer.stats(self), **stats) + return stat + + +class LearnerThread(threading.Thread): + """Background thread that updates the local model from replay data. + + The learner thread communicates with the main thread through Queues. This + is needed since Ray operations can only be run on the main thread. In + addition, moving heavyweight gradient ops session runs off the main thread + improves overall throughput. 
+ """ + + def __init__(self, local_evaluator): + threading.Thread.__init__(self) + self.learner_queue_size = WindowStat("size", 50) + self.local_evaluator = local_evaluator + self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE) + self.outqueue = queue.Queue() + self.queue_timer = TimerStat() + self.grad_timer = TimerStat() + self.daemon = True + self.weights_updated = False + self.stopped = False + self.stats = {} + + def run(self): + while not self.stopped: + self.step() + + def step(self): + with self.queue_timer: + ra, replay = self.inqueue.get() + if replay is not None: + prio_dict = {} + with self.grad_timer: + grad_out = self.local_evaluator.compute_apply(replay) + for pid, info in grad_out.items(): + prio_dict[pid] = ( + replay.policy_batches[pid].data.get("batch_indexes"), + info.get("td_error")) + if "stats" in info: + self.stats[pid] = info["stats"] + self.outqueue.put((ra, prio_dict, replay.count)) + self.learner_queue_size.push(self.inqueue.qsize()) + self.weights_updated = True diff --git a/python/ray/rllib/optimizers/async_samples_optimizer.py b/python/ray/rllib/optimizers/async_samples_optimizer.py index 6b8f6014d75a1..9322e1fcd79a5 100644 --- a/python/ray/rllib/optimizers/async_samples_optimizer.py +++ b/python/ray/rllib/optimizers/async_samples_optimizer.py @@ -18,6 +18,7 @@ from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.utils.actors import TaskPool +from ray.rllib.utils.annotations import override from ray.rllib.utils.timer import TimerStat from ray.rllib.utils.window_stat import WindowStat @@ -27,6 +28,199 @@ NUM_DATA_LOAD_THREADS = 16 +class AsyncSamplesOptimizer(PolicyOptimizer): + """Main event loop of the IMPALA architecture. + + This class coordinates the data transfers between the learner thread + and remote evaluators (IMPALA actors). 
+ """ + + @override(PolicyOptimizer) + def _init(self, + train_batch_size=500, + sample_batch_size=50, + num_envs_per_worker=1, + num_gpus=0, + lr=0.0005, + replay_buffer_num_slots=0, + replay_proportion=0.0, + num_data_loader_buffers=1, + max_sample_requests_in_flight_per_worker=2, + broadcast_interval=1, + num_sgd_iter=1, + minibatch_buffer_size=1, + _fake_gpus=False): + self.learning_started = False + self.train_batch_size = train_batch_size + self.sample_batch_size = sample_batch_size + self.broadcast_interval = broadcast_interval + + if num_gpus > 1 or num_data_loader_buffers > 1: + logger.info( + "Enabling multi-GPU mode, {} GPUs, {} parallel loaders".format( + num_gpus, num_data_loader_buffers)) + if num_data_loader_buffers < minibatch_buffer_size: + raise ValueError( + "In multi-gpu mode you must have at least as many " + "parallel data loader buffers as minibatch buffers: " + "{} vs {}".format(num_data_loader_buffers, + minibatch_buffer_size)) + self.learner = TFMultiGPULearner( + self.local_evaluator, + lr=lr, + num_gpus=num_gpus, + train_batch_size=train_batch_size, + num_data_loader_buffers=num_data_loader_buffers, + minibatch_buffer_size=minibatch_buffer_size, + num_sgd_iter=num_sgd_iter, + _fake_gpus=_fake_gpus) + else: + self.learner = LearnerThread(self.local_evaluator, + minibatch_buffer_size, num_sgd_iter) + self.learner.start() + + assert len(self.remote_evaluators) > 0 + + # Stats + self.timers = {k: TimerStat() for k in ["train", "sample"]} + self.num_weight_syncs = 0 + self.num_replayed = 0 + self.learning_started = False + + # Kick off async background sampling + self.sample_tasks = TaskPool() + weights = self.local_evaluator.get_weights() + for ev in self.remote_evaluators: + ev.set_weights.remote(weights) + for _ in range(max_sample_requests_in_flight_per_worker): + self.sample_tasks.add(ev, ev.sample.remote()) + + self.batch_buffer = [] + + if replay_proportion: + if replay_buffer_num_slots * sample_batch_size <= train_batch_size: + raise ValueError( + "Replay buffer size is too small to produce train, " + "please increase replay_buffer_num_slots.", + replay_buffer_num_slots, sample_batch_size, + train_batch_size) + self.replay_proportion = replay_proportion + self.replay_buffer_num_slots = replay_buffer_num_slots + self.replay_batches = [] + + @override(PolicyOptimizer) + def step(self): + assert self.learner.is_alive() + start = time.time() + sample_timesteps, train_timesteps = self._step() + time_delta = time.time() - start + self.timers["sample"].push(time_delta) + self.timers["sample"].push_units_processed(sample_timesteps) + if train_timesteps > 0: + self.learning_started = True + if self.learning_started: + self.timers["train"].push(time_delta) + self.timers["train"].push_units_processed(train_timesteps) + self.num_steps_sampled += sample_timesteps + self.num_steps_trained += train_timesteps + + @override(PolicyOptimizer) + def stop(self): + self.learner.stopped = True + + @override(PolicyOptimizer) + def stats(self): + timing = { + "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3) + for k in self.timers + } + timing["learner_grad_time_ms"] = round( + 1000 * self.learner.grad_timer.mean, 3) + timing["learner_load_time_ms"] = round( + 1000 * self.learner.load_timer.mean, 3) + timing["learner_load_wait_time_ms"] = round( + 1000 * self.learner.load_wait_timer.mean, 3) + timing["learner_dequeue_time_ms"] = round( + 1000 * self.learner.queue_timer.mean, 3) + stats = { + "sample_throughput": round(self.timers["sample"].mean_throughput, + 3), + 
"train_throughput": round(self.timers["train"].mean_throughput, 3), + "num_weight_syncs": self.num_weight_syncs, + "num_steps_replayed": self.num_replayed, + "timing_breakdown": timing, + "learner_queue": self.learner.learner_queue_size.stats(), + } + if self.learner.stats: + stats["learner"] = self.learner.stats + return dict(PolicyOptimizer.stats(self), **stats) + + def _step(self): + sample_timesteps, train_timesteps = 0, 0 + num_sent = 0 + weights = None + + for ev, sample_batch in self._augment_with_replay( + self.sample_tasks.completed_prefetch()): + self.batch_buffer.append(sample_batch) + if sum(b.count + for b in self.batch_buffer) >= self.train_batch_size: + train_batch = self.batch_buffer[0].concat_samples( + self.batch_buffer) + self.learner.inqueue.put(train_batch) + self.batch_buffer = [] + + # If the batch was replayed, skip the update below. + if ev is None: + continue + + sample_timesteps += sample_batch.count + + # Put in replay buffer if enabled + if self.replay_buffer_num_slots > 0: + self.replay_batches.append(sample_batch) + if len(self.replay_batches) > self.replay_buffer_num_slots: + self.replay_batches.pop(0) + + # Note that it's important to pull new weights once + # updated to avoid excessive correlation between actors + if weights is None or (self.learner.weights_updated + and num_sent >= self.broadcast_interval): + self.learner.weights_updated = False + weights = ray.put(self.local_evaluator.get_weights()) + num_sent = 0 + ev.set_weights.remote(weights) + self.num_weight_syncs += 1 + num_sent += 1 + + # Kick off another sample request + self.sample_tasks.add(ev, ev.sample.remote()) + + while not self.learner.outqueue.empty(): + count = self.learner.outqueue.get() + train_timesteps += count + + return sample_timesteps, train_timesteps + + def _augment_with_replay(self, sample_futures): + def can_replay(): + num_needed = int( + np.ceil(self.train_batch_size / self.sample_batch_size)) + return len(self.replay_batches) > num_needed + + for ev, sample_batch in sample_futures: + sample_batch = ray.get(sample_batch) + yield ev, sample_batch + + if can_replay(): + f = self.replay_proportion + while random.random() < f: + f -= 1 + replay_batch = random.choice(self.replay_batches) + self.num_replayed += replay_batch.count + yield None, replay_batch + + class LearnerThread(threading.Thread): """Background thread that updates the local model from sample trajectories. @@ -36,12 +230,14 @@ class LearnerThread(threading.Thread): improves overall throughput. """ - def __init__(self, local_evaluator): + def __init__(self, local_evaluator, minibatch_buffer_size, num_sgd_iter): threading.Thread.__init__(self) self.learner_queue_size = WindowStat("size", 50) self.local_evaluator = local_evaluator self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE) self.outqueue = queue.Queue() + self.minibatch_buffer = MinibatchBuffer( + self.inqueue, minibatch_buffer_size, num_sgd_iter) self.queue_timer = TimerStat() self.grad_timer = TimerStat() self.load_timer = TimerStat() @@ -57,7 +253,7 @@ def run(self): def step(self): with self.queue_timer: - batch = self.inqueue.get() + batch, _ = self.minibatch_buffer.get() with self.grad_timer: fetches = self.local_evaluator.compute_apply(batch) @@ -76,19 +272,24 @@ def __init__(self, num_gpus=1, lr=0.0005, train_batch_size=500, - grad_clip=40, - num_parallel_data_loaders=1): + num_data_loader_buffers=1, + minibatch_buffer_size=1, + num_sgd_iter=1, + _fake_gpus=False): # Multi-GPU requires TensorFlow to function. 
import tensorflow as tf - LearnerThread.__init__(self, local_evaluator) + LearnerThread.__init__(self, local_evaluator, minibatch_buffer_size, + num_sgd_iter) self.lr = lr self.train_batch_size = train_batch_size if not num_gpus: self.devices = ["/cpu:0"] + elif _fake_gpus: + self.devices = ["/cpu:{}".format(i) for i in range(num_gpus)] else: self.devices = ["/gpu:{}".format(i) for i in range(num_gpus)] - logger.info("TFMultiGPULearner devices {}".format(self.devices)) + logger.info("TFMultiGPULearner devices {}".format(self.devices)) assert self.train_batch_size % len(self.devices) == 0 assert self.train_batch_size >= len(self.devices), "batch too small" self.policy = self.local_evaluator.policy_map["default"] @@ -107,16 +308,15 @@ def __init__(self, else: rnn_inputs = [] adam = tf.train.AdamOptimizer(self.lr) - for _ in range(num_parallel_data_loaders): + for _ in range(num_data_loader_buffers): self.par_opt.append( LocalSyncParallelOptimizer( adam, self.devices, - [v for _, v in self.policy.loss_inputs()], + [v for _, v in self.policy._loss_inputs], rnn_inputs, 999999, # it will get rounded down - self.policy.copy, - grad_norm_clipping=grad_clip)) + self.policy.copy)) self.sess = self.local_evaluator.tf_sess self.sess.run(tf.global_variables_initializer()) @@ -129,17 +329,22 @@ def __init__(self, self.loader_thread = _LoaderThread(self, share_stats=(i == 0)) self.loader_thread.start() + self.minibatch_buffer = MinibatchBuffer( + self.ready_optimizers, minibatch_buffer_size, num_sgd_iter) + + @override(LearnerThread) def step(self): assert self.loader_thread.is_alive() with self.load_wait_timer: - opt = self.ready_optimizers.get() + opt, released = self.minibatch_buffer.get() + if released: + self.idle_optimizers.put(opt) with self.grad_timer: fetches = opt.optimize(self.sess, 0) self.weights_updated = True self.stats = fetches.get("stats", {}) - self.idle_optimizers.put(opt) self.outqueue.put(self.train_batch_size) self.learner_queue_size.push(self.inqueue.qsize()) @@ -158,9 +363,9 @@ def __init__(self, learner, share_stats): def run(self): while True: - self.step() + self._step() - def step(self): + def _step(self): s = self.learner with self.queue_timer: batch = s.inqueue.get() @@ -169,7 +374,7 @@ def step(self): with self.load_timer: tuples = s.policy._get_loss_inputs_dict(batch) - data_keys = [ph for _, ph in s.policy.loss_inputs()] + data_keys = [ph for _, ph in s.policy._loss_inputs] if s.policy._state_inputs: state_keys = s.policy._state_inputs + [s.policy._seq_lens] else: @@ -180,180 +385,41 @@ def step(self): s.ready_optimizers.put(opt) -class AsyncSamplesOptimizer(PolicyOptimizer): - """Main event loop of the IMPALA architecture. - - This class coordinates the data transfers between the learner thread - and remote evaluators (IMPALA actors). 
- """ - - def _init(self, - train_batch_size=500, - sample_batch_size=50, - num_envs_per_worker=1, - num_gpus=0, - lr=0.0005, - grad_clip=40, - replay_buffer_num_slots=0, - replay_proportion=0.0, - num_parallel_data_loaders=1, - max_sample_requests_in_flight_per_worker=2, - broadcast_interval=1): - self.learning_started = False - self.train_batch_size = train_batch_size - self.sample_batch_size = sample_batch_size - self.broadcast_interval = broadcast_interval - - if num_gpus > 1 or num_parallel_data_loaders > 1: - logger.info( - "Enabling multi-GPU mode, {} GPUs, {} parallel loaders".format( - num_gpus, num_parallel_data_loaders)) - if train_batch_size // max(1, num_gpus) % ( - sample_batch_size // num_envs_per_worker) != 0: - raise ValueError( - "Sample batches must evenly divide across GPUs.") - self.learner = TFMultiGPULearner( - self.local_evaluator, - lr=lr, - num_gpus=num_gpus, - train_batch_size=train_batch_size, - grad_clip=grad_clip, - num_parallel_data_loaders=num_parallel_data_loaders) - else: - self.learner = LearnerThread(self.local_evaluator) - self.learner.start() - - assert len(self.remote_evaluators) > 0 - - # Stats - self.timers = {k: TimerStat() for k in ["train", "sample"]} - self.num_weight_syncs = 0 - self.num_replayed = 0 - self.learning_started = False - - # Kick off async background sampling - self.sample_tasks = TaskPool() - weights = self.local_evaluator.get_weights() - for ev in self.remote_evaluators: - ev.set_weights.remote(weights) - for _ in range(max_sample_requests_in_flight_per_worker): - self.sample_tasks.add(ev, ev.sample.remote()) - - self.batch_buffer = [] - - if replay_proportion: - assert replay_buffer_num_slots > 0 - assert (replay_buffer_num_slots * sample_batch_size > - train_batch_size) - self.replay_proportion = replay_proportion - self.replay_buffer_num_slots = replay_buffer_num_slots - self.replay_batches = [] - - def step(self): - assert self.learner.is_alive() - start = time.time() - sample_timesteps, train_timesteps = self._step() - time_delta = time.time() - start - self.timers["sample"].push(time_delta) - self.timers["sample"].push_units_processed(sample_timesteps) - if train_timesteps > 0: - self.learning_started = True - if self.learning_started: - self.timers["train"].push(time_delta) - self.timers["train"].push_units_processed(train_timesteps) - self.num_steps_sampled += sample_timesteps - self.num_steps_trained += train_timesteps - - def _augment_with_replay(self, sample_futures): - def can_replay(): - num_needed = int( - np.ceil(self.train_batch_size / self.sample_batch_size)) - return len(self.replay_batches) > num_needed - - for ev, sample_batch in sample_futures: - sample_batch = ray.get(sample_batch) - yield ev, sample_batch - - if can_replay(): - f = self.replay_proportion - while random.random() < f: - f -= 1 - replay_batch = random.choice(self.replay_batches) - self.num_replayed += replay_batch.count - yield None, replay_batch - - def _step(self): - sample_timesteps, train_timesteps = 0, 0 - num_sent = 0 - weights = None - - for ev, sample_batch in self._augment_with_replay( - self.sample_tasks.completed_prefetch()): - self.batch_buffer.append(sample_batch) - if sum(b.count - for b in self.batch_buffer) >= self.train_batch_size: - train_batch = self.batch_buffer[0].concat_samples( - self.batch_buffer) - self.learner.inqueue.put(train_batch) - self.batch_buffer = [] - - # If the batch was replayed, skip the update below. 
- if ev is None: - continue - - sample_timesteps += sample_batch.count - - # Put in replay buffer if enabled - if self.replay_buffer_num_slots > 0: - self.replay_batches.append(sample_batch) - if len(self.replay_batches) > self.replay_buffer_num_slots: - self.replay_batches.pop(0) - - # Note that it's important to pull new weights once - # updated to avoid excessive correlation between actors - if weights is None or (self.learner.weights_updated - and num_sent >= self.broadcast_interval): - self.learner.weights_updated = False - weights = ray.put(self.local_evaluator.get_weights()) - num_sent = 0 - ev.set_weights.remote(weights) - self.num_weight_syncs += 1 - num_sent += 1 - - # Kick off another sample request - self.sample_tasks.add(ev, ev.sample.remote()) - - while not self.learner.outqueue.empty(): - count = self.learner.outqueue.get() - train_timesteps += count - - return sample_timesteps, train_timesteps - - def stop(self): - self.learner.stopped = True - - def stats(self): - timing = { - "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3) - for k in self.timers - } - timing["learner_grad_time_ms"] = round( - 1000 * self.learner.grad_timer.mean, 3) - timing["learner_load_time_ms"] = round( - 1000 * self.learner.load_timer.mean, 3) - timing["learner_load_wait_time_ms"] = round( - 1000 * self.learner.load_wait_timer.mean, 3) - timing["learner_dequeue_time_ms"] = round( - 1000 * self.learner.queue_timer.mean, 3) - stats = { - "sample_throughput": round(self.timers["sample"].mean_throughput, - 3), - "train_throughput": round(self.timers["train"].mean_throughput, 3), - "num_weight_syncs": self.num_weight_syncs, - "num_steps_replayed": self.num_replayed, - "timing_breakdown": timing, - "learner_queue": self.learner.learner_queue_size.stats(), - } - if self.learner.stats: - stats["learner"] = self.learner.stats - return dict(PolicyOptimizer.stats(self), **stats) +class MinibatchBuffer(object): + """Ring buffer of recent data batches for minibatch SGD.""" + + def __init__(self, inqueue, size, num_passes): + """Initialize a minibatch buffer. + + Arguments: + inqueue: Queue to populate the internal ring buffer from. + size: Max number of data items to buffer. + num_passes: Max num times each data item should be emitted. + """ + self.inqueue = inqueue + self.size = size + self.max_ttl = num_passes + self.cur_max_ttl = 1 # ramp up slowly to better mix the input data + self.buffers = [None] * size + self.ttl = [0] * size + self.idx = 0 + + def get(self): + """Get a new batch from the internal ring buffer. + + Returns: + buf: Data item saved from inqueue. + released: True if the item is now removed from the ring buffer. + """ + if self.ttl[self.idx] <= 0: + self.buffers[self.idx] = self.inqueue.get() + self.ttl[self.idx] = self.cur_max_ttl + if self.cur_max_ttl < self.max_ttl: + self.cur_max_ttl += 1 + buf = self.buffers[self.idx] + self.ttl[self.idx] -= 1 + released = self.ttl[self.idx] <= 0 + if released: + self.buffers[self.idx] = None + self.idx = (self.idx + 1) % len(self.buffers) + return buf, released diff --git a/python/ray/rllib/optimizers/multi_gpu_impl.py b/python/ray/rllib/optimizers/multi_gpu_impl.py index c548b20cc022d..0a03df41cb162 100644 --- a/python/ray/rllib/optimizers/multi_gpu_impl.py +++ b/python/ray/rllib/optimizers/multi_gpu_impl.py @@ -46,7 +46,6 @@ class LocalSyncParallelOptimizer(object): clipped. build_graph: Function that takes the specified inputs and returns a TF Policy Graph instance. 
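Returning to the MinibatchBuffer added above, a small usage sketch (stand-in values, not part of this diff; assumes MinibatchBuffer from async_samples_optimizer.py is in scope) showing that get() re-emits each queued item up to num_passes times and reports when an item has been released from the ring buffer:

import queue

inq = queue.Queue()
for batch in ["b0", "b1", "b2"]:  # stand-ins for SampleBatch objects
    inq.put(batch)
buf = MinibatchBuffer(inq, size=2, num_passes=3)
for _ in range(4):
    item, released = buf.get()  # items repeat until their pass count is used up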
- grad_norm_clipping: None or int stdev to clip grad norms by """ def __init__(self, diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py index 771acb5ac72c8..73c2416b16b81 100644 --- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py +++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py @@ -12,7 +12,12 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer +from ray.rllib.optimizers.rollout import collect_samples, \ + collect_samples_straggler_mitigation +from ray.rllib.utils.annotations import override from ray.rllib.utils.timer import TimerStat +from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ + MultiAgentBatch logger = logging.getLogger(__name__) @@ -33,15 +38,22 @@ class LocalMultiGPUOptimizer(PolicyOptimizer): may result in unexpected behavior. """ + @override(PolicyOptimizer) def _init(self, sgd_batch_size=128, num_sgd_iter=10, + sample_batch_size=200, + num_envs_per_worker=1, train_batch_size=1024, num_gpus=0, - standardize_fields=[]): + standardize_fields=[], + straggler_mitigation=False): self.batch_size = sgd_batch_size self.num_sgd_iter = num_sgd_iter + self.num_envs_per_worker = num_envs_per_worker + self.sample_batch_size = sample_batch_size self.train_batch_size = train_batch_size + self.straggler_mitigation = straggler_mitigation if not num_gpus: self.devices = ["/cpu:0"] else: @@ -61,36 +73,40 @@ def _init(self, logger.info("LocalMultiGPUOptimizer devices {}".format(self.devices)) - if set(self.local_evaluator.policy_map.keys()) != {"default"}: - raise ValueError( - "Multi-agent is not supported with multi-GPU. Try using the " - "simple optimizer instead.") - self.policy = self.local_evaluator.policy_map["default"] - if not isinstance(self.policy, TFPolicyGraph): - raise ValueError( - "Only TF policies are supported with multi-GPU. Try using the " - "simple optimizer instead.") + self.policies = dict( + self.local_evaluator.foreach_trainable_policy(lambda p, i: (i, p))) + logger.debug("Policies to train: {}".format(self.policies)) + for policy_id, policy in self.policies.items(): + if not isinstance(policy, TFPolicyGraph): + raise ValueError( + "Only TF policies are supported with multi-GPU. Try using " + "the simple optimizer instead.") # per-GPU graph copies created below must share vars with the policy # reuse is set to AUTO_REUSE because Adam nodes are created after # all of the device copies are created. 
+ self.optimizers = {} with self.local_evaluator.tf_sess.graph.as_default(): with self.local_evaluator.tf_sess.as_default(): - with tf.variable_scope("default", reuse=tf.AUTO_REUSE): - if self.policy._state_inputs: - rnn_inputs = self.policy._state_inputs + [ - self.policy._seq_lens - ] - else: - rnn_inputs = [] - self.par_opt = LocalSyncParallelOptimizer( - self.policy.optimizer(), self.devices, - [v for _, v in self.policy.loss_inputs()], rnn_inputs, - self.per_device_batch_size, self.policy.copy) + for policy_id, policy in self.policies.items(): + with tf.variable_scope(policy_id, reuse=tf.AUTO_REUSE): + if policy._state_inputs: + rnn_inputs = policy._state_inputs + [ + policy._seq_lens + ] + else: + rnn_inputs = [] + self.optimizers[policy_id] = ( + LocalSyncParallelOptimizer( + policy._optimizer, self.devices, + [v + for _, v in policy._loss_inputs], rnn_inputs, + self.per_device_batch_size, policy.copy)) self.sess = self.local_evaluator.tf_sess self.sess.run(tf.global_variables_initializer()) + @override(PolicyOptimizer) def step(self): with self.update_weights_timer: if self.remote_evaluators: @@ -100,54 +116,85 @@ def step(self): with self.sample_timer: if self.remote_evaluators: - # TODO(rliaw): remove when refactoring - from ray.rllib.agents.ppo.rollout import collect_samples - samples = collect_samples(self.remote_evaluators, - self.train_batch_size) + if self.straggler_mitigation: + samples = collect_samples_straggler_mitigation( + self.remote_evaluators, self.train_batch_size) + else: + samples = collect_samples( + self.remote_evaluators, self.sample_batch_size, + self.num_envs_per_worker, self.train_batch_size) + if samples.count > self.train_batch_size * 2: + logger.info( + "Collected more training samples than expected " + "(actual={}, train_batch_size={}). 
".format( + samples.count, self.train_batch_size) + + "This may be because you have many workers or " + "long episodes in 'complete_episodes' batch mode.") else: samples = self.local_evaluator.sample() - self._check_not_multiagent(samples) - - for field in self.standardize_fields: - value = samples[field] - standardized = (value - value.mean()) / max(1e-4, value.std()) - samples[field] = standardized - - # Important: don't shuffle RNN sequence elements - if not self.policy._state_inputs: - samples.shuffle() - + # Handle everything as if multiagent + if isinstance(samples, SampleBatch): + samples = MultiAgentBatch({ + DEFAULT_POLICY_ID: samples + }, samples.count) + + for policy_id, policy in self.policies.items(): + if policy_id not in samples.policy_batches: + continue + + batch = samples.policy_batches[policy_id] + for field in self.standardize_fields: + value = batch[field] + standardized = (value - value.mean()) / max(1e-4, value.std()) + batch[field] = standardized + + # Important: don't shuffle RNN sequence elements + if not policy._state_inputs: + batch.shuffle() + + num_loaded_tuples = {} with self.load_timer: - tuples = self.policy._get_loss_inputs_dict(samples) - data_keys = [ph for _, ph in self.policy.loss_inputs()] - if self.policy._state_inputs: - state_keys = ( - self.policy._state_inputs + [self.policy._seq_lens]) - else: - state_keys = [] - tuples_per_device = self.par_opt.load_data( - self.sess, [tuples[k] for k in data_keys], - [tuples[k] for k in state_keys]) - + for policy_id, batch in samples.policy_batches.items(): + if policy_id not in self.policies: + continue + + policy = self.policies[policy_id] + tuples = policy._get_loss_inputs_dict(batch) + data_keys = [ph for _, ph in policy._loss_inputs] + if policy._state_inputs: + state_keys = policy._state_inputs + [policy._seq_lens] + else: + state_keys = [] + num_loaded_tuples[policy_id] = ( + self.optimizers[policy_id].load_data( + self.sess, [tuples[k] for k in data_keys], + [tuples[k] for k in state_keys])) + + fetches = {} with self.grad_timer: - num_batches = ( - int(tuples_per_device) // int(self.per_device_batch_size)) - logger.debug("== sgd epochs ==") - for i in range(self.num_sgd_iter): - iter_extra_fetches = defaultdict(list) - permutation = np.random.permutation(num_batches) - for batch_index in range(num_batches): - batch_fetches = self.par_opt.optimize( - self.sess, - permutation[batch_index] * self.per_device_batch_size) - for k, v in batch_fetches.items(): - iter_extra_fetches[k].append(v) - logger.debug("{} {}".format(i, _averaged(iter_extra_fetches))) + for policy_id, tuples_per_device in num_loaded_tuples.items(): + optimizer = self.optimizers[policy_id] + num_batches = ( + int(tuples_per_device) // int(self.per_device_batch_size)) + logger.debug("== sgd epochs for {} ==".format(policy_id)) + for i in range(self.num_sgd_iter): + iter_extra_fetches = defaultdict(list) + permutation = np.random.permutation(num_batches) + for batch_index in range(num_batches): + batch_fetches = optimizer.optimize( + self.sess, permutation[batch_index] * + self.per_device_batch_size) + for k, v in batch_fetches.items(): + iter_extra_fetches[k].append(v) + logger.debug("{} {}".format(i, + _averaged(iter_extra_fetches))) + fetches[policy_id] = _averaged(iter_extra_fetches) self.num_steps_sampled += samples.count self.num_steps_trained += samples.count - return _averaged(iter_extra_fetches) + return fetches + @override(PolicyOptimizer) def stats(self): return dict( PolicyOptimizer.stats(self), **{ diff --git 
a/python/ray/rllib/optimizers/policy_optimizer.py b/python/ray/rllib/optimizers/policy_optimizer.py index 3f958c4d66057..a0cc085eec898 100644 --- a/python/ray/rllib/optimizers/policy_optimizer.py +++ b/python/ray/rllib/optimizers/policy_optimizer.py @@ -7,7 +7,6 @@ import ray from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes -from ray.rllib.evaluation.sample_batch import MultiAgentBatch logger = logging.getLogger(__name__) @@ -63,7 +62,7 @@ def __init__(self, local_evaluator, remote_evaluators=None, config=None): def _init(self): """Subclasses should prefer overriding this instead of __init__.""" - pass + raise NotImplementedError def step(self): """Takes a logical optimization step. @@ -86,11 +85,33 @@ def stats(self): "num_steps_sampled": self.num_steps_sampled, } - def collect_metrics(self, timeout_seconds, min_history=100): + def save(self): + """Returns a serializable object representing the optimizer state.""" + + return [self.num_steps_trained, self.num_steps_sampled] + + def restore(self, data): + """Restores optimizer state from the given data object.""" + + self.num_steps_trained = data[0] + self.num_steps_sampled = data[1] + + def stop(self): + """Release any resources used by this optimizer.""" + pass + + def collect_metrics(self, + timeout_seconds, + min_history=100, + selected_evaluators=None): """Returns evaluator and optimizer stats. Arguments: + timeout_seconds (int): Max wait time for a evaluator before + dropping its results. This usually indicates a hung evaluator. min_history (int): Min history length to smooth results over. + selected_evaluators (list): Override the list of remote evaluators + to collect metrics from. Returns: res (dict): A training result dict from evaluator metrics with @@ -98,7 +119,7 @@ def collect_metrics(self, timeout_seconds, min_history=100): """ episodes, num_dropped = collect_episodes( self.local_evaluator, - self.remote_evaluators, + selected_evaluators or self.remote_evaluators, timeout_seconds=timeout_seconds) orig_episodes = list(episodes) missing = min_history - len(episodes) @@ -111,17 +132,6 @@ def collect_metrics(self, timeout_seconds, min_history=100): res.update(info=self.stats()) return res - def save(self): - """Returns a serializable object representing the optimizer state.""" - - return [self.num_steps_trained, self.num_steps_sampled] - - def restore(self, data): - """Restores optimizer state from the given data object.""" - - self.num_steps_trained = data[0] - self.num_steps_sampled = data[1] - def foreach_evaluator(self, func): """Apply the given function to each evaluator instance.""" @@ -143,16 +153,6 @@ def foreach_evaluator_with_index(self, func): ]) return local_result + remote_results - def stop(self): - """Release any resources used by this optimizer.""" - pass - - @staticmethod - def _check_not_multiagent(sample_batch): - if isinstance(sample_batch, MultiAgentBatch): - raise NotImplementedError( - "This optimizer does not support multi-agent yet.") - @classmethod def make(cls, env_creator, diff --git a/python/ray/rllib/optimizers/rollout.py b/python/ray/rllib/optimizers/rollout.py new file mode 100644 index 0000000000000..d13f7c5efae7f --- /dev/null +++ b/python/ray/rllib/optimizers/rollout.py @@ -0,0 +1,71 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging + +import ray +from ray.rllib.evaluation.sample_batch import SampleBatch + +logger = 
logging.getLogger(__name__) + + +def collect_samples(agents, sample_batch_size, num_envs_per_worker, + train_batch_size): + """Collects at least train_batch_size samples, never discarding any.""" + + num_timesteps_so_far = 0 + trajectories = [] + agent_dict = {} + + for agent in agents: + fut_sample = agent.sample.remote() + agent_dict[fut_sample] = agent + + while agent_dict: + [fut_sample], _ = ray.wait(list(agent_dict)) + agent = agent_dict.pop(fut_sample) + next_sample = ray.get(fut_sample) + assert next_sample.count >= sample_batch_size * num_envs_per_worker + num_timesteps_so_far += next_sample.count + trajectories.append(next_sample) + + # Only launch more tasks if we don't already have enough pending + pending = len(agent_dict) * sample_batch_size * num_envs_per_worker + if num_timesteps_so_far + pending < train_batch_size: + fut_sample2 = agent.sample.remote() + agent_dict[fut_sample2] = agent + + return SampleBatch.concat_samples(trajectories) + + +def collect_samples_straggler_mitigation(agents, train_batch_size): + """Collects at least train_batch_size samples. + + This is the legacy behavior as of 0.6, and launches extra sample tasks to + potentially improve performance but can result in many wasted samples. + """ + + num_timesteps_so_far = 0 + trajectories = [] + agent_dict = {} + + for agent in agents: + fut_sample = agent.sample.remote() + agent_dict[fut_sample] = agent + + while num_timesteps_so_far < train_batch_size: + # TODO(pcm): Make wait support arbitrary iterators and remove the + # conversion to list here. + [fut_sample], _ = ray.wait(list(agent_dict)) + agent = agent_dict.pop(fut_sample) + # Start task with next trajectory and record it in the dictionary. + fut_sample2 = agent.sample.remote() + agent_dict[fut_sample2] = agent + + next_sample = ray.get(fut_sample) + num_timesteps_so_far += next_sample.count + trajectories.append(next_sample) + + logger.info("Discarding {} sample tasks".format(len(agent_dict))) + return SampleBatch.concat_samples(trajectories) diff --git a/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py b/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py new file mode 100644 index 0000000000000..744886f40ff21 --- /dev/null +++ b/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py @@ -0,0 +1,101 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random + +import ray +from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer +from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ + MultiAgentBatch +from ray.rllib.utils.annotations import override +from ray.rllib.utils.timer import TimerStat + + +class SyncBatchReplayOptimizer(PolicyOptimizer): + """Variant of the sync replay optimizer that replays entire batches. + + This enables RNN support. 
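For context on the rollout helpers defined above, a minimal sketch of how an optimizer might call them (assuming `remote_evaluators` are PolicyEvaluator actor handles exposing sample.remote(), as elsewhere in this diff):

# Default path: never discards samples, so the result can overshoot
# train_batch_size but will not undershoot it.
train_batch = collect_samples(
    remote_evaluators, sample_batch_size=200,
    num_envs_per_worker=1, train_batch_size=4000)
assert train_batch.count >= 4000

# Legacy straggler-mitigation path: lower latency when some evaluators are
# slow, at the cost of discarding still-pending sample tasks.
train_batch = collect_samples_straggler_mitigation(
    remote_evaluators, train_batch_size=4000)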
Does not currently support prioritization.""" + + @override(PolicyOptimizer) + def _init(self, + learning_starts=1000, + buffer_size=10000, + train_batch_size=32): + self.replay_starts = learning_starts + self.max_buffer_size = buffer_size + self.train_batch_size = train_batch_size + assert self.max_buffer_size >= self.replay_starts + + # List of buffered sample batches + self.replay_buffer = [] + self.buffer_size = 0 + + # Stats + self.update_weights_timer = TimerStat() + self.sample_timer = TimerStat() + self.grad_timer = TimerStat() + self.learner_stats = {} + + @override(PolicyOptimizer) + def step(self): + with self.update_weights_timer: + if self.remote_evaluators: + weights = ray.put(self.local_evaluator.get_weights()) + for e in self.remote_evaluators: + e.set_weights.remote(weights) + + with self.sample_timer: + if self.remote_evaluators: + batches = ray.get( + [e.sample.remote() for e in self.remote_evaluators]) + else: + batches = [self.local_evaluator.sample()] + + # Handle everything as if multiagent + tmp = [] + for batch in batches: + if isinstance(batch, SampleBatch): + batch = MultiAgentBatch({ + DEFAULT_POLICY_ID: batch + }, batch.count) + tmp.append(batch) + batches = tmp + + for batch in batches: + self.replay_buffer.append(batch) + self.num_steps_sampled += batch.count + self.buffer_size += batch.count + while self.buffer_size > self.max_buffer_size: + evicted = self.replay_buffer.pop(0) + self.buffer_size -= evicted.count + + if self.num_steps_sampled >= self.replay_starts: + self._optimize() + + @override(PolicyOptimizer) + def stats(self): + return dict( + PolicyOptimizer.stats(self), **{ + "sample_time_ms": round(1000 * self.sample_timer.mean, 3), + "grad_time_ms": round(1000 * self.grad_timer.mean, 3), + "update_time_ms": round(1000 * self.update_weights_timer.mean, + 3), + "opt_peak_throughput": round(self.grad_timer.mean_throughput, + 3), + "opt_samples": round(self.grad_timer.mean_units_processed, 3), + "learner": self.learner_stats, + }) + + def _optimize(self): + samples = [random.choice(self.replay_buffer)] + while sum(s.count for s in samples) < self.train_batch_size: + samples.append(random.choice(self.replay_buffer)) + samples = SampleBatch.concat_samples(samples) + with self.grad_timer: + info_dict = self.local_evaluator.compute_apply(samples) + for policy_id, info in info_dict.items(): + if "stats" in info: + self.learner_stats[policy_id] = info["stats"] + self.grad_timer.push_units_processed(samples.count) + self.num_steps_trained += samples.count diff --git a/python/ray/rllib/optimizers/sync_replay_optimizer.py b/python/ray/rllib/optimizers/sync_replay_optimizer.py index 73df006014679..cdd187112e045 100644 --- a/python/ray/rllib/optimizers/sync_replay_optimizer.py +++ b/python/ray/rllib/optimizers/sync_replay_optimizer.py @@ -11,8 +11,8 @@ from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch +from ray.rllib.utils.annotations import override from ray.rllib.utils.compression import pack_if_needed -from ray.rllib.utils.filter import RunningStat from ray.rllib.utils.timer import TimerStat from ray.rllib.utils.schedules import LinearSchedule @@ -24,6 +24,7 @@ class SyncReplayOptimizer(PolicyOptimizer): "td_error" array in the info return of compute_gradients(). 
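As a concrete sketch of that prioritization step (hypothetical values; in practice the inputs come from the compute_gradients() info dict and the sampled batch):

import numpy as np

td_errors = np.array([0.5, -2.0, 0.1])  # per-sample TD errors from the learner
batch_indexes = np.array([3, 7, 11])    # rows that were sampled from the buffer
prioritized_replay_eps = 1e-6

new_priorities = np.abs(td_errors) + prioritized_replay_eps
# replay_buffer.update_priorities(batch_indexes, new_priorities)  # buffer not constructed in this sketch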
This error term will be used for sample prioritization.""" + @override(PolicyOptimizer) def _init(self, learning_starts=1000, buffer_size=10000, @@ -52,7 +53,7 @@ def _init(self, self.sample_timer = TimerStat() self.replay_timer = TimerStat() self.grad_timer = TimerStat() - self.throughput = RunningStat() + self.learner_stats = {} # Set up replay buffer if prioritized_replay: @@ -69,6 +70,7 @@ def new_buffer(): assert buffer_size >= self.replay_starts + @override(PolicyOptimizer) def step(self): with self.update_weights_timer: if self.remote_evaluators: @@ -105,12 +107,29 @@ def step(self): self.num_steps_sampled += batch.count + @override(PolicyOptimizer) + def stats(self): + return dict( + PolicyOptimizer.stats(self), **{ + "sample_time_ms": round(1000 * self.sample_timer.mean, 3), + "replay_time_ms": round(1000 * self.replay_timer.mean, 3), + "grad_time_ms": round(1000 * self.grad_timer.mean, 3), + "update_time_ms": round(1000 * self.update_weights_timer.mean, + 3), + "opt_peak_throughput": round(self.grad_timer.mean_throughput, + 3), + "opt_samples": round(self.grad_timer.mean_units_processed, 3), + "learner": self.learner_stats, + }) + def _optimize(self): samples = self._replay() with self.grad_timer: info_dict = self.local_evaluator.compute_apply(samples) for policy_id, info in info_dict.items(): + if "stats" in info: + self.learner_stats[policy_id] = info["stats"] replay_buffer = self.replay_buffers[policy_id] if isinstance(replay_buffer, PrioritizedReplayBuffer): td_error = info["td_error"] @@ -138,26 +157,13 @@ def _replay(self): dones) = replay_buffer.sample(self.train_batch_size) weights = np.ones_like(rewards) batch_indexes = -np.ones_like(rewards) - samples[policy_id] = SampleBatch({ - "obs": obses_t, - "actions": actions, - "rewards": rewards, - "new_obs": obses_tp1, - "dones": dones, - "weights": weights, - "batch_indexes": batch_indexes - }) + samples[policy_id] = SampleBatch({ + "obs": obses_t, + "actions": actions, + "rewards": rewards, + "new_obs": obses_tp1, + "dones": dones, + "weights": weights, + "batch_indexes": batch_indexes + }) return MultiAgentBatch(samples, self.train_batch_size) - - def stats(self): - return dict( - PolicyOptimizer.stats(self), **{ - "sample_time_ms": round(1000 * self.sample_timer.mean, 3), - "replay_time_ms": round(1000 * self.replay_timer.mean, 3), - "grad_time_ms": round(1000 * self.grad_timer.mean, 3), - "update_time_ms": round(1000 * self.update_weights_timer.mean, - 3), - "opt_peak_throughput": round(self.grad_timer.mean_throughput, - 3), - "opt_samples": round(self.grad_timer.mean_units_processed, 3), - }) diff --git a/python/ray/rllib/optimizers/sync_samples_optimizer.py b/python/ray/rllib/optimizers/sync_samples_optimizer.py index 38d5269f00393..b78e3ed01d70e 100644 --- a/python/ray/rllib/optimizers/sync_samples_optimizer.py +++ b/python/ray/rllib/optimizers/sync_samples_optimizer.py @@ -6,6 +6,7 @@ import logging from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override from ray.rllib.utils.filter import RunningStat from ray.rllib.utils.timer import TimerStat @@ -20,6 +21,7 @@ class SyncSamplesOptimizer(PolicyOptimizer): model weights are then broadcast to all remote evaluators. 
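That broadcast is the same pattern every optimizer in this diff uses: the weights are put into the object store once and each remote evaluator fetches the same copy (sketch; `local_evaluator` and `remote_evaluators` are the PolicyOptimizer attributes):

import ray

weights = ray.put(local_evaluator.get_weights())
for ev in remote_evaluators:
    ev.set_weights.remote(weights)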
""" + @override(PolicyOptimizer) def _init(self, num_sgd_iter=1, train_batch_size=1): self.update_weights_timer = TimerStat() self.sample_timer = TimerStat() @@ -29,6 +31,7 @@ def _init(self, num_sgd_iter=1, train_batch_size=1): self.train_batch_size = train_batch_size self.learner_stats = {} + @override(PolicyOptimizer) def step(self): with self.update_weights_timer: if self.remote_evaluators: @@ -62,6 +65,7 @@ def step(self): self.num_steps_trained += samples.count return fetches + @override(PolicyOptimizer) def stats(self): return dict( PolicyOptimizer.stats(self), **{ diff --git a/python/ray/rllib/rollout.py b/python/ray/rllib/rollout.py index 7249262dc9fdf..09960f2bb8d92 100755 --- a/python/ray/rllib/rollout.py +++ b/python/ray/rllib/rollout.py @@ -11,8 +11,7 @@ import gym import ray -from ray.rllib.agents.agent import get_agent_class -from ray.rllib.models import ModelCatalog +from ray.rllib.agents.registry import get_agent_class EXAMPLE_USAGE = """ Example Usage via RLlib CLI: @@ -24,6 +23,11 @@ --env CartPole-v0 --steps 1000000 --out rollouts.pkl """ +# Note: if you use any custom models or envs, register them here first, e.g.: +# +# ModelCatalog.register_custom_model("pa_model", ParametricActionsModel) +# register_env("pa_cartpole", lambda _: ParametricActionCartpole(10)) + def create_parser(parser_creator=None): parser_creator = parser_creator or argparse.ArgumentParser @@ -92,16 +96,19 @@ def run(args, parser): agent = cls(env=args.env, config=config) agent.restore(args.checkpoint) num_steps = int(args.steps) + rollout(agent, args.env, num_steps, args.out, args.no_render) + +def rollout(agent, env_name, num_steps, out=None, no_render=True): if hasattr(agent, "local_evaluator"): env = agent.local_evaluator.env else: - env = ModelCatalog.get_preprocessor_as_wrapper(gym.make(args.env)) - if args.out is not None: + env = gym.make(env_name) + if out is not None: rollouts = [] steps = 0 while steps < (num_steps or steps + 1): - if args.out is not None: + if out is not None: rollout = [] state = env.reset() done = False @@ -110,17 +117,17 @@ def run(args, parser): action = agent.compute_action(state) next_state, reward, done, _ = env.step(action) reward_total += reward - if not args.no_render: + if not no_render: env.render() - if args.out is not None: + if out is not None: rollout.append([state, action, next_state, reward, done]) steps += 1 state = next_state - if args.out is not None: + if out is not None: rollouts.append(rollout) print("Episode reward", reward_total) - if args.out is not None: - pickle.dump(rollouts, open(args.out, "wb")) + if out is not None: + pickle.dump(rollouts, open(out, "wb")) if __name__ == "__main__": diff --git a/python/ray/rllib/scripts.py b/python/ray/rllib/scripts.py index cc48b83cf3341..88d5d56292b13 100644 --- a/python/ray/rllib/scripts.py +++ b/python/ray/rllib/scripts.py @@ -14,7 +14,7 @@ rllib train --run DQN --env CartPole-v0 Example usage for rollout: - rllib rollout /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN + rllib rollout /trial_dir/checkpoint_1/checkpoint-1 --run DQN """ diff --git a/python/ray/rllib/setup-rllib-dev.py b/python/ray/rllib/setup-rllib-dev.py new file mode 100755 index 0000000000000..3876a83f7988f --- /dev/null +++ b/python/ray/rllib/setup-rllib-dev.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +"""This script allows you to develop RLlib without needing to compile Ray.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import click +import os +import 
subprocess + +import ray + +if __name__ == "__main__": + rllib_home = os.path.abspath(os.path.join(ray.__file__, "../rllib")) + local_home = os.path.abspath(os.path.dirname(__file__)) + assert os.path.isdir(rllib_home), rllib_home + assert os.path.isdir(local_home), local_home + click.confirm( + "This will replace:\n {}\nwith a symlink to:\n {}".format( + rllib_home, local_home), + abort=True) + if os.access(os.path.dirname(rllib_home), os.W_OK): + subprocess.check_call(["rm", "-rf", rllib_home]) + subprocess.check_call(["ln", "-s", local_home, rllib_home]) + else: + print("You don't have write permission to {}, using sudo:".format( + rllib_home)) + subprocess.check_call(["sudo", "rm", "-rf", rllib_home]) + subprocess.check_call(["sudo", "ln", "-s", local_home, rllib_home]) + print("Created links.\n\nIf you run into issues initializing Ray, please " + "ensure that your local repo and the installed Ray is in sync " + "(pip install -U the latest wheels at " + "https://ray.readthedocs.io/en/latest/installation.html, " + "and ensure you are up-to-date on the master branch on git).\n\n" + "Note that you may need to delete the rllib symlink when pip " + "installing new Ray versions to prevent pip from overwriting files " + "in your git repo.") diff --git a/python/ray/rllib/test/multiagent_pendulum.py b/python/ray/rllib/test/multiagent_pendulum.py new file mode 100644 index 0000000000000..c4ee5ce767b24 --- /dev/null +++ b/python/ray/rllib/test/multiagent_pendulum.py @@ -0,0 +1,42 @@ +"""Integration test: (1) pendulum works, (2) single-agent multi-agent works.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ray +from ray.rllib.test.test_multi_agent_env import make_multiagent +from ray.tune import run_experiments +from ray.tune.registry import register_env + +if __name__ == "__main__": + ray.init() + MultiPendulum = make_multiagent("Pendulum-v0") + register_env("multi_pend", lambda _: MultiPendulum(1)) + trials = run_experiments({ + "test": { + "run": "PPO", + "env": "multi_pend", + "stop": { + "timesteps_total": 500000, + "episode_reward_mean": -200, + }, + "config": { + "train_batch_size": 2048, + "vf_clip_param": 10.0, + "num_workers": 0, + "num_envs_per_worker": 10, + "lambda": 0.1, + "gamma": 0.95, + "lr": 0.0003, + "sgd_minibatch_size": 64, + "num_sgd_iter": 10, + "model": { + "fcnet_hiddens": [64, 64], + }, + "batch_mode": "complete_episodes", + }, + } + }) + if trials[0].last_result["episode_reward_mean"] < -200: + raise ValueError("Did not get to -200 reward", trials[0].last_result) diff --git a/python/ray/rllib/test/run_regression_tests.py b/python/ray/rllib/test/run_regression_tests.py new file mode 100644 index 0000000000000..a542924bd50b7 --- /dev/null +++ b/python/ray/rllib/test/run_regression_tests.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +# Runs one or more regression tests. Retries tests up to 3 times. 
+# +# Example usage: +# ./run_regression_tests.sh regression-tests/cartpole-es.yaml + +import yaml +import sys + +import ray +from ray.tune import run_experiments + +if __name__ == '__main__': + + ray.init() + + for test in sys.argv[1:]: + experiments = yaml.load(open(test).read()) + + print("== Test config ==") + print(yaml.dump(experiments)) + + for i in range(3): + trials = run_experiments(experiments) + + num_failures = 0 + for t in trials: + if (t.last_result["episode_reward_mean"] < + t.stopping_criterion["episode_reward_mean"]): + num_failures += 1 + + if not num_failures: + print("Regression test PASSED") + sys.exit(0) + + print("Regression test flaked, retry", i) + + print("Regression test FAILED") + sys.exit(1) diff --git a/python/ray/rllib/test/test_avail_actions_qmix.py b/python/ray/rllib/test/test_avail_actions_qmix.py new file mode 100644 index 0000000000000..606f358733b96 --- /dev/null +++ b/python/ray/rllib/test/test_avail_actions_qmix.py @@ -0,0 +1,68 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from gym.spaces import Tuple, Discrete, Dict, Box + +import ray +from ray.tune import register_env +from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.agents.qmix import QMixAgent + + +class AvailActionsTestEnv(MultiAgentEnv): + action_space = Discrete(10) + observation_space = Dict({ + "obs": Discrete(3), + "action_mask": Box(0, 1, (10, )), + }) + + def __init__(self, env_config): + self.state = None + self.avail = env_config["avail_action"] + self.action_mask = [0] * 10 + self.action_mask[env_config["avail_action"]] = 1 + + def reset(self): + self.state = 0 + return { + "agent_1": { + "obs": self.state, + "action_mask": self.action_mask + } + } + + def step(self, action_dict): + if self.state > 0: + assert action_dict["agent_1"] == self.avail, \ + "Failed to obey available actions mask!" 
+ self.state += 1 + rewards = {"agent_1": 1} + obs = {"agent_1": {"obs": 0, "action_mask": self.action_mask}} + dones = {"__all__": self.state > 20} + return obs, rewards, dones, {} + + +if __name__ == "__main__": + grouping = { + "group_1": ["agent_1"], # trivial grouping for testing + } + obs_space = Tuple([AvailActionsTestEnv.observation_space]) + act_space = Tuple([AvailActionsTestEnv.action_space]) + register_env( + "action_mask_test", + lambda config: AvailActionsTestEnv(config).with_agent_groups( + grouping, obs_space=obs_space, act_space=act_space)) + + ray.init() + agent = QMixAgent( + env="action_mask_test", + config={ + "num_envs_per_worker": 5, # test with vectorization on + "env_config": { + "avail_action": 3, + }, + }) + for _ in range(5): + agent.train() # OK if it doesn't trip the action assertion error + assert agent.train()["episode_reward_mean"] == 21.0 diff --git a/python/ray/rllib/test/test_catalog.py b/python/ray/rllib/test/test_catalog.py index 852a02fc4d1e0..efa1aba0e2f07 100644 --- a/python/ray/rllib/test/test_catalog.py +++ b/python/ray/rllib/test/test_catalog.py @@ -72,13 +72,13 @@ def testDefaultModels(self): with tf.variable_scope("test1"): p1 = ModelCatalog.get_model({ - "obs": np.zeros((10, 3), dtype=np.float32) + "obs": tf.zeros((10, 3), dtype=tf.float32) }, Box(0, 1, shape=(3, ), dtype=np.float32), 5, {}) self.assertEqual(type(p1), FullyConnectedNetwork) with tf.variable_scope("test2"): p2 = ModelCatalog.get_model({ - "obs": np.zeros((10, 84, 84, 3), dtype=np.float32) + "obs": tf.zeros((10, 84, 84, 3), dtype=tf.float32) }, Box(0, 1, shape=(84, 84, 3), dtype=np.float32), 5, {}) self.assertEqual(type(p2), VisionNetwork) diff --git a/python/ray/rllib/test/test_checkpoint_restore.py b/python/ray/rllib/test/test_checkpoint_restore.py index aa8fac28086ab..926c8573c911c 100644 --- a/python/ray/rllib/test/test_checkpoint_restore.py +++ b/python/ray/rllib/test/test_checkpoint_restore.py @@ -4,10 +4,12 @@ from __future__ import division from __future__ import print_function +import os +import shutil import numpy as np import ray -from ray.rllib.agents.agent import get_agent_class +from ray.rllib.agents.registry import get_agent_class def get_mean_action(alg, obs): @@ -55,7 +57,7 @@ def get_mean_action(alg, obs): } -def test(use_object_store, alg_name, failures): +def test_ckpt_restore(use_object_store, alg_name, failures): cls = get_agent_class(alg_name) if "DDPG" in alg_name: alg1 = cls(config=CONFIGS[name], env="Pendulum-v0") @@ -86,11 +88,45 @@ def test(use_object_store, alg_name, failures): failures.append((alg_name, [a1, a2])) +def test_export(algo_name, failures): + cls = get_agent_class(algo_name) + if "DDPG" in algo_name: + algo = cls(config=CONFIGS[name], env="Pendulum-v0") + else: + algo = cls(config=CONFIGS[name], env="CartPole-v0") + + for _ in range(3): + res = algo.train() + print("current status: " + str(res)) + + export_dir = "/tmp/export_dir_%s" % algo_name + print("Exporting model ", algo_name, export_dir) + algo.export_policy_model(export_dir) + if not os.path.exists(os.path.join(export_dir, "saved_model.pb")) \ + or not os.listdir(os.path.join(export_dir, "variables")): + failures.append(algo_name) + shutil.rmtree(export_dir) + + print("Exporting checkpoint", algo_name, export_dir) + algo.export_policy_checkpoint(export_dir) + if not os.path.exists(os.path.join(export_dir, "model.meta")) \ + or not os.path.exists(os.path.join(export_dir, "model.index")) \ + or not os.path.exists(os.path.join(export_dir, "checkpoint")): + failures.append(algo_name) 
+ shutil.rmtree(export_dir) + + if __name__ == "__main__": failures = [] for use_object_store in [False, True]: for name in ["ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG", "ARS"]: - test(use_object_store, name, failures) + test_ckpt_restore(use_object_store, name, failures) assert not failures, failures print("All checkpoint restore tests passed!") + + failures = [] + for name in ["DQN", "DDPG", "PPO", "A3C"]: + test_export(name, failures) + assert not failures, failures + print("All export tests passed!") diff --git a/python/ray/rllib/test/test_env_with_subprocess.py b/python/ray/rllib/test/test_env_with_subprocess.py index 70ccb46cce507..fc940cdea05eb 100644 --- a/python/ray/rllib/test/test_env_with_subprocess.py +++ b/python/ray/rllib/test/test_env_with_subprocess.py @@ -38,10 +38,10 @@ def __init__(self, config): atexit.register(lambda: self.subproc.kill()) def reset(self): - return [0] + return 0 def step(self, action): - return [0], 0, True, {} + return 0, 0, True, {} def leaked_processes(): diff --git a/python/ray/rllib/test/test_evaluators.py b/python/ray/rllib/test/test_evaluators.py index 9ae0994f33466..c7a72d7a5bb87 100644 --- a/python/ray/rllib/test/test_evaluators.py +++ b/python/ray/rllib/test/test_evaluators.py @@ -4,7 +4,7 @@ import unittest -from ray.rllib.agents.dqn.dqn_policy_graph import adjust_nstep +from ray.rllib.agents.dqn.dqn_policy_graph import _adjust_nstep class DQNTest(unittest.TestCase): @@ -14,7 +14,7 @@ def testNStep(self): rewards = [10.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100.0] new_obs = [2, 3, 4, 5, 6, 7, 8] dones = [0, 0, 0, 0, 0, 0, 1] - adjust_nstep(3, 0.9, obs, actions, rewards, new_obs, dones) + _adjust_nstep(3, 0.9, obs, actions, rewards, new_obs, dones) self.assertEqual(obs, [1, 2, 3, 4, 5, 6, 7]) self.assertEqual(actions, ["a", "b", "a", "a", "a", "b", "a"]) self.assertEqual(new_obs, [4, 5, 6, 7, 8, 8, 8]) diff --git a/python/ray/rllib/test/test_io.py b/python/ray/rllib/test/test_io.py new file mode 100644 index 0000000000000..e455503407e6d --- /dev/null +++ b/python/ray/rllib/test/test_io.py @@ -0,0 +1,235 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import numpy as np +import os +import shutil +import tempfile +import time +import unittest + +import ray +from ray.rllib.agents.pg import PGAgent +from ray.rllib.evaluation import SampleBatch +from ray.rllib.offline import IOContext, JsonWriter, JsonReader +from ray.rllib.offline.json_writer import _to_json + +SAMPLES = SampleBatch({ + "actions": np.array([1, 2, 3]), + "obs": np.array([4, 5, 6]) +}) + + +def make_sample_batch(i): + return SampleBatch({ + "actions": np.array([i, i, i]), + "obs": np.array([i, i, i]) + }) + + +class AgentIOTest(unittest.TestCase): + def setUp(self): + self.test_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.test_dir) + + def writeOutputs(self, output): + agent = PGAgent( + env="CartPole-v0", + config={ + "output": output, + "sample_batch_size": 250, + }) + agent.train() + return agent + + def testAgentOutputOk(self): + self.writeOutputs(self.test_dir) + self.assertEqual(len(os.listdir(self.test_dir)), 1) + ioctx = IOContext(self.test_dir, {}, 0, None) + reader = JsonReader(ioctx, self.test_dir + "/*.json") + reader.next() + + def testAgentOutputLogdir(self): + agent = self.writeOutputs("logdir") + self.assertEqual(len(glob.glob(agent.logdir + "/output-*.json")), 1) + + def testAgentInputDir(self): + self.writeOutputs(self.test_dir) + agent = PGAgent( + 
env="CartPole-v0", + config={ + "input": self.test_dir, + "input_evaluation": None, + }) + result = agent.train() + self.assertEqual(result["timesteps_total"], 250) # read from input + self.assertTrue(np.isnan(result["episode_reward_mean"])) + + def testAgentInputEvalSim(self): + self.writeOutputs(self.test_dir) + agent = PGAgent( + env="CartPole-v0", + config={ + "input": self.test_dir, + "input_evaluation": "simulation", + }) + for _ in range(50): + result = agent.train() + if not np.isnan(result["episode_reward_mean"]): + return # simulation ok + time.sleep(0.1) + assert False, "did not see any simulation results" + + def testAgentInputList(self): + self.writeOutputs(self.test_dir) + agent = PGAgent( + env="CartPole-v0", + config={ + "input": glob.glob(self.test_dir + "/*.json"), + "input_evaluation": None, + "sample_batch_size": 99, + }) + result = agent.train() + self.assertEqual(result["timesteps_total"], 250) # read from input + self.assertTrue(np.isnan(result["episode_reward_mean"])) + + def testAgentInputDict(self): + self.writeOutputs(self.test_dir) + agent = PGAgent( + env="CartPole-v0", + config={ + "input": { + self.test_dir: 0.1, + "sampler": 0.9, + }, + "train_batch_size": 2000, + "input_evaluation": None, + }) + result = agent.train() + self.assertTrue(not np.isnan(result["episode_reward_mean"])) + + +class JsonIOTest(unittest.TestCase): + def setUp(self): + self.test_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.test_dir) + + def testWriteSimple(self): + ioctx = IOContext(self.test_dir, {}, 0, None) + writer = JsonWriter( + ioctx, self.test_dir, max_file_size=1000, compress_columns=["obs"]) + self.assertEqual(len(os.listdir(self.test_dir)), 0) + writer.write(SAMPLES) + writer.write(SAMPLES) + self.assertEqual(len(os.listdir(self.test_dir)), 1) + + def testWriteFileURI(self): + ioctx = IOContext(self.test_dir, {}, 0, None) + writer = JsonWriter( + ioctx, + "file:" + self.test_dir, + max_file_size=1000, + compress_columns=["obs"]) + self.assertEqual(len(os.listdir(self.test_dir)), 0) + writer.write(SAMPLES) + writer.write(SAMPLES) + self.assertEqual(len(os.listdir(self.test_dir)), 1) + + def testWritePaginate(self): + ioctx = IOContext(self.test_dir, {}, 0, None) + writer = JsonWriter( + ioctx, self.test_dir, max_file_size=5000, compress_columns=["obs"]) + self.assertEqual(len(os.listdir(self.test_dir)), 0) + for _ in range(100): + writer.write(SAMPLES) + self.assertEqual(len(os.listdir(self.test_dir)), 12) + + def testReadWrite(self): + ioctx = IOContext(self.test_dir, {}, 0, None) + writer = JsonWriter( + ioctx, self.test_dir, max_file_size=5000, compress_columns=["obs"]) + for i in range(100): + writer.write(make_sample_batch(i)) + reader = JsonReader(ioctx, self.test_dir + "/*.json") + seen_a = set() + seen_o = set() + for i in range(1000): + batch = reader.next() + seen_a.add(batch["actions"][0]) + seen_o.add(batch["obs"][0]) + self.assertGreater(len(seen_a), 90) + self.assertLess(len(seen_a), 101) + self.assertGreater(len(seen_o), 90) + self.assertLess(len(seen_o), 101) + + def testSkipsOverEmptyLinesAndFiles(self): + ioctx = IOContext(self.test_dir, {}, 0, None) + open(self.test_dir + "/empty", "w").close() + with open(self.test_dir + "/f1", "w") as f: + f.write("\n") + f.write("\n") + f.write(_to_json(make_sample_batch(0), [])) + with open(self.test_dir + "/f2", "w") as f: + f.write(_to_json(make_sample_batch(1), [])) + f.write("\n") + reader = JsonReader(ioctx, [ + self.test_dir + "/empty", + self.test_dir + "/f1", + "file:" + self.test_dir + 
"/f2", + ]) + seen_a = set() + for i in range(100): + batch = reader.next() + seen_a.add(batch["actions"][0]) + self.assertEqual(len(seen_a), 2) + + def testSkipsOverCorruptedLines(self): + ioctx = IOContext(self.test_dir, {}, 0, None) + with open(self.test_dir + "/f1", "w") as f: + f.write(_to_json(make_sample_batch(0), [])) + f.write("\n") + f.write(_to_json(make_sample_batch(1), [])) + f.write("\n") + f.write(_to_json(make_sample_batch(2), [])) + f.write("\n") + f.write(_to_json(make_sample_batch(3), [])) + f.write("\n") + f.write("{..corrupted_json_record") + reader = JsonReader(ioctx, [ + self.test_dir + "/f1", + ]) + seen_a = set() + for i in range(10): + batch = reader.next() + seen_a.add(batch["actions"][0]) + self.assertEqual(len(seen_a), 4) + + def testAbortOnAllEmptyInputs(self): + ioctx = IOContext(self.test_dir, {}, 0, None) + open(self.test_dir + "/empty", "w").close() + reader = JsonReader(ioctx, [ + self.test_dir + "/empty", + ]) + self.assertRaises(ValueError, lambda: reader.next()) + with open(self.test_dir + "/empty1", "w") as f: + for _ in range(100): + f.write("\n") + with open(self.test_dir + "/empty2", "w") as f: + for _ in range(100): + f.write("\n") + reader = JsonReader(ioctx, [ + self.test_dir + "/empty1", + self.test_dir + "/empty2", + ]) + self.assertRaises(ValueError, lambda: reader.next()) + + +if __name__ == "__main__": + ray.init(num_cpus=1) + unittest.main(verbosity=2) diff --git a/python/ray/rllib/test/test_multi_agent_env.py b/python/ray/rllib/test/test_multi_agent_env.py index 5712390c05c6e..6f5d3325dc52f 100644 --- a/python/ray/rllib/test/test_multi_agent_env.py +++ b/python/ray/rllib/test/test_multi_agent_env.py @@ -22,6 +22,12 @@ from ray.tune.registry import register_env +def one_hot(i, n): + out = [0.0] * n + out[i] = 1.0 + return out + + class BasicMultiAgent(MultiAgentEnv): """Env of N independent agents, each of which exits after 25 steps.""" @@ -64,7 +70,7 @@ def __init__(self, num, increment_obs=False): self.last_info = {} self.i = 0 self.num = num - self.observation_space = gym.spaces.Discrete(2) + self.observation_space = gym.spaces.Discrete(10) self.action_space = gym.spaces.Discrete(2) def reset(self): @@ -290,7 +296,7 @@ def testMultiAgentSampleWithHorizon(self): def testMultiAgentSampleRoundRobin(self): act_space = gym.spaces.Discrete(2) - obs_space = gym.spaces.Discrete(2) + obs_space = gym.spaces.Discrete(10) ev = PolicyEvaluator( env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True), policy_graph={ @@ -303,10 +309,20 @@ def testMultiAgentSampleRoundRobin(self): # since we round robin introduce agents into the env, some of the env # steps don't count as proper transitions self.assertEqual(batch.policy_batches["p0"].count, 42) - self.assertEqual(batch.policy_batches["p0"]["obs"].tolist()[:10], - [0, 1, 2, 3, 4] * 2) - self.assertEqual(batch.policy_batches["p0"]["new_obs"].tolist()[:10], - [1, 2, 3, 4, 5] * 2) + self.assertEqual(batch.policy_batches["p0"]["obs"].tolist()[:10], [ + one_hot(0, 10), + one_hot(1, 10), + one_hot(2, 10), + one_hot(3, 10), + one_hot(4, 10), + ] * 2) + self.assertEqual(batch.policy_batches["p0"]["new_obs"].tolist()[:10], [ + one_hot(1, 10), + one_hot(2, 10), + one_hot(3, 10), + one_hot(4, 10), + one_hot(5, 10), + ] * 2) self.assertEqual(batch.policy_batches["p0"]["rewards"].tolist()[:10], [100, 100, 100, 100, 0] * 2) self.assertEqual(batch.policy_batches["p0"]["dones"].tolist()[:10], @@ -323,7 +339,8 @@ def compute_actions(self, state_batches, prev_action_batch=None, prev_reward_batch=None, - 
episodes=None): + episodes=None, + **kwargs): return [0] * len(obs_batch), [[h] * len(obs_batch)], {} def get_initial_state(self): @@ -347,7 +364,8 @@ def compute_actions(self, state_batches, prev_action_batch=None, prev_reward_batch=None, - episodes=None): + episodes=None, + **kwargs): # Pretend we did a model-based rollout and want to return # the extra trajectory. builder = episodes[0].new_batch_builder() diff --git a/python/ray/rllib/test/test_nested_spaces.py b/python/ray/rllib/test/test_nested_spaces.py index 490e6af1520af..bbdfb07ed062a 100644 --- a/python/ray/rllib/test/test_nested_spaces.py +++ b/python/ray/rllib/test/test_nested_spaces.py @@ -13,10 +13,13 @@ import ray from ray.rllib.agents.pg import PGAgent +from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph +from ray.rllib.env import MultiAgentEnv from ray.rllib.env.async_vector_env import AsyncVectorEnv from ray.rllib.env.vector_env import VectorEnv from ray.rllib.models import ModelCatalog from ray.rllib.models.model import Model +from ray.rllib.rollout import rollout from ray.rllib.test.test_external_env import SimpleServing from ray.tune.registry import register_env @@ -88,6 +91,34 @@ def step(self, action): return TUPLE_SAMPLES[self.steps], 1, self.steps >= 5, {} +class NestedMultiAgentEnv(MultiAgentEnv): + def __init__(self): + self.steps = 0 + + def reset(self): + return { + "dict_agent": DICT_SAMPLES[0], + "tuple_agent": TUPLE_SAMPLES[0], + } + + def step(self, actions): + self.steps += 1 + obs = { + "dict_agent": DICT_SAMPLES[self.steps], + "tuple_agent": TUPLE_SAMPLES[self.steps], + } + rew = { + "dict_agent": 0, + "tuple_agent": 0, + } + dones = {"__all__": self.steps >= 5} + infos = { + "dict_agent": {}, + "tuple_agent": {}, + } + return obs, rew, dones, infos + + class InvalidModel(Model): def _build_layers_v2(self, input_dict, num_outputs, options): return "not", "valid" @@ -107,7 +138,8 @@ def spy(pos, front_cam, task): # redis to communicate back to our suite ray.experimental.internal_kv._internal_kv_put( "d_spy_in_{}".format(DictSpyModel.capture_index), - pickle.dumps((pos, front_cam, task))) + pickle.dumps((pos, front_cam, task)), + overwrite=True) DictSpyModel.capture_index += 1 return 0 @@ -135,7 +167,8 @@ def spy(pos, cam, task): # redis to communicate back to our suite ray.experimental.internal_kv._internal_kv_put( "t_spy_in_{}".format(TupleSpyModel.capture_index), - pickle.dumps((pos, cam, task))) + pickle.dumps((pos, cam, task)), + overwrite=True) TupleSpyModel.capture_index += 1 return 0 @@ -182,6 +215,7 @@ def doTestNestedDict(self, make_env, test_lstm=False): config={ "num_workers": 0, "sample_batch_size": 5, + "train_batch_size": 5, "model": { "custom_model": "composite", "use_lstm": test_lstm, @@ -210,6 +244,7 @@ def doTestNestedTuple(self, make_env): config={ "num_workers": 0, "sample_batch_size": 5, + "train_batch_size": 5, "model": { "custom_model": "composite2", }, @@ -242,10 +277,8 @@ def testNestedDictServing(self): self.doTestNestedDict(lambda _: SimpleServing(NestedDictEnv())) def testNestedDictAsync(self): - self.assertRaisesRegexp( - ValueError, "Found raw Dict space.*", - lambda: self.doTestNestedDict( - lambda _: AsyncVectorEnv.wrap_async(NestedDictEnv()))) + self.doTestNestedDict( + lambda _: AsyncVectorEnv.wrap_async(NestedDictEnv())) def testNestedTupleGym(self): self.doTestNestedTuple(lambda _: NestedTupleEnv()) @@ -258,10 +291,73 @@ def testNestedTupleServing(self): self.doTestNestedTuple(lambda _: SimpleServing(NestedTupleEnv())) def testNestedTupleAsync(self): - 
self.assertRaisesRegexp( - ValueError, "Found raw Tuple space.*", - lambda: self.doTestNestedTuple( - lambda _: AsyncVectorEnv.wrap_async(NestedTupleEnv()))) + self.doTestNestedTuple( + lambda _: AsyncVectorEnv.wrap_async(NestedTupleEnv())) + + def testMultiAgentComplexSpaces(self): + ModelCatalog.register_custom_model("dict_spy", DictSpyModel) + ModelCatalog.register_custom_model("tuple_spy", TupleSpyModel) + register_env("nested_ma", lambda _: NestedMultiAgentEnv()) + act_space = spaces.Discrete(2) + pg = PGAgent( + env="nested_ma", + config={ + "num_workers": 0, + "sample_batch_size": 5, + "train_batch_size": 5, + "multiagent": { + "policy_graphs": { + "tuple_policy": ( + PGPolicyGraph, TUPLE_SPACE, act_space, + {"model": {"custom_model": "tuple_spy"}}), + "dict_policy": ( + PGPolicyGraph, DICT_SPACE, act_space, + {"model": {"custom_model": "dict_spy"}}), + }, + "policy_mapping_fn": lambda a: { + "tuple_agent": "tuple_policy", + "dict_agent": "dict_policy"}[a], + }, + }) + pg.train() + + for i in range(4): + seen = pickle.loads( + ray.experimental.internal_kv._internal_kv_get( + "d_spy_in_{}".format(i))) + pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist() + cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist() + task_i = one_hot( + DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5) + self.assertEqual(seen[0][0].tolist(), pos_i) + self.assertEqual(seen[1][0].tolist(), cam_i) + self.assertEqual(seen[2][0].tolist(), task_i) + + for i in range(4): + seen = pickle.loads( + ray.experimental.internal_kv._internal_kv_get( + "t_spy_in_{}".format(i))) + pos_i = TUPLE_SAMPLES[i][0].tolist() + cam_i = TUPLE_SAMPLES[i][1][0].tolist() + task_i = one_hot(TUPLE_SAMPLES[i][2], 5) + self.assertEqual(seen[0][0].tolist(), pos_i) + self.assertEqual(seen[1][0].tolist(), cam_i) + self.assertEqual(seen[2][0].tolist(), task_i) + + def testRolloutDictSpace(self): + register_env("nested", lambda _: NestedDictEnv()) + agent = PGAgent(env="nested") + agent.train() + path = agent.save() + agent.stop() + + # Test train works on restore + agent2 = PGAgent(env="nested") + agent2.restore(path) + agent2.train() + + # Test rollout works on restore + rollout(agent2, "nested", 100) if __name__ == "__main__": diff --git a/python/ray/rllib/test/test_optimizers.py b/python/ray/rllib/test/test_optimizers.py index 6a5022d368d72..074e0c081bb52 100644 --- a/python/ray/rllib/test/test_optimizers.py +++ b/python/ray/rllib/test/test_optimizers.py @@ -2,14 +2,19 @@ from __future__ import division from __future__ import print_function -import unittest - +import gym import numpy as np +import tensorflow as tf +import time +import unittest import ray -from ray.rllib.test.mock_evaluator import _MockEvaluator -from ray.rllib.optimizers import AsyncGradientsOptimizer +from ray.rllib.agents.ppo import PPOAgent +from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph from ray.rllib.evaluation import SampleBatch +from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator +from ray.rllib.optimizers import AsyncGradientsOptimizer, AsyncSamplesOptimizer +from ray.rllib.test.mock_evaluator import _MockEvaluator class AsyncOptimizerTest(unittest.TestCase): @@ -27,6 +32,64 @@ def testBasic(self): self.assertTrue(all(local.get_weights() == 0)) +class PPOCollectTest(unittest.TestCase): + def tearDown(self): + ray.shutdown() + + def testPPOSampleWaste(self): + ray.init(num_cpus=4) + + # Check we at least collect the initial wave of samples + ppo = PPOAgent( + env="CartPole-v0", + config={ + "sample_batch_size": 200, + 
"train_batch_size": 128, + "num_workers": 3, + }) + ppo.train() + self.assertEqual(ppo.optimizer.num_steps_sampled, 600) + ppo.stop() + + # Check we collect at least the specified amount of samples + ppo = PPOAgent( + env="CartPole-v0", + config={ + "sample_batch_size": 200, + "train_batch_size": 900, + "num_workers": 3, + }) + ppo.train() + self.assertEqual(ppo.optimizer.num_steps_sampled, 1000) + ppo.stop() + + # Check in vectorized mode + ppo = PPOAgent( + env="CartPole-v0", + config={ + "sample_batch_size": 200, + "num_envs_per_worker": 2, + "train_batch_size": 900, + "num_workers": 3, + }) + ppo.train() + self.assertEqual(ppo.optimizer.num_steps_sampled, 1200) + ppo.stop() + + # Check legacy mode + ppo = PPOAgent( + env="CartPole-v0", + config={ + "sample_batch_size": 200, + "train_batch_size": 128, + "num_workers": 3, + "straggler_mitigation": True, + }) + ppo.train() + self.assertEqual(ppo.optimizer.num_steps_sampled, 200) + ppo.stop() + + class SampleBatchTest(unittest.TestCase): def testConcat(self): b1 = SampleBatch({"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])}) @@ -40,5 +103,137 @@ def testConcat(self): self.assertEqual(b["b"].tolist(), [4, 5, 6, 4, 5]) +class AsyncSamplesOptimizerTest(unittest.TestCase): + @classmethod + def tearDownClass(cls): + ray.shutdown() + + @classmethod + def setUpClass(cls): + ray.init(num_cpus=8) + + def testSimple(self): + local, remotes = self._make_evs() + optimizer = AsyncSamplesOptimizer(local, remotes, {}) + self._wait_for(optimizer, 1000, 1000) + + def testMultiGPU(self): + local, remotes = self._make_evs() + optimizer = AsyncSamplesOptimizer(local, remotes, { + "num_gpus": 2, + "_fake_gpus": True + }) + self._wait_for(optimizer, 1000, 1000) + + def testMultiGPUParallelLoad(self): + local, remotes = self._make_evs() + optimizer = AsyncSamplesOptimizer(local, remotes, { + "num_gpus": 2, + "num_data_loader_buffers": 2, + "_fake_gpus": True + }) + self._wait_for(optimizer, 1000, 1000) + + def testMultiplePasses(self): + local, remotes = self._make_evs() + optimizer = AsyncSamplesOptimizer( + local, remotes, { + "minibatch_buffer_size": 10, + "num_sgd_iter": 10, + "sample_batch_size": 10, + "train_batch_size": 50, + }) + self._wait_for(optimizer, 1000, 10000) + self.assertLess(optimizer.stats()["num_steps_sampled"], 5000) + self.assertGreater(optimizer.stats()["num_steps_trained"], 8000) + + def testReplay(self): + local, remotes = self._make_evs() + optimizer = AsyncSamplesOptimizer( + local, remotes, { + "replay_buffer_num_slots": 100, + "replay_proportion": 10, + "sample_batch_size": 10, + "train_batch_size": 10, + }) + self._wait_for(optimizer, 1000, 1000) + self.assertLess(optimizer.stats()["num_steps_sampled"], 5000) + self.assertGreater(optimizer.stats()["num_steps_replayed"], 8000) + self.assertGreater(optimizer.stats()["num_steps_trained"], 8000) + + def testReplayAndMultiplePasses(self): + local, remotes = self._make_evs() + optimizer = AsyncSamplesOptimizer( + local, remotes, { + "minibatch_buffer_size": 10, + "num_sgd_iter": 10, + "replay_buffer_num_slots": 100, + "replay_proportion": 10, + "sample_batch_size": 10, + "train_batch_size": 10, + }) + self._wait_for(optimizer, 1000, 1000) + self.assertLess(optimizer.stats()["num_steps_sampled"], 5000) + self.assertGreater(optimizer.stats()["num_steps_replayed"], 8000) + self.assertGreater(optimizer.stats()["num_steps_trained"], 40000) + + def testRejectBadConfigs(self): + local, remotes = self._make_evs() + self.assertRaises( + ValueError, lambda: AsyncSamplesOptimizer( + local, 
remotes, + {"num_data_loader_buffers": 2, "minibatch_buffer_size": 4})) + optimizer = AsyncSamplesOptimizer( + local, remotes, { + "num_gpus": 2, + "train_batch_size": 100, + "sample_batch_size": 50, + "_fake_gpus": True + }) + self._wait_for(optimizer, 1000, 1000) + optimizer = AsyncSamplesOptimizer( + local, remotes, { + "num_gpus": 2, + "train_batch_size": 100, + "sample_batch_size": 25, + "_fake_gpus": True + }) + self._wait_for(optimizer, 1000, 1000) + optimizer = AsyncSamplesOptimizer( + local, remotes, { + "num_gpus": 2, + "train_batch_size": 100, + "sample_batch_size": 74, + "_fake_gpus": True + }) + self._wait_for(optimizer, 1000, 1000) + + def _make_evs(self): + def make_sess(): + return tf.Session(config=tf.ConfigProto(device_count={"CPU": 2})) + + local = PolicyEvaluator( + env_creator=lambda _: gym.make("CartPole-v0"), + policy_graph=PPOPolicyGraph, + tf_session_creator=make_sess) + remotes = [ + PolicyEvaluator.as_remote().remote( + env_creator=lambda _: gym.make("CartPole-v0"), + policy_graph=PPOPolicyGraph, + tf_session_creator=make_sess) + ] + return local, remotes + + def _wait_for(self, optimizer, num_steps_sampled, num_steps_trained): + start = time.time() + while time.time() - start < 30: + optimizer.step() + if optimizer.num_steps_sampled > num_steps_sampled and \ + optimizer.num_steps_trained > num_steps_trained: + print("OK", optimizer.stats()) + return + raise AssertionError("TIMED OUT", optimizer.stats()) + + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/python/ray/rllib/test/test_policy_evaluator.py b/python/ray/rllib/test/test_policy_evaluator.py index cf319a7e922b2..adff6aa91eb37 100644 --- a/python/ray/rllib/test/test_policy_evaluator.py +++ b/python/ray/rllib/test/test_policy_evaluator.py @@ -25,7 +25,8 @@ def compute_actions(self, state_batches, prev_action_batch=None, prev_reward_batch=None, - episodes=None): + episodes=None, + **kwargs): return [0] * len(obs_batch), [], {} def postprocess_trajectory(self, @@ -42,7 +43,8 @@ def compute_actions(self, state_batches, prev_action_batch=None, prev_reward_batch=None, - episodes=None): + episodes=None, + **kwargs): raise Exception("intentional error") def postprocess_trajectory(self, @@ -186,6 +188,7 @@ def testCallbacks(self): env="CartPole-v0", config={ "num_workers": 0, "sample_batch_size": 50, + "train_batch_size": 50, "callbacks": { "on_episode_start": lambda x: counts.update({"start": 1}), "on_episode_step": lambda x: counts.update({"step": 1}), @@ -263,18 +266,6 @@ def testAsync(self): self.assertIn(key, batch) self.assertGreater(batch["advantages"][0], 1) - def testAutoConcat(self): - ev = PolicyEvaluator( - env_creator=lambda _: MockEnv(episode_length=40), - policy_graph=MockPolicyGraph, - sample_async=True, - batch_steps=10, - batch_mode="truncate_episodes", - observation_filter="ConcurrentMeanStdFilter") - time.sleep(2) - batch = ev.sample() - self.assertEqual(batch.count, 40) # auto-concat up to 5 episodes - def testAutoVectorization(self): ev = PolicyEvaluator( env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg), diff --git a/python/ray/rllib/test/test_supported_spaces.py b/python/ray/rllib/test/test_supported_spaces.py index b98a006bca3bb..fbfd1f5eae563 100644 --- a/python/ray/rllib/test/test_supported_spaces.py +++ b/python/ray/rllib/test/test_supported_spaces.py @@ -8,7 +8,7 @@ import sys import ray -from ray.rllib.agents.agent import get_agent_class +from ray.rllib.agents.registry import get_agent_class from ray.rllib.test.test_multi_agent_env import MultiCartpole, 
MultiMountainCar from ray.rllib.utils.error import UnsupportedSpaceException from ray.tune.registry import register_env @@ -120,12 +120,15 @@ def testAll(self): stats, check_bounds=True) check_support("DQN", {"timesteps_per_iteration": 1}, stats) - check_support("A3C", { - "num_workers": 1, - "optimizer": { - "grads_per_step": 1 - } - }, stats) + check_support( + "A3C", { + "num_workers": 1, + "optimizer": { + "grads_per_step": 1 + } + }, + stats, + check_bounds=True) check_support( "PPO", { "num_workers": 1, @@ -133,9 +136,6 @@ def testAll(self): "train_batch_size": 10, "sample_batch_size": 10, "sgd_minibatch_size": 1, - "model": { - "squash_to_range": True - }, }, stats, check_bounds=True) @@ -153,7 +153,13 @@ def testAll(self): "num_rollouts": 1, "rollouts_used": 1 }, stats) - check_support("PG", {"num_workers": 1, "optimizer": {}}, stats) + check_support( + "PG", { + "num_workers": 1, + "optimizer": {} + }, + stats, + check_bounds=True) num_unexpected_errors = 0 for (alg, a_name, o_name), stat in sorted(stats.items()): if stat not in ["ok", "unsupported"]: @@ -178,7 +184,6 @@ def testMultiAgent(self): "train_batch_size": 10, "sample_batch_size": 10, "sgd_minibatch_size": 1, - "simple_optimizer": True, }) check_support_multiagent("PG", {"num_workers": 1, "optimizer": {}}) check_support_multiagent("DDPG", {"timesteps_per_iteration": 1}) diff --git a/python/ray/rllib/train.py b/python/ray/rllib/train.py index a1d8e13a1f57c..d9f7cf58e0b4c 100755 --- a/python/ray/rllib/train.py +++ b/python/ray/rllib/train.py @@ -38,30 +38,33 @@ def create_parser(parser_creator=None): "--redis-address", default=None, type=str, - help="The Redis address of the cluster.") + help="Connect to an existing Ray cluster at this address instead " + "of starting a new one.") parser.add_argument( "--ray-num-cpus", default=None, type=int, - help="--num-cpus to pass to Ray." - " This only has an affect in local mode.") + help="--num-cpus to use if starting a new cluster.") parser.add_argument( "--ray-num-gpus", default=None, type=int, - help="--num-gpus to pass to Ray." - " This only has an affect in local mode.") + help="--num-gpus to use if starting a new cluster.") parser.add_argument( "--ray-num-local-schedulers", default=None, type=int, help="Emulate multiple cluster nodes for debugging.") + parser.add_argument( + "--ray-redis-max-memory", + default=None, + type=int, + help="--redis-max-memory to use if starting a new cluster.") parser.add_argument( "--ray-object-store-memory", default=None, type=int, - help="--object-store-memory to pass to Ray." 
- " This only has an affect in local mode.") + help="--object-store-memory to use if starting a new cluster.") parser.add_argument( "--experiment-name", default="default", @@ -97,9 +100,9 @@ def run(args, parser): "run": args.run, "checkpoint_freq": args.checkpoint_freq, "local_dir": args.local_dir, - "trial_resources": ( - args.trial_resources and - resources_to_json(args.trial_resources)), + "resources_per_trial": ( + args.resources_per_trial and + resources_to_json(args.resources_per_trial)), "stop": args.stop, "config": dict(args.config, env=args.env), "restore": args.restore, @@ -122,12 +125,14 @@ def run(args, parser): "num_cpus": args.ray_num_cpus or 1, "num_gpus": args.ray_num_gpus or 0, }, - object_store_memory=args.ray_object_store_memory) + object_store_memory=args.ray_object_store_memory, + redis_max_memory=args.ray_redis_max_memory) ray.init(redis_address=cluster.redis_address) else: ray.init( redis_address=args.redis_address, object_store_memory=args.ray_object_store_memory, + redis_max_memory=args.ray_redis_max_memory, num_cpus=args.ray_num_cpus, num_gpus=args.ray_num_gpus) run_experiments( diff --git a/python/ray/rllib/tuned_examples/generate_regression_tests.py b/python/ray/rllib/tuned_examples/generate_regression_tests.py deleted file mode 100755 index 3196bd4d03d2f..0000000000000 --- a/python/ray/rllib/tuned_examples/generate_regression_tests.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env python -# This script generates all the regression tests for RLlib. - -import glob -import re -import os -import os.path as osp - -CONFIG_DIR = osp.join(osp.dirname(osp.abspath(__file__)), "regression_tests") - -TEMPLATE = """ -class Test{name}(Regression): - _file = "{filename}" - - def setup_cache(self): - return _evaulate_config(self._file) - -""" - -if __name__ == '__main__': - os.chdir(CONFIG_DIR) - - with open("regression_test.py", "a") as f: - for filename in sorted(glob.glob("*.yaml")): - splits = re.findall(r"\w+", osp.splitext(filename)[0]) - test_name = "".join([s.capitalize() for s in splits]) - f.write(TEMPLATE.format(name=test_name, filename=filename)) diff --git a/python/ray/rllib/tuned_examples/pendulum-ppo.yaml b/python/ray/rllib/tuned_examples/pendulum-ppo.yaml index b8c0293a3e338..3e9d45179bc59 100644 --- a/python/ray/rllib/tuned_examples/pendulum-ppo.yaml +++ b/python/ray/rllib/tuned_examples/pendulum-ppo.yaml @@ -5,7 +5,8 @@ pendulum-ppo: config: train_batch_size: 2048 vf_clip_param: 10.0 - num_workers: 2 + num_workers: 0 + num_envs_per_worker: 10 lambda: 0.1 gamma: 0.95 lr: 0.0003 diff --git a/python/ray/rllib/tuned_examples/pong-impala-fast.yaml b/python/ray/rllib/tuned_examples/pong-impala-fast.yaml index 3c29f4e0c08e4..2162e08d3342a 100644 --- a/python/ray/rllib/tuned_examples/pong-impala-fast.yaml +++ b/python/ray/rllib/tuned_examples/pong-impala-fast.yaml @@ -13,7 +13,7 @@ pong-impala-fast: num_envs_per_worker: 5 broadcast_interval: 5 max_sample_requests_in_flight_per_worker: 1 - num_parallel_data_loaders: 4 + num_data_loader_buffers: 4 num_gpus: 2 model: dim: 42 diff --git a/python/ray/rllib/tuned_examples/pong-ppo.yaml b/python/ray/rllib/tuned_examples/pong-ppo.yaml index 1447481643fe5..d7e273cc6e2bd 100644 --- a/python/ray/rllib/tuned_examples/pong-ppo.yaml +++ b/python/ray/rllib/tuned_examples/pong-ppo.yaml @@ -1,17 +1,26 @@ -# On a Tesla K80 GPU, this achieves the maximum reward in about 1-1.5 hours. +# On a single GPU, this achieves maximum reward in ~15-20 minutes. 
# -# $ python train.py -f tuned_examples/pong-ppo.yaml --ray-num-gpus=1 +# $ python train.py -f tuned_examples/pong-ppo.yaml # -# - PPO_PongDeterministic-v4_0: TERMINATED [pid=16387], 4984 s, 1117981 ts, 21 rew -# - PPO_PongDeterministic-v4_0: TERMINATED [pid=83606], 4592 s, 1068671 ts, 21 rew -# -pong-deterministic-ppo: - env: PongDeterministic-v4 +pong-ppo: + env: PongNoFrameskip-v4 run: PPO - stop: - episode_reward_mean: 21 config: - gamma: 0.99 - num_workers: 4 - num_sgd_iter: 20 + lambda: 0.95 + kl_coeff: 0.5 + clip_rewards: True + clip_param: 0.1 + vf_clip_param: 10.0 + entropy_coeff: 0.01 + train_batch_size: 5000 + sample_batch_size: 20 + sgd_minibatch_size: 500 + num_sgd_iter: 10 + num_workers: 32 + num_envs_per_worker: 5 + batch_mode: truncate_episodes + observation_filter: NoFilter + vf_share_layers: true num_gpus: 1 + model: + dim: 42 diff --git a/python/ray/rllib/tuned_examples/regression_tests/cartpole-a3c-pytorch.yaml b/python/ray/rllib/tuned_examples/regression_tests/cartpole-a3c-pytorch.yaml deleted file mode 100644 index a25da3c7769a9..0000000000000 --- a/python/ray/rllib/tuned_examples/regression_tests/cartpole-a3c-pytorch.yaml +++ /dev/null @@ -1,10 +0,0 @@ -cartpole-a3c: - env: CartPole-v0 - run: A3C - stop: - episode_reward_mean: 200 - time_total_s: 600 - config: - num_workers: 1 - gamma: 0.95 - use_pytorch: true diff --git a/python/ray/rllib/tuned_examples/regression_tests/cartpole-a3c.yaml b/python/ray/rllib/tuned_examples/regression_tests/cartpole-a3c.yaml index f20ea73c3b681..08ff2206fa428 100644 --- a/python/ray/rllib/tuned_examples/regression_tests/cartpole-a3c.yaml +++ b/python/ray/rllib/tuned_examples/regression_tests/cartpole-a3c.yaml @@ -2,8 +2,8 @@ cartpole-a3c: env: CartPole-v0 run: A3C stop: - episode_reward_mean: 200 - time_total_s: 600 + episode_reward_mean: 100 + timesteps_total: 100000 config: num_workers: 1 gamma: 0.95 diff --git a/python/ray/rllib/tuned_examples/regression_tests/cartpole-ars.yaml b/python/ray/rllib/tuned_examples/regression_tests/cartpole-ars.yaml index 550170c2ec14a..bae79b2bb513b 100644 --- a/python/ray/rllib/tuned_examples/regression_tests/cartpole-ars.yaml +++ b/python/ray/rllib/tuned_examples/regression_tests/cartpole-ars.yaml @@ -2,15 +2,15 @@ cartpole-ars: env: CartPole-v0 run: ARS stop: - episode_reward_mean: 200 - time_total_s: 600 + episode_reward_mean: 50 + timesteps_total: 500000 config: noise_stdev: 0.02 num_rollouts: 50 rollouts_used: 25 num_workers: 2 sgd_stepsize: 0.01 - noise_size: 250000000 + noise_size: 25000000 eval_prob: 0.5 - policy_type: MLPPolicy - fcnet_hiddens: [16, 16] + model: + fcnet_hiddens: [] # a linear policy diff --git a/python/ray/rllib/tuned_examples/regression_tests/cartpole-dqn.yaml b/python/ray/rllib/tuned_examples/regression_tests/cartpole-dqn.yaml index 04aa2dc6edcce..5a6ba5033392a 100644 --- a/python/ray/rllib/tuned_examples/regression_tests/cartpole-dqn.yaml +++ b/python/ray/rllib/tuned_examples/regression_tests/cartpole-dqn.yaml @@ -2,8 +2,8 @@ cartpole-dqn: env: CartPole-v0 run: DQN stop: - episode_reward_mean: 200 - time_total_s: 600 + episode_reward_mean: 150 + timesteps_total: 50000 config: n_step: 3 gamma: 0.95 diff --git a/python/ray/rllib/tuned_examples/regression_tests/cartpole-es.yaml b/python/ray/rllib/tuned_examples/regression_tests/cartpole-es.yaml index a0246f1e26b06..5c411188d3890 100644 --- a/python/ray/rllib/tuned_examples/regression_tests/cartpole-es.yaml +++ b/python/ray/rllib/tuned_examples/regression_tests/cartpole-es.yaml @@ -2,8 +2,8 @@ cartpole-es: env: CartPole-v0 
run: ES stop: - episode_reward_mean: 200 - time_total_s: 300 + episode_reward_mean: 75 + timesteps_total: 400000 config: num_workers: 2 noise_size: 25000000 diff --git a/python/ray/rllib/tuned_examples/regression_tests/cartpole-pg.yaml b/python/ray/rllib/tuned_examples/regression_tests/cartpole-pg.yaml index 2bf9e7548b865..58c29e9e5f9d0 100644 --- a/python/ray/rllib/tuned_examples/regression_tests/cartpole-pg.yaml +++ b/python/ray/rllib/tuned_examples/regression_tests/cartpole-pg.yaml @@ -2,7 +2,7 @@ cartpole-pg: env: CartPole-v0 run: PG stop: - episode_reward_mean: 200 - time_total_s: 300 + episode_reward_mean: 100 + timesteps_total: 100000 config: - num_workers: 1 + num_workers: 0 diff --git a/python/ray/rllib/tuned_examples/regression_tests/cartpole-ppo.yaml b/python/ray/rllib/tuned_examples/regression_tests/cartpole-ppo.yaml index 82ea5846e733c..3f326cf83062e 100644 --- a/python/ray/rllib/tuned_examples/regression_tests/cartpole-ppo.yaml +++ b/python/ray/rllib/tuned_examples/regression_tests/cartpole-ppo.yaml @@ -2,8 +2,8 @@ cartpole-ppo: env: CartPole-v0 run: PPO stop: - episode_reward_mean: 200 - time_total_s: 300 + episode_reward_mean: 150 + timesteps_total: 100000 config: num_workers: 1 batch_mode: complete_episodes diff --git a/python/ray/rllib/tuned_examples/regression_tests/pendulum-ddpg.yaml b/python/ray/rllib/tuned_examples/regression_tests/pendulum-ddpg.yaml index 124f756ecc1c6..696c251c99747 100644 --- a/python/ray/rllib/tuned_examples/regression_tests/pendulum-ddpg.yaml +++ b/python/ray/rllib/tuned_examples/regression_tests/pendulum-ddpg.yaml @@ -2,8 +2,8 @@ pendulum-ddpg: env: Pendulum-v0 run: DDPG stop: - episode_reward_mean: -160 - time_total_s: 900 + episode_reward_mean: -900 + timesteps_total: 100000 config: use_huber: True clip_rewards: False diff --git a/python/ray/rllib/tuned_examples/regression_tests/pendulum-ppo.yaml b/python/ray/rllib/tuned_examples/regression_tests/pendulum-ppo.yaml index 63536d3be3704..015429110e228 100644 --- a/python/ray/rllib/tuned_examples/regression_tests/pendulum-ppo.yaml +++ b/python/ray/rllib/tuned_examples/regression_tests/pendulum-ppo.yaml @@ -2,12 +2,13 @@ pendulum-ppo: env: Pendulum-v0 run: PPO stop: - episode_reward_mean: -160 - timesteps_total: 600000 + episode_reward_mean: -200 + timesteps_total: 500000 config: train_batch_size: 2048 vf_clip_param: 10.0 - num_workers: 4 + num_workers: 0 + num_envs_per_worker: 10 lambda: 0.1 gamma: 0.95 lr: 0.0003 diff --git a/python/ray/rllib/tuned_examples/regression_tests/regression_test.py b/python/ray/rllib/tuned_examples/regression_tests/regression_test.py deleted file mode 100644 index ff994b904ed8c..0000000000000 --- a/python/ray/rllib/tuned_examples/regression_tests/regression_test.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env python -""" -This class runs the regression YAMLs in the ASV format. 
-""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from collections import defaultdict -import numpy as np -import os -import yaml - -import ray -from ray import tune - -CONFIG_DIR = os.path.dirname(os.path.abspath(__file__)) - - -def _evaulate_config(filename): - with open(os.path.join(CONFIG_DIR, filename)) as f: - experiments = yaml.load(f) - for _, config in experiments.items(): - config["num_samples"] = 3 - ray.init() - trials = tune.run_experiments(experiments) - results = defaultdict(list) - for t in trials: - results["time_total_s"] += [t.last_result["time_total_s"]] - results["episode_reward_mean"] += [ - t.last_result["episode_reward_mean"] - ] - results["training_iteration"] += [t.last_result["training_iteration"]] - - return {k: np.median(v) for k, v in results.items()} - - -class Regression(): - def setup_cache(self): - # We need to implement this in separate classes - # below so that ASV will register the setup/class - # as a separate test. - raise NotImplementedError - - def teardown(self, *args): - ray.shutdown() - - def track_time(self, result): - return result["time_total_s"] - - def track_reward(self, result): - return result["episode_reward_mean"] - - def track_iterations(self, result): - return result["training_iteration"] diff --git a/python/ray/rllib/tuned_examples/run_regression_tests.py b/python/ray/rllib/tuned_examples/run_regression_tests.py deleted file mode 100755 index 823da327cc5e4..0000000000000 --- a/python/ray/rllib/tuned_examples/run_regression_tests.py +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env python -# This script runs all the integration tests for RLlib. -# TODO(ekl) add large-scale tests on different envs here. - -import glob -import yaml - -import ray -from ray.tune import run_experiments - -if __name__ == '__main__': - experiments = {} - - for test in glob.glob("regression_tests/*.yaml"): - config = yaml.load(open(test).read()) - experiments.update(config) - - print("== Test config ==") - print(yaml.dump(experiments)) - - ray.init() - trials = run_experiments(experiments) - - num_failures = 0 - for t in trials: - if (t.last_result["episode_reward_mean"] < - t.stopping_criterion["episode_reward_mean"]): - num_failures += 1 - - if num_failures: - raise Exception("{} trials did not converge".format(num_failures)) diff --git a/python/ray/rllib/utils/__init__.py b/python/ray/rllib/utils/__init__.py index a738e7419c3be..c25cefeb41b52 100644 --- a/python/ray/rllib/utils/__init__.py +++ b/python/ray/rllib/utils/__init__.py @@ -1,43 +1,10 @@ -import copy - from ray.rllib.utils.filter_manager import FilterManager from ray.rllib.utils.filter import Filter from ray.rllib.utils.policy_client import PolicyClient from ray.rllib.utils.policy_server import PolicyServer +from ray.tune.util import merge_dicts, deep_update -__all__ = ["Filter", "FilterManager", "PolicyClient", "PolicyServer"] - - -def merge_dicts(d1, d2): - """Returns a new dict that is d1 and d2 deep merged.""" - merged = copy.deepcopy(d1) - deep_update(merged, d2, True, []) - return merged - - -def deep_update(original, new_dict, new_keys_allowed, whitelist): - """Updates original dict with values from new_dict recursively. - If new key is introduced in new_dict, then if new_keys_allowed is not - True, an error will be thrown. Further, for sub-dicts, if the key is - in the whitelist, then new subkeys can be introduced. - - Args: - original (dict): Dictionary with default values. 
- new_dict (dict): Dictionary with values to be updated - new_keys_allowed (bool): Whether new keys are allowed. - whitelist (list): List of keys that correspond to dict values - where new subkeys can be introduced. This is only at - the top level. - """ - for k, value in new_dict.items(): - if k not in original: - if not new_keys_allowed: - raise Exception("Unknown config parameter `{}` ".format(k)) - if type(original.get(k)) is dict: - if k in whitelist: - deep_update(original[k], value, True, []) - else: - deep_update(original[k], value, new_keys_allowed, []) - else: - original[k] = value - return original +__all__ = [ + "Filter", "FilterManager", "PolicyClient", "PolicyServer", "merge_dicts", + "deep_update" +] diff --git a/python/ray/rllib/utils/actors.py b/python/ray/rllib/utils/actors.py index 7018073313112..689aa945cabfc 100644 --- a/python/ray/rllib/utils/actors.py +++ b/python/ray/rllib/utils/actors.py @@ -39,8 +39,8 @@ def completed_prefetch(self): for worker, obj_id in self.completed(): plasma_id = ray.pyarrow.plasma.ObjectID(obj_id.id()) - (ray.worker.global_worker.local_scheduler_client. - fetch_or_reconstruct([obj_id], True)) + (ray.worker.global_worker.raylet_client.fetch_or_reconstruct( + [obj_id], True)) self._fetching.append((worker, obj_id)) remaining = [] diff --git a/python/ray/rllib/utils/annotations.py b/python/ray/rllib/utils/annotations.py new file mode 100644 index 0000000000000..d68f76a69600e --- /dev/null +++ b/python/ray/rllib/utils/annotations.py @@ -0,0 +1,20 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def override(cls): + """Annotation for documenting method overrides. + + Arguments: + cls (type): The superclass that provides the overriden method. If this + cls does not actually have the method, an error is raised. + """ + + def check_override(method): + if method.__name__ not in dir(cls): + raise NameError("{} does not override any method of {}".format( + method, cls)) + return method + + return check_override diff --git a/python/ray/rllib/utils/compression.py b/python/ray/rllib/utils/compression.py index aed0dd5985600..8fc7d5890b152 100644 --- a/python/ray/rllib/utils/compression.py +++ b/python/ray/rllib/utils/compression.py @@ -7,6 +7,7 @@ import base64 import numpy as np import pyarrow +from six import string_types logger = logging.getLogger(__name__) @@ -14,9 +15,9 @@ import lz4.frame LZ4_ENABLED = True except ImportError: - logger.warn("lz4 not available, disabling sample compression. " - "This will significantly impact RLlib performance. " - "To install lz4, run `pip install lz4`.") + logger.warning("lz4 not available, disabling sample compression. " + "This will significantly impact RLlib performance. " + "To install lz4, run `pip install lz4`.") LZ4_ENABLED = False @@ -26,7 +27,7 @@ def pack(data): data = lz4.frame.compress(data) # TODO(ekl) we shouldn't need to base64 encode this data, but this # seems to not survive a transfer through the object store if we don't. 
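Editor's note (illustrative sketch, not part of this patch): the change just below — appending .decode("ascii") in pack() and letting unpack_if_needed() accept str — matters because base64.b64encode returns bytes on Python 3, and a plain str survives JSON serialization and the object-store transfer mentioned in the comment above. The round trip these helpers implement looks roughly like this, using only lz4 and base64 with made-up function names:

import base64
import lz4.frame

def pack_bytes(data):
    # compress, then base64-encode; decoding to ASCII yields a plain str
    # that can be embedded in, e.g., the JSON sample files written earlier
    return base64.b64encode(lz4.frame.compress(data)).decode("ascii")

def unpack_str(packed):
    # accepts the str produced above and recovers the original bytes
    return lz4.frame.decompress(base64.b64decode(packed))

assert unpack_str(pack_bytes(b"obs")) == b"obs"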
- data = base64.b64encode(data) + data = base64.b64encode(data).decode("ascii") return data @@ -45,7 +46,7 @@ def unpack(data): def unpack_if_needed(data): - if isinstance(data, bytes): + if isinstance(data, bytes) or isinstance(data, string_types): data = unpack(data) return data diff --git a/python/ray/rllib/utils/filter.py b/python/ray/rllib/utils/filter.py index fbdb39ae18b3f..9a1f37dbd15a5 100644 --- a/python/ray/rllib/utils/filter.py +++ b/python/ray/rllib/utils/filter.py @@ -2,9 +2,12 @@ from __future__ import division from __future__ import print_function +import logging import numpy as np import threading +logger = logging.getLogger(__name__) + class Filter(object): """Processes input, possibly statefully.""" @@ -39,7 +42,10 @@ def __init__(self, *args): pass def __call__(self, x, update=True): - return np.asarray(x) + try: + return np.asarray(x) + except Exception: + raise ValueError("Failed to convert to array", x) def apply_changes(self, other, *args, **kwargs): pass diff --git a/python/ray/rllib/utils/policy_client.py b/python/ray/rllib/utils/policy_client.py index 1bb4b5e134046..ad1334886c8d6 100644 --- a/python/ray/rllib/utils/policy_client.py +++ b/python/ray/rllib/utils/policy_client.py @@ -11,8 +11,9 @@ import requests # `requests` is not part of stdlib. except ImportError: requests = None - logger.warn("Couldn't import `requests` library. Be sure to install it on" - " the client side.") + logger.warning( + "Couldn't import `requests` library. Be sure to install it on" + " the client side.") class PolicyClient(object): diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 13fa63efbd402..b84db6757c86a 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -116,6 +116,22 @@ def cli(logging_level, logging_format): type=int, help="the maximum amount of memory (in bytes) to allow the " "object store to use") +@click.option( + "--redis-max-memory", + required=False, + type=int, + help=("The max amount of memory (in bytes) to allow redis to use, or None " + "for no limit. Once the limit is exceeded, redis will start LRU " + "eviction of entries. This only applies to the sharded " + "redis tables (task and object tables).")) +@click.option( + "--collect-profiling-data", + default=True, + type=bool, + help=("Whether to collect profiling data. Note that " + "profiling data cannot be LRU evicted, so if you set " + "redis_max_memory then profiling will also be disabled to prevent " + "it from consuming all available redis memory.")) @click.option( "--num-workers", required=False, @@ -202,11 +218,11 @@ def cli(logging_level, logging_format): def start(node_ip_address, redis_address, redis_port, num_redis_shards, redis_max_clients, redis_password, redis_shard_ports, object_manager_port, node_manager_port, object_store_memory, - num_workers, num_cpus, num_gpus, resources, head, no_ui, block, - plasma_directory, huge_pages, autoscaling_config, - no_redirect_worker_output, no_redirect_output, - plasma_store_socket_name, raylet_socket_name, temp_dir, - internal_config): + redis_max_memory, collect_profiling_data, num_workers, num_cpus, + num_gpus, resources, head, no_ui, block, plasma_directory, + huge_pages, autoscaling_config, no_redirect_worker_output, + no_redirect_output, plasma_store_socket_name, raylet_socket_name, + temp_dir, internal_config): # Convert hostnames to numerical IP address. 
if node_ip_address is not None: node_ip_address = services.address_to_ip(node_ip_address) @@ -262,6 +278,8 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, redis_port=redis_port, redis_shard_ports=redis_shard_ports, object_store_memory=object_store_memory, + redis_max_memory=redis_max_memory, + collect_profiling_data=collect_profiling_data, num_workers=num_workers, cleanup=False, redirect_worker_output=not no_redirect_worker_output, @@ -591,9 +609,21 @@ def submit(cluster_config_file, screen, tmux, stop, start, cluster_name, cmd = " ".join(["python", target] + list(script_args)) exec_cluster(cluster_config_file, cmd, screen, tmux, stop, False, cluster_name, port_forward) - if tmux: - logger.info("Use `ray attach {} --tmux` " - "to check on command status.".format(cluster_config_file)) + + if tmux or screen: + attach_command_parts = ["ray attach", cluster_config_file] + if cluster_name is not None: + attach_command_parts.append( + "--cluster-name={}".format(cluster_name)) + if tmux: + attach_command_parts.append("--tmux") + elif screen: + attach_command_parts.append("--screen") + + attach_command = " ".join(attach_command_parts) + attach_info = "Use `{}` to check on command status.".format( + attach_command) + logger.info(attach_info) @cli.command() @@ -627,11 +657,24 @@ def submit(cluster_config_file, screen, tmux, stop, start, cluster_name, def exec_cmd(cluster_config_file, cmd, screen, tmux, stop, start, cluster_name, port_forward): assert not (screen and tmux), "Can specify only one of `screen` or `tmux`." + exec_cluster(cluster_config_file, cmd, screen, tmux, stop, start, cluster_name, port_forward) - if tmux: - logger.info("Use `ray attach {} --tmux` " - "to check on command status.".format(cluster_config_file)) + + if tmux or screen: + attach_command_parts = ["ray attach", cluster_config_file] + if cluster_name is not None: + attach_command_parts.append( + "--cluster-name={}".format(cluster_name)) + if tmux: + attach_command_parts.append("--tmux") + elif screen: + attach_command_parts.append("--screen") + + attach_command = " ".join(attach_command_parts) + attach_info = "Use `{}` to check on command status.".format( + attach_command) + logger.info(attach_info) @cli.command() diff --git a/python/ray/services.py b/python/ray/services.py index 841fababd1b53..77138715de574 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -416,7 +416,8 @@ def start_redis(node_ip_address, redirect_worker_output=False, cleanup=True, password=None, - use_credis=None): + use_credis=None, + redis_max_memory=None): """Start the Redis global state store. Args: @@ -445,6 +446,10 @@ def start_redis(node_ip_address, use_credis: If True, additionally load the chain-replicated libraries into the redis servers. Defaults to None, which means its value is set by the presence of "RAY_USE_NEW_GCS" in os.environ. + redis_max_memory: The max amount of memory (in bytes) to allow redis + to use, or None for no limit. Once the limit is exceeded, redis + will start LRU eviction of entries. This only applies to the + sharded redis tables (task and object tables). 
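Editor's note (illustrative, not part of this patch): with the plumbing this patch adds, the cap described above can be supplied from the CLI as `ray start --head --redis-max-memory=2000000000`, or from Python, assuming ray.init forwards the keyword the same way train.py does earlier in this diff; the 2 GB value is only an example.

import ray

# Cap the sharded Redis tables (task/object metadata) at roughly 2 GB;
# once exceeded, Redis evicts older entries with an LRU policy.
ray.init(redis_max_memory=2 * 10**9)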
Returns: A tuple of the address for the primary Redis shard and a list of @@ -475,7 +480,8 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, cleanup=cleanup, - password=password) + password=password, + redis_max_memory=None) else: assigned_port, _ = _start_redis_instance( node_ip_address=node_ip_address, @@ -489,7 +495,8 @@ def start_redis(node_ip_address, # as the latter contains an extern declaration that the former # supplies. modules=[CREDIS_MASTER_MODULE, REDIS_MODULE], - password=password) + password=password, + redis_max_memory=None) if port is not None: assert assigned_port == port port = assigned_port @@ -523,7 +530,8 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, cleanup=cleanup, - password=password) + password=password, + redis_max_memory=redis_max_memory) else: assert num_redis_shards == 1, \ "For now, RAY_USE_NEW_GCS supports 1 shard, and credis "\ @@ -540,7 +548,8 @@ def start_redis(node_ip_address, # It is important to load the credis module BEFORE the ray # module, as the latter contains an extern declaration that the # former supplies. - modules=[CREDIS_MEMBER_MODULE, REDIS_MODULE]) + modules=[CREDIS_MEMBER_MODULE, REDIS_MODULE], + redis_max_memory=redis_max_memory) if redis_shard_ports[i] is not None: assert redis_shard_port == redis_shard_ports[i] @@ -570,7 +579,8 @@ def _start_redis_instance(node_ip_address="127.0.0.1", cleanup=True, password=None, executable=REDIS_EXECUTABLE, - modules=None): + modules=None, + redis_max_memory=None): """Start a single Redis server. Args: @@ -594,6 +604,9 @@ def _start_redis_instance(node_ip_address="127.0.0.1", modules (list of str): A list of pathnames, pointing to the redis module(s) that will be loaded in this redis server. If None, load the default Ray redis module. + redis_max_memory: The max amount of memory (in bytes) to allow redis + to use, or None for no limit. Once the limit is exceeded, redis + will start LRU eviction of entries. Returns: A tuple of the port used by Redis and a handle to the process that was @@ -657,6 +670,14 @@ def _start_redis_instance(node_ip_address="127.0.0.1", # hosts can connect to it. TODO(rkn): Do this in a more secure way. redis_client.config_set("protected-mode", "no") + # Discard old task and object metadata. + if redis_max_memory is not None: + redis_client.config_set("maxmemory", str(redis_max_memory)) + redis_client.config_set("maxmemory-policy", "allkeys-lru") + redis_client.config_set("maxmemory-samples", "10") + logger.info("Starting Redis shard with {} GB max memory.".format( + round(redis_max_memory / 1e9, 2))) + # If redis_max_clients is provided, attempt to raise the number of maximum # number of Redis clients. if redis_max_clients is not None: @@ -861,7 +882,8 @@ def start_raylet(redis_address, stderr_file=None, cleanup=True, config=None, - redis_password=None): + redis_password=None, + collect_profiling_data=True): """Start a raylet, which is a combined local scheduler and object manager. Args: @@ -894,6 +916,7 @@ def start_raylet(redis_address, config (dict|None): Optional Raylet configuration that will override defaults in RayConfig. redis_password (str): The password of the redis server. + collect_profiling_data: Whether to collect profiling data from workers. Returns: The raylet socket name. 
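Editor's note (illustrative, not part of this patch): the config_set calls added to _start_redis_instance above can be sanity-checked from any Redis client; the host and port below are placeholders for one of the sharded Redis instances Ray started.

import redis

# Inspect the eviction settings on a (hypothetical) Ray Redis shard.
client = redis.StrictRedis(host="127.0.0.1", port=6380)
print(client.config_get("maxmemory"))         # e.g. {'maxmemory': '2000000000'}
print(client.config_get("maxmemory-policy"))  # {'maxmemory-policy': 'allkeys-lru'}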
@@ -923,9 +946,11 @@ def start_raylet(redis_address, "--object-store-name={} " "--raylet-name={} " "--redis-address={} " + "--collect-profiling-data={} " "--temp-dir={}".format( sys.executable, worker_path, node_ip_address, plasma_store_name, raylet_name, redis_address, + "1" if collect_profiling_data else "0", get_temp_root())) if redis_password: start_worker_command += " --redis-password {}".format(redis_password) @@ -1028,9 +1053,6 @@ def determine_plasma_store_config(object_store_memory=None, "when calling ray.init() or ray start.") object_store_memory = MAX_DEFAULT_MEM - if plasma_directory is not None: - plasma_directory = os.path.abspath(plasma_directory) - # Determine which directory to use. By default, use /tmp on MacOS and # /dev/shm on Linux, unless the shared-memory file system is too small, # in which case we default to /tmp on Linux. @@ -1055,10 +1077,15 @@ def determine_plasma_store_config(object_store_memory=None, else: plasma_directory = "/tmp" - # Do some sanity checks. - if object_store_memory > system_memory: - raise Exception("The requested object store memory size is greater " - "than the total available memory.") + # Do some sanity checks. + if object_store_memory > system_memory: + raise Exception( + "The requested object store memory size is greater " + "than the total available memory.") + else: + plasma_directory = os.path.abspath(plasma_directory) + logger.warning("WARNING: object_store_memory is not verified when " + "plasma_directory is set.") if not os.path.isdir(plasma_directory): raise Exception("The file {} does not exist or is not a directory." @@ -1258,6 +1285,8 @@ def start_ray_processes(address_info=None, num_workers=None, num_local_schedulers=1, object_store_memory=None, + redis_max_memory=None, + collect_profiling_data=True, num_redis_shards=1, redis_max_clients=None, redis_password=None, @@ -1304,6 +1333,14 @@ def start_ray_processes(address_info=None, address_info. object_store_memory: The amount of memory (in bytes) to start the object store with. + redis_max_memory: The max amount of memory (in bytes) to allow redis + to use, or None for no limit. Once the limit is exceeded, redis + will start LRU eviction of entries. This only applies to the + sharded redis tables (task and object tables). + collect_profiling_data: Whether to collect profiling data. Note that + profiling data cannot be LRU evicted, so if you set + redis_max_memory then profiling will also be disabled to prevent + it from consuming all available redis memory. num_redis_shards: The number of Redis shards to start in addition to the primary Redis shard. redis_max_clients: If provided, attempt to configure Redis with this @@ -1395,7 +1432,8 @@ def start_ray_processes(address_info=None, redirect_output=True, redirect_worker_output=redirect_worker_output, cleanup=cleanup, - password=redis_password) + password=redis_password, + redis_max_memory=redis_max_memory) address_info["redis_address"] = redis_address time.sleep(0.1) @@ -1495,6 +1533,7 @@ def start_ray_processes(address_info=None, stderr_file=raylet_stderr_file, cleanup=cleanup, redis_password=redis_password, + collect_profiling_data=collect_profiling_data, config=config)) # Try to start the web UI. @@ -1546,8 +1585,7 @@ def start_ray_node(node_ip_address, this node (typically just one). num_workers (int): The number of workers to start. num_local_schedulers (int): The number of local schedulers to start. - This is also the number of plasma stores and plasma managers to - start. 
+ This is also the number of plasma stores and raylets to start. object_store_memory (int): The maximum amount of memory (in bytes) to let the plasma store use. redis_password (str): Prevents external clients without the password @@ -1615,6 +1653,8 @@ def start_ray_head(address_info=None, num_workers=None, num_local_schedulers=1, object_store_memory=None, + redis_max_memory=None, + collect_profiling_data=True, worker_path=None, cleanup=True, redirect_worker_output=False, @@ -1660,6 +1700,11 @@ def start_ray_head(address_info=None, address_info. object_store_memory: The amount of memory (in bytes) to start the object store with. + redis_max_memory: The max amount of memory (in bytes) to allow redis + to use, or None for no limit. Once the limit is exceeded, redis + will start LRU eviction of entries. This only applies to the + sharded redis tables (task and object tables). + collect_profiling_data: Whether to collect profiling data from workers. worker_path (str): The path of the source code that will be run by the worker. cleanup (bool): If cleanup is true, then the processes started here @@ -1710,6 +1755,8 @@ def start_ray_head(address_info=None, num_workers=num_workers, num_local_schedulers=num_local_schedulers, object_store_memory=object_store_memory, + redis_max_memory=redis_max_memory, + collect_profiling_data=collect_profiling_data, worker_path=worker_path, cleanup=cleanup, redirect_worker_output=redirect_worker_output, diff --git a/python/ray/tempfile_services.py b/python/ray/tempfile_services.py index d4e94aec8a2ae..791b8a257f089 100644 --- a/python/ray/tempfile_services.py +++ b/python/ray/tempfile_services.py @@ -64,7 +64,10 @@ def try_to_create_directory(directory_path): "exists.".format(directory_path)) # Change the log directory permissions so others can use it. This is # important when multiple people are using the same machine. + try: os.chmod(directory_path, 0o0777) + except PermissionError: + pass def get_temp_root(): diff --git a/python/ray/test/cluster_utils.py b/python/ray/test/cluster_utils.py index 41dc3b6cdd26a..aff302efc4344 100644 --- a/python/ray/test/cluster_utils.py +++ b/python/ray/test/cluster_utils.py @@ -34,6 +34,7 @@ def __init__(self, self.head_node = None self.worker_nodes = {} self.redis_address = None + self.connected = False if not initialize_head and connect: raise RuntimeError("Cannot connect to uninitialized cluster.") @@ -41,14 +42,19 @@ def __init__(self, head_node_args = head_node_args or {} self.add_node(**head_node_args) if connect: - redis_password = head_node_args.get("redis_password") - output_info = ray.init( - redis_address=self.redis_address, - redis_password=redis_password) - logger.info(output_info) + self.connect(head_node_args) if shutdown_at_exit: atexit.register(self.shutdown) + def connect(self, head_node_args): + assert self.redis_address is not None + assert not self.connected + redis_password = head_node_args.get("redis_password") + output_info = ray.init( + redis_address=self.redis_address, redis_password=redis_password) + logger.info(output_info) + self.connected = True + def add_node(self, **override_kwargs): """Adds a node to the local Ray Cluster. 
@@ -83,7 +89,7 @@ def add_node(self, **override_kwargs): process_dict_copy = services.all_processes.copy() for key in services.all_processes: services.all_processes[key] = [] - node = Node(process_dict_copy) + node = Node(address_info, process_dict_copy) self.head_node = node else: address_info = services.start_ray_node( @@ -93,7 +99,7 @@ def add_node(self, **override_kwargs): process_dict_copy = services.all_processes.copy() for key in services.all_processes: services.all_processes[key] = [] - node = Node(process_dict_copy) + node = Node(address_info, process_dict_copy) self.worker_nodes[node] = address_info logger.info("Starting Node with raylet socket {}".format( address_info["raylet_socket_names"])) @@ -182,8 +188,9 @@ def shutdown(self): class Node(object): """Abstraction for a Ray node.""" - def __init__(self, process_dict): + def __init__(self, address_info, process_dict): # TODO(rliaw): Is there a unique identifier for a node? + self.address_info = address_info self.process_dict = process_dict def kill_plasma_store(self): @@ -224,3 +231,11 @@ def any_processes_alive(self): def all_processes_alive(self): return not any(self.dead_processes()) + + def get_plasma_store_name(self): + """Return the plasma store name. + + Assuming one plasma store per raylet, this may be used as a unique + identifier for a node. + """ + return self.address_info['object_store_addresses'][0] diff --git a/python/ray/test/test_utils.py b/python/ray/test/test_utils.py index a3614650e97ba..189f9ae35f58a 100644 --- a/python/ray/test/test_utils.py +++ b/python/ray/test/test_utils.py @@ -19,8 +19,7 @@ def _wait_for_nodes_to_join(num_nodes, timeout=20): """Wait until the nodes have joined the cluster. - This will wait until exactly num_nodes have joined the cluster and each - node has a local scheduler and a plasma manager. + This will wait until exactly num_nodes have joined the cluster. Args: num_nodes: The number of nodes to wait for. @@ -35,10 +34,6 @@ def _wait_for_nodes_to_join(num_nodes, timeout=20): client_table = ray.global_state.client_table() num_ready_nodes = len(client_table) if num_ready_nodes == num_nodes: - # Check that for each node, a local scheduler and a plasma manager - # are present. - # In raylet mode, this is a list of map. - # The GCS info will appear as a whole instead of part by part. return if num_ready_nodes > num_nodes: # Too many nodes have joined. Something must be wrong. diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py index 83d4f4fdece37..1e341b26526ea 100644 --- a/python/ray/tune/__init__.py +++ b/python/ray/tune/__init__.py @@ -7,9 +7,16 @@ from ray.tune.experiment import Experiment from ray.tune.registry import register_env, register_trainable from ray.tune.trainable import Trainable -from ray.tune.suggest import grid_search, function +from ray.tune.suggest import grid_search, function, sample_from __all__ = [ - "Trainable", "TuneError", "grid_search", "register_env", - "register_trainable", "run_experiments", "Experiment", "function" + "Trainable", + "TuneError", + "grid_search", + "register_env", + "register_trainable", + "run_experiments", + "Experiment", + "function", + "sample_from", ] diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py index 64ef399d342c4..22adfc397ecc6 100644 --- a/python/ray/tune/config_parser.py +++ b/python/ray/tune/config_parser.py @@ -83,7 +83,7 @@ def make_parser(parser_creator=None, **kwargs): help="Algorithm-specific configuration (e.g. 
env, hyperparams), " "specified in JSON.") parser.add_argument( - "--trial-resources", + "--resources-per-trial", default=None, type=json_to_resources, help="Override the machine resources to allocate per trial, e.g. " @@ -106,6 +106,21 @@ def make_parser(parser_creator=None, **kwargs): default="", type=str, help="Optional URI to sync training results to (e.g. s3://bucket).") + parser.add_argument( + "--trial-name-creator", + default=None, + help="Optional creator function for the trial string, used in " + "generating a trial directory.") + parser.add_argument( + "--sync-function", + default=None, + help="Function for syncing the local_dir to upload_dir. If string, " + "then it must be a string template for syncer to run and needs to " + "include replacement fields '{local_dir}' and '{remote_dir}'.") + parser.add_argument( + "--custom-loggers", + default=None, + help="List of custom logger creators to be used with each Trial.") parser.add_argument( "--checkpoint-freq", default=0, @@ -182,8 +197,9 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs): args = parser.parse_args(to_argv(spec)) except SystemExit: raise TuneError("Error parsing args, see above message", spec) - if "trial_resources" in spec: - trial_kwargs["resources"] = json_to_resources(spec["trial_resources"]) + if "resources_per_trial" in spec: + trial_kwargs["resources"] = json_to_resources( + spec["resources_per_trial"]) return Trial( # Submitting trial via server in py2.7 creates Unicode, which does not # convert to string in a straightforward manner. @@ -198,5 +214,9 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs): # str(None) doesn't create None restore_path=spec.get("restore"), upload_dir=args.upload_dir, + trial_name_creator=spec.get("trial_name_creator"), + custom_loggers=spec.get("custom_loggers"), + # str(None) doesn't create None + sync_function=spec.get("sync_function"), max_failures=args.max_failures, **trial_kwargs) diff --git a/python/ray/tune/examples/README.rst b/python/ray/tune/examples/README.rst index a762a057021c3..172bb4ef42913 100644 --- a/python/ray/tune/examples/README.rst +++ b/python/ray/tune/examples/README.rst @@ -22,6 +22,8 @@ General Examples Example of using a Trainable class with PopulationBasedTraining scheduler. - `pbt_ppo_example `__: Example of optimizing a distributed RLlib algorithm (PPO) with the PopulationBasedTraining scheduler. +- `logging_example `__: + Example of custom loggers and custom trial directory naming. 
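Editor's note (illustrative, not part of this patch): pulling together the Tune options renamed or newly exposed above, an experiment spec under this patch would look roughly like the following; the values and the trial-name helper are made up for illustration, and resources_per_trial replaces the old trial_resources key.

import random

from ray import tune
from ray.tune import run_experiments

def trial_name(trial):
    # hypothetical helper mirroring the trial_name_creator hook added above
    return "{}_{}".format(trial.trainable_name, trial.trial_id)

spec = {
    "my_exp": {
        "run": "PG",
        "stop": {"training_iteration": 10},
        "resources_per_trial": {"cpu": 1, "gpu": 0},  # formerly "trial_resources"
        "config": {
            # sampled values are now wrapped in tune.sample_from(...)
            "lr": tune.sample_from(lambda spec: random.uniform(0.001, 0.1)),
        },
        "trial_name_creator": tune.function(trial_name),
    },
}
run_experiments(spec)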
Keras Examples diff --git a/python/ray/tune/examples/async_hyperband_example.py b/python/ray/tune/examples/async_hyperband_example.py index e07f11b325a81..a2e4b63e42bac 100644 --- a/python/ray/tune/examples/async_hyperband_example.py +++ b/python/ray/tune/examples/async_hyperband_example.py @@ -12,7 +12,7 @@ import numpy as np import ray -from ray.tune import Trainable, run_experiments +from ray.tune import Trainable, run_experiments, sample_from from ray.tune.schedulers import AsyncHyperBandScheduler @@ -71,13 +71,15 @@ def _restore(self, checkpoint_path): "training_iteration": 1 if args.smoke_test else 99999 }, "num_samples": 20, - "trial_resources": { + "resources_per_trial": { "cpu": 1, "gpu": 0 }, "config": { - "width": lambda spec: 10 + int(90 * random.random()), - "height": lambda spec: int(100 * random.random()), + "width": sample_from( + lambda spec: 10 + int(90 * random.random())), + "height": sample_from( + lambda spec: int(100 * random.random())), }, } }, diff --git a/python/ray/tune/examples/hyperband_example.py b/python/ray/tune/examples/hyperband_example.py index baf133b411bf6..d403a0e0f8af1 100755 --- a/python/ray/tune/examples/hyperband_example.py +++ b/python/ray/tune/examples/hyperband_example.py @@ -12,7 +12,7 @@ import numpy as np import ray -from ray.tune import Trainable, run_experiments, Experiment +from ray.tune import Trainable, run_experiments, Experiment, sample_from from ray.tune.schedulers import HyperBandScheduler @@ -67,8 +67,8 @@ def _restore(self, checkpoint_path): num_samples=20, stop={"training_iteration": 1 if args.smoke_test else 99999}, config={ - "width": lambda spec: 10 + int(90 * random.random()), - "height": lambda spec: int(100 * random.random()) + "width": sample_from(lambda spec: 10 + int(90 * random.random())), + "height": sample_from(lambda spec: int(100 * random.random())) }) run_experiments(exp, scheduler=hyperband) diff --git a/python/ray/tune/examples/hyperopt_example.py b/python/ray/tune/examples/hyperopt_example.py index 2898bf26d8539..d70d16b9488ed 100644 --- a/python/ray/tune/examples/hyperopt_example.py +++ b/python/ray/tune/examples/hyperopt_example.py @@ -17,7 +17,7 @@ def easy_objective(config, reporter): time.sleep(0.2) assert type(config["activation"]) == str, \ "Config is incorrect: {}".format(type(config["activation"])) - for i in range(100): + for i in range(config["iterations"]): reporter( timesteps_total=i, neg_mean_loss=-(config["height"] - 14)**2 + @@ -47,6 +47,9 @@ def easy_objective(config, reporter): "my_exp": { "run": "exp", "num_samples": 10 if args.smoke_test else 1000, + "config": { + "iterations": 100, + }, "stop": { "timesteps_total": 100 }, diff --git a/python/ray/tune/examples/logging_example.py b/python/ray/tune/examples/logging_example.py new file mode 100755 index 0000000000000..d2e40c1e8e76f --- /dev/null +++ b/python/ray/tune/examples/logging_example.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import json +import os +import random + +import numpy as np + +import ray +from ray import tune +from ray.tune import Trainable, run_experiments, Experiment + + +class TestLogger(tune.logger.Logger): + def on_result(self, result): + print("TestLogger", result) + + +def trial_str_creator(trial): + return "{}_{}_123".format(trial.trainable_name, trial.trial_id) + + +class MyTrainableClass(Trainable): + """Example agent whose learning curve is a random sigmoid. 
+ + The dummy hyperparameters "width" and "height" determine the slope and + maximum reward value reached. + """ + + def _setup(self, config): + self.timestep = 0 + + def _train(self): + self.timestep += 1 + v = np.tanh(float(self.timestep) / self.config["width"]) + v *= self.config["height"] + + # Here we use `episode_reward_mean`, but you can also report other + # objectives such as loss or accuracy. + return {"episode_reward_mean": v} + + def _save(self, checkpoint_dir): + path = os.path.join(checkpoint_dir, "checkpoint") + with open(path, "w") as f: + f.write(json.dumps({"timestep": self.timestep})) + return path + + def _restore(self, checkpoint_path): + with open(checkpoint_path) as f: + self.timestep = json.loads(f.read())["timestep"] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + args, _ = parser.parse_known_args() + ray.init() + exp = Experiment( + name="hyperband_test", + run=MyTrainableClass, + num_samples=1, + trial_name_creator=tune.function(trial_str_creator), + custom_loggers=[TestLogger], + stop={"training_iteration": 1 if args.smoke_test else 99999}, + config={ + "width": lambda spec: 10 + int(90 * random.random()), + "height": lambda spec: int(100 * random.random()) + }) + + trials = run_experiments(exp) diff --git a/python/ray/tune/examples/mnist_pytorch.py b/python/ray/tune/examples/mnist_pytorch.py index bfd319bff3c9d..a5fe48e5d1ba1 100644 --- a/python/ray/tune/examples/mnist_pytorch.py +++ b/python/ray/tune/examples/mnist_pytorch.py @@ -137,13 +137,12 @@ def test(): data, target = Variable(data, volatile=True), Variable(target) output = model(data) test_loss += F.nll_loss( - output, target, - size_average=False).data[0] # sum up batch loss + output, target, size_average=False).item() # sum up batch loss pred = output.data.max( 1, keepdim=True)[1] # get the index of the max log-probability correct += pred.eq(target.data.view_as(pred)).long().cpu().sum() - test_loss = test_loss.item() / len(test_loader.dataset) + test_loss = test_loss / len(test_loader.dataset) accuracy = correct.item() / len(test_loader.dataset) reporter(mean_loss=test_loss, mean_accuracy=accuracy) @@ -176,14 +175,16 @@ def test(): "mean_accuracy": 0.98, "training_iteration": 1 if args.smoke_test else 20 }, - "trial_resources": { + "resources_per_trial": { "cpu": 3 }, "run": "train_mnist", "num_samples": 1 if args.smoke_test else 10, "config": { - "lr": lambda spec: np.random.uniform(0.001, 0.1), - "momentum": lambda spec: np.random.uniform(0.1, 0.9), + "lr": tune.sample_from( + lambda spec: np.random.uniform(0.001, 0.1)), + "momentum": tune.sample_from( + lambda spec: np.random.uniform(0.1, 0.9)), } } }, diff --git a/python/ray/tune/examples/mnist_pytorch_trainable.py b/python/ray/tune/examples/mnist_pytorch_trainable.py index b5ab0f2ab9f5a..24fc4951dd37c 100644 --- a/python/ray/tune/examples/mnist_pytorch_trainable.py +++ b/python/ray/tune/examples/mnist_pytorch_trainable.py @@ -145,13 +145,13 @@ def _test(self): output = self.model(data) # sum up batch loss - test_loss += F.nll_loss(output, target, size_average=False).data[0] + test_loss += F.nll_loss(output, target, size_average=False).item() # get the index of the max log-probability pred = output.data.max(1, keepdim=True)[1] correct += pred.eq(target.data.view_as(pred)).long().cpu().sum() - test_loss = test_loss.item() / len(self.test_loader.dataset) + test_loss = test_loss / len(self.test_loader.dataset) accuracy = 
correct.item() / len(self.test_loader.dataset) return {"mean_loss": test_loss, "mean_accuracy": accuracy} @@ -187,7 +187,7 @@ def _restore(self, checkpoint_path): "mean_accuracy": 0.95, "training_iteration": 1 if args.smoke_test else 20, }, - "trial_resources": { + "resources_per_trial": { "cpu": 3 }, "run": TrainMNIST, @@ -195,8 +195,10 @@ def _restore(self, checkpoint_path): "checkpoint_at_end": True, "config": { "args": args, - "lr": lambda spec: np.random.uniform(0.001, 0.1), - "momentum": lambda spec: np.random.uniform(0.1, 0.9), + "lr": tune.sample_from( + lambda spec: np.random.uniform(0.001, 0.1)), + "momentum": tune.sample_from( + lambda spec: np.random.uniform(0.1, 0.9)), } } }, diff --git a/python/ray/tune/examples/pbt_ppo_example.py b/python/ray/tune/examples/pbt_ppo_example.py index efd7ee4a89580..a81d4109f62c1 100755 --- a/python/ray/tune/examples/pbt_ppo_example.py +++ b/python/ray/tune/examples/pbt_ppo_example.py @@ -13,7 +13,7 @@ import random import ray -from ray.tune import run_experiments +from ray.tune import run_experiments, sample_from from ray.tune.schedulers import PopulationBasedTraining if __name__ == "__main__": @@ -63,12 +63,12 @@ def explore(config): "clip_param": 0.2, "lr": 1e-4, # These params start off randomly drawn from a set. - "num_sgd_iter": - lambda spec: random.choice([10, 20, 30]), - "sgd_minibatch_size": - lambda spec: random.choice([128, 512, 2048]), - "train_batch_size": - lambda spec: random.choice([10000, 20000, 40000]) + "num_sgd_iter": sample_from( + lambda spec: random.choice([10, 20, 30])), + "sgd_minibatch_size": sample_from( + lambda spec: random.choice([128, 512, 2048])), + "train_batch_size": sample_from( + lambda spec: random.choice([10000, 20000, 40000])) }, }, }, diff --git a/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py b/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py index 63e3d00e8d1fd..2b7520aeb8b71 100755 --- a/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py +++ b/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py @@ -23,7 +23,7 @@ from tensorflow.python.keras.preprocessing.image import ImageDataGenerator import ray -from ray.tune import grid_search, run_experiments +from ray.tune import grid_search, run_experiments, sample_from from ray.tune import Trainable from ray.tune.schedulers import PopulationBasedTraining @@ -181,7 +181,7 @@ def _stop(self): train_spec = { "run": Cifar10Model, - "trial_resources": { + "resources_per_trial": { "cpu": 1, "gpu": 1 }, @@ -193,7 +193,7 @@ def _stop(self): "epochs": 1, "batch_size": 64, "lr": grid_search([10**-4, 10**-5]), - "decay": lambda spec: spec.config.lr / 100.0, + "decay": sample_from(lambda spec: spec.config.lr / 100.0), "dropout": grid_search([0.25, 0.5]), }, "num_samples": 4, diff --git a/python/ray/tune/examples/tune_mnist_keras.py b/python/ray/tune/examples/tune_mnist_keras.py index da2e0c9e4b981..03afe65944be6 100644 --- a/python/ray/tune/examples/tune_mnist_keras.py +++ b/python/ray/tune/examples/tune_mnist_keras.py @@ -187,15 +187,19 @@ def create_parser(): }, "run": "train_mnist", "num_samples": 1 if args.smoke_test else 10, - "trial_resources": { + "resources_per_trial": { "cpu": args.threads, "gpu": 0.5 if args.use_gpu else 0 }, "config": { - "lr": lambda spec: np.random.uniform(0.001, 0.1), - "momentum": lambda spec: np.random.uniform(0.1, 0.9), - "hidden": lambda spec: np.random.randint(32, 512), - "dropout1": lambda spec: np.random.uniform(0.2, 0.8), + "lr": tune.sample_from( + lambda spec: np.random.uniform(0.001, 0.1)), + "momentum": 
tune.sample_from( + lambda spec: np.random.uniform(0.1, 0.9)), + "hidden": tune.sample_from( + lambda spec: np.random.randint(32, 512)), + "dropout1": tune.sample_from( + lambda spec: np.random.uniform(0.2, 0.8)), } } }, diff --git a/python/ray/tune/examples/tune_mnist_ray_hyperband.py b/python/ray/tune/examples/tune_mnist_ray_hyperband.py index 9dbc467752326..bce19deca6859 100755 --- a/python/ray/tune/examples/tune_mnist_ray_hyperband.py +++ b/python/ray/tune/examples/tune_mnist_ray_hyperband.py @@ -31,7 +31,7 @@ import ray from ray.tune import grid_search, run_experiments, register_trainable, \ - Trainable + Trainable, sample_from from ray.tune.schedulers import HyperBandScheduler from tensorflow.examples.tutorials.mnist import input_data @@ -221,7 +221,8 @@ def _restore(self, path): 'time_total_s': 600, }, 'config': { - 'learning_rate': lambda spec: 10**np.random.uniform(-5, -3), + 'learning_rate': sample_from( + lambda spec: 10**np.random.uniform(-5, -3)), 'activation': grid_search(['relu', 'elu', 'tanh']), }, "num_samples": 10, diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 3a4ddc9c7aab8..5859487527db6 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -7,13 +7,31 @@ import six import types -from ray.tune.result import DEFAULT_RESULTS_DIR from ray.tune.error import TuneError +from ray.tune.log_sync import validate_sync_function from ray.tune.registry import register_trainable +from ray.tune.result import DEFAULT_RESULTS_DIR logger = logging.getLogger(__name__) +def _raise_deprecation_note(deprecated, replacement, soft=False): + """User notification for deprecated parameter. + + Arguments: + deprecated (str): Deprecated parameter. + replacement (str): Replacement parameter to use instead. + soft (bool): Fatal if True. + """ + error_msg = ("`{deprecated}` is deprecated. Please use `{replacement}`. " + "`{deprecated}` will be removed in future versions of " + "Ray.".format(deprecated=deprecated, replacement=replacement)) + if soft: + logger.warning(error_msg) + else: + raise DeprecationWarning(error_msg) + + class Experiment(object): """Tracks experiment specifications. @@ -30,12 +48,10 @@ class Experiment(object): config (dict): Algorithm-specific configuration for Tune variant generation (e.g. env, hyperparams). Defaults to empty dict. Custom search algorithms may ignore this. - trial_resources (dict): Machine resources to allocate per trial, + resources_per_trial (dict): Machine resources to allocate per trial, e.g. ``{"cpu": 64, "gpu": 8}``. Note that GPUs will not be assigned unless you specify them here. Defaults to 1 CPU and 0 GPUs in ``Trainable.default_resource_request()``. - repeat (int): Deprecated and will be removed in future versions of - Ray. Use `num_samples` instead. num_samples (int): Number of times to sample from the hyperparameter space. Defaults to 1. If `grid_search` is provided as an argument, the grid will be repeated @@ -44,6 +60,14 @@ class Experiment(object): Defaults to ``~/ray_results``. upload_dir (str): Optional URI to sync training results to (e.g. ``s3://bucket``). + trial_name_creator (func): Optional function for generating + the trial string representation. + custom_loggers (list): List of custom logger creators to be used with + each Trial. See `ray/tune/logger.py`. + sync_function (func|str): Function for syncing the local_dir to + upload_dir. If string, then it must be a string template for + syncer to run. 
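To illustrate `_raise_deprecation_note` above and the constructor handling further below, a small sketch of how the deprecated keyword arguments now behave (`"__fake"` is the mock trainable registered by Tune's test harness; names are illustrative):

    from ray.tune import Experiment

    # Soft deprecation: still accepted, logs a warning, and is mapped onto
    # resources_per_trial.
    exp = Experiment(
        name="deprecation_sketch",
        run="__fake",
        trial_resources={"cpu": 1})

    # Hard deprecation: `repeat` raises immediately.
    try:
        Experiment(name="boom", run="__fake", repeat=10)
    except DeprecationWarning as exc:
        print(exc)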
If not provided, the sync command defaults + to standard S3 or gsutil sync comamnds. checkpoint_freq (int): How many training iterations between checkpoints. A value of 0 (default) disables checkpointing. checkpoint_at_end (bool): Whether to checkpoint at the end of the @@ -53,6 +77,10 @@ class Experiment(object): checkpointing is enabled. Defaults to 3. restore (str): Path to checkpoint. Only makes sense to set if running 1 trial. Defaults to None. + repeat: Deprecated and will be removed in future versions of + Ray. Use `num_samples` instead. + trial_resources: Deprecated and will be removed in future versions of + Ray. Use `resources_per_trial` instead. Examples: @@ -64,7 +92,7 @@ class Experiment(object): >>> "alpha": tune.grid_search([0.2, 0.4, 0.6]), >>> "beta": tune.grid_search([1, 2]), >>> }, - >>> trial_resources={ + >>> resources_per_trial={ >>> "cpu": 1, >>> "gpu": 0 >>> }, @@ -81,23 +109,41 @@ def __init__(self, run, stop=None, config=None, - trial_resources=None, - repeat=1, + resources_per_trial=None, num_samples=1, local_dir=None, upload_dir=None, + trial_name_creator=None, + custom_loggers=None, + sync_function=None, checkpoint_freq=0, checkpoint_at_end=False, max_failures=3, - restore=None): + restore=None, + repeat=None, + trial_resources=None): + validate_sync_function(sync_function) + if sync_function: + assert upload_dir, "Need `upload_dir` if sync_function given." + + if repeat: + _raise_deprecation_note("repeat", "num_samples", soft=False) + if trial_resources: + _raise_deprecation_note( + "trial_resources", "resources_per_trial", soft=True) + resources_per_trial = trial_resources + spec = { "run": self._register_if_needed(run), "stop": stop or {}, "config": config or {}, - "trial_resources": trial_resources, + "resources_per_trial": resources_per_trial, "num_samples": num_samples, "local_dir": local_dir or DEFAULT_RESULTS_DIR, "upload_dir": upload_dir or "", # argparse converts None to "null" + "trial_name_creator": trial_name_creator, + "custom_loggers": custom_loggers, + "sync_function": sync_function or "", # See `upload_dir`. "checkpoint_freq": checkpoint_freq, "checkpoint_at_end": checkpoint_at_end, "max_failures": max_failures, @@ -118,13 +164,6 @@ def from_json(cls, name, spec): if "run" not in spec: raise TuneError("No trainable specified!") - if "repeat" in spec: - raise DeprecationWarning("The parameter `repeat` is deprecated; \ - converting to `num_samples`. `repeat` will be removed in \ - future versions of Ray.") - spec["num_samples"] = spec["repeat"] - del spec["repeat"] - # Special case the `env` param for RLlib by automatically # moving it into the `config` section. 
if "env" in spec: diff --git a/python/ray/tune/log_sync.py b/python/ray/tune/log_sync.py index 109c11a01707f..2046165c0129d 100644 --- a/python/ray/tune/log_sync.py +++ b/python/ray/tune/log_sync.py @@ -7,6 +7,7 @@ import os import subprocess import time +import types try: # py3 from shlex import quote @@ -17,6 +18,7 @@ from ray.tune.cluster_info import get_ssh_key, get_ssh_user from ray.tune.error import TuneError from ray.tune.result import DEFAULT_RESULTS_DIR +from ray.tune.suggest.variant_generator import function as tune_function logger = logging.getLogger(__name__) @@ -28,9 +30,9 @@ ALLOWED_REMOTE_PREFIXES = (S3_PREFIX, GCS_PREFIX) -def get_syncer(local_dir, remote_dir=None): +def get_syncer(local_dir, remote_dir=None, sync_function=None): if remote_dir: - if not any( + if not sync_function and not any( remote_dir.startswith(prefix) for prefix in ALLOWED_REMOTE_PREFIXES): raise TuneError("Upload uri must start with one of: {}" @@ -53,7 +55,7 @@ def get_syncer(local_dir, remote_dir=None): key = (local_dir, remote_dir) if key not in _syncers: - _syncers[key] = _LogSyncer(local_dir, remote_dir) + _syncers[key] = _LogSyncer(local_dir, remote_dir, sync_function) return _syncers[key] @@ -63,15 +65,47 @@ def wait_for_log_sync(): syncer.wait() +def validate_sync_function(sync_function): + if sync_function is None: + return + elif isinstance(sync_function, str): + assert "{remote_dir}" in sync_function, ( + "Sync template missing '{remote_dir}'.") + assert "{local_dir}" in sync_function, ( + "Sync template missing '{local_dir}'.") + elif not (isinstance(sync_function, types.FunctionType) + or isinstance(sync_function, tune_function)): + raise ValueError("Sync function {} must be string or function".format( + sync_function)) + + class _LogSyncer(object): """Log syncer for tune. This syncs files from workers to the local node, and optionally also from - the local node to a remote directory (e.g. S3).""" - - def __init__(self, local_dir, remote_dir=None): + the local node to a remote directory (e.g. S3). + + Arguments: + logdir (str): Directory to sync from. + upload_uri (str): Directory to sync to. + sync_function (func|str): Function for syncing the local_dir to + upload_dir. If string, then it must be a string template + for syncer to run and needs to include replacement fields + '{local_dir}' and '{remote_dir}'. 
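As a sketch of the two accepted `sync_function` forms validated above (the bucket URI and the sync command are placeholders, not part of this change):

    from ray import tune
    from ray.tune import Experiment

    # 1) String form: a shell template that must contain both replacement
    #    fields, otherwise validate_sync_function() raises an AssertionError.
    template_exp = Experiment(
        name="sync_template_sketch",
        run="__fake",
        upload_dir="s3://my-bucket/results",  # hypothetical bucket
        sync_function="aws s3 sync {local_dir} {remote_dir}")

    # 2) Callable form, wrapped in tune.function() so it is treated as a
    #    function literal rather than a value to sample.
    def copy_results(local_dir, remote_dir):
        print("would sync", local_dir, "to", remote_dir)

    callable_exp = Experiment(
        name="sync_callable_sketch",
        run="__fake",
        upload_dir="s3://my-bucket/results",
        sync_function=tune.function(copy_results))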
+ """ + + def __init__(self, local_dir, remote_dir=None, sync_function=None): self.local_dir = local_dir self.remote_dir = remote_dir + + # Resolve sync_function into template or function + self.sync_func = None + self.sync_cmd_tmpl = None + if isinstance(sync_function, types.FunctionType) or isinstance( + sync_function, tune_function): + self.sync_func = sync_function + elif isinstance(sync_function, str): + self.sync_cmd_tmpl = sync_function self.last_sync_time = 0 self.sync_process = None self.local_ip = ray.services.get_node_ip_address() @@ -116,12 +150,14 @@ def sync_now(self, force=False): quote(ssh_key), quote(source), quote(target))) if self.remote_dir: - if self.remote_dir.startswith(S3_PREFIX): - local_to_remote_sync_cmd = ("aws s3 sync {} {}".format( - quote(self.local_dir), quote(self.remote_dir))) - elif self.remote_dir.startswith(GCS_PREFIX): - local_to_remote_sync_cmd = ("gsutil rsync -r {} {}".format( - quote(self.local_dir), quote(self.remote_dir))) + if self.sync_func: + local_to_remote_sync_cmd = None + try: + self.sync_func(self.local_dir, self.remote_dir) + except Exception: + logger.exception("Sync function failed.") + else: + local_to_remote_sync_cmd = self.get_remote_sync_cmd() else: local_to_remote_sync_cmd = None @@ -148,3 +184,24 @@ def sync_now(self, force=False): def wait(self): if self.sync_process: self.sync_process.wait() + + def get_remote_sync_cmd(self): + if self.sync_cmd_tmpl: + local_to_remote_sync_cmd = (self.sync_cmd_tmpl.format( + local_dir=quote(self.local_dir), + remote_dir=quote(self.remote_dir))) + elif self.remote_dir.startswith(S3_PREFIX): + local_to_remote_sync_cmd = ( + "aws s3 sync {local_dir} {remote_dir}".format( + local_dir=quote(self.local_dir), + remote_dir=quote(self.remote_dir))) + elif self.remote_dir.startswith(GCS_PREFIX): + local_to_remote_sync_cmd = ( + "gsutil rsync -r {local_dir} {remote_dir}".format( + local_dir=quote(self.local_dir), + remote_dir=quote(self.remote_dir))) + else: + logger.warning("Remote sync unsupported, skipping.") + local_to_remote_sync_cmd = None + + return local_to_remote_sync_cmd diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index d2c79e6d871a0..a0f84d57b6fe6 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -8,8 +8,11 @@ import numpy as np import os import yaml +import distutils.version +import ray.cloudpickle as cloudpickle from ray.tune.log_sync import get_syncer +from ray.cloudpickle import cloudpickle from ray.tune.result import NODE_IP, TRAINING_ITERATION, TIME_TOTAL_S, \ TIMESTEPS_TOTAL @@ -17,17 +20,26 @@ try: import tensorflow as tf + use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= + distutils.version.LooseVersion("1.5.0")) except ImportError: tf = None + use_tf150_api = True logger.warning("Couldn't import TensorFlow - " "disabling TensorBoard logging.") class Logger(object): - """Logging interface for ray.tune; specialized implementations follow. + """Logging interface for ray.tune. By default, the UnifiedLogger implementation is used which logs results in - multiple formats (TensorBoard, rllab/viskit, plain json) at once. + multiple formats (TensorBoard, rllab/viskit, plain json, custom loggers) + at once. + + Arguments: + config: Configuration passed to all logger creators. + logdir: Directory for all logger creators to log to. + upload_uri (str): Optional URI where the logdir is sync'ed to. 
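A hypothetical custom logger, sketched to show the hook that `custom_loggers` plugs into `UnifiedLogger` below; the file name and fields are invented:

    import os

    from ray.tune.logger import Logger


    class CSVSketchLogger(Logger):
        """Hypothetical logger appending one CSV line per reported result."""

        def _init(self):
            # config, logdir and uri are set by the Logger base class.
            self._file = open(os.path.join(self.logdir, "sketch.csv"), "w")

        def on_result(self, result):
            self._file.write("{},{}\n".format(
                result.get("training_iteration"),
                result.get("episode_reward_mean")))
            self._file.flush()

        def close(self):
            self._file.close()

It would then be passed per experiment, e.g. `custom_loggers=[CSVSketchLogger]` in the spec or in the `Experiment` constructor.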
""" def __init__(self, config, logdir, upload_uri=None): @@ -58,17 +70,41 @@ def flush(self): class UnifiedLogger(Logger): """Unified result logger for TensorBoard, rllab/viskit, plain json. - This class also periodically syncs output to the given upload uri.""" + This class also periodically syncs output to the given upload uri. + + Arguments: + config: Configuration passed to all logger creators. + logdir: Directory for all logger creators to log to. + upload_uri (str): Optional URI where the logdir is sync'ed to. + custom_loggers (list): List of custom logger creators. + sync_function (func|str): Optional function for syncer to run. + See ray/python/ray/tune/log_sync.py + """ + + def __init__(self, + config, + logdir, + upload_uri=None, + custom_loggers=None, + sync_function=None): + self._logger_list = [_JsonLogger, _TFLogger, _VisKitLogger] + self._sync_function = sync_function + if custom_loggers: + assert isinstance(custom_loggers, list), "Improper custom loggers." + self._logger_list += custom_loggers + + Logger.__init__(self, config, logdir, upload_uri) def _init(self): self._loggers = [] - for cls in [_JsonLogger, _TFLogger, _VisKitLogger]: - if cls is _TFLogger and tf is None: - logger.info("TF not installed - " - "cannot log with {}...".format(cls)) - continue - self._loggers.append(cls(self.config, self.logdir, self.uri)) - self._log_syncer = get_syncer(self.logdir, self.uri) + for cls in self._logger_list: + try: + self._loggers.append(cls(self.config, self.logdir, self.uri)) + except Exception: + logger.exception("Could not instantiate {} - skipping.".format( + str(cls))) + self._log_syncer = get_syncer( + self.logdir, self.uri, sync_function=self._sync_function) def on_result(self, result): for logger in self._loggers: @@ -103,6 +139,11 @@ def _init(self): indent=2, sort_keys=True, cls=_SafeFallbackEncoder) + pkl_out = os.path.join(self.logdir, "params.pkl") + with open(config_out, "wb") as f: + cloudpickle.dump( + self.config, + f) local_file = os.path.join(self.logdir, "result.json") self.local_out = open(local_file, "w") @@ -122,7 +163,11 @@ def to_tf_values(result, path): values = [] for attr, value in result.items(): if value is not None: - if type(value) in [int, float, np.float32, np.float64, np.int32]: + if use_tf150_api: + type_list = [int, float, np.float32, np.float64, np.int32] + else: + type_list = [int, float] + if type(value) in type_list: values.append( tf.Summary.Value( tag="/".join(path + [attr]), simple_value=value)) diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 7b725e05f342e..6b107b17c82f9 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -110,19 +110,27 @@ def _stop_trial(self, trial, error=False, error_msg=None, if stop_logger: trial.close_logger() - def start_trial(self, trial, checkpoint_obj=None): - """Starts the trial.""" + def start_trial(self, trial, checkpoint=None): + """Starts the trial. + + Will not return resources if trial repeatedly fails on start. + + Args: + trial (Trial): Trial to be started. + checkpoint (Checkpoint): A Python object or path storing the state + of trial. 
+ """ self._commit_resources(trial.resources) try: - self._start_trial(trial, checkpoint_obj) + self._start_trial(trial, checkpoint) except Exception: logger.exception("Error stopping runner - retrying...") error_msg = traceback.format_exc() time.sleep(2) self._stop_trial(trial, error=True, error_msg=error_msg) try: - self._start_trial(trial) + self._start_trial(trial, checkpoint) except Exception: logger.exception("Error starting runner, aborting!") error_msg = traceback.format_exc() @@ -140,6 +148,7 @@ def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True): self._stop_trial( trial, error=error, error_msg=error_msg, stop_logger=stop_logger) if prior_status == Trial.RUNNING: + logger.debug("Returning resources for this trial.") self._return_resources(trial.resources) out = self._find_item(self._running, trial) for result_id in out: diff --git a/python/ray/tune/suggest/__init__.py b/python/ray/tune/suggest/__init__.py index f0146ca5e7992..9f4a5e6a7ad06 100644 --- a/python/ray/tune/suggest/__init__.py +++ b/python/ray/tune/suggest/__init__.py @@ -2,9 +2,15 @@ from ray.tune.suggest.basic_variant import BasicVariantGenerator from ray.tune.suggest.suggestion import SuggestionAlgorithm from ray.tune.suggest.hyperopt import HyperOptSearch -from ray.tune.suggest.variant_generator import grid_search, function +from ray.tune.suggest.variant_generator import grid_search, function, \ + sample_from __all__ = [ - "SearchAlgorithm", "BasicVariantGenerator", "HyperOptSearch", - "SuggestionAlgorithm", "grid_search", "function" + "SearchAlgorithm", + "BasicVariantGenerator", + "HyperOptSearch", + "SuggestionAlgorithm", + "grid_search", + "function", + "sample_from", ] diff --git a/python/ray/tune/suggest/suggestion.py b/python/ray/tune/suggest/suggestion.py index aa6cbe7173cf5..f6e7d532a9c5f 100644 --- a/python/ray/tune/suggest/suggestion.py +++ b/python/ray/tune/suggest/suggestion.py @@ -7,10 +7,11 @@ from ray.tune.error import TuneError from ray.tune.trial import Trial +from ray.tune.util import merge_dicts from ray.tune.experiment import convert_to_experiment_list from ray.tune.config_parser import make_parser, create_trial_from_spec from ray.tune.suggest.search import SearchAlgorithm -from ray.tune.suggest.variant_generator import format_vars +from ray.tune.suggest.variant_generator import format_vars, resolve_nested_dict class SuggestionAlgorithm(SearchAlgorithm): @@ -33,9 +34,6 @@ class SuggestionAlgorithm(SearchAlgorithm): def __init__(self): """Constructs a generator given experiment specifications. - - Arguments: - experiments (Experiment | list | dict): Experiments to run. 
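The nested-config support added here relies on `merge_dicts` and `resolve_nested_dict` (defined further below in `variant_generator.py`); a rough sketch of what the flattening contributes to the experiment tag, with invented keys and values:

    from ray.tune.suggest.variant_generator import format_vars, resolve_nested_dict

    suggested = {"model": {"depth": 4, "width": 128}, "lr": 0.01}
    flat = resolve_nested_dict(suggested)
    # {("model", "depth"): 4, ("model", "width"): 128, ("lr",): 0.01}

    # Only the leaf key names end up in the experiment tag, e.g. roughly
    # "lr=0.01,depth=4,width=128".
    print(format_vars(flat))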
""" self._parser = make_parser() self._trial_generator = [] @@ -91,10 +89,11 @@ def _generate_trials(self, experiment_spec, output_path=""): else: break spec = copy.deepcopy(experiment_spec) - spec["config"] = suggested_config + spec["config"] = merge_dicts(spec["config"], suggested_config) + flattened_config = resolve_nested_dict(spec["config"]) self._counter += 1 tag = "{0}_{1}".format( - str(self._counter), format_vars(spec["config"])) + str(self._counter), format_vars(flattened_config)) yield create_trial_from_spec( spec, output_path, diff --git a/python/ray/tune/suggest/variant_generator.py b/python/ray/tune/suggest/variant_generator.py index c33e7925167db..09729f9883f8f 100644 --- a/python/ray/tune/suggest/variant_generator.py +++ b/python/ray/tune/suggest/variant_generator.py @@ -3,12 +3,15 @@ from __future__ import print_function import copy +import logging import numpy import random import types from ray.tune import TuneError +logger = logging.getLogger(__name__) + def generate_variants(unresolved_spec): """Generates variants from a spec (dict) with unresolved values. @@ -30,10 +33,6 @@ def generate_variants(unresolved_spec): "cpu": lambda spec: spec.config.num_workers "batch_size": lambda spec: random.uniform(1, 1000) - It is also possible to nest the two, e.g. have a lambda function - return a grid search or vice versa, as long as there are no cyclic - dependencies between unresolved values. - Finally, to support defining specs in plain JSON / YAML, grid search and lambda functions can also be defined alternatively as follows: @@ -55,8 +54,29 @@ def grid_search(values): return {"grid_search": values} +class sample_from(object): + """Specify that tune should sample configuration values from this function. + + The use of function arguments in tune configs must be disambiguated by + either wrapped the function in tune.eval() or tune.function(). + + Arguments: + func: An callable function to draw a sample from. + """ + + def __init__(self, func): + self.func = func + + class function(object): - """Wraps `func` to make sure it is not expanded during resolution.""" + """Wraps `func` to make sure it is not expanded during resolution. + + The use of function arguments in tune configs must be disambiguated by + either wrapped the function in tune.eval() or tune.function(). + + Arguments: + func: A function literal. + """ def __init__(self, func): self.func = func @@ -73,10 +93,25 @@ def __call__(self, *args, **kwargs): _MAX_RESOLUTION_PASSES = 20 +def resolve_nested_dict(nested_dict): + """Flattens a nested dict by joining keys into tuple of paths. + + Can then be passed into `format_vars`. + """ + res = {} + for k, v in nested_dict.items(): + if isinstance(v, dict): + for k_, v_ in resolve_nested_dict(v).items(): + res[(k, ) + k_] = v_ + else: + res[(k, )] = v + return res + + def format_vars(resolved_vars): out = [] for path, value in sorted(resolved_vars.items()): - if path[0] in ["run", "env", "trial_resources"]: + if path[0] in ["run", "env", "resources_per_trial"]: continue # TrialRunner already has these in the experiment_tag pieces = [] last_string = True @@ -126,7 +161,7 @@ def _generate_variants(spec): raise ValueError( "The variable `{}` could not be unambiguously " "resolved to a single value. 
Consider simplifying " - "your variable dependencies.".format(k)) + "your configuration.".format(k)) resolved_vars[k] = v yield resolved_vars, spec @@ -203,8 +238,17 @@ def _is_resolved(v): def _try_resolve(v): if isinstance(v, types.FunctionType): - # Lambda function + logger.warning( + "Deprecation warning: Function values are ambiguous in Tune " + "configuations. Either wrap the function with " + "`tune.function(func)` to specify a function literal, or " + "`tune.sample_from(func)` to tell Tune to " + "sample values from the function during variant generation: " + "{}".format(v)) return False, v + elif isinstance(v, sample_from): + # Function to sample from + return False, v.func elif isinstance(v, dict) and len(v) == 1 and "eval" in v: # Lambda function in eval syntax return False, lambda spec: eval( diff --git a/python/ray/tune/test/cluster_tests.py b/python/ray/tune/test/cluster_tests.py index f9425cc3e301a..59f12181b8ff9 100644 --- a/python/ray/tune/test/cluster_tests.py +++ b/python/ray/tune/test/cluster_tests.py @@ -3,45 +3,22 @@ from __future__ import print_function import json -import time import pytest try: import pytest_timeout except ImportError: pytest_timeout = None -from ray.test.cluster_utils import Cluster import ray -from ray import tune +from ray.rllib import _register_all +from ray.test.cluster_utils import Cluster from ray.tune.error import TuneError from ray.tune.trial import Trial from ray.tune.trial_runner import TrialRunner from ray.tune.suggest import BasicVariantGenerator -def register_test_trainable(): - class _Train(tune.Trainable): - def _setup(self, config): - self.state = {"hi": 1} - - def _train(self): - self.state["hi"] += 1 - time.sleep(0.5) - return {} - - def _save(self, path): - return self.state - - def _restore(self, state): - self.state = state - - tune.register_trainable("test", _Train) - - -@pytest.fixture -def start_connected_cluster(): - # Start the Ray processes. - +def _start_new_cluster(): cluster = Cluster( initialize_head=True, connect=True, @@ -51,7 +28,15 @@ def start_connected_cluster(): "num_heartbeats_timeout": 10 }) }) - register_test_trainable() + # Pytest doesn't play nicely with imports + _register_all() + return cluster + + +@pytest.fixture +def start_connected_cluster(): + # Start the Ray processes. + cluster = _start_new_cluster() yield cluster # The code after the yield will run as teardown code. ray.shutdown() @@ -71,39 +56,36 @@ def start_connected_emptyhead_cluster(): "num_heartbeats_timeout": 10 }) }) - register_test_trainable() + # Pytest doesn't play nicely with imports + _register_all() yield cluster # The code after the yield will run as teardown code. 
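For reference, a minimal sketch of the disambiguation the deprecation warning above asks for (config keys are invented): values to be resolved during variant generation are wrapped in `tune.sample_from`, while function literals that should be passed through untouched are wrapped in `tune.function`.

    import random

    from ray import tune

    config = {
        # Deprecated: a bare lambda is ambiguous and now logs the warning.
        # "lr": lambda spec: random.uniform(0.001, 0.1),

        # Sampled once per trial during variant generation.
        "lr": tune.sample_from(lambda spec: random.uniform(0.001, 0.1)),

        # May depend on other resolved values through `spec`.
        "decay": tune.sample_from(lambda spec: spec.config.lr / 100.0),

        # A function literal handed to the trainable unchanged.
        "activation_fn": tune.function(lambda x: max(0.0, x)),
    }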
ray.shutdown() cluster.shutdown() -@pytest.mark.skipif( - pytest_timeout is None, - reason="Timeout package not installed; skipping test.") -@pytest.mark.timeout(10, method="thread") def test_counting_resources(start_connected_cluster): """Tests that Tune accounting is consistent with actual cluster.""" cluster = start_connected_cluster - assert ray.global_state.cluster_resources()["CPU"] == 1 nodes = [] - nodes += [cluster.add_node(resources=dict(CPU=1))] - assert cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 2 - + assert ray.global_state.cluster_resources()["CPU"] == 1 runner = TrialRunner(BasicVariantGenerator()) kwargs = {"stopping_criterion": {"training_iteration": 10}} - trials = [Trial("test", **kwargs), Trial("test", **kwargs)] + trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)] for t in trials: runner.add_trial(t) runner.step() # run 1 + nodes += [cluster.add_node(resources=dict(CPU=1))] + assert cluster.wait_for_nodes() + assert ray.global_state.cluster_resources()["CPU"] == 2 cluster.remove_node(nodes.pop()) assert cluster.wait_for_nodes() assert ray.global_state.cluster_resources()["CPU"] == 1 runner.step() # run 2 + assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 1 for i in range(5): nodes += [cluster.add_node(resources=dict(CPU=1))] @@ -111,12 +93,7 @@ def test_counting_resources(start_connected_cluster): assert ray.global_state.cluster_resources()["CPU"] == 6 runner.step() # 1 result - - for i in range(5): - node = nodes.pop() - cluster.remove_node(node) - assert cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 1 + assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 2 @pytest.mark.skip("Add this test once reconstruction is fixed") @@ -133,7 +110,7 @@ def test_remove_node_before_result(start_connected_cluster): runner = TrialRunner(BasicVariantGenerator()) kwargs = {"stopping_criterion": {"training_iteration": 3}} - trials = [Trial("test", **kwargs), Trial("test", **kwargs)] + trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)] for t in trials: runner.add_trial(t) @@ -179,7 +156,7 @@ def test_trial_migration(start_connected_emptyhead_cluster): } # Test recovery of trial that hasn't been checkpointed - t = Trial("test", **kwargs) + t = Trial("__fake", **kwargs) runner.add_trial(t) runner.step() # start runner.step() # 1 result @@ -199,7 +176,7 @@ def test_trial_migration(start_connected_emptyhead_cluster): assert t.status == Trial.TERMINATED # Test recovery of trial that has been checkpointed - t2 = Trial("test", **kwargs) + t2 = Trial("__fake", **kwargs) runner.add_trial(t2) runner.step() # start runner.step() # 1 result @@ -216,7 +193,7 @@ def test_trial_migration(start_connected_emptyhead_cluster): assert t2.status == Trial.TERMINATED # Test recovery of trial that won't be checkpointed - t3 = Trial("test", **{"stopping_criterion": {"training_iteration": 3}}) + t3 = Trial("__fake", **{"stopping_criterion": {"training_iteration": 3}}) runner.add_trial(t3) runner.step() # start runner.step() # 1 result @@ -238,6 +215,7 @@ def test_trial_requeue(start_connected_emptyhead_cluster): """Removing a node in full cluster causes Trial to be requeued.""" cluster = start_connected_emptyhead_cluster node = cluster.add_node(resources=dict(CPU=1)) + assert cluster.wait_for_nodes() runner = TrialRunner(BasicVariantGenerator()) kwargs = { @@ -248,7 +226,7 @@ def test_trial_requeue(start_connected_emptyhead_cluster): "max_failures": 1 } - trials = [Trial("test", 
**kwargs), Trial("test", **kwargs)] + trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)] for t in trials: runner.add_trial(t) diff --git a/python/ray/tune/test/ray_trial_executor_test.py b/python/ray/tune/test/ray_trial_executor_test.py index 35c413e717bb4..86c4bb189595f 100644 --- a/python/ray/tune/test/ray_trial_executor_test.py +++ b/python/ray/tune/test/ray_trial_executor_test.py @@ -9,8 +9,9 @@ from ray.rllib import _register_all from ray.tune import Trainable from ray.tune.ray_trial_executor import RayTrialExecutor +from ray.tune.registry import _global_registry, TRAINABLE_CLASS from ray.tune.suggest import BasicVariantGenerator -from ray.tune.trial import Trial, Checkpoint +from ray.tune.trial import Trial, Checkpoint, Resources class RayTrialExecutorTest(unittest.TestCase): @@ -50,6 +51,12 @@ def testPauseResume(self): self.trial_executor.stop_trial(trial) self.assertEqual(Trial.TERMINATED, trial.status) + def testStartFailure(self): + _global_registry.register(TRAINABLE_CLASS, "asdf", None) + trial = Trial("asdf", resources=Resources(1, 0)) + self.trial_executor.start_trial(trial) + self.assertEqual(Trial.ERROR, trial.status) + def testPauseResume2(self): """Tests that pausing works for trials being processed.""" trial = Trial("__fake") diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index 6b142d354ec7f..141936b706664 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -3,12 +3,14 @@ from __future__ import print_function import os +import sys import time import unittest import ray from ray.rllib import _register_all +from ray import tune from ray.tune import Trainable, TuneError from ray.tune import register_env, register_trainable, run_experiments from ray.tune.ray_trial_executor import RayTrialExecutor @@ -16,6 +18,7 @@ from ray.tune.registry import _global_registry, TRAINABLE_CLASS from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE, EPISODES_TOTAL) +from ray.tune.logger import Logger from ray.tune.util import pin_in_object_store, get_pinned_object from ray.tune.experiment import Experiment from ray.tune.trial import Trial, Resources @@ -23,7 +26,13 @@ from ray.tune.suggest import grid_search, BasicVariantGenerator from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm, SuggestionAlgorithm) -from ray.tune.suggest.variant_generator import RecursiveDependencyError +from ray.tune.suggest.variant_generator import (RecursiveDependencyError, + resolve_nested_dict) + +if sys.version_info >= (3, 3): + from unittest.mock import patch +else: + from mock import patch class TrainableFunctionApiTest(unittest.TestCase): @@ -286,7 +295,7 @@ def f(): run_experiments({ "foo": { "run": "PPO", - "trial_resources": { + "resources_per_trial": { "asdf": 1 } } @@ -673,6 +682,99 @@ def _save(self, path): self.assertEqual(trial.status, Trial.TERMINATED) self.assertTrue(trial.has_checkpoint()) + def testDeprecatedResources(self): + class train(Trainable): + def _train(self): + return {"timesteps_this_iter": 1, "done": True} + + trials = run_experiments({ + "foo": { + "run": train, + "trial_resources": { + "cpu": 1 + } + } + }) + for trial in trials: + self.assertEqual(trial.status, Trial.TERMINATED) + + def testCustomLogger(self): + class CustomLogger(Logger): + def on_result(self, result): + with open(os.path.join(self.logdir, "test.log"), "w") as f: + f.write("hi") + + [trial] = run_experiments({ + "foo": { + "run": "__fake", + "stop": { + 
"training_iteration": 1 + }, + "custom_loggers": [CustomLogger] + } + }) + self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log"))) + + def testCustomTrialString(self): + [trial] = run_experiments({ + "foo": { + "run": "__fake", + "stop": { + "training_iteration": 1 + }, + "trial_name_creator": tune.function( + lambda t: "{}_{}_321".format(t.trainable_name, t.trial_id)) + } + }) + self.assertEquals( + str(trial), "{}_{}_321".format(trial.trainable_name, + trial.trial_id)) + + def testSyncFunction(self): + def fail_sync_local(): + [trial] = run_experiments({ + "foo": { + "run": "__fake", + "stop": { + "training_iteration": 1 + }, + "upload_dir": "test", + "sync_function": "ls {remote_dir}" + } + }) + + self.assertRaises(AssertionError, fail_sync_local) + + def fail_sync_remote(): + [trial] = run_experiments({ + "foo": { + "run": "__fake", + "stop": { + "training_iteration": 1 + }, + "upload_dir": "test", + "sync_function": "ls {local_dir}" + } + }) + + self.assertRaises(AssertionError, fail_sync_remote) + + def sync_func(local, remote): + with open(os.path.join(local, "test.log"), "w") as f: + f.write(remote) + + [trial] = run_experiments({ + "foo": { + "run": "__fake", + "stop": { + "training_iteration": 1 + }, + "upload_dir": "test", + "sync_function": tune.function(sync_func) + } + }) + self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log"))) + class VariantGeneratorTest(unittest.TestCase): def setUp(self): @@ -801,6 +903,20 @@ def testDependentGridSearch(self): self.assertEqual(trials[0].config, {"x": 100, "y": 1}) self.assertEqual(trials[1].config, {"x": 200, "y": 1}) + def test_resolve_dict(self): + config = { + "a": { + "b": 1, + "c": 2, + }, + "b": { + "a": 3 + } + } + resolved = resolve_nested_dict(config) + for k, v in [(("a", "b"), 1), (("a", "c"), 2), (("b", "a"), 3)]: + self.assertEqual(resolved.get(k), v) + def testRecursiveDep(self): try: list( @@ -845,6 +961,25 @@ def testMaxConcurrentSuggestions(self): self.assertEqual(len(searcher.next_trials()), 0) +def create_mock_components(): + class _MockScheduler(FIFOScheduler): + errored_trials = [] + + def on_trial_error(self, trial_runner, trial): + self.errored_trials += [trial] + + class _MockSearchAlg(BasicVariantGenerator): + errored_trials = [] + + def on_trial_complete(self, trial_id, error=False, **kwargs): + if error: + self.errored_trials += [trial_id] + + searchalg = _MockSearchAlg() + scheduler = _MockScheduler() + return searchalg, scheduler + + class TrialRunnerTest(unittest.TestCase): def tearDown(self): ray.shutdown() @@ -889,16 +1024,6 @@ def train(config, reporter): self.assertLessEqual(len(trial.logdir), 200) trial_executor.stop_trial(trial) - def testTrialErrorOnStart(self): - ray.init() - trial_executor = RayTrialExecutor() - _global_registry.register(TRAINABLE_CLASS, "asdf", None) - trial = Trial("asdf", resources=Resources(1, 0)) - try: - trial_executor.start_trial(trial) - except Exception as e: - self.assertIn("a class", str(e)) - def testExtraResources(self): ray.init(num_cpus=4, num_gpus=2) runner = TrialRunner(BasicVariantGenerator()) @@ -1055,7 +1180,9 @@ def testThrowOnOverstep(self): def testFailureRecoveryDisabled(self): ray.init(num_cpus=1, num_gpus=1) - runner = TrialRunner(BasicVariantGenerator()) + searchalg, scheduler = create_mock_components() + + runner = TrialRunner(searchalg, scheduler=scheduler) kwargs = { "resources": Resources(cpu=1, gpu=1), "checkpoint_freq": 1, @@ -1074,10 +1201,15 @@ def testFailureRecoveryDisabled(self): runner.step() 
self.assertEqual(trials[0].status, Trial.ERROR) self.assertEqual(trials[0].num_failures, 1) + self.assertEqual(len(searchalg.errored_trials), 1) + self.assertEqual(len(scheduler.errored_trials), 1) def testFailureRecoveryEnabled(self): ray.init(num_cpus=1, num_gpus=1) - runner = TrialRunner(BasicVariantGenerator()) + searchalg, scheduler = create_mock_components() + + runner = TrialRunner(searchalg, scheduler=scheduler) + kwargs = { "resources": Resources(cpu=1, gpu=1), "checkpoint_freq": 1, @@ -1098,6 +1230,40 @@ def testFailureRecoveryEnabled(self): self.assertEqual(trials[0].num_failures, 1) runner.step() self.assertEqual(trials[0].status, Trial.RUNNING) + self.assertEqual(len(searchalg.errored_trials), 0) + self.assertEqual(len(scheduler.errored_trials), 0) + + def testFailureRecoveryNodeRemoval(self): + ray.init(num_cpus=1, num_gpus=1) + searchalg, scheduler = create_mock_components() + + runner = TrialRunner(searchalg, scheduler=scheduler) + + kwargs = { + "resources": Resources(cpu=1, gpu=1), + "checkpoint_freq": 1, + "max_failures": 1, + "config": { + "mock_error": True, + }, + } + runner.add_trial(Trial("__fake", **kwargs)) + trials = runner.get_trials() + + with patch('ray.global_state.cluster_resources') as resource_mock: + resource_mock.return_value = {"CPU": 1, "GPU": 1} + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + + # Mimic a node failure + resource_mock.return_value = {"CPU": 0, "GPU": 0} + runner.step() + self.assertEqual(trials[0].status, Trial.PENDING) + self.assertEqual(trials[0].num_failures, 1) + self.assertEqual(len(searchalg.errored_trials), 0) + self.assertEqual(len(scheduler.errored_trials), 1) def testFailureRecoveryMaxFailures(self): ray.init(num_cpus=1, num_gpus=1) @@ -1500,5 +1666,18 @@ def _suggest(self, trial_id): self.assertRaises(TuneError, runner.step) +class SearchAlgorithmTest(unittest.TestCase): + def testNestedSuggestion(self): + class TestSuggestion(SuggestionAlgorithm): + def _suggest(self, trial_id): + return {"a": {"b": {"c": {"d": 4, "e": 5}}}} + + alg = TestSuggestion() + alg.add_configurations({"test": {"run": "__fake"}}) + trial = alg.next_trials()[0] + self.assertTrue("e=5" in trial.experiment_tag) + self.assertTrue("d=4" in trial.experiment_tag) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/python/ray/tune/test/trial_scheduler_test.py b/python/ray/tune/test/trial_scheduler_test.py index 1c32f72e07a1f..e8e2f20544b49 100644 --- a/python/ray/tune/test/trial_scheduler_test.py +++ b/python/ray/tune/test/trial_scheduler_test.py @@ -574,6 +574,7 @@ def __init__(self, i, config): self.trainable_name = "trial_{}".format(i) self.config = config self.experiment_tag = "tag" + self.trial_name_creator = None self.logger_running = False self.restored_checkpoint = None self.resources = Resources(1, 0) diff --git a/python/ray/tune/test/tune_server_test.py b/python/ray/tune/test/tune_server_test.py index a535b421b4dc7..db99aae2e22a0 100644 --- a/python/ray/tune/test/tune_server_test.py +++ b/python/ray/tune/test/tune_server_test.py @@ -65,7 +65,7 @@ def testAddTrial(self): "stop": { "training_iteration": 3 }, - "trial_resources": { + "resources_per_trial": { 'cpu': 1, 'gpu': 1 }, diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 65683eeb53c71..3a766a400f32d 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -124,6 +124,9 @@ def __init__(self, checkpoint_at_end=False, restore_path=None, upload_dir=None, 
+ trial_name_creator=None, + custom_loggers=None, + sync_function=None, max_failures=0): """Initialize a new trial. @@ -146,6 +149,9 @@ def __init__(self, or self._get_trainable_cls().default_resource_request(self.config)) self.stopping_criterion = stopping_criterion or {} self.upload_dir = upload_dir + self.trial_name_creator = trial_name_creator + self.custom_loggers = custom_loggers + self.sync_function = sync_function self.verbose = True self.max_failures = max_failures @@ -160,10 +166,7 @@ def __init__(self, self.logdir = None self.result_logger = None self.last_debug = 0 - if trial_id is not None: - self.trial_id = trial_id - else: - self.trial_id = Trial.generate_id() + self.trial_id = Trial.generate_id() if trial_id is None else trial_id self.error_file = None self.num_failures = 0 @@ -181,8 +184,12 @@ def init_logger(self): prefix="{}_{}".format( str(self)[:MAX_LEN_IDENTIFIER], date_str()), dir=self.local_dir) - self.result_logger = UnifiedLogger(self.config, self.logdir, - self.upload_dir) + self.result_logger = UnifiedLogger( + self.config, + self.logdir, + upload_uri=self.upload_dir, + custom_loggers=self.custom_loggers, + sync_function=self.sync_function) def close_logger(self): """Close logger.""" @@ -216,17 +223,19 @@ def should_stop(self, result): return False - def should_checkpoint(self, result): + def should_checkpoint(self): """Whether this trial is due for checkpointing.""" + result = self.last_result or {} if result.get(DONE) and self.checkpoint_at_end: return True - if not self.checkpoint_freq: + if self.checkpoint_freq: + return result.get(TRAINING_ITERATION, + 0) % self.checkpoint_freq == 0 + else: return False - return self.last_result[TRAINING_ITERATION] % self.checkpoint_freq == 0 - def progress_string(self): """Returns a progress message for printing out to the console.""" @@ -281,10 +290,12 @@ def has_checkpoint(self): def should_recover(self): """Returns whether the trial qualifies for restoring. - This is if a checkpoint frequency is set, which includes settings - where there may not yet be a checkpoint. + This is if a checkpoint frequency is set and has not failed more than + max_failures. This may return true even when there may not yet + be a checkpoint. """ - return self.checkpoint_freq > 0 + return (self.checkpoint_freq > 0 + and self.num_failures < self.max_failures) def update_last_result(self, result, terminate=False): if terminate: @@ -312,12 +323,20 @@ def __repr__(self): return str(self) def __str__(self): - """Combines ``env`` with ``trainable_name`` and ``experiment_tag``.""" + """Combines ``env`` with ``trainable_name`` and ``experiment_tag``. + + Can be overriden with a custom string creator. 
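A small sketch of the custom string creator that `trial_name_creator` enables (mirroring `trial_str_creator` in `logging_example.py`); the format string is illustrative:

    from ray import tune


    def trial_str_creator(trial):
        # Used for str(trial) and hence for the trial directory prefix.
        return "{}_{}_sketch".format(trial.trainable_name, trial.trial_id)


    experiment_spec = {
        "naming_sketch": {
            "run": "__fake",  # mock trainable from Tune's test suite
            "stop": {"training_iteration": 1},
            # Wrapped with tune.function() so it is not treated as a sample.
            "trial_name_creator": tune.function(trial_str_creator),
        }
    }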
+ """ + if self.trial_name_creator: + return self.trial_name_creator(self) + if "env" in self.config: - identifier = "{}_{}".format(self.trainable_name, - self.config["env"]) + env = self.config["env"] + if isinstance(env, type): + env = env.__name__ + identifier = "{}_{}".format(self.trainable_name, env) else: identifier = self.trainable_name if self.experiment_tag: identifier += "_" + self.experiment_tag - return identifier + return identifier.replace("/", "_") diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index e0b541218bf19..063129780b47a 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -32,12 +32,10 @@ def has_resources(self, resources): "has_resources() method") def start_trial(self, trial, checkpoint=None): - """Starts the trial restoring from checkpoint if checkpoint != None. - - If an error is encountered when starting the trial, an exception will - be thrown. + """Starts the trial restoring from checkpoint if checkpoint is provided. Args: + trial (Trial): Trial to be started. checkpoint(Checkpoint): A Python object or path storing the state of trial. """ @@ -59,26 +57,6 @@ def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True): raise NotImplementedError("Subclasses of TrialExecutor must provide " "stop_trial() method") - def restart_trial(self, trial, error_msg=None): - """Restarts or requeues the trial. - - The state of the trial should restore from the last checkpoint. Trial - is requeued if the cluster no longer has resources to accomodate it. - - Args: - error_msg (str): Optional error message. - """ - self.stop_trial( - trial, - error=error_msg is not None, - error_msg=error_msg, - stop_logger=False) - trial.result_logger.flush() - if self.has_resources(trial.resources): - self.start_trial(trial) - else: - trial.status = Trial.PENDING - def continue_training(self, trial): """Continues the training of this trial.""" pass diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index d89b3dda7ee1e..84457ff8d9e95 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -12,7 +12,7 @@ from ray.tune import TuneError from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.result import TIME_THIS_ITER_S -from ray.tune.trial import Trial +from ray.tune.trial import Trial, Checkpoint from ray.tune.schedulers import FIFOScheduler, TrialScheduler from ray.tune.web_server import TuneServer @@ -279,17 +279,14 @@ def _process_events(self): result, terminate=(decision == TrialScheduler.STOP)) if decision == TrialScheduler.CONTINUE: - if trial.should_checkpoint(result): - # TODO(rliaw): This is a blocking call - self.trial_executor.save(trial) + self._checkpoint_if_needed(trial) self.trial_executor.continue_training(trial) elif decision == TrialScheduler.PAUSE: self.trial_executor.pause_trial(trial) elif decision == TrialScheduler.STOP: # Checkpoint before ending the trial # if checkpoint_at_end experiment option is set to True - if trial.should_checkpoint(result): - self.trial_executor.save(trial) + self._checkpoint_if_needed(trial) self.trial_executor.stop_trial(trial) else: assert False, "Invalid scheduling decision: {}".format( @@ -298,24 +295,61 @@ def _process_events(self): logger.exception("Error processing event.") error_msg = traceback.format_exc() if trial.status == Trial.RUNNING: - if trial.should_recover() and \ - trial.num_failures < trial.max_failures: + if trial.should_recover(): self._try_recover(trial, error_msg) 
else: self._scheduler_alg.on_trial_error(self, trial) self._search_alg.on_trial_complete( trial.trial_id, error=True) - self.trial_executor.stop_trial(trial, True, error_msg) + self.trial_executor.stop_trial( + trial, error=True, error_msg=error_msg) + + def _checkpoint_if_needed(self, trial): + """Checkpoints trial based off trial.last_result.""" + if trial.should_checkpoint(): + # Save trial runtime if possible + if hasattr(trial, "runner") and trial.runner: + self.trial_executor.save(trial, storage=Checkpoint.DISK) def _try_recover(self, trial, error_msg): + """Tries to recover trial. + + Notifies SearchAlgorithm and Scheduler if failure to recover. + + Args: + trial (Trial): Trial to recover. + error_msg (str): Error message from prior to invoking this method. + """ try: - logger.info("Attempting to recover" - " trial state from last checkpoint.") - self.trial_executor.restart_trial(trial, error_msg) + self.trial_executor.stop_trial( + trial, + error=error_msg is not None, + error_msg=error_msg, + stop_logger=False) + trial.result_logger.flush() + if self.trial_executor.has_resources(trial.resources): + logger.info("Attempting to recover" + " trial state from last checkpoint.") + self.trial_executor.start_trial(trial) + if trial.status == Trial.ERROR: + raise RuntimeError("Trial did not start correctly.") + else: + logger.debug("Notifying Scheduler and requeueing trial.") + self._requeue_trial(trial) except Exception: - error_msg = traceback.format_exc() - logger.warning("Error recovering trial from checkpoint, abort.") - self.trial_executor.stop_trial(trial, True, error_msg=error_msg) + logger.exception("Error recovering trial from checkpoint, abort.") + self._scheduler_alg.on_trial_error(self, trial) + self._search_alg.on_trial_complete(trial.trial_id, error=True) + + def _requeue_trial(self, trial): + """Notification to TrialScheduler and requeue trial. + + This does not notify the SearchAlgorithm because + the function evaluation is still in progress. + """ + self._scheduler_alg.on_trial_error(self, trial) + trial.status = Trial.PENDING + self._scheduler_alg.on_trial_add(self, trial) def _update_trial_queue(self, blocking=False, timeout=600): """Adds next trials to queue if possible. diff --git a/python/ray/tune/util.py b/python/ray/tune/util.py index 9c047fd80043e..5d2db272679f1 100644 --- a/python/ray/tune/util.py +++ b/python/ray/tune/util.py @@ -3,6 +3,7 @@ from __future__ import print_function import base64 +import copy import numpy as np import ray @@ -35,6 +36,41 @@ def get_pinned_object(pinned_id): ObjectID(base64.b64decode(pinned_id[len(PINNED_OBJECT_PREFIX):])))) +def merge_dicts(d1, d2): + """Returns a new dict that is d1 and d2 deep merged.""" + merged = copy.deepcopy(d1) + deep_update(merged, d2, True, []) + return merged + + +def deep_update(original, new_dict, new_keys_allowed, whitelist): + """Updates original dict with values from new_dict recursively. + If new key is introduced in new_dict, then if new_keys_allowed is not + True, an error will be thrown. Further, for sub-dicts, if the key is + in the whitelist, then new subkeys can be introduced. + + Args: + original (dict): Dictionary with default values. + new_dict (dict): Dictionary with values to be updated + new_keys_allowed (bool): Whether new keys are allowed. + whitelist (list): List of keys that correspond to dict values + where new subkeys can be introduced. This is only at + the top level. 
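A short usage sketch of the helpers documented above, with invented values: `merge_dicts` returns a new deep-merged dict, while `deep_update` mutates its first argument and can reject unknown keys.

    from ray.tune.util import deep_update, merge_dicts

    defaults = {"optimizer": {"lr": 0.01, "momentum": 0.9}, "batch_size": 32}
    overrides = {"optimizer": {"lr": 0.001}, "batch_size": 64}

    merged = merge_dicts(defaults, overrides)
    # merged == {"optimizer": {"lr": 0.001, "momentum": 0.9}, "batch_size": 64};
    # `defaults` itself is untouched because merge_dicts deep-copies it first.

    # With new_keys_allowed=False, a key missing from `defaults` (e.g. a typo
    # such as "batchsize") would raise instead of being silently added.
    deep_update(defaults, overrides, False, [])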
+ """ + for k, value in new_dict.items(): + if k not in original: + if not new_keys_allowed: + raise Exception("Unknown config parameter `{}` ".format(k)) + if type(original.get(k)) is dict: + if k in whitelist: + deep_update(original[k], value, True, []) + else: + deep_update(original[k], value, new_keys_allowed, []) + else: + original[k] = value + return original + + def _to_pinnable(obj): """Converts obj to a form that can be pinned in object store memory. diff --git a/python/ray/utils.py b/python/ray/utils.py index e75e006721444..cb3c33ecceac6 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -69,7 +69,7 @@ def push_error_to_driver(worker, if driver_id is None: driver_id = ray_constants.NIL_JOB_ID.id() data = {} if data is None else data - worker.local_scheduler_client.push_error( + worker.raylet_client.push_error( ray.ObjectID(driver_id), error_type, message, time.time()) @@ -165,10 +165,24 @@ def random_string(): return random_id -def decode(byte_str): - """Make this unicode in Python 3, otherwise leave it as bytes.""" +def decode(byte_str, allow_none=False): + """Make this unicode in Python 3, otherwise leave it as bytes. + + Args: + byte_str: The byte string to decode. + allow_none: If true, then we will allow byte_str to be None in which + case we will return an empty string. TODO(rkn): Remove this flag. + This is only here to simplify upgrading to flatbuffers 1.10.0. + + Returns: + A byte string in Python 2 and a unicode string in Python 3. + """ + if byte_str is None and allow_none: + return "" + if not isinstance(byte_str, bytes): - raise ValueError("The argument must be a bytes object.") + raise ValueError( + "The argument {} must be a bytes object.".format(byte_str)) if sys.version_info >= (3, 0): return byte_str.decode("ascii") else: diff --git a/python/ray/worker.py b/python/ray/worker.py index c3c01f4859fc0..9cdf3bbc9cd8b 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -13,6 +13,7 @@ import os import redis import signal +from six.moves import queue import sys import threading import time @@ -35,7 +36,7 @@ import ray.ray_constants as ray_constants from ray import import_thread from ray import profiling -from ray.function_manager import FunctionActorManager +from ray.function_manager import (FunctionActorManager, FunctionDescriptor) from ray.utils import ( check_oversized_pickle, is_cython, @@ -53,7 +54,6 @@ # This must match the definition of NIL_ACTOR_ID in task.h. NIL_ID = ray_constants.ID_SIZE * b"\xff" NIL_LOCAL_SCHEDULER_ID = NIL_ID -NIL_FUNCTION_ID = NIL_ID NIL_ACTOR_ID = NIL_ID NIL_ACTOR_HANDLE_ID = NIL_ID NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff" @@ -76,10 +76,6 @@ import setproctitle except ImportError: setproctitle = None - logger.warning( - "WARNING: Not updating worker name since `setproctitle` is not " - "installed. Install this with `pip install setproctitle` " - "(or ray[debug]) to enable monitoring of worker processes.") class RayTaskError(Exception): @@ -101,82 +97,35 @@ class RayTaskError(Exception): traceback_str (str): The traceback from the exception. 
""" - def __init__(self, function_name, exception, traceback_str): + def __init__(self, function_name, traceback_str): """Initialize a RayTaskError.""" - self.function_name = function_name - if (isinstance(exception, RayGetError) - or isinstance(exception, RayGetArgumentError)): - self.exception = exception + if setproctitle: + self.proctitle = setproctitle.getproctitle() else: - self.exception = None + self.proctitle = "ray_worker" + self.pid = os.getpid() + self.host = os.uname()[1] + self.function_name = function_name self.traceback_str = traceback_str + assert traceback_str is not None def __str__(self): """Format a RayTaskError as a string.""" - if self.traceback_str is None: - # This path is taken if getting the task arguments failed. - return ("Remote function {}{}{} failed with:\n\n{}".format( - colorama.Fore.RED, self.function_name, colorama.Fore.RESET, - self.exception)) - else: - # This path is taken if the task execution failed. - return ("Remote function {}{}{} failed with:\n\n{}".format( - colorama.Fore.RED, self.function_name, colorama.Fore.RESET, - self.traceback_str)) - - -class RayGetError(Exception): - """An exception used when get is called on an output of a failed task. - - Attributes: - objectid (lib.ObjectID): The ObjectID that get was called on. - task_error (RayTaskError): The RayTaskError object created by the - failed task. - """ - - def __init__(self, objectid, task_error): - """Initialize a RayGetError object.""" - self.objectid = objectid - self.task_error = task_error - - def __str__(self): - """Format a RayGetError as a string.""" - return ("Could not get objectid {}. It was created by remote function " - "{}{}{} which failed with:\n\n{}".format( - self.objectid, colorama.Fore.RED, - self.task_error.function_name, colorama.Fore.RESET, - self.task_error)) - - -class RayGetArgumentError(Exception): - """An exception used when a task's argument was produced by a failed task. - - Attributes: - argument_index (int): The index (zero indexed) of the failed argument - in present task's remote function call. - function_name (str): The name of the function for the current task. - objectid (lib.ObjectID): The ObjectID that was passed in as the - argument. - task_error (RayTaskError): The RayTaskError object created by the - failed task. - """ - - def __init__(self, function_name, argument_index, objectid, task_error): - """Initialize a RayGetArgumentError object.""" - self.argument_index = argument_index - self.function_name = function_name - self.objectid = objectid - self.task_error = task_error - - def __str__(self): - """Format a RayGetArgumentError as a string.""" - return ("Failed to get objectid {} as argument {} for remote function " - "{}{}{}. 
It was created by remote function {}{}{} which " - "failed with:\n{}".format( - self.objectid, self.argument_index, colorama.Fore.RED, - self.function_name, colorama.Fore.RESET, colorama.Fore.RED, - self.task_error.function_name, colorama.Fore.RESET, - self.task_error)) + lines = self.traceback_str.split("\n") + out = [] + in_worker = False + for line in lines: + if line.startswith("Traceback "): + out.append("{}{}{} (pid={}, host={})".format( + colorama.Fore.CYAN, self.proctitle, colorama.Fore.RESET, + self.pid, self.host)) + elif in_worker: + in_worker = False + elif "ray/worker.py" in line or "ray/function_manager.py" in line: + in_worker = True + else: + out.append(line) + return "\n".join(out) class Worker(object): @@ -211,17 +160,13 @@ def __init__(self): self.make_actor = None self.actors = {} self.actor_task_counter = 0 - # A set of all of the actor class keys that have been imported by the - # import thread. It is safe to convert this worker into an actor of - # these types. - self.imported_actor_classes = set() # The number of threads Plasma should use when putting an object in the # object store. self.memcopy_threads = 12 # When the worker is constructed. Record the original value of the # CUDA_VISIBLE_DEVICES environment variable. self.original_gpu_ids = ray.utils.get_cuda_visible_devices() - self.profiler = profiling.Profiler(self) + self.profiler = None self.memory_monitor = memory_monitor.MemoryMonitor() self.state_lock = threading.Lock() # A dictionary that maps from driver id to SerializationContext @@ -418,6 +363,17 @@ def put_object(self, object_id, value): logger.info( "The object with ID {} already exists in the object store." .format(object_id)) + except TypeError: + # This error can happen because one of the members of the object + # may not be serializable for cloudpickle. So we need these extra + # fallbacks here to start from the beginning. Hopefully the object + # has a `__reduce__` method. + register_custom_serializer(type(value), use_pickle=True) + warning_message = ("WARNING: Serializing the class {} failed, " + "so we are falling back to cloudpickle." + .format(type(value))) + logger.warning(warning_message) + self.store_and_register(object_id, value) def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10): start_time = time.time() @@ -442,7 +398,7 @@ def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10): # TODO(ekl): the local scheduler could include relevant # metadata in the task kill case for a better error message invalid_error = RayTaskError( - "", None, + "", "Invalid return value: likely worker died or was killed " "while executing the task; check previous logs or dmesg " "for errors.") @@ -495,7 +451,7 @@ def get_object(self, object_ids): ] for i in range(0, len(object_ids), ray._config.worker_fetch_request_size()): - self.local_scheduler_client.fetch_or_reconstruct( + self.raylet_client.fetch_or_reconstruct( object_ids[i:(i + ray._config.worker_fetch_request_size())], True) @@ -530,7 +486,7 @@ def get_object(self, object_ids): ray._config.worker_fetch_request_size()) for i in range(0, len(object_ids_to_fetch), fetch_request_size): - self.local_scheduler_client.fetch_or_reconstruct( + self.raylet_client.fetch_or_reconstruct( ray_object_ids_to_fetch[i:( i + fetch_request_size)], False, current_task_id) @@ -551,13 +507,13 @@ def get_object(self, object_ids): # If there were objects that we weren't able to get locally, # let the local scheduler know that we're now unblocked.
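A usage sketch for the merge_dicts/deep_update helpers added to python/ray/tune/util.py in this diff; the dictionary contents below are illustrative:

    from ray.tune.util import merge_dicts

    defaults = {"optimizer": {"lr": 0.01, "momentum": 0.9}, "num_workers": 2}
    overrides = {"optimizer": {"lr": 0.1}, "num_workers": 4}

    # Nested dicts are merged recursively, so "momentum" survives while "lr"
    # is overridden; top-level scalars are simply replaced.
    print(merge_dicts(defaults, overrides))
    # {'optimizer': {'lr': 0.1, 'momentum': 0.9}, 'num_workers': 4}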
- self.local_scheduler_client.notify_unblocked(current_task_id) + self.raylet_client.notify_unblocked(current_task_id) assert len(final_results) == len(object_ids) return final_results def submit_task(self, - function_id, + function_descriptor, args, actor_id=None, actor_handle_id=None, @@ -565,19 +521,21 @@ def submit_task(self, is_actor_checkpoint_method=False, actor_creation_id=None, actor_creation_dummy_object_id=None, + max_actor_reconstructions=0, execution_dependencies=None, num_return_vals=None, resources=None, placement_resources=None, - driver_id=None): + driver_id=None, + language=ray.gcs_utils.Language.PYTHON): """Submit a remote task to the scheduler. - Tell the scheduler to schedule the execution of the function with ID - function_id with arguments args. Retrieve object IDs for the outputs of - the function from the scheduler and immediately return them. + Tell the scheduler to schedule the execution of the function with + function_descriptor with arguments args. Retrieve object IDs for the + outputs of the function from the scheduler and immediately return them. Args: - function_id: The ID of the function to execute. + function_descriptor: The function descriptor to execute. args: The arguments to pass into the function. Arguments can be object IDs or they can be values. If they are values, they must be serializable objects. @@ -661,14 +619,16 @@ def submit_task(self, # The parent task must be set for the submitted task. assert not self.current_task_id.is_nil() # Submit the task to local scheduler. + function_descriptor_list = ( + function_descriptor.get_function_descriptor_list()) task = ray.raylet.Task( - driver_id, ray.ObjectID( - function_id.id()), args_for_local_scheduler, + driver_id, function_descriptor_list, args_for_local_scheduler, num_return_vals, self.current_task_id, task_index, - actor_creation_id, actor_creation_dummy_object_id, actor_id, - actor_handle_id, actor_counter, execution_dependencies, - resources, placement_resources) - self.local_scheduler_client.submit(task) + actor_creation_id, actor_creation_dummy_object_id, + max_actor_reconstructions, actor_id, actor_handle_id, + actor_counter, execution_dependencies, resources, + placement_resources) + self.raylet_client.submit_task(task) return task.returns() @@ -750,7 +710,7 @@ def _get_arguments_for_execution(self, function_name, serialized_args): passed by value. Raises: - RayGetArgumentError: This exception is raised if a task that + RayTaskError: This exception is raised if a task that created one of the arguments failed. """ arguments = [] @@ -759,10 +719,7 @@ def _get_arguments_for_execution(self, function_name, serialized_args): # get the object from the local object store argument = self.get_object([arg])[0] if isinstance(argument, RayTaskError): - # If the result is a RayTaskError, then the task that - # created this object failed, and we should propagate the - # error message here. 
- raise RayGetArgumentError(function_name, i, arg, argument) + raise argument else: # pass the argument by value argument = arg @@ -819,10 +776,12 @@ def _process_task(self, task, function_execution_info): self.task_driver_id = task.driver_id() self.current_task_id = task.task_id() - function_id = task.function_id() + function_descriptor = FunctionDescriptor.from_bytes_list( + task.function_descriptor_list()) args = task.arguments() return_object_ids = task.returns() - if task.actor_id().id() != NIL_ACTOR_ID: + if (task.actor_id().id() != NIL_ACTOR_ID + or task.actor_creation_id().id() != NIL_ACTOR_ID): dummy_return_id = return_object_ids.pop() function_executor = function_execution_info.function function_name = function_execution_info.function_name @@ -835,34 +794,38 @@ def _process_task(self, task, function_execution_info): with profiling.profile("task:deserialize_arguments", worker=self): arguments = self._get_arguments_for_execution( function_name, args) - except (RayGetError, RayGetArgumentError) as e: - self._handle_process_task_failure(function_id, function_name, - return_object_ids, e, None) + except RayTaskError as e: + self._handle_process_task_failure( + function_descriptor, return_object_ids, e, + ray.utils.format_error_message(traceback.format_exc())) return except Exception as e: self._handle_process_task_failure( - function_id, function_name, return_object_ids, e, + function_descriptor, return_object_ids, e, ray.utils.format_error_message(traceback.format_exc())) return # Execute the task. try: with profiling.profile("task:execute", worker=self): - if task.actor_id().id() == NIL_ACTOR_ID: + if (task.actor_id().id() == NIL_ACTOR_ID + and task.actor_creation_id().id() == NIL_ACTOR_ID): outputs = function_executor(*arguments) else: - outputs = function_executor( - dummy_return_id, self.actors[task.actor_id().id()], - *arguments) + if task.actor_id().id() != NIL_ACTOR_ID: + key = task.actor_id().id() + else: + key = task.actor_creation_id().id() + outputs = function_executor(dummy_return_id, + self.actors[key], *arguments) except Exception as e: # Determine whether the exception occured during a task, not an # actor method. task_exception = task.actor_id().id() == NIL_ACTOR_ID traceback_str = ray.utils.format_error_message( traceback.format_exc(), task_exception=task_exception) - self._handle_process_task_failure(function_id, function_name, - return_object_ids, e, - traceback_str) + self._handle_process_task_failure( + function_descriptor, return_object_ids, e, traceback_str) return # Store the outputs in the local object store. 
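With RayGetError and RayGetArgumentError removed, a failed dependency now re-raises the stored RayTaskError directly. A minimal driver-side sketch of the behavior after this change; the remote functions are illustrative, and the exact traceback formatting produced by RayTaskError.__str__ will differ:

    import ray
    from ray.worker import RayTaskError

    ray.init()

    @ray.remote
    def fail():
        raise ValueError("boom")

    @ray.remote
    def consume(x):
        return x

    try:
        # Both the failed task and the task consuming its output now raise
        # the original RayTaskError on ray.get(), instead of RayGetError /
        # RayGetArgumentError wrappers.
        ray.get(consume.remote(fail.remote()))
    except RayTaskError as e:
        print(e)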
@@ -877,12 +840,14 @@ def _process_task(self, task, function_execution_info): self._store_outputs_in_object_store(return_object_ids, outputs) except Exception as e: self._handle_process_task_failure( - function_id, function_name, return_object_ids, e, + function_descriptor, return_object_ids, e, ray.utils.format_error_message(traceback.format_exc())) - def _handle_process_task_failure(self, function_id, function_name, + def _handle_process_task_failure(self, function_descriptor, return_object_ids, error, backtrace): - failure_object = RayTaskError(function_name, error, backtrace) + function_name = function_descriptor.function_name + function_id = function_descriptor.function_id + failure_object = RayTaskError(function_name, backtrace) failure_objects = [ failure_object for _ in range(len(return_object_ids)) ] @@ -895,53 +860,34 @@ def _handle_process_task_failure(self, function_id, function_name, driver_id=self.task_driver_id.id(), data={ "function_id": function_id.id(), - "function_name": function_name + "function_name": function_name, + "module_name": function_descriptor.module_name, + "class_name": function_descriptor.class_name }) # Mark the actor init as failed if self.actor_id != NIL_ACTOR_ID and function_name == "__init__": self.mark_actor_init_failed(error) - def _become_actor(self, task): - """Turn this worker into an actor. - - Args: - task: The actor creation task. - """ - assert self.actor_id == NIL_ACTOR_ID - arguments = task.arguments() - assert len(arguments) == 1 - self.actor_id = task.actor_creation_id().id() - class_id = arguments[0] - - key = b"ActorClass:" + class_id - - # Wait for the actor class key to have been imported by the import - # thread. TODO(rkn): It shouldn't be possible to end up in an infinite - # loop here, but we should push an error to the driver if too much time - # is spent here. - while key not in self.imported_actor_classes: - time.sleep(0.001) - - with self.lock: - self.function_actor_manager.fetch_and_register_actor(key) - def _wait_for_and_process_task(self, task): """Wait for a task to be ready and process the task. Args: task: The task to execute. """ - function_id = task.function_id() + function_descriptor = FunctionDescriptor.from_bytes_list( + task.function_descriptor_list()) driver_id = task.driver_id().id() # TODO(rkn): It would be preferable for actor creation tasks to share # more of the code path with regular task execution. if (task.actor_creation_id() != ray.ObjectID(NIL_ACTOR_ID)): - self._become_actor(task) - return + assert self.actor_id == NIL_ACTOR_ID + self.actor_id = task.actor_creation_id().id() + self.function_actor_manager.load_actor(driver_id, + function_descriptor) execution_info = self.function_actor_manager.get_execution_info( - driver_id, function_id) + driver_id, function_descriptor) # Execute the task. 
# TODO(rkn): Consider acquiring this lock with a timeout and pushing a @@ -955,8 +901,14 @@ def _wait_for_and_process_task(self, task): "task_id": task.task_id().hex() } if task.actor_id().id() == NIL_ACTOR_ID: - title = "ray_worker:{}()".format(function_name) - next_title = "ray_worker" + if (task.actor_creation_id() == ray.ObjectID(NIL_ACTOR_ID)): + title = "ray_worker:{}()".format(function_name) + next_title = "ray_worker" + else: + actor = self.actors[task.actor_creation_id().id()] + title = "ray_{}:{}()".format(actor.__class__.__name__, + function_name) + next_title = "ray_{}".format(actor.__class__.__name__) else: actor = self.actors[task.actor_id().id()] title = "ray_{}:{}()".format(actor.__class__.__name__, @@ -974,12 +926,12 @@ def _wait_for_and_process_task(self, task): # Increase the task execution counter. self.function_actor_manager.increase_task_counter( - driver_id, function_id.id()) + driver_id, function_descriptor) reached_max_executions = (self.function_actor_manager.get_task_counter( - driver_id, function_id.id()) == execution_info.max_calls) + driver_id, function_descriptor) == execution_info.max_calls) if reached_max_executions: - self.local_scheduler_client.disconnect() + self.raylet_client.disconnect() sys.exit(0) def _get_next_task_from_local_scheduler(self): @@ -989,7 +941,7 @@ def _get_next_task_from_local_scheduler(self): A task from the local scheduler. """ with profiling.profile("worker_idle", worker=self): - task = self.local_scheduler_client.get_task() + task = self.raylet_client.get_task() # Automatically restrict the GPUs available to this task. ray.utils.set_cuda_visible_devices(ray.get_gpu_ids()) @@ -1025,7 +977,7 @@ def get_gpu_ids(): raise Exception("ray.get_gpu_ids() currently does not work in PYTHON " "MODE.") - all_resource_ids = global_worker.local_scheduler_client.resource_ids() + all_resource_ids = global_worker.raylet_client.resource_ids() assigned_ids = [ resource_id for resource_id, _ in all_resource_ids.get("GPU", []) ] @@ -1053,7 +1005,7 @@ def get_resource_ids(): "ray.get_resource_ids() currently does not work in PYTHON " "MODE.") - return global_worker.local_scheduler_client.resource_ids() + return global_worker.raylet_client.resource_ids() def _webui_url_helper(client): @@ -1189,18 +1141,6 @@ def actor_handle_deserializer(serialized_obj): local=True, driver_id=driver_id, class_id="ray.RayTaskError") - register_custom_serializer( - RayGetError, - use_dict=True, - local=True, - driver_id=driver_id, - class_id="ray.RayGetError") - register_custom_serializer( - RayGetArgumentError, - use_dict=True, - local=True, - driver_id=driver_id, - class_id="ray.RayGetArgumentError") # Tell Ray to serialize lambdas with pickle. register_custom_serializer( type(lambda: 0), @@ -1335,6 +1275,8 @@ def _init(address_info=None, num_workers=None, num_local_schedulers=None, object_store_memory=None, + redis_max_memory=None, + collect_profiling_data=True, local_mode=False, driver_mode=None, redirect_worker_output=False, @@ -1379,6 +1321,11 @@ def _init(address_info=None, This is only provided if start_ray_local is True. object_store_memory: The maximum amount of memory (in bytes) to allow the object store to use. + redis_max_memory: The max amount of memory (in bytes) to allow redis + to use, or None for no limit. Once the limit is exceeded, redis + will start LRU eviction of entries. This only applies to the + sharded redis tables (task and object tables). + collect_profiling_data: Whether to collect profiling data from workers. 
local_mode (bool): True if the code should be executed serially without Ray. This is useful for debugging. redirect_worker_output: True if the stdout and stderr of worker @@ -1433,6 +1380,12 @@ def _init(address_info=None, else: driver_mode = SCRIPT_MODE + if redis_max_memory and collect_profiling_data: + logger.warning( + "Profiling data cannot be LRU evicted, so it is disabled " + "when redis_max_memory is set.") + collect_profiling_data = False + # Get addresses of existing services. if address_info is None: address_info = {} @@ -1472,6 +1425,8 @@ def _init(address_info=None, num_workers=num_workers, num_local_schedulers=num_local_schedulers, object_store_memory=object_store_memory, + redis_max_memory=redis_max_memory, + collect_profiling_data=collect_profiling_data, redirect_worker_output=redirect_worker_output, redirect_output=redirect_output, start_workers_from_local_scheduler=( @@ -1512,6 +1467,9 @@ def _init(address_info=None, if object_store_memory is not None: raise Exception("When connecting to an existing cluster, " "object_store_memory must not be provided.") + if redis_max_memory is not None: + raise Exception("When connecting to an existing cluster, " + "redis_max_memory must not be provided.") if plasma_directory is not None: raise Exception("When connecting to an existing cluster, " "plasma_directory must not be provided.") @@ -1549,7 +1507,7 @@ def _init(address_info=None, "node_ip_address": node_ip_address, "redis_address": address_info["redis_address"], "store_socket_name": address_info["object_store_addresses"][0], - "webui_url": address_info["webui_url"] + "webui_url": address_info["webui_url"], } driver_address_info["raylet_socket_name"] = ( address_info["raylet_socket_names"][0]) @@ -1562,7 +1520,8 @@ def _init(address_info=None, mode=driver_mode, worker=global_worker, driver_id=driver_id, - redis_password=redis_password) + redis_password=redis_password, + collect_profiling_data=collect_profiling_data) return address_info @@ -1571,6 +1530,8 @@ def init(redis_address=None, num_gpus=None, resources=None, object_store_memory=None, + redis_max_memory=None, + collect_profiling_data=True, node_ip_address=None, object_id_seed=None, num_workers=None, @@ -1627,6 +1588,11 @@ def init(redis_address=None, of that resource available. object_store_memory: The amount of memory (in bytes) to start the object store with. + redis_max_memory: The max amount of memory (in bytes) to allow redis + to use, or None for no limit. Once the limit is exceeded, redis + will start LRU eviction of entries. This only applies to the + sharded redis tables (task and object tables). + collect_profiling_data: Whether to collect profiling data from workers. node_ip_address (str): The IP address of the node that we are on. object_id_seed (int): Used to seed the deterministic generation of object IDs. The same value can be used across multiple runs of the @@ -1674,17 +1640,24 @@ def init(redis_address=None, Exception: An exception is raised if an inappropriate combination of arguments is passed in. """ + + if configure_logging: + logging.basicConfig(level=logging_level, format=logging_format) + # Add the use_raylet option for backwards compatibility. if use_raylet is not None: if use_raylet: - logger.warn("WARNING: The use_raylet argument has been " - "deprecated. Please remove it.") + logger.warning("WARNING: The use_raylet argument has been " + "deprecated. Please remove it.") else: raise DeprecationWarning("The use_raylet argument is deprecated. 
" "Please remove it.") - if configure_logging: - logging.basicConfig(level=logging_level, format=logging_format) + if setproctitle is None: + logger.warning( + "WARNING: Not updating worker name since `setproctitle` is not " + "installed. Install this with `pip install setproctitle` " + "(or ray[debug]) to enable monitoring of worker processes.") if global_worker.connected: if ignore_reinit_error: @@ -1720,6 +1693,8 @@ def init(redis_address=None, huge_pages=huge_pages, include_webui=include_webui, object_store_memory=object_store_memory, + redis_max_memory=redis_max_memory, + collect_profiling_data=collect_profiling_data, driver_id=driver_id, plasma_store_socket_name=plasma_store_socket_name, raylet_socket_name=raylet_socket_name, @@ -1754,8 +1729,8 @@ def shutdown(worker=global_worker): will need to reload the module. """ disconnect(worker) - if hasattr(worker, "local_scheduler_client"): - del worker.local_scheduler_client + if hasattr(worker, "raylet_client"): + del worker.raylet_client if hasattr(worker, "plasma_client"): worker.plasma_client.disconnect() @@ -1792,12 +1767,38 @@ def custom_excepthook(type, value, tb): sys.excepthook = custom_excepthook +# The last time we raised a TaskError in this process. We use this value to +# suppress redundant error messages pushed from the workers. +last_task_error_raise_time = 0 -def print_error_messages_raylet(worker): - """Print error messages in the background on the driver. +# The max amount of seconds to wait before printing out an uncaught error. +UNCAUGHT_ERROR_GRACE_PERIOD = 5 - This runs in a separate thread on the driver and prints error messages in - the background. + +def print_error_messages_raylet(task_error_queue): + """Prints message received in the given output queue. + + This checks periodically if any un-raised errors occured in the background. + """ + + while True: + error, t = task_error_queue.get() + # Delay errors a little bit of time to attempt to suppress redundant + # messages originating from the worker. + while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time(): + time.sleep(1) + if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD: + logger.debug("Suppressing error from worker: {}".format(error)) + else: + logger.error( + "Possible unhandled error from worker: {}".format(error)) + + +def listen_error_messages_raylet(worker, task_error_queue): + """Listen to error messages in the background on the driver. + + This runs in a separate thread on the driver and pushes (error, time) + tuples to the output queue. """ worker.error_message_pubsub_client = worker.redis_client.pubsub( ignore_subscribe_messages=True) @@ -1834,7 +1835,12 @@ def print_error_messages_raylet(worker): continue error_message = ray.utils.decode(error_data.ErrorMessage()) - logger.error(error_message) + if (ray.utils.decode( + error_data.Type()) == ray_constants.TASK_PUSH_ERROR): + # Delay it a bit to see if we can suppress it + task_error_queue.put((error_message, time.time())) + else: + logger.error(error_message) except redis.ConnectionError: # When Redis terminates the listen call will throw a ConnectionError, @@ -1899,12 +1905,13 @@ def connect(info, mode=WORKER_MODE, worker=global_worker, driver_id=None, - redis_password=None): + redis_password=None, + collect_profiling_data=True): """Connect this worker to the local scheduler, to Plasma, and to Redis. Args: info (dict): A dictionary with address of the Redis server and the - sockets of the plasma store, plasma manager, and local scheduler. + sockets of the plasma store and raylet. 
object_id_seed: A seed to use to make the generation of object IDs deterministic. mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and @@ -1912,6 +1919,7 @@ def connect(info, driver_id: The ID of driver. If it's None, then we will generate one. redis_password (str): Prevents external clients without the password from connecting to Redis if provided. + collect_profiling_data: Whether to collect profiling data from workers. """ # Do some basic checking to make sure we didn't call ray.init twice. error_message = "Perhaps you called ray.init twice by accident?" @@ -1921,6 +1929,11 @@ def connect(info, # Enable nice stack traces on SIGSEGV etc. faulthandler.enable(all_threads=False) + if collect_profiling_data: + worker.profiler = profiling.Profiler(worker) + else: + worker.profiler = profiling.NoopProfiler() + # Initialize some fields. if mode is WORKER_MODE: worker.worker_id = random_string() @@ -2043,7 +2056,7 @@ def connect(info, # Create an object store client. worker.plasma_client = thread_safe_client( - plasma.connect(info["store_socket_name"], "", 64)) + plasma.connect(info["store_socket_name"])) raylet_socket = info["raylet_socket_name"] @@ -2076,15 +2089,14 @@ def connect(info, # rerun the driver. nil_actor_counter = 0 - driver_task = ray.raylet.Task(worker.task_driver_id, - ray.ObjectID(NIL_FUNCTION_ID), [], 0, - worker.current_task_id, - worker.task_index, - ray.ObjectID(NIL_ACTOR_ID), - ray.ObjectID(NIL_ACTOR_ID), - ray.ObjectID(NIL_ACTOR_ID), - ray.ObjectID(NIL_ACTOR_ID), - nil_actor_counter, [], {"CPU": 0}, {}) + function_descriptor = FunctionDescriptor.for_driver_task() + driver_task = ray.raylet.Task( + worker.task_driver_id, + function_descriptor.get_function_descriptor_list(), [], 0, + worker.current_task_id, worker.task_index, + ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), 0, + ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), + nil_actor_counter, [], {"CPU": 0}, {}) # Add the driver task to the task table. global_state._execute_command(driver_task.task_id(), "RAY.TABLE_ADD", @@ -2103,7 +2115,7 @@ def connect(info, # multithreading per worker. worker.multithreading_warned = False - worker.local_scheduler_client = ray.raylet.LocalSchedulerClient( + worker.raylet_client = ray.raylet.RayletClient( raylet_socket, worker.worker_id, is_worker, worker.current_task_id) # Start the import thread @@ -2116,14 +2128,19 @@ def connect(info, # temporarily using this implementation which constantly queries the # scheduler for new error messages. if mode == SCRIPT_MODE: - t = threading.Thread( + q = queue.Queue() + listener = threading.Thread( + target=listen_error_messages_raylet, + name="ray_listen_error_messages", + args=(worker, q)) + printer = threading.Thread( target=print_error_messages_raylet, name="ray_print_error_messages", - args=(worker, )) - # Making the thread a daemon causes it to exit when the main thread - # exits. - t.daemon = True - t.start() + args=(q, )) + listener.daemon = True + listener.start() + printer.daemon = True + printer.start() # If we are using the raylet code path and we are not in local mode, start # a background thread to periodically flush profiling data to the GCS. @@ -2351,11 +2368,13 @@ def get(object_ids, worker=global_worker): # In LOCAL_MODE, ray.get is the identity operation (the input will # actually be a value not an objectid). 
return object_ids + global last_task_error_raise_time if isinstance(object_ids, list): values = worker.get_object(object_ids) for i, value in enumerate(values): if isinstance(value, RayTaskError): - raise RayGetError(object_ids[i], value) + last_task_error_raise_time = time.time() + raise value return values else: value = worker.get_object([object_ids])[0] @@ -2363,7 +2382,8 @@ def get(object_ids, worker=global_worker): # If the result is a RayTaskError, then the task that created # this object failed, and we should propagate the error message # here. - raise RayGetError(object_ids, value) + last_task_error_raise_time = time.time() + raise value return value @@ -2381,7 +2401,7 @@ def put(value, worker=global_worker): if worker.mode == LOCAL_MODE: # In LOCAL_MODE, ray.put is the identity operation. return value - object_id = worker.local_scheduler_client.compute_put_id( + object_id = worker.raylet_client.compute_put_id( worker.current_task_id, worker.put_index) worker.put_object(object_id, value) worker.put_index += 1 @@ -2461,7 +2481,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): current_task_id = worker.get_current_thread_task_id() timeout = timeout if timeout is not None else 2**30 - ready_ids, remaining_ids = worker.local_scheduler_client.wait( + ready_ids, remaining_ids = worker.raylet_client.wait( object_ids, num_returns, timeout, False, current_task_id) return ready_ids, remaining_ids @@ -2487,6 +2507,7 @@ def make_decorator(num_return_vals=None, resources=None, max_calls=None, checkpoint_interval=None, + max_reconstructions=None, worker=None): def decorator(function_or_class): if (inspect.isfunction(function_or_class) @@ -2495,6 +2516,9 @@ def decorator(function_or_class): if checkpoint_interval is not None: raise Exception("The keyword 'checkpoint_interval' is not " "allowed for remote functions.") + if max_reconstructions is not None: + raise Exception("The keyword 'max_reconstructions' is not " + "allowed for remote functions.") return ray.remote_function.RemoteFunction( function_or_class, num_cpus, num_gpus, resources, @@ -2524,7 +2548,7 @@ def decorator(function_or_class): return worker.make_actor(function_or_class, cpus_to_use, num_gpus, resources, actor_method_cpus, - checkpoint_interval) + checkpoint_interval, max_reconstructions) raise Exception("The @ray.remote decorator must be applied to " "either a function or to a class.") @@ -2566,6 +2590,11 @@ def method(self): third-party libraries or to reclaim resources that cannot easily be released, e.g., GPU memory that was acquired by TensorFlow). By default this is infinite. + * **max_reconstructions**: Only for *actors*. This specifies the maximum + number of times that the actor should be reconstructed when it dies + unexpectedly. The minimum valid value is 0 (default), which indicates + that the actor doesn't need to be reconstructed. And the maximum valid + value is ray.ray_constants.INFINITE_RECONSTRUCTIONS. 
This can be done as follows: @@ -2591,14 +2620,15 @@ def method(self): "with no arguments and no parentheses, for example " "'@ray.remote', or it must be applied using some of " "the arguments 'num_return_vals', 'num_cpus', 'num_gpus', " - "'resources', 'max_calls', or 'checkpoint_interval', like " + "'resources', 'max_calls', 'checkpoint_interval'," + "or 'max_reconstructions', like " "'@ray.remote(num_return_vals=2, " "resources={\"CustomResource\": 1})'.") assert len(args) == 0 and len(kwargs) > 0, error_string for key in kwargs: assert key in [ "num_return_vals", "num_cpus", "num_gpus", "resources", - "max_calls", "checkpoint_interval" + "max_calls", "checkpoint_interval", "max_reconstructions" ], error_string num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else None @@ -2616,6 +2646,7 @@ def method(self): num_return_vals = kwargs.get("num_return_vals") max_calls = kwargs.get("max_calls") checkpoint_interval = kwargs.get("checkpoint_interval") + max_reconstructions = kwargs.get("max_reconstructions") return make_decorator( num_return_vals=num_return_vals, @@ -2624,4 +2655,5 @@ def method(self): resources=resources, max_calls=max_calls, checkpoint_interval=checkpoint_interval, + max_reconstructions=max_reconstructions, worker=worker) diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index b9c9500e70870..dc1085783b8aa 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -55,6 +55,11 @@ type=str, default=ray_constants.LOGGER_FORMAT, help=ray_constants.LOGGER_FORMAT_HELP) +parser.add_argument( + "--collect-profiling-data", + type=int, # int since argparse can't handle bool values + default=1, + help="Whether to collect profiling data from workers.") parser.add_argument( "--temp-dir", required=False, @@ -71,7 +76,7 @@ "redis_password": args.redis_password, "store_socket_name": args.object_store_name, "manager_socket_name": args.object_store_manager_name, - "raylet_socket_name": args.raylet_name + "raylet_socket_name": args.raylet_name, } logging.basicConfig( @@ -82,7 +87,10 @@ tempfile_services.set_temp_root(args.temp_dir) ray.worker.connect( - info, mode=ray.WORKER_MODE, redis_password=args.redis_password) + info, + mode=ray.WORKER_MODE, + redis_password=args.redis_password, + collect_profiling_data=args.collect_profiling_data) error_explanation = """ This error is unexpected and should not have happened. Somehow a worker diff --git a/python/setup.py b/python/setup.py index c92ffa65b481d..41d04fb85ff86 100644 --- a/python/setup.py +++ b/python/setup.py @@ -23,7 +23,7 @@ "ray/core/src/ray/thirdparty/redis/src/redis-server", "ray/core/src/ray/gcs/redis_module/libray_redis_module.so", "ray/core/src/plasma/plasma_store_server", - "ray/core/src/ray/raylet/liblocal_scheduler_library_python.so", + "ray/core/src/ray/raylet/libraylet_library_python.so", "ray/core/src/ray/raylet/raylet_monitor", "ray/core/src/ray/raylet/raylet", "ray/WebUI.ipynb" ] @@ -135,7 +135,8 @@ def find_version(*filepath): requires = [ - "numpy", + "numpy >= 1.10.4", + "filelock", "funcsigs", "click", "colorama", @@ -153,6 +154,8 @@ def find_version(*filepath): setup( name="ray", version=find_version("ray", "__init__.py"), + author="Ray Team", + author_email="ray-dev@googlegroups.com", description=("A system for parallel and distributed Python that unifies " "the ML ecosystem."), long_description=open("../README.rst").read(), @@ -164,7 +167,7 @@ def find_version(*filepath): # The BinaryDistribution argument triggers build_ext. 
distclass=BinaryDistribution, install_requires=requires, - setup_requires=["cython >= 0.27, < 0.28"], + setup_requires=["cython >= 0.29"], extras_require=extras, entry_points={ "console_scripts": [ diff --git a/src/ray/CMakeLists.txt b/src/ray/CMakeLists.txt index 3916423a6f261..378b5fc67ad7d 100644 --- a/src/ray/CMakeLists.txt +++ b/src/ray/CMakeLists.txt @@ -1,11 +1,6 @@ cmake_minimum_required(VERSION 3.2) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -std=c++11") -add_subdirectory(util) -add_subdirectory(gcs) -add_subdirectory(object_manager) -add_subdirectory(raylet) - include_directories(thirdparty/ae) set(HIREDIS_SRCS @@ -56,8 +51,15 @@ set(RAY_SRCS raylet/raylet.cc ) -set(RAY_LIB_STATIC_LINK_LIBS ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB}) +set(RAY_SYSTEM_LIBS ${Boost_SYSTEM_LIBRARY} ${Boost_THREAD_LIBRARY} pthread) +if(UNIX AND NOT APPLE) + set(RAY_SYSTEM_LIBS ${RAY_SYSTEM_LIBS} -lrt) +endif() + +set(RAY_LIB_STATIC_LINK_LIBS ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${RAY_SYSTEM_LIBS}) + set(RAY_LIB_DEPENDENCIES + boost_thread arrow_ep gen_gcs_fbs gen_object_manager_fbs @@ -89,6 +91,13 @@ ADD_RAY_LIB(ray SHARED_LINK_LIBS "" STATIC_LINK_LIBS ${RAY_LIB_STATIC_LINK_LIBS}) +set(RAY_TEST_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main gmock_main ${RAY_SYSTEM_LIBS}) + +add_subdirectory(util) +add_subdirectory(gcs) +add_subdirectory(object_manager) +add_subdirectory(raylet) + add_custom_target(copy_redis ALL) foreach(file "redis-cli" "redis-server") add_custom_command(TARGET copy_redis POST_BUILD diff --git a/src/ray/common/client_connection.cc b/src/ray/common/client_connection.cc index 1ae225443fc36..1ca93ba2a988b 100644 --- a/src/ray/common/client_connection.cc +++ b/src/ray/common/client_connection.cc @@ -16,7 +16,13 @@ ray::Status TcpConnect(boost::asio::ip::tcp::socket &socket, boost::asio::ip::tcp::endpoint endpoint(ip_address, port); boost::system::error_code error; socket.connect(endpoint, error); - return boost_to_ray_status(error); + const auto status = boost_to_ray_status(error); + if (!status.ok()) { + // Close the socket if the connect failed. + boost::system::error_code close_error; + socket.close(close_error); + } + return status; } template @@ -33,6 +39,14 @@ ServerConnection::ServerConnection(boost::asio::basic_stream_socket &&sock async_write_queue_(), async_write_in_flight_(false) {} +template +ServerConnection::~ServerConnection() { + // If there are any pending messages, invoke their callbacks with an IOError status. + for (const auto &write_buffer : async_write_queue_) { + write_buffer->handler(Status::IOError("Connection closed.")); + } +} + template Status ServerConnection::WriteBuffer( const std::vector &buffer) { diff --git a/src/ray/common/client_connection.h b/src/ray/common/client_connection.h index d4ca993d26a4b..7246c2b8125ef 100644 --- a/src/ray/common/client_connection.h +++ b/src/ray/common/client_connection.h @@ -29,6 +29,9 @@ ray::Status TcpConnect(boost::asio::ip::tcp::socket &socket, template class ServerConnection : public std::enable_shared_from_this> { public: + /// ServerConnection destructor. + virtual ~ServerConnection(); + /// Allocate a new server connection. /// /// \param socket A reference to the server socket. 
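For the actor-only max_reconstructions option described in the @ray.remote docstring above, a short usage sketch; the actor class and the value 3 are illustrative:

    import ray

    ray.init()

    # Allow this actor to be reconstructed up to 3 times if its process dies
    # unexpectedly; the default of 0 disables reconstruction.
    @ray.remote(max_reconstructions=3)
    class Counter(object):
        def __init__(self):
            self.count = 0

        def increment(self):
            self.count += 1
            return self.count

    counter = Counter.remote()
    print(ray.get(counter.increment.remote()))  # 1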
diff --git a/src/ray/common/common_protocol.cc b/src/ray/common/common_protocol.cc index 5ce4c89d62120..bcbfcc5f02006 100644 --- a/src/ray/common/common_protocol.cc +++ b/src/ray/common/common_protocol.cc @@ -69,3 +69,25 @@ map_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, } return fbb.CreateVector(resource_vector); } + +std::vector<std::string> string_vec_from_flatbuf( + const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> &flatbuf_vec) { + std::vector<std::string> string_vector; + string_vector.reserve(flatbuf_vec.size()); + for (int64_t i = 0; i < flatbuf_vec.size(); i++) { + const auto flatbuf_str = flatbuf_vec.Get(i); + string_vector.push_back(string_from_flatbuf(*flatbuf_str)); + } + return string_vector; +} + +flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> +string_vec_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, + const std::vector<std::string> &string_vector) { + std::vector<flatbuffers::Offset<flatbuffers::String>> flatbuf_str_vec; + flatbuf_str_vec.reserve(string_vector.size()); + for (auto const &str : string_vector) { + flatbuf_str_vec.push_back(fbb.CreateString(str)); + } + return fbb.CreateVector(flatbuf_str_vec); +} diff --git a/src/ray/common/common_protocol.h b/src/ray/common/common_protocol.h index 3afa6b8e5781c..de8f27fc4f388 100644 --- a/src/ray/common/common_protocol.h +++ b/src/ray/common/common_protocol.h @@ -72,4 +72,10 @@ map_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::unordered_map map_from_flatbuf( const flatbuffers::Vector> &resource_vector); +std::vector<std::string> string_vec_from_flatbuf( + const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> &flatbuf_vec); + +flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> +string_vec_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, + const std::vector<std::string> &string_vector); #endif diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index c1cbee11303e4..ba660e4f0ffc6 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -106,13 +106,13 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, /*password=*/password)); } + actor_table_.reset(new ActorTable({primary_context_}, this)); client_table_.reset(new ClientTable({primary_context_}, this, client_id)); error_table_.reset(new ErrorTable({primary_context_}, this)); driver_table_.reset(new DriverTable({primary_context_}, this)); heartbeat_batch_table_.reset(new HeartbeatBatchTable({primary_context_}, this)); // Tables below would be sharded. object_table_.reset(new ObjectTable(shard_contexts_, this, command_type)); - actor_table_.reset(new ActorTable(shard_contexts_, this)); raylet_task_table_.reset(new raylet::TaskTable(shard_contexts_, this, command_type)); task_reconstruction_log_.reset(new TaskReconstructionLog(shard_contexts_, this)); task_lease_table_.reset(new TaskLeaseTable(shard_contexts_, this)); diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 6414613d45840..df4077a45fecb 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -60,6 +60,9 @@ table TaskInfo { actor_creation_id: string; // The dummy object ID of the actor creation task if this is an actor method. actor_creation_dummy_object_id: string; + // The max number of times this actor should be reconstructed. + // If this number is 0 or negative, the actor won't be reconstructed on failure. + max_actor_reconstructions: int; // Actor ID of the task. This is the actor that this task is executed on // or NIL_ACTOR_ID if the task is just a normal task. actor_id: string; @@ -70,8 +73,6 @@ table TaskInfo { actor_counter: int; // True if this task is an actor checkpoint task and false otherwise. is_actor_checkpoint_method: bool; - // Function ID of the task. - function_id: string; // Task arguments.
args: [Arg]; // Object IDs of return values. @@ -167,8 +168,11 @@ table ClassTableData { enum ActorState:int { // Actor is alive. ALIVE = 0, + // Actor is dead, now being reconstructed. + // After reconstruction finishes, the state will become alive again. + RECONSTRUCTING = 1, // Actor is already dead and won't be reconstructed. - DEAD + DEAD = 2 } table ActorTableData { @@ -184,6 +188,10 @@ table ActorTableData { node_manager_id: string; // Current state of this actor. state: ActorState; + // Max number of times this actor should be reconstructed. + max_reconstructions: int; + // Remaining number of reconstructions. + remaining_reconstructions: int; } table ErrorTableData { diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index f832f9e186095..f19eacbb63fe2 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -20,41 +20,6 @@ extern RedisChainModule module; #endif -// Various tables are maintained in redis: -// -// == OBJECT TABLE == -// -// This consists of two parts: -// - The object location table, indexed by OL:object_id, which is the set of -// plasma manager indices that have access to the object. -// (In redis this is represented by a zset (sorted set).) -// -// - The object info table, indexed by OI:object_id, which is a hashmap of: -// "hash" -> the hash of the object, -// "data_size" -> the size of the object in bytes, -// "task" -> the task ID that generated this object. -// "is_put" -> 0 or 1. -// -// == TASK TABLE == -// -// It maps each TT:task_id to a hash: -// "state" -> the state of the task, encoded as a bit mask of scheduling_state -// enum values in task.h, -// "local_scheduler_id" -> the ID of the local scheduler the task is assigned -// to, -// "TaskSpec" -> serialized bytes of a TaskInfo (defined in common.fbs), which -// describes the details this task. -// -// See also the definition of TaskReply in common.fbs. - -#define OBJECT_INFO_PREFIX "OI:" -#define OBJECT_LOCATION_PREFIX "OL:" -#define OBJECT_NOTIFICATION_PREFIX "ON:" -#define TASK_PREFIX "TT:" -#define OBJECT_BCAST "BCAST" - -#define OBJECT_CHANNEL_PREFIX "OC:" - #define CHECK_ERROR(STATUS, MESSAGE) \ if ((STATUS) == REDISMODULE_ERR) { \ return RedisModule_ReplyWithError(ctx, (MESSAGE)); \ diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index c6c12aa53069c..e5770effaf64e 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -121,7 +121,7 @@ Status Log::Subscribe(const JobID &job_id, const ClientID &client_id, if (subscribe != nullptr) { // Parse the notification. auto root = flatbuffers::GetRoot(data.data()); - ID id = UniqueID::nil(); + ID id; if (root->id()->size() > 0) { id = from_flatbuf(*root->id()); } @@ -368,10 +368,13 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } RAY_CHECK(removed_clients_.find(client_id) == removed_clients_.end()); } else { + // NOTE(swang): The client should be added to this data structure before + // the callback gets called, in case the callback depends on the data + // structure getting updated. 
+ removed_clients_.insert(client_id); if (client_removed_callback_ != nullptr) { client_removed_callback_(client, client_id, data); } - removed_clients_.insert(client_id); } } } diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 5cca066fb453c..1e17210cd3e03 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -393,7 +393,11 @@ class FunctionTable : public Table { using ClassTable = Table; -// TODO(swang): Set the pubsub channel for the actor table. +/// Actor table starts with an ALIVE entry, which represents the first time the actor +/// is created. This may be followed by 0 or more pairs of RECONSTRUCTING, ALIVE entries, +/// which represent each time the actor fails (RECONSTRUCTING) and gets recreated (ALIVE). +/// These may be followed by a DEAD entry, which means that the actor has failed and will +/// not be reconstructed. class ActorTable : public Log { public: ActorTable(const std::vector> &contexts, @@ -548,7 +552,7 @@ class ClientTable : private Log { : Log(contexts, client), // We set the client log's key equal to nil so that all instances of // ClientTable have the same key. - client_log_key_(UniqueID::nil()), + client_log_key_(), disconnected_(false), client_id_(client_id), local_client_() { @@ -618,8 +622,6 @@ class ClientTable : private Log { /// Get the information of all clients. /// - /// Note: The return value contains ClientID::nil() which should be filtered. - /// /// \return The client ID to client information map. const std::unordered_map &GetAllClients() const; diff --git a/src/ray/id.cc b/src/ray/id.cc index 95ab1bd640805..b3b2b187e50f4 100644 --- a/src/ray/id.cc +++ b/src/ray/id.cc @@ -24,6 +24,11 @@ std::mt19937 RandomlySeededMersenneTwister() { return seeded_engine; } +UniqueID::UniqueID() { + // Set the ID to nil. 
+ std::fill_n(id_, kUniqueIDSize, 255); +} + UniqueID::UniqueID(const plasma::UniqueID &from) { std::memcpy(&id_, from.data(), kUniqueIDSize); } @@ -50,11 +55,9 @@ UniqueID UniqueID::from_binary(const std::string &binary) { return id; } -const UniqueID UniqueID::nil() { - UniqueID result; - uint8_t *data = result.mutable_data(); - std::fill_n(data, kUniqueIDSize, 255); - return result; +const UniqueID &UniqueID::nil() { + static const UniqueID nil_id; + return nil_id; } bool UniqueID::is_nil() const { @@ -67,17 +70,11 @@ bool UniqueID::is_nil() const { return true; } -const uint8_t *UniqueID::data() const { - return id_; -} +const uint8_t *UniqueID::data() const { return id_; } -uint8_t *UniqueID::mutable_data() { - return id_; -} +uint8_t *UniqueID::mutable_data() { return id_; } -size_t UniqueID::size() const { - return kUniqueIDSize; -} +size_t UniqueID::size() const { return kUniqueIDSize; } std::string UniqueID::binary() const { return std::string(reinterpret_cast(id_), kUniqueIDSize); diff --git a/src/ray/id.h b/src/ray/id.h index daac028fd7a85..0ab0c56408e74 100644 --- a/src/ray/id.h +++ b/src/ray/id.h @@ -14,11 +14,11 @@ namespace ray { class RAY_EXPORT UniqueID { public: - UniqueID() {} + UniqueID(); UniqueID(const plasma::UniqueID &from); static UniqueID from_random(); static UniqueID from_binary(const std::string &binary); - static const UniqueID nil(); + static const UniqueID &nil(); size_t hash() const; bool is_nil() const; bool operator==(const UniqueID &rhs) const; @@ -34,8 +34,7 @@ class RAY_EXPORT UniqueID { uint8_t id_[kUniqueIDSize]; }; -static_assert(std::is_standard_layout::value, - "UniqueID must be standard"); +static_assert(std::is_standard_layout::value, "UniqueID must be standard"); std::ostream &operator<<(std::ostream &os, const UniqueID &id); diff --git a/src/ray/object_manager/CMakeLists.txt b/src/ray/object_manager/CMakeLists.txt index 2d3c0f42f1bcc..054bd8cbacfb6 100644 --- a/src/ray/object_manager/CMakeLists.txt +++ b/src/ray/object_manager/CMakeLists.txt @@ -17,11 +17,11 @@ add_custom_command( add_custom_target(gen_object_manager_fbs DEPENDS ${OBJECT_MANAGER_FBS_OUTPUT_FILES}) -ADD_RAY_TEST(test/object_manager_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY}) -ADD_RAY_TEST(test/object_manager_stress_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY}) +ADD_RAY_TEST(test/object_manager_test STATIC_LINK_LIBS ${RAY_TEST_LIBS}) +ADD_RAY_TEST(test/object_manager_stress_test STATIC_LINK_LIBS ${RAY_TEST_LIBS}) add_library(object_manager object_manager.cc object_manager.h ${OBJECT_MANAGER_FBS_OUTPUT_FILES}) -target_link_libraries(object_manager common ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${Boost_SYSTEM_LIBRARY}) +target_link_libraries(object_manager common ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${Boost_SYSTEM_LIBRARY} ${Boost_THREAD_LIBRARY}) install(FILES object_manager diff --git a/src/ray/object_manager/object_buffer_pool.cc b/src/ray/object_manager/object_buffer_pool.cc index 5f101e7959202..fa312f0c78fb4 100644 --- a/src/ray/object_manager/object_buffer_pool.cc +++ b/src/ray/object_manager/object_buffer_pool.cc @@ -3,10 +3,10 @@ namespace ray { ObjectBufferPool::ObjectBufferPool(const std::string &store_socket_name, - uint64_t chunk_size, int release_delay) + uint64_t chunk_size) : default_chunk_size_(chunk_size) { store_socket_name_ = store_socket_name; - 
ARROW_CHECK_OK(store_client_.Connect(store_socket_name_.c_str(), "", release_delay)); + ARROW_CHECK_OK(store_client_.Connect(store_socket_name_.c_str())); } ObjectBufferPool::~ObjectBufferPool() { diff --git a/src/ray/object_manager/object_buffer_pool.h b/src/ray/object_manager/object_buffer_pool.h index ed6594ed4b496..e4790dfe537cf 100644 --- a/src/ray/object_manager/object_buffer_pool.h +++ b/src/ray/object_manager/object_buffer_pool.h @@ -40,10 +40,7 @@ class ObjectBufferPool { /// \param store_socket_name The socket name of the store to which plasma clients /// connect. /// \param chunk_size The chunk size into which objects are to be split. - /// \param release_delay The number of release calls before objects are released - /// from the store client (FIFO). - ObjectBufferPool(const std::string &store_socket_name, const uint64_t chunk_size, - const int release_delay); + ObjectBufferPool(const std::string &store_socket_name, const uint64_t chunk_size); ~ObjectBufferPool(); diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index ab20d27b66c61..db2b4b1490de7 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -8,10 +8,15 @@ ObjectDirectory::ObjectDirectory(boost::asio::io_service &io_service, namespace { -std::vector UpdateObjectLocations( - std::unordered_set &client_ids, - const std::vector &location_history, - const ray::gcs::ClientTable &client_table) { +/// Process a suffix of the object table log and store the result in +/// client_ids. This assumes that client_ids already contains the result of the +/// object table log up to but not including this suffix. This also stores a +/// bool in has_been_created indicating whether the object has ever been +/// created before. +void UpdateObjectLocations(const std::vector &location_history, + const ray::gcs::ClientTable &client_table, + std::unordered_set *client_ids, + bool *has_been_created) { // location_history contains the history of locations of the object (it is a log), // which might look like the following: // client1.is_eviction = false @@ -19,23 +24,27 @@ std::vector UpdateObjectLocations( // client2.is_eviction = false // In such a scenario, we want to indicate client2 is the only client that contains // the object, which the following code achieves. + if (!location_history.empty()) { + // If there are entries, then the object has been created. Once this flag + // is set to true, it should never go back to false. + *has_been_created = true; + } for (const auto &object_table_data : location_history) { ClientID client_id = ClientID::from_binary(object_table_data.manager); if (!object_table_data.is_eviction) { - client_ids.insert(client_id); + client_ids->insert(client_id); } else { - client_ids.erase(client_id); + client_ids->erase(client_id); } } // Filter out the removed clients from the object locations. - for (auto it = client_ids.begin(); it != client_ids.end();) { + for (auto it = client_ids->begin(); it != client_ids->end();) { if (client_table.IsRemoved(*it)) { - it = client_ids.erase(it); + it = client_ids->erase(it); } else { it++; } } - return std::vector(client_ids.begin(), client_ids.end()); } } // namespace @@ -45,18 +54,18 @@ void ObjectDirectory::RegisterBackend() { gcs::AsyncGcsClient *client, const ObjectID &object_id, const std::vector &location_history) { // Objects are added to this map in SubscribeObjectLocations. 
- auto object_id_listener_pair = listeners_.find(object_id); + auto it = listeners_.find(object_id); // Do nothing for objects we are not listening for. - if (object_id_listener_pair == listeners_.end()) { + if (it == listeners_.end()) { return; } // Update entries for this object. - std::vector client_id_vec = - UpdateObjectLocations(object_id_listener_pair->second.current_object_locations, - location_history, gcs_client_->client_table()); + UpdateObjectLocations(location_history, gcs_client_->client_table(), + &it->second.current_object_locations, + &it->second.has_been_created); // Copy the callbacks so that the callbacks can unsubscribe without interrupting // looping over the callbacks. - auto callbacks = object_id_listener_pair->second.callbacks; + auto callbacks = it->second.callbacks; // Call all callbacks associated with the object id locations we have // received. This notifies the client even if the list of locations is // empty, since this may indicate that the objects have been evicted from @@ -64,7 +73,8 @@ void ObjectDirectory::RegisterBackend() { for (const auto &callback_pair : callbacks) { // It is safe to call the callback directly since this is already running // in the subscription callback stack. - callback_pair.second(client_id_vec, object_id); + callback_pair.second(object_id, it->second.current_object_locations, + it->second.has_been_created); } }; RAY_CHECK_OK(gcs_client_->object_table().Subscribe( @@ -76,27 +86,25 @@ ray::Status ObjectDirectory::ReportObjectAdded( const ObjectID &object_id, const ClientID &client_id, const object_manager::protocol::ObjectInfoT &object_info) { // Append the addition entry to the object table. - JobID job_id = JobID::nil(); auto data = std::make_shared(); data->manager = client_id.binary(); data->is_eviction = false; data->num_evictions = object_evictions_[object_id]; data->object_size = object_info.data_size; ray::Status status = - gcs_client_->object_table().Append(job_id, object_id, data, nullptr); + gcs_client_->object_table().Append(JobID::nil(), object_id, data, nullptr); return status; } ray::Status ObjectDirectory::ReportObjectRemoved(const ObjectID &object_id, const ClientID &client_id) { // Append the eviction entry to the object table. - JobID job_id = JobID::nil(); auto data = std::make_shared(); data->manager = client_id.binary(); data->is_eviction = true; data->num_evictions = object_evictions_[object_id]; ray::Status status = - gcs_client_->object_table().Append(job_id, object_id, data, nullptr); + gcs_client_->object_table().Append(JobID::nil(), object_id, data, nullptr); // Increment the number of times we've evicted this object. NOTE(swang): This // is only necessary because the Ray redis module expects unique entries in a // log. We track the number of evictions so that the next eviction, if there @@ -133,28 +141,51 @@ std::vector ObjectDirectory::LookupAllRemoteConnections() return remote_connections; } +void ObjectDirectory::HandleClientRemoved(const ClientID &client_id) { + for (auto &listener : listeners_) { + const ObjectID &object_id = listener.first; + if (listener.second.current_object_locations.count(client_id) > 0) { + // If the subscribed object has the removed client as a location, update + // its locations with an empty log so that the location will be removed. + UpdateObjectLocations({}, gcs_client_->client_table(), + &listener.second.current_object_locations, + &listener.second.has_been_created); + // Re-call all the subscribed callbacks for the object, since its + // locations have changed. 
+ for (const auto &callback_pair : listener.second.callbacks) { + // It is safe to call the callback directly since this is already running + // in the subscription callback stack. + callback_pair.second(object_id, listener.second.current_object_locations, + listener.second.has_been_created); + } + } + } +} + ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_id, const ObjectID &object_id, const OnLocationsFound &callback) { ray::Status status = ray::Status::OK(); - if (listeners_.find(object_id) == listeners_.end()) { - listeners_.emplace(object_id, LocationListenerState()); + auto it = listeners_.find(object_id); + if (it == listeners_.end()) { + it = listeners_.emplace(object_id, LocationListenerState()).first; status = gcs_client_->object_table().RequestNotifications( JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId()); } - auto &listener_state = listeners_.find(object_id)->second; + auto &listener_state = it->second; // TODO(hme): Make this fatal after implementing Pull suppression. if (listener_state.callbacks.count(callback_id) > 0) { return ray::Status::OK(); } listener_state.callbacks.emplace(callback_id, callback); - // Immediately notify of object locations. This notifies the client even if - // the list of locations is empty, since this may indicate that the objects - // have been evicted from all nodes. - std::vector client_id_vec(listener_state.current_object_locations.begin(), - listener_state.current_object_locations.end()); - io_service_.post( - [callback, client_id_vec, object_id]() { callback(client_id_vec, object_id); }); + // If we previously received some notifications about the object's locations, + // immediately notify the caller of the current known locations. + if (listener_state.has_been_created) { + auto &locations = listener_state.current_object_locations; + io_service_.post([callback, locations, object_id]() { + callback(object_id, locations, /*has_been_created=*/true); + }); + } return status; } @@ -176,19 +207,31 @@ ray::Status ObjectDirectory::UnsubscribeObjectLocations(const UniqueID &callback ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id, const OnLocationsFound &callback) { - JobID job_id = JobID::nil(); - ray::Status status = gcs_client_->object_table().Lookup( - job_id, object_id, - [this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id, - const std::vector &location_history) { - // Build the set of current locations based on the entries in the log. - std::unordered_set client_ids; - std::vector locations_vector = UpdateObjectLocations( - client_ids, location_history, gcs_client_->client_table()); - // It is safe to call the callback directly since this is already running - // in the GCS client's lookup callback stack. - callback(locations_vector, object_id); - }); + ray::Status status; + auto it = listeners_.find(object_id); + if (it == listeners_.end()) { + status = gcs_client_->object_table().Lookup( + JobID::nil(), object_id, + [this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id, + const std::vector &location_history) { + // Build the set of current locations based on the entries in the log. + std::unordered_set client_ids; + bool has_been_created = false; + UpdateObjectLocations(location_history, gcs_client_->client_table(), + &client_ids, &has_been_created); + // It is safe to call the callback directly since this is already running + // in the GCS client's lookup callback stack. 
+ callback(object_id, client_ids, has_been_created); + }); + } else { + // If we have locations cached due to a concurrent SubscribeObjectLocations + // call, call the callback immediately with the cached locations. + auto &locations = it->second.current_object_locations; + bool has_been_created = it->second.has_been_created; + io_service_.post([callback, object_id, locations, has_been_created]() { + callback(object_id, locations, has_been_created); + }); + } return status; } diff --git a/src/ray/object_manager/object_directory.h b/src/ray/object_manager/object_directory.h index e36c4c41604ef..b44197b639ef5 100644 --- a/src/ray/object_manager/object_directory.h +++ b/src/ray/object_manager/object_directory.h @@ -48,8 +48,9 @@ class ObjectDirectoryInterface { virtual std::vector LookupAllRemoteConnections() const = 0; /// Callback for object location notifications. - using OnLocationsFound = std::function &, - const ray::ObjectID &object_id)>; + using OnLocationsFound = std::function &, + bool has_been_created)>; /// Lookup object locations. Callback may be invoked with empty list of client ids. /// @@ -59,6 +60,13 @@ class ObjectDirectoryInterface { virtual ray::Status LookupLocations(const ObjectID &object_id, const OnLocationsFound &callback) = 0; + /// Handle the removal of an object manager client. This updates the + /// locations of all subscribed objects that have the removed client as a + /// location, and fires the subscribed callbacks for those objects. + /// + /// \param client_id The object manager client that was removed. + virtual void HandleClientRemoved(const ClientID &client_id) = 0; + /// Subscribe to be notified of locations (ClientID) of the given object. /// The callback will be invoked with the complete list of known locations /// whenever the set of locations changes. The callback will also be fired if @@ -138,6 +146,8 @@ class ObjectDirectory : public ObjectDirectoryInterface { ray::Status LookupLocations(const ObjectID &object_id, const OnLocationsFound &callback) override; + void HandleClientRemoved(const ClientID &client_id) override; + ray::Status SubscribeObjectLocations(const UniqueID &callback_id, const ObjectID &object_id, const OnLocationsFound &callback) override; @@ -164,6 +174,12 @@ class ObjectDirectory : public ObjectDirectoryInterface { std::unordered_map callbacks; /// The current set of known locations of this object. std::unordered_set current_object_locations; + /// This flag will get set to true if the object has ever been created. It + /// should never go back to false once set to true. If this is true, and + /// the current_object_locations is empty, then this means that the object + /// does not exist on any nodes due to eviction (rather than due to the + /// object never getting created, for instance). + bool has_been_created; }; /// Reference to the event loop. diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index a3cc87c7f17ca..6afc5f3dc2c70 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -10,42 +10,11 @@ namespace ray { ObjectManager::ObjectManager(asio::io_service &main_service, const ObjectManagerConfig &config, - std::shared_ptr gcs_client) - // TODO(hme): Eliminate knowledge of GCS. 
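// A minimal sketch of how a subscriber might consume the new three-argument
// OnLocationsFound callback introduced in this change. The `directory`,
// `callback_id`, and `object_id` names are assumed to be in scope and the
// lambda body is illustrative only, not code from this patch.
RAY_CHECK_OK(directory->SubscribeObjectLocations(
    callback_id, object_id,
    [](const ray::ObjectID &object_id,
       const std::unordered_set<ray::ClientID> &locations, bool has_been_created) {
      if (locations.empty() && has_been_created) {
        // The object was created at some point but is on no node now, so it
        // has been evicted everywhere.
      } else if (!locations.empty()) {
        // At least one node holds the object; a Pull could be issued here.
      }
      // If has_been_created is false, nothing has ever been appended to the
      // object table log, so the object may simply not exist yet.
    }));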
- : client_id_(gcs_client->client_table().GetLocalClientId()), - config_(config), - object_directory_(new ObjectDirectory(main_service, gcs_client)), - store_notification_(main_service, config_.store_socket_name), - // release_delay of 2 * config_.max_sends is to ensure the pool does not release - // an object prematurely whenever we reach the maximum number of sends. - buffer_pool_(config_.store_socket_name, config_.object_chunk_size, - /*release_delay=*/2 * config_.max_sends), - send_work_(send_service_), - receive_work_(receive_service_), - connection_pool_(), - gen_(std::chrono::high_resolution_clock::now().time_since_epoch().count()) { - RAY_CHECK(config_.max_sends > 0); - RAY_CHECK(config_.max_receives > 0); - main_service_ = &main_service; - store_notification_.SubscribeObjAdded( - [this](const object_manager::protocol::ObjectInfoT &object_info) { - HandleObjectAdded(object_info); - }); - store_notification_.SubscribeObjDeleted( - [this](const ObjectID &oid) { NotifyDirectoryObjectDeleted(oid); }); - StartIOService(); -} - -ObjectManager::ObjectManager(asio::io_service &main_service, - const ObjectManagerConfig &config, - std::unique_ptr od) + std::shared_ptr object_directory) : config_(config), - object_directory_(std::move(od)), + object_directory_(std::move(object_directory)), store_notification_(main_service, config_.store_socket_name), - // release_delay of 2 * config_.max_sends is to ensure the pool does not release - // an object prematurely whenever we reach the maximum number of sends. - buffer_pool_(config_.store_socket_name, config_.object_chunk_size, - /*release_delay=*/2 * config_.max_sends), + buffer_pool_(config_.store_socket_name, config_.object_chunk_size), send_work_(send_service_), receive_work_(receive_service_), connection_pool_(), @@ -156,7 +125,8 @@ ray::Status ObjectManager::Pull(const ObjectID &object_id) { // no ordering guarantee between notifications. return object_directory_->SubscribeObjectLocations( object_directory_pull_callback_id_, object_id, - [this](const std::vector &client_ids, const ObjectID &object_id) { + [this](const ObjectID &object_id, const std::unordered_set &client_ids, + bool created) { // Exit if the Pull request has already been fulfilled or canceled. auto it = pull_requests_.find(object_id); if (it == pull_requests_.end()) { @@ -166,7 +136,8 @@ ray::Status ObjectManager::Pull(const ObjectID &object_id) { // NOTE(swang): Since we are overwriting the previous list of clients, // we may end up sending a duplicate request to the same client as // before. - it->second.client_locations = client_ids; + it->second.client_locations = + std::vector(client_ids.begin(), client_ids.end()); if (it->second.client_locations.empty()) { // The object locations are now empty, so we should wait for the next // notification about a new object location. Cancel the timer until @@ -591,8 +562,9 @@ ray::Status ObjectManager::LookupRemainingWaitObjects(const UniqueID &wait_id) { // Lookup remaining objects. 
wait_state.requested_objects.insert(object_id); RAY_RETURN_NOT_OK(object_directory_->LookupLocations( - object_id, [this, wait_id](const std::vector &client_ids, - const ObjectID &lookup_object_id) { + object_id, + [this, wait_id](const ObjectID &lookup_object_id, + const std::unordered_set &client_ids, bool created) { auto &wait_state = active_wait_requests_.find(wait_id)->second; if (!client_ids.empty()) { wait_state.remaining.erase(lookup_object_id); @@ -624,8 +596,9 @@ void ObjectManager::SubscribeRemainingWaitObjects(const UniqueID &wait_id) { wait_state.requested_objects.insert(object_id); // Subscribe to object notifications. RAY_CHECK_OK(object_directory_->SubscribeObjectLocations( - wait_id, object_id, [this, wait_id](const std::vector &client_ids, - const ObjectID &subscribe_object_id) { + wait_id, object_id, + [this, wait_id](const ObjectID &subscribe_object_id, + const std::unordered_set &client_ids, bool created) { if (!client_ids.empty()) { auto object_id_wait_state = active_wait_requests_.find(wait_id); if (object_id_wait_state == active_wait_requests_.end()) { diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index b2583376701e2..57cf27e8b77ba 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -70,25 +70,16 @@ class ObjectManagerInterface { // TODO(hme): Add success/failure callbacks for push and pull. class ObjectManager : public ObjectManagerInterface { public: - /// Implicitly instantiates Ray implementation of ObjectDirectory. - /// - /// \param main_service The main asio io_service. - /// \param config ObjectManager configuration. - /// \param gcs_client A client connection to the Ray GCS. - explicit ObjectManager(boost::asio::io_service &main_service, - const ObjectManagerConfig &config, - std::shared_ptr gcs_client); - /// Takes user-defined ObjectDirectoryInterface implementation. /// When this constructor is used, the ObjectManager assumes ownership of /// the given ObjectDirectory instance. /// /// \param main_service The main asio io_service. /// \param config ObjectManager configuration. - /// \param od An object implementing the object directory interface. + /// \param object_directory An object implementing the object directory interface. 
explicit ObjectManager(boost::asio::io_service &main_service, const ObjectManagerConfig &config, - std::unique_ptr od); + std::shared_ptr object_directory); ~ObjectManager(); @@ -363,7 +354,7 @@ class ObjectManager : public ObjectManagerInterface { ClientID client_id_; const ObjectManagerConfig config_; - std::unique_ptr object_directory_; + std::shared_ptr object_directory_; ObjectStoreNotificationManager store_notification_; ObjectBufferPool buffer_pool_; diff --git a/src/ray/object_manager/object_store_notification_manager.cc b/src/ray/object_manager/object_store_notification_manager.cc index e590e8efa549d..fce65a607dd4f 100644 --- a/src/ray/object_manager/object_store_notification_manager.cc +++ b/src/ray/object_manager/object_store_notification_manager.cc @@ -14,9 +14,12 @@ namespace ray { ObjectStoreNotificationManager::ObjectStoreNotificationManager( boost::asio::io_service &io_service, const std::string &store_socket_name) - : store_client_(), socket_(io_service) { - ARROW_CHECK_OK(store_client_.Connect(store_socket_name.c_str(), "", - plasma::kPlasmaDefaultReleaseDelay)); + : store_client_(), + length_(0), + num_adds_processed_(0), + num_removes_processed_(0), + socket_(io_service) { + ARROW_CHECK_OK(store_client_.Connect(store_socket_name.c_str())); ARROW_CHECK_OK(store_client_.Subscribe(&c_socket_)); boost::system::error_code ec; diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index c7f38b0c7f1d7..84e27e5ed9c51 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -33,7 +33,8 @@ class MockServer { main_service, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 0)), object_manager_socket_(main_service), gcs_client_(gcs_client), - object_manager_(main_service, object_manager_config, gcs_client) { + object_manager_(main_service, object_manager_config, + std::make_shared(main_service, gcs_client_)) { RAY_CHECK_OK(RegisterGcs(main_service)); // Start listening for clients. DoAcceptObjectManager(); @@ -153,8 +154,8 @@ class TestObjectManagerBase : public ::testing::Test { server2.reset(new MockServer(main_service, om_config_2, gcs_client_2)); // connect to stores. - ARROW_CHECK_OK(client1.Connect(store_id_1, "", plasma::kPlasmaDefaultReleaseDelay)); - ARROW_CHECK_OK(client2.Connect(store_id_2, "", plasma::kPlasmaDefaultReleaseDelay)); + ARROW_CHECK_OK(client1.Connect(store_id_1)); + ARROW_CHECK_OK(client2.Connect(store_id_2)); } void TearDown() { diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index cb98706753d91..4c108f2d307c6 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -24,7 +24,8 @@ class MockServer { main_service, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 0)), object_manager_socket_(main_service), gcs_client_(gcs_client), - object_manager_(main_service, object_manager_config, gcs_client) { + object_manager_(main_service, object_manager_config, + std::make_shared(main_service, gcs_client_)) { RAY_CHECK_OK(RegisterGcs(main_service)); // Start listening for clients. DoAcceptObjectManager(); @@ -138,8 +139,8 @@ class TestObjectManagerBase : public ::testing::Test { server2.reset(new MockServer(main_service, om_config_2, gcs_client_2)); // connect to stores. 
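// With the release-delay plumbing removed in this change, connecting a plasma
// client only needs the store socket path. A minimal sketch of the new call,
// where `store_socket_name` is a placeholder value:
#include "plasma/client.h"
plasma::PlasmaClient client;
ARROW_CHECK_OK(client.Connect(store_socket_name));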
- ARROW_CHECK_OK(client1.Connect(store_id_1, "", plasma::kPlasmaDefaultReleaseDelay)); - ARROW_CHECK_OK(client2.Connect(store_id_2, "", plasma::kPlasmaDefaultReleaseDelay)); + ARROW_CHECK_OK(client1.Connect(store_id_1)); + ARROW_CHECK_OK(client2.Connect(store_id_2)); } void TearDown() { @@ -285,8 +286,9 @@ class TestObjectManager : public TestObjectManagerBase { RAY_CHECK_OK(server1->object_manager_.object_directory_->SubscribeObjectLocations( sub_id, object_1, - [this, sub_id, object_1, object_2](const std::vector &clients, - const ray::ObjectID &object_id) { + [this, sub_id, object_1, object_2]( + const ray::ObjectID &object_id, + const std::unordered_set &clients, bool created) { if (!clients.empty()) { TestWaitWhileSubscribed(sub_id, object_1, object_2); } @@ -449,7 +451,7 @@ class TestObjectManager : public TestObjectManagerBase { << "\n"; ClientTableDataT data; gcs_client_1->client_table().GetClient(client_id_1, data); - RAY_LOG(DEBUG) << (ClientID::from_binary(data.client_id) == ClientID::nil()); + RAY_LOG(DEBUG) << (ClientID::from_binary(data.client_id).is_nil()); RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::from_binary(data.client_id); RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address; RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port; diff --git a/src/ray/ray_config.h b/src/ray/ray_config.h index 480c976a5defa..4887026d04d9b 100644 --- a/src/ray/ray_config.h +++ b/src/ray/ray_config.h @@ -64,10 +64,6 @@ class RayConfig { return kill_worker_timeout_milliseconds_; } - int64_t manager_timeout_milliseconds() const { return manager_timeout_milliseconds_; } - - int64_t buf_size() const { return buf_size_; } - int64_t max_time_for_handler_milliseconds() const { return max_time_for_handler_milliseconds_; } @@ -154,10 +150,6 @@ class RayConfig { local_scheduler_fetch_request_size_ = pair.second; } else if (pair.first == "kill_worker_timeout_milliseconds") { kill_worker_timeout_milliseconds_ = pair.second; - } else if (pair.first == "manager_timeout_milliseconds") { - manager_timeout_milliseconds_ = pair.second; - } else if (pair.first == "buf_size") { - buf_size_ = pair.second; } else if (pair.first == "max_time_for_handler_milliseconds") { max_time_for_handler_milliseconds_ = pair.second; } else if (pair.first == "size_limit") { @@ -200,7 +192,7 @@ class RayConfig { : ray_protocol_version_(0x0000000000000000), handler_warning_timeout_ms_(100), heartbeat_timeout_milliseconds_(100), - num_heartbeats_timeout_(100), + num_heartbeats_timeout_(300), num_heartbeats_warning_(5), debug_dump_period_milliseconds_(10000), initial_reconstruction_timeout_milliseconds_(10000), @@ -216,8 +208,6 @@ class RayConfig { max_num_to_reconstruct_(10000), local_scheduler_fetch_request_size_(10000), kill_worker_timeout_milliseconds_(100), - manager_timeout_milliseconds_(1000), - buf_size_(80 * 1024), max_time_for_handler_milliseconds_(1000), size_limit_(10000), num_elements_limit_(10000), @@ -245,8 +235,7 @@ class RayConfig { /// warning is logged that the handler is taking too long. int64_t handler_warning_timeout_ms_; - /// The duration between heartbeats. These are sent by the plasma manager and - /// local scheduler. + /// The duration between heartbeats. These are sent by the raylet. int64_t heartbeat_timeout_milliseconds_; /// If a component has not sent a heartbeat in the last num_heartbeats_timeout /// heartbeat intervals, the global scheduler or monitor process will report @@ -306,10 +295,6 @@ class RayConfig { /// the worker SIGKILL. 
int64_t kill_worker_timeout_milliseconds_; - /// These are used by the plasma manager. - int64_t manager_timeout_milliseconds_; - int64_t buf_size_; - /// This is a timeout used to cause failures in the plasma manager and local /// scheduler when certain event loop handlers take too long. int64_t max_time_for_handler_milliseconds_; diff --git a/src/ray/raylet/CMakeLists.txt b/src/ray/raylet/CMakeLists.txt index 2faac13d02beb..5834b03006217 100644 --- a/src/ray/raylet/CMakeLists.txt +++ b/src/ray/raylet/CMakeLists.txt @@ -30,66 +30,64 @@ add_custom_command( add_dependencies(gen_node_manager_fbs flatbuffers_ep) -ADD_RAY_TEST(object_manager_integration_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY}) - -ADD_RAY_TEST(worker_pool_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) - -ADD_RAY_TEST(client_connection_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) -ADD_RAY_TEST(task_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) -ADD_RAY_TEST(lineage_cache_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) -ADD_RAY_TEST(task_dependency_manager_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) -ADD_RAY_TEST(reconstruction_policy_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) +ADD_RAY_TEST(object_manager_integration_test STATIC_LINK_LIBS ${RAY_TEST_LIBS}) +ADD_RAY_TEST(worker_pool_test STATIC_LINK_LIBS ${RAY_TEST_LIBS}) +ADD_RAY_TEST(client_connection_test STATIC_LINK_LIBS ${RAY_TEST_LIBS}) +ADD_RAY_TEST(task_test STATIC_LINK_LIBS ${RAY_TEST_LIBS}) +ADD_RAY_TEST(lineage_cache_test STATIC_LINK_LIBS ${RAY_TEST_LIBS}) +ADD_RAY_TEST(task_dependency_manager_test STATIC_LINK_LIBS ${RAY_TEST_LIBS}) +ADD_RAY_TEST(reconstruction_policy_test STATIC_LINK_LIBS ${RAY_TEST_LIBS}) include_directories(${GCS_FBS_OUTPUT_DIRECTORY}) add_library(rayletlib raylet.cc ${NODE_MANAGER_FBS_OUTPUT_FILES}) -target_link_libraries(rayletlib ray_static ${Boost_SYSTEM_LIBRARY}) +target_link_libraries(rayletlib ray_static ${Boost_SYSTEM_LIBRARY} ${Boost_THREAD_LIBRARY}) -add_library(local_scheduler_client STATIC local_scheduler_client.cc) +add_library(raylet_client STATIC raylet_client.cc) # Encode the fact that some things require some autogenerated flatbuffer files # to be created first. 
add_dependencies(rayletlib gen_gcs_fbs) -add_dependencies(local_scheduler_client gen_gcs_fbs arrow_ep gen_node_manager_fbs) +add_dependencies(raylet_client gen_gcs_fbs arrow_ep gen_node_manager_fbs) add_executable(raylet main.cc) -target_link_libraries(raylet rayletlib ${Boost_SYSTEM_LIBRARY} pthread) +target_link_libraries(raylet rayletlib ${RAY_SYSTEM_LIBS}) add_executable(raylet_monitor monitor_main.cc) -target_link_libraries(raylet_monitor rayletlib ${Boost_SYSTEM_LIBRARY} pthread) +target_link_libraries(raylet_monitor rayletlib ${Boost_SYSTEM_LIBRARY} ${Boost_THREAD_LIBRARY} pthread) install(FILES raylet DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/ray/raylet") -macro(get_local_scheduler_library LANG VAR) - set(${VAR} "local_scheduler_library_${LANG}") +macro(get_raylet_library LANG VAR) + set(${VAR} "raylet_library_${LANG}") endmacro() -macro(set_local_scheduler_library LANG) - get_local_scheduler_library(${LANG} LOCAL_SCHEDULER_LIBRARY_${LANG}) - set(LOCAL_SCHEDULER_LIBRARY_LANG ${LOCAL_SCHEDULER_LIBRARY_${LANG}}) +macro(set_raylet_library LANG) + get_raylet_library(${LANG} RAYLET_LIBRARY_${LANG}) + set(RAYLET_LIBRARY_LANG ${RAYLET_LIBRARY_${LANG}}) - file(GLOB LOCAL_SCHEDULER_LIBRARY_${LANG}_SRC + file(GLOB RAYLET_LIBRARY_${LANG}_SRC lib/${LANG}/*.cc) - add_library(${LOCAL_SCHEDULER_LIBRARY_LANG} SHARED - ${LOCAL_SCHEDULER_LIBRARY_${LANG}_SRC}) + add_library(${RAYLET_LIBRARY_LANG} SHARED + ${RAYLET_LIBRARY_${LANG}_SRC}) if(APPLE) if ("${LANG}" STREQUAL "python") - SET_TARGET_PROPERTIES(${LOCAL_SCHEDULER_LIBRARY_LANG} PROPERTIES SUFFIX .so) + SET_TARGET_PROPERTIES(${RAYLET_LIBRARY_LANG} PROPERTIES SUFFIX .so) endif() - target_link_libraries(${LOCAL_SCHEDULER_LIBRARY_LANG} "-undefined dynamic_lookup" local_scheduler_client ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${Boost_SYSTEM_LIBRARY}) + target_link_libraries(${RAYLET_LIBRARY_LANG} "-undefined dynamic_lookup" raylet_client ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${Boost_SYSTEM_LIBRARY} ${Boost_THREAD_LIBRARY}) else(APPLE) - target_link_libraries(${LOCAL_SCHEDULER_LIBRARY_LANG} local_scheduler_client ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${Boost_SYSTEM_LIBRARY}) + target_link_libraries(${RAYLET_LIBRARY_LANG} raylet_client ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${Boost_SYSTEM_LIBRARY} ${Boost_THREAD_LIBRARY}) endif(APPLE) - add_dependencies(${LOCAL_SCHEDULER_LIBRARY_LANG} gen_node_manager_fbs) + add_dependencies(${RAYLET_LIBRARY_LANG} gen_node_manager_fbs) - install(TARGETS ${LOCAL_SCHEDULER_LIBRARY_LANG} DESTINATION ${CMAKE_SOURCE_DIR}/local_scheduler) + install(TARGETS ${RAYLET_LIBRARY_LANG} DESTINATION ${CMAKE_SOURCE_DIR}/raylet) endmacro() if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") - set_local_scheduler_library("python") + set_raylet_library("python") include_directories("${PYTHON_INCLUDE_DIRS}") include_directories("${NUMPY_INCLUDE_DIR}") endif() @@ -103,5 +101,5 @@ if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES") else() # linux add_compile_options("-I$ENV{JAVA_HOME}/include/linux") endif() - set_local_scheduler_library("java") + set_raylet_library("java") endif() diff --git a/src/ray/raylet/actor_registration.cc b/src/ray/raylet/actor_registration.cc index 7ea95e6566426..5c9b322288811 100644 --- a/src/ray/raylet/actor_registration.cc +++ b/src/ray/raylet/actor_registration.cc @@ -9,9 +9,7 @@ namespace ray { namespace raylet { ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data) - : actor_table_data_(actor_table_data), - 
execution_dependency_(ObjectID::nil()), - frontier_() {} + : actor_table_data_(actor_table_data) {} const ClientID ActorRegistration::GetNodeManagerId() const { return ClientID::from_binary(actor_table_data_.node_manager_id); @@ -25,6 +23,18 @@ const ObjectID ActorRegistration::GetExecutionDependency() const { return execution_dependency_; } +const DriverID ActorRegistration::GetDriverId() const { + return DriverID::from_binary(actor_table_data_.driver_id); +} + +const int64_t ActorRegistration::GetMaxReconstructions() const { + return actor_table_data_.max_reconstructions; +} + +const int64_t ActorRegistration::GetRemainingReconstructions() const { + return actor_table_data_.remaining_reconstructions; +} + const std::unordered_map &ActorRegistration::GetFrontier() const { return frontier_; @@ -39,10 +49,6 @@ void ActorRegistration::ExtendFrontier(const ActorHandleID &handle_id, dummy_objects_.push_back(execution_dependency); } -bool ActorRegistration::IsAlive() const { - return actor_table_data_.state == ActorState::ALIVE; -} - int ActorRegistration::NumHandles() const { return frontier_.size(); } } // namespace raylet diff --git a/src/ray/raylet/actor_registration.h b/src/ray/raylet/actor_registration.h index 4cf9b110afe12..9c4664455b990 100644 --- a/src/ray/raylet/actor_registration.h +++ b/src/ray/raylet/actor_registration.h @@ -46,6 +46,9 @@ class ActorRegistration { /// \return The actor's current state. const ActorState &GetState() const { return actor_table_data_.state; } + /// Update actor's state. + void SetState(const ActorState &state) { actor_table_data_.state = state; } + /// Get the actor's node manager location. /// /// \return The actor's node manager location. All tasks for the actor should @@ -59,6 +62,15 @@ class ActorRegistration { /// \return The execution dependency returned by the actor's creation task. const ObjectID GetActorCreationDependency() const; + /// Get actor's driver ID. + const DriverID GetDriverId() const; + + /// Get the max number of times this actor should be reconstructed. + const int64_t GetMaxReconstructions() const; + + /// Get the remaining number of times this actor should be reconstructed. + const int64_t GetRemainingReconstructions() const; + /// Get the object that represents the actor's current state. This is the /// execution dependency returned by the task most recently executed on the /// actor. The next task to execute on the actor should be marked as @@ -88,12 +100,6 @@ class ActorRegistration { void ExtendFrontier(const ActorHandleID &handle_id, const ObjectID &execution_dependency); - /// Return whether the actor is alive or not. This should only be called on - /// local actors. - /// - /// \return True if the local actor is alive and false if it is dead. - bool IsAlive() const; - /// Returns num handles to this actor entry. /// /// \return int. @@ -111,6 +117,7 @@ class ActorRegistration { /// executed so far and which tasks may execute next, based on execution /// dependencies. This is indexed by handle. std::unordered_map frontier_; + /// All of the dummy object IDs from this actor's tasks. std::vector dummy_objects_; }; diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 1e62202d79129..7725b9bc7c60b 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -118,7 +118,7 @@ table GetTaskReply { } // This struct is used to register a new worker with the local scheduler. -// It is shipped as part of local_scheduler_connect. 
+// It is shipped as part of raylet_connect. table RegisterClientRequest { // True if the client is a worker and false if the client is a driver. is_worker: bool; diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index 12388d181ff31..212f91a84f61b 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -3,7 +3,7 @@ #include #include "ray/id.h" -#include "ray/raylet/local_scheduler_client.h" +#include "ray/raylet/raylet_client.h" #include "ray/util/logging.h" #ifdef __cplusplus @@ -42,10 +42,10 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit( UniqueIdFromJByteArray worker_id(env, workerId); UniqueIdFromJByteArray driver_id(env, driverId); const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE); - auto client = LocalSchedulerConnection_init(nativeString, *worker_id.PID, isWorker, - *driver_id.PID, Language::JAVA); + auto raylet_client = new RayletClient(nativeString, *worker_id.PID, isWorker, + *driver_id.PID, Language::JAVA); env->ReleaseStringUTFChars(sockName, nativeString); - return reinterpret_cast(client); + return reinterpret_cast(raylet_client); } /* @@ -56,7 +56,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit( JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask( JNIEnv *env, jclass, jlong client, jbyteArray cursorId, jobject taskBuff, jint pos, jint taskSize) { - auto conn = reinterpret_cast(client); + auto raylet_client = reinterpret_cast(client); std::vector execution_dependencies; if (cursorId != nullptr) { @@ -66,7 +66,8 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmit auto data = reinterpret_cast(env->GetDirectBufferAddress(taskBuff)) + pos; ray::raylet::TaskSpecification task_spec(std::string(data, taskSize)); - local_scheduler_submit_raylet(conn, execution_dependencies, task_spec); + auto status = raylet_client->SubmitTask(execution_dependencies, task_spec); + RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to submit a task to raylet."); } /* @@ -76,10 +77,12 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmit */ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeGetTask( JNIEnv *env, jclass, jlong client) { - auto conn = reinterpret_cast(client); + auto raylet_client = reinterpret_cast(client); // TODO: handle actor failure later - ray::raylet::TaskSpecification *spec = local_scheduler_get_task_raylet(conn); + std::unique_ptr spec; + auto status = raylet_client->GetTask(&spec); + RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to get a task from raylet."); // We serialize the task specification using flatbuffers and then parse the // resulting string. 
This awkwardness is due to the fact that the Java @@ -100,7 +103,6 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_native result, 0, task_message->size(), reinterpret_cast(const_cast(task_message->data()))); - delete spec; return result; } @@ -111,17 +113,18 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_native */ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy( JNIEnv *, jclass, jlong client) { - auto conn = reinterpret_cast(client); - local_scheduler_disconnect_client(conn); - LocalSchedulerConnection_free(conn); + auto raylet_client = reinterpret_cast(client); + RAY_CHECK_OK_PREPEND(raylet_client->Disconnect(), + "[RayletClient] Failed to disconnect."); + delete raylet_client; } /* * Class: org_ray_runtime_raylet_RayletClientImpl * Method: nativeFetchOrReconstruct - * Signature: (J[[BZ)V + * Signature: (J[[BZ[B)I */ -JNIEXPORT void JNICALL +JNIEXPORT jint JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct( JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jboolean fetchOnly, jbyteArray currentTaskId) { @@ -135,8 +138,10 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct( env->DeleteLocalRef(object_id_bytes); } UniqueIdFromJByteArray current_task_id(env, currentTaskId); - auto conn = reinterpret_cast(client); - local_scheduler_fetch_or_reconstruct(conn, object_ids, fetchOnly, *current_task_id.PID); + auto raylet_client = reinterpret_cast(client); + auto status = + raylet_client->FetchOrReconstruct(object_ids, fetchOnly, *current_task_id.PID); + return static_cast(status.code()); } /* @@ -147,8 +152,9 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct( JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyUnblocked( JNIEnv *env, jclass, jlong client, jbyteArray currentTaskId) { UniqueIdFromJByteArray current_task_id(env, currentTaskId); - auto conn = reinterpret_cast(client); - local_scheduler_notify_unblocked(conn, *current_task_id.PID); + auto raylet_client = reinterpret_cast(client); + auto status = raylet_client->NotifyUnblocked(*current_task_id.PID); + RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to notify unblocked."); } /* @@ -171,12 +177,14 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject( } UniqueIdFromJByteArray current_task_id(env, currentTaskId); - auto conn = reinterpret_cast(client); + auto raylet_client = reinterpret_cast(client); // Invoke wait. - std::pair, std::vector> result = - local_scheduler_wait(conn, object_ids, numReturns, timeoutMillis, - static_cast(isWaitLocal), *current_task_id.PID); + WaitResultPair result; + auto status = + raylet_client->Wait(object_ids, numReturns, timeoutMillis, + static_cast(isWaitLocal), *current_task_id.PID, &result); + RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to wait for objects."); // Convert result to java object. 
jboolean put_value = true; @@ -245,8 +253,9 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects( object_ids.push_back(*object_id.PID); env->DeleteLocalRef(object_id_bytes); } - auto conn = reinterpret_cast(client); - local_scheduler_free_objects_in_object_store(conn, object_ids, localOnly); + auto raylet_client = reinterpret_cast(client); + auto status = raylet_client->FreeObjects(object_ids, localOnly); + RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to free objects."); } #ifdef __cplusplus diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h index 8940046cea9d0..0d5d0b9cbb513 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h @@ -42,9 +42,9 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy(JNIEnv *, jclass, jlo /* * Class: org_ray_runtime_raylet_RayletClientImpl * Method: nativeFetchOrReconstruct - * Signature: (J[[BZ)V + * Signature: (J[[BZ[B)I */ -JNIEXPORT void JNICALL +JNIEXPORT jint JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct(JNIEnv *, jclass, jlong, jobjectArray, jboolean, diff --git a/src/ray/raylet/lib/python/common_extension.cc b/src/ray/raylet/lib/python/common_extension.cc index f4979620c3195..d986427cdd14a 100644 --- a/src/ray/raylet/lib/python/common_extension.cc +++ b/src/ray/raylet/lib/python/common_extension.cc @@ -84,6 +84,32 @@ int PyObjectToUniqueID(PyObject *object, ObjectID *objectid) { } } +int PyListStringToStringVector(PyObject *object, + std::vector *function_descriptor) { + if (function_descriptor == nullptr) { + PyErr_SetString(PyExc_TypeError, "function descriptor must be non-empty pointer"); + return 0; + } + function_descriptor->clear(); + std::vector string_vector; + if (PyList_Check(object)) { + Py_ssize_t size = PyList_Size(object); + for (Py_ssize_t i = 0; i < size; ++i) { + PyObject *item = PyList_GetItem(object, i); + if (PyBytes_Check(item) == 0) { + PyErr_SetString(PyExc_TypeError, + "PyListStringToStringVector takes a list of byte strings."); + return 0; + } + function_descriptor->emplace_back(PyBytes_AsString(item), PyBytes_Size(item)); + } + return 1; + } else { + PyErr_SetString(PyExc_TypeError, "must be a list of strings"); + return 0; + } +} + static int PyObjectID_init(PyObjectID *self, PyObject *args, PyObject *kwds) { const char *data; int size; @@ -117,7 +143,7 @@ TaskSpec *TaskSpec_copy(TaskSpec *spec, int64_t task_spec_size) { * * This is called from Python like * - * task = local_scheduler.task_from_string("...") + * task = raylet.task_from_string("...") * * @param task_string String representation of the task specification. * @return Python task specification object. @@ -142,7 +168,7 @@ PyObject *PyTask_from_string(PyObject *self, PyObject *args) { * * This is called from Python like * - * s = local_scheduler.task_to_string(task) + * s = raylet.task_to_string(task) * * @param task Ray task specification Python object. * @return String representing the task specification. @@ -358,25 +384,28 @@ static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) { // ID of the driver that this task originates from. UniqueID driver_id; // ID of the actor this task should run on. - UniqueID actor_id = ActorID::nil(); + UniqueID actor_id; // ID of the actor handle used to submit this task. 
- UniqueID actor_handle_id = ActorHandleID::nil(); + UniqueID actor_handle_id; // How many tasks have been launched on the actor so far? int actor_counter = 0; - // ID of the function this task executes. - FunctionID function_id; // Arguments of the task (can be PyObjectIDs or Python values). PyObject *arguments; // Number of return values of this task. int num_returns; + // Task language type enum number. + int language = static_cast(Language::PYTHON); // The ID of the task that called this task. TaskID parent_task_id; // The number of tasks that the parent task has called prior to this one. int parent_counter; // The actor creation ID. - ActorID actor_creation_id = ActorID::nil(); + ActorID actor_creation_id; // The dummy object for the actor creation task (if this is an actor method). - ObjectID actor_creation_dummy_object_id = ObjectID::nil(); + ObjectID actor_creation_dummy_object_id; + // Max number of times to reconstruct this actor (only used for actor creation + // task). + int32_t max_actor_reconstructions; // Arguments of the task that are execution-dependent. These must be // PyObjectIDs). PyObject *execution_arguments = nullptr; @@ -384,13 +413,17 @@ static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) { PyObject *resource_map = nullptr; // Dictionary of required placement resources for this task. PyObject *placement_resource_map = nullptr; - if (!PyArg_ParseTuple(args, "O&O&OiO&i|O&O&O&O&iOOO", &PyObjectToUniqueID, &driver_id, - &PyObjectToUniqueID, &function_id, &arguments, &num_returns, - &PyObjectToUniqueID, &parent_task_id, &parent_counter, - &PyObjectToUniqueID, &actor_creation_id, &PyObjectToUniqueID, - &actor_creation_dummy_object_id, &PyObjectToUniqueID, &actor_id, - &PyObjectToUniqueID, &actor_handle_id, &actor_counter, - &execution_arguments, &resource_map, &placement_resource_map)) { + + // Function descriptor. + std::vector function_descriptor; + if (!PyArg_ParseTuple( + args, "O&O&OiO&i|O&O&iO&O&iOOOi", &PyObjectToUniqueID, &driver_id, + &PyListStringToStringVector, &function_descriptor, &arguments, &num_returns, + &PyObjectToUniqueID, &parent_task_id, &parent_counter, &PyObjectToUniqueID, + &actor_creation_id, &PyObjectToUniqueID, &actor_creation_dummy_object_id, + &max_actor_reconstructions, &PyObjectToUniqueID, &actor_id, &PyObjectToUniqueID, + &actor_handle_id, &actor_counter, &execution_arguments, &resource_map, + &placement_resource_map, &language)) { return -1; } @@ -420,6 +453,7 @@ static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) { self->task_spec = nullptr; // Create the task spec. + // Parse the arguments from the list. std::vector> task_args; for (Py_ssize_t i = 0; i < num_args; ++i) { @@ -439,9 +473,9 @@ static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) { self->task_spec = new ray::raylet::TaskSpecification( driver_id, parent_task_id, parent_counter, actor_creation_id, - actor_creation_dummy_object_id, actor_id, actor_handle_id, actor_counter, - function_id, task_args, num_returns, required_resources, - required_placement_resources, Language::PYTHON); + actor_creation_dummy_object_id, max_actor_reconstructions, actor_id, + actor_handle_id, actor_counter, task_args, num_returns, required_resources, + required_placement_resources, Language::PYTHON, function_descriptor); /* Set the task's execution dependencies. 
*/ self->execution_dependencies = new std::vector(); @@ -466,9 +500,23 @@ static void PyTask_dealloc(PyTask *self) { Py_TYPE(self)->tp_free(reinterpret_cast(self)); } -static PyObject *PyTask_function_id(PyTask *self) { - FunctionID function_id = self->task_spec->FunctionId(); - return PyObjectID_make(function_id); +// Helper function to change a c++ string vector to a Python string list. +static PyObject *VectorStringToPyBytesList( + const std::vector &function_descriptor) { + size_t size = function_descriptor.size(); + PyObject *return_list = PyList_New(static_cast(size)); + for (size_t i = 0; i < size; ++i) { + auto py_bytes = PyBytes_FromStringAndSize(function_descriptor[i].data(), + function_descriptor[i].size()); + PyList_SetItem(return_list, i, py_bytes); + } + return return_list; +} + +static PyObject *PyTask_function_descriptor_vector(PyTask *self) { + std::vector function_descriptor; + function_descriptor = self->task_spec->FunctionDescriptor(); + return VectorStringToPyBytesList(function_descriptor); } static PyObject *PyTask_actor_id(PyTask *self) { @@ -593,8 +641,8 @@ static PyObject *PyTask_to_serialized_flatbuf(PyTask *self) { } static PyMethodDef PyTask_methods[] = { - {"function_id", (PyCFunction)PyTask_function_id, METH_NOARGS, - "Return the function ID for this task."}, + {"function_descriptor_list", (PyCFunction)PyTask_function_descriptor_vector, + METH_NOARGS, "Return the function descriptor for this task."}, {"parent_task_id", (PyCFunction)PyTask_parent_task_id, METH_NOARGS, "Return the task ID of the parent task."}, {"parent_counter", (PyCFunction)PyTask_parent_counter, METH_NOARGS, diff --git a/src/ray/raylet/lib/python/config_extension.cc b/src/ray/raylet/lib/python/config_extension.cc index 06b0a032ad435..3431641d22a50 100644 --- a/src/ray/raylet/lib/python/config_extension.cc +++ b/src/ray/raylet/lib/python/config_extension.cc @@ -69,18 +69,6 @@ PyObject *PyRayConfig_kill_worker_timeout_milliseconds(PyObject *self) { return PyLong_FromLongLong(RayConfig::instance().kill_worker_timeout_milliseconds()); } -PyObject *PyRayConfig_manager_timeout_milliseconds(PyObject *self) { - return PyLong_FromLongLong(RayConfig::instance().manager_timeout_milliseconds()); -} - -PyObject *PyRayConfig_buf_size(PyObject *self) { - return PyLong_FromLongLong(RayConfig::instance().buf_size()); -} - -PyObject *PyRayConfig_max_time_for_handler_milliseconds(PyObject *self) { - return PyLong_FromLongLong(RayConfig::instance().max_time_for_handler_milliseconds()); -} - PyObject *PyRayConfig_size_limit(PyObject *self) { return PyLong_FromLongLong(RayConfig::instance().size_limit()); } @@ -144,13 +132,6 @@ static PyMethodDef PyRayConfig_methods[] = { {"kill_worker_timeout_milliseconds", (PyCFunction)PyRayConfig_kill_worker_timeout_milliseconds, METH_NOARGS, "Return kill_worker_timeout_milliseconds"}, - {"manager_timeout_milliseconds", - (PyCFunction)PyRayConfig_manager_timeout_milliseconds, METH_NOARGS, - "Return manager_timeout_milliseconds"}, - {"buf_size", (PyCFunction)PyRayConfig_buf_size, METH_NOARGS, "Return buf_size"}, - {"max_time_for_handler_milliseconds", - (PyCFunction)PyRayConfig_max_time_for_handler_milliseconds, METH_NOARGS, - "Return max_time_for_handler_milliseconds"}, {"size_limit", (PyCFunction)PyRayConfig_size_limit, METH_NOARGS, "Return size_limit"}, {"num_elements_limit", (PyCFunction)PyRayConfig_num_elements_limit, METH_NOARGS, "Return num_elements_limit"}, diff --git a/src/ray/raylet/lib/python/config_extension.h b/src/ray/raylet/lib/python/config_extension.h index 
3cd4f56afc4c6..182158e9bd8e4 100644 --- a/src/ray/raylet/lib/python/config_extension.h +++ b/src/ray/raylet/lib/python/config_extension.h @@ -28,9 +28,6 @@ PyObject *PyRayConfig_local_scheduler_reconstruction_timeout_milliseconds(PyObje PyObject *PyRayConfig_max_num_to_reconstruct(PyObject *self); PyObject *PyRayConfig_local_scheduler_fetch_request_size(PyObject *self); PyObject *PyRayConfig_kill_worker_timeout_milliseconds(PyObject *self); -PyObject *PyRayConfig_manager_timeout_milliseconds(PyObject *self); -PyObject *PyRayConfig_buf_size(PyObject *self); -PyObject *PyRayConfig_max_time_for_handler_milliseconds(PyObject *self); PyObject *PyRayConfig_size_limit(PyObject *self); PyObject *PyRayConfig_num_elements_limit(PyObject *self); PyObject *PyRayConfig_max_time_for_loop(PyObject *self); diff --git a/src/ray/raylet/lib/python/local_scheduler_extension.cc b/src/ray/raylet/lib/python/raylet_extension.cc similarity index 55% rename from src/ray/raylet/lib/python/local_scheduler_extension.cc rename to src/ray/raylet/lib/python/raylet_extension.cc index 05d24cdb48439..c5f7eafd59107 100644 --- a/src/ray/raylet/lib/python/local_scheduler_extension.cc +++ b/src/ray/raylet/lib/python/raylet_extension.cc @@ -1,79 +1,75 @@ #include +#include #include "common_extension.h" #include "config_extension.h" -#include "ray/raylet/local_scheduler_client.h" +#include "ray/raylet/raylet_client.h" PyObject *LocalSchedulerError; // clang-format off typedef struct { PyObject_HEAD - LocalSchedulerConnection *local_scheduler_connection; -} PyLocalSchedulerClient; + RayletClient *raylet_client; +} PyRayletClient; // clang-format on -static int PyLocalSchedulerClient_init(PyLocalSchedulerClient *self, PyObject *args, - PyObject *kwds) { +static int PyRayletClient_init(PyRayletClient *self, PyObject *args, PyObject *kwds) { char *socket_name; UniqueID client_id; PyObject *is_worker; JobID driver_id; if (!PyArg_ParseTuple(args, "sO&OO&", &socket_name, PyStringToUniqueID, &client_id, &is_worker, &PyObjectToUniqueID, &driver_id)) { - self->local_scheduler_connection = NULL; + self->raylet_client = NULL; return -1; } /* Connect to the local scheduler. 
*/ - self->local_scheduler_connection = LocalSchedulerConnection_init( - socket_name, client_id, static_cast(PyObject_IsTrue(is_worker)), driver_id, - Language::PYTHON); + self->raylet_client = new RayletClient(socket_name, client_id, + static_cast(PyObject_IsTrue(is_worker)), + driver_id, Language::PYTHON); return 0; } -static void PyLocalSchedulerClient_dealloc(PyLocalSchedulerClient *self) { - if (self->local_scheduler_connection != NULL) { - LocalSchedulerConnection_free(self->local_scheduler_connection); +static void PyRayletClient_dealloc(PyRayletClient *self) { + if (self->raylet_client != NULL) { + delete self->raylet_client; } Py_TYPE(self)->tp_free((PyObject *)self); } -static PyObject *PyLocalSchedulerClient_disconnect(PyObject *self) { - local_scheduler_disconnect_client( - ((PyLocalSchedulerClient *)self)->local_scheduler_connection); +static PyObject *PyRayletClient_Disconnect(PyRayletClient *self) { + auto status = self->raylet_client->Disconnect(); + RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to disconnect."); Py_RETURN_NONE; } -static PyObject *PyLocalSchedulerClient_submit(PyObject *self, PyObject *args) { +static PyObject *PyRayletClient_SubmitTask(PyRayletClient *self, PyObject *args) { PyObject *py_task; if (!PyArg_ParseTuple(args, "O", &py_task)) { return NULL; } - LocalSchedulerConnection *connection = - reinterpret_cast(self)->local_scheduler_connection; PyTask *task = reinterpret_cast(py_task); - - local_scheduler_submit_raylet(connection, *task->execution_dependencies, - *task->task_spec); - + auto status = + self->raylet_client->SubmitTask(*task->execution_dependencies, *task->task_spec); + RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to submit a task to raylet."); Py_RETURN_NONE; } // clang-format off -static PyObject *PyLocalSchedulerClient_get_task(PyObject *self) { - ray::raylet::TaskSpecification *task_spec; +static PyObject *PyRayletClient_GetTask(PyRayletClient *self) { + std::unique_ptr task_spec; /* Drop the global interpreter lock while we get a task because - * local_scheduler_get_task may block for a long time. */ + * raylet_GetTask may block for a long time. */ Py_BEGIN_ALLOW_THREADS - task_spec = local_scheduler_get_task_raylet( - reinterpret_cast(self)->local_scheduler_connection); + auto status = self->raylet_client->GetTask(&task_spec); + RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to get a task from raylet."); Py_END_ALLOW_THREADS - return PyTask_make(task_spec); + return PyTask_make(task_spec.release()); } // clang-format on -static PyObject *PyLocalSchedulerClient_fetch_or_reconstruct(PyObject *self, - PyObject *args) { +static PyObject *PyRayletClient_FetchOrReconstruct(PyRayletClient *self, PyObject *args) { PyObject *py_object_ids; PyObject *py_fetch_only; std::vector object_ids; @@ -92,23 +88,31 @@ static PyObject *PyLocalSchedulerClient_fetch_or_reconstruct(PyObject *self, } object_ids.push_back(object_id); } - local_scheduler_fetch_or_reconstruct( - reinterpret_cast(self)->local_scheduler_connection, - object_ids, fetch_only, current_task_id); - Py_RETURN_NONE; + auto status = + self->raylet_client->FetchOrReconstruct(object_ids, fetch_only, current_task_id); + if (status.ok()) { + Py_RETURN_NONE; + } else { + std::ostringstream stream; + stream << "[RayletClient] FetchOrReconstruct failed: " + << "raylet client may be closed, check raylet status. 
error message: " + << status.ToString(); + PyErr_SetString(CommonError, stream.str().c_str()); + return NULL; + } } -static PyObject *PyLocalSchedulerClient_notify_unblocked(PyObject *self, PyObject *args) { +static PyObject *PyRayletClient_NotifyUnblocked(PyRayletClient *self, PyObject *args) { TaskID current_task_id; if (!PyArg_ParseTuple(args, "O&", &PyObjectToUniqueID, ¤t_task_id)) { return NULL; } - local_scheduler_notify_unblocked( - ((PyLocalSchedulerClient *)self)->local_scheduler_connection, current_task_id); + auto status = self->raylet_client->NotifyUnblocked(current_task_id); + RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to notify unblocked."); Py_RETURN_NONE; } -static PyObject *PyLocalSchedulerClient_compute_put_id(PyObject *self, PyObject *args) { +static PyObject *PyRayletClient_compute_put_id(PyObject *self, PyObject *args) { int put_index; TaskID task_id; if (!PyArg_ParseTuple(args, "O&i", &PyObjectToUniqueID, &task_id, &put_index)) { @@ -118,25 +122,10 @@ static PyObject *PyLocalSchedulerClient_compute_put_id(PyObject *self, PyObject return PyObjectID_make(put_id); } -static PyObject *PyLocalSchedulerClient_gpu_ids(PyObject *self) { - /* Construct a Python list of GPU IDs. */ - std::vector gpu_ids = - ((PyLocalSchedulerClient *)self)->local_scheduler_connection->gpu_ids; - int num_gpu_ids = gpu_ids.size(); - PyObject *gpu_ids_list = PyList_New((Py_ssize_t)num_gpu_ids); - for (int i = 0; i < num_gpu_ids; ++i) { - PyList_SetItem(gpu_ids_list, i, PyLong_FromLong(gpu_ids[i])); - } - return gpu_ids_list; -} - -// NOTE(rkn): This function only makes sense for the raylet code path. -static PyObject *PyLocalSchedulerClient_resource_ids(PyObject *self) { +static PyObject *PyRayletClient_resource_ids(PyRayletClient *self) { // Construct a Python dictionary of resource IDs and resource fractions. PyObject *resource_ids = PyDict_New(); - - for (auto const &resource_info : reinterpret_cast(self) - ->local_scheduler_connection->resource_ids_) { + for (auto const &resource_info : self->raylet_client->GetResourceIDs()) { auto const &resource_name = resource_info.first; auto const &ids_and_fractions = resource_info.second; @@ -161,7 +150,7 @@ static PyObject *PyLocalSchedulerClient_resource_ids(PyObject *self) { return resource_ids; } -static PyObject *PyLocalSchedulerClient_wait(PyObject *self, PyObject *args) { +static PyObject *PyRayletClient_Wait(PyRayletClient *self, PyObject *args) { PyObject *py_object_ids; int num_returns; int64_t timeout_ms; @@ -195,9 +184,10 @@ static PyObject *PyLocalSchedulerClient_wait(PyObject *self, PyObject *args) { } // Invoke wait. - std::pair, std::vector> result = local_scheduler_wait( - reinterpret_cast(self)->local_scheduler_connection, - object_ids, num_returns, timeout_ms, wait_local, current_task_id); + WaitResultPair result; + auto status = self->raylet_client->Wait(object_ids, num_returns, timeout_ms, wait_local, + current_task_id, &result); + RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to wait for objects."); // Convert result to py object. 
PyObject *py_found = PyList_New(static_cast(result.first.size())); @@ -211,7 +201,7 @@ static PyObject *PyLocalSchedulerClient_wait(PyObject *self, PyObject *args) { return Py_BuildValue("(OO)", py_found, py_remaining); } -static PyObject *PyLocalSchedulerClient_push_error(PyObject *self, PyObject *args) { +static PyObject *PyRayletClient_PushError(PyRayletClient *self, PyObject *args) { JobID job_id; const char *type; int type_length; @@ -224,11 +214,10 @@ static PyObject *PyLocalSchedulerClient_push_error(PyObject *self, PyObject *arg return NULL; } - local_scheduler_push_error( - reinterpret_cast(self)->local_scheduler_connection, + auto status = self->raylet_client->PushError( job_id, std::string(type, type_length), std::string(error_message, error_message_length), timestamp); - + RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to push errors to raylet."); Py_RETURN_NONE; } @@ -248,8 +237,7 @@ int PyBytes_or_PyUnicode_to_string(PyObject *py_string, std::string &out) { return 0; } -static PyObject *PyLocalSchedulerClient_push_profile_events(PyObject *self, - PyObject *args) { +static PyObject *PyRayletClient_PushProfileEvents(PyRayletClient *self, PyObject *args) { const char *component_type; int component_type_length; UniqueID component_id; @@ -321,14 +309,12 @@ static PyObject *PyLocalSchedulerClient_push_profile_events(PyObject *self, profile_info.profile_events.emplace_back(new ProfileEventT(profile_event)); } - local_scheduler_push_profile_events( - reinterpret_cast(self)->local_scheduler_connection, - profile_info); - + auto status = self->raylet_client->PushProfileEvents(profile_info); + RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to push profile events to raylet."); Py_RETURN_NONE; } -static PyObject *PyLocalSchedulerClient_free(PyObject *self, PyObject *args) { +static PyObject *PyRayletClient_FreeObjects(PyRayletClient *self, PyObject *args) { PyObject *py_object_ids; PyObject *py_local_only; @@ -357,83 +343,80 @@ static PyObject *PyLocalSchedulerClient_free(PyObject *self, PyObject *args) { object_ids.push_back(object_id); } - // Invoke local_scheduler_free_objects_in_object_store. - local_scheduler_free_objects_in_object_store( - reinterpret_cast(self)->local_scheduler_connection, - object_ids, local_only); + // Invoke raylet_FreeObjects. 
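// Taken together, the renamed bindings above reduce to the new RayletClient
// API, in which each former local_scheduler_* free function becomes a method
// returning ray::Status. A minimal sketch; the socket path, IDs, task objects,
// and object ID list below are placeholders assumed to be in scope.
#include "ray/raylet/raylet_client.h"
RayletClient client("/tmp/raylet_socket", worker_id, /*is_worker=*/true,
                    driver_id, Language::PYTHON);
RAY_CHECK_OK(client.SubmitTask(execution_dependencies, task_spec));
std::unique_ptr<ray::raylet::TaskSpecification> next_task;
RAY_CHECK_OK(client.GetTask(&next_task));
RAY_CHECK_OK(client.FreeObjects(object_ids, /*local_only=*/true));
RAY_CHECK_OK(client.Disconnect());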
+ auto status = self->raylet_client->FreeObjects(object_ids, local_only); + RAY_CHECK_OK_PREPEND(status, "[RayletClient] Failed to free objects."); Py_RETURN_NONE; } -static PyMethodDef PyLocalSchedulerClient_methods[] = { - {"disconnect", (PyCFunction)PyLocalSchedulerClient_disconnect, METH_NOARGS, +static PyMethodDef PyRayletClient_methods[] = { + {"disconnect", (PyCFunction)PyRayletClient_Disconnect, METH_NOARGS, "Notify the local scheduler that this client is exiting gracefully."}, - {"submit", (PyCFunction)PyLocalSchedulerClient_submit, METH_VARARGS, + {"submit_task", (PyCFunction)PyRayletClient_SubmitTask, METH_VARARGS, "Submit a task to the local scheduler."}, - {"get_task", (PyCFunction)PyLocalSchedulerClient_get_task, METH_NOARGS, + {"get_task", (PyCFunction)PyRayletClient_GetTask, METH_NOARGS, "Get a task from the local scheduler."}, - {"fetch_or_reconstruct", (PyCFunction)PyLocalSchedulerClient_fetch_or_reconstruct, - METH_VARARGS, "Ask the local scheduler to reconstruct an object."}, - {"notify_unblocked", (PyCFunction)PyLocalSchedulerClient_notify_unblocked, - METH_VARARGS, "Notify the local scheduler that we are unblocked."}, - {"compute_put_id", (PyCFunction)PyLocalSchedulerClient_compute_put_id, METH_VARARGS, + {"fetch_or_reconstruct", (PyCFunction)PyRayletClient_FetchOrReconstruct, METH_VARARGS, + "Ask the local scheduler to reconstruct an object."}, + {"notify_unblocked", (PyCFunction)PyRayletClient_NotifyUnblocked, METH_VARARGS, + "Notify the local scheduler that we are unblocked."}, + {"compute_put_id", (PyCFunction)PyRayletClient_compute_put_id, METH_VARARGS, "Return the object ID for a put call within a task."}, - {"gpu_ids", (PyCFunction)PyLocalSchedulerClient_gpu_ids, METH_NOARGS, - "Get the IDs of the GPUs that are reserved for this client."}, - {"resource_ids", (PyCFunction)PyLocalSchedulerClient_resource_ids, METH_NOARGS, + {"resource_ids", (PyCFunction)PyRayletClient_resource_ids, METH_NOARGS, "Get the IDs of the resources that are reserved for this client."}, - {"wait", (PyCFunction)PyLocalSchedulerClient_wait, METH_VARARGS, + {"wait", (PyCFunction)PyRayletClient_Wait, METH_VARARGS, "Wait for a list of objects to be created."}, - {"push_error", (PyCFunction)PyLocalSchedulerClient_push_error, METH_VARARGS, + {"push_error", (PyCFunction)PyRayletClient_PushError, METH_VARARGS, "Push an error message to the relevant driver."}, - {"push_profile_events", (PyCFunction)PyLocalSchedulerClient_push_profile_events, - METH_VARARGS, "Store some profiling events in the GCS."}, - {"free", (PyCFunction)PyLocalSchedulerClient_free, METH_VARARGS, + {"push_profile_events", (PyCFunction)PyRayletClient_PushProfileEvents, METH_VARARGS, + "Store some profiling events in the GCS."}, + {"free_objects", (PyCFunction)PyRayletClient_FreeObjects, METH_VARARGS, "Free a list of objects from object stores."}, {NULL} /* Sentinel */ }; -static PyTypeObject PyLocalSchedulerClientType = { - PyVarObject_HEAD_INIT(NULL, 0) /* ob_size */ - "local_scheduler.LocalSchedulerClient", /* tp_name */ - sizeof(PyLocalSchedulerClient), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)PyLocalSchedulerClient_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - "LocalSchedulerClient object", 
/* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - PyLocalSchedulerClient_methods, /* tp_methods */ - 0, /* tp_members */ - 0, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - (initproc)PyLocalSchedulerClient_init, /* tp_init */ - 0, /* tp_alloc */ - PyType_GenericNew, /* tp_new */ +static PyTypeObject PyRayletClientType = { + PyVarObject_HEAD_INIT(NULL, 0) /* ob_size */ + "raylet.RayletClient", /* tp_name */ + sizeof(PyRayletClient), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)PyRayletClient_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "RayletClient object", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + PyRayletClient_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)PyRayletClient_init, /* tp_init */ + 0, /* tp_alloc */ + PyType_GenericNew, /* tp_new */ }; -static PyMethodDef local_scheduler_methods[] = { +static PyMethodDef raylet_methods[] = { {"check_simple_value", check_simple_value, METH_VARARGS, "Should the object be passed by value?"}, {"compute_task_id", compute_task_id, METH_VARARGS, @@ -449,14 +432,14 @@ static PyMethodDef local_scheduler_methods[] = { #if PY_MAJOR_VERSION >= 3 static struct PyModuleDef moduledef = { PyModuleDef_HEAD_INIT, - "liblocal_scheduler", /* m_name */ - "A module for the local scheduler.", /* m_doc */ - 0, /* m_size */ - local_scheduler_methods, /* m_methods */ - NULL, /* m_reload */ - NULL, /* m_traverse */ - NULL, /* m_clear */ - NULL, /* m_free */ + "libraylet", /* m_name */ + "A module for the raylet.", /* m_doc */ + 0, /* m_size */ + raylet_methods, /* m_methods */ + NULL, /* m_reload */ + NULL, /* m_traverse */ + NULL, /* m_clear */ + NULL, /* m_free */ }; #endif @@ -476,7 +459,7 @@ static struct PyModuleDef moduledef = { #define MOD_INIT(name) PyMODINIT_FUNC init##name(void) #endif -MOD_INIT(liblocal_scheduler_library_python) { +MOD_INIT(libraylet_library_python) { if (PyType_Ready(&PyTaskType) < 0) { INITERROR; } @@ -485,7 +468,7 @@ MOD_INIT(liblocal_scheduler_library_python) { INITERROR; } - if (PyType_Ready(&PyLocalSchedulerClientType) < 0) { + if (PyType_Ready(&PyRayletClientType) < 0) { INITERROR; } @@ -496,9 +479,8 @@ MOD_INIT(liblocal_scheduler_library_python) { #if PY_MAJOR_VERSION >= 3 PyObject *m = PyModule_Create(&moduledef); #else - PyObject *m = - Py_InitModule3("liblocal_scheduler_library_python", local_scheduler_methods, - "A module for the local scheduler."); + PyObject *m = Py_InitModule3("libraylet_library_python", raylet_methods, + "A module for the raylet."); #endif init_numpy_module(); @@ -510,8 +492,8 @@ MOD_INIT(liblocal_scheduler_library_python) { Py_INCREF(&PyObjectIDType); PyModule_AddObject(m, "ObjectID", (PyObject *)&PyObjectIDType); - Py_INCREF(&PyLocalSchedulerClientType); - PyModule_AddObject(m, "LocalSchedulerClient", (PyObject 
*)&PyLocalSchedulerClientType); + Py_INCREF(&PyRayletClientType); + PyModule_AddObject(m, "RayletClient", (PyObject *)&PyRayletClientType); char common_error[] = "common.error"; CommonError = PyErr_NewException(common_error, NULL, NULL); diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index edfb0db69f597..32a0e593268bc 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -112,9 +112,10 @@ static inline Task ExampleTask(const std::vector &arguments, std::vector references = {argument}; task_arguments.emplace_back(std::make_shared(references)); } + std::vector function_descriptor(3); auto spec = TaskSpecification(UniqueID::nil(), UniqueID::from_random(), 0, - UniqueID::from_random(), task_arguments, num_returns, - required_resources, Language::PYTHON); + task_arguments, num_returns, required_resources, + Language::PYTHON, function_descriptor); auto execution_spec = TaskExecutionSpecification(std::vector()); execution_spec.IncrementNumForwards(); Task task = Task(execution_spec, spec); diff --git a/src/ray/raylet/local_scheduler_client.cc b/src/ray/raylet/local_scheduler_client.cc deleted file mode 100644 index ec9434cbc06db..0000000000000 --- a/src/ray/raylet/local_scheduler_client.cc +++ /dev/null @@ -1,416 +0,0 @@ -#include "local_scheduler_client.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ray/common/common_protocol.h" -#include "ray/ray_config.h" -#include "ray/raylet/format/node_manager_generated.h" -#include "ray/raylet/task_spec.h" -#include "ray/util/logging.h" - -using MessageType = ray::protocol::MessageType; - -// TODO(rkn): The io methods below should be removed. - -int connect_ipc_sock(const char *socket_pathname) { - struct sockaddr_un socket_address; - int socket_fd; - - socket_fd = socket(AF_UNIX, SOCK_STREAM, 0); - if (socket_fd < 0) { - RAY_LOG(ERROR) << "socket() failed for pathname " << socket_pathname; - return -1; - } - - memset(&socket_address, 0, sizeof(socket_address)); - socket_address.sun_family = AF_UNIX; - if (strlen(socket_pathname) + 1 > sizeof(socket_address.sun_path)) { - RAY_LOG(ERROR) << "Socket pathname is too long."; - return -1; - } - strncpy(socket_address.sun_path, socket_pathname, strlen(socket_pathname) + 1); - - if (connect(socket_fd, (struct sockaddr *)&socket_address, sizeof(socket_address)) != - 0) { - close(socket_fd); - return -1; - } - - return socket_fd; -} - -int connect_ipc_sock_retry(const char *socket_pathname, int num_retries, - int64_t timeout) { - /* Pick the default values if the user did not specify. */ - if (num_retries < 0) { - num_retries = RayConfig::instance().num_connect_attempts(); - } - if (timeout < 0) { - timeout = RayConfig::instance().connect_timeout_milliseconds(); - } - - RAY_CHECK(socket_pathname); - int fd = -1; - for (int num_attempts = 0; num_attempts < num_retries; ++num_attempts) { - fd = connect_ipc_sock(socket_pathname); - if (fd >= 0) { - break; - } - if (num_attempts > 0) { - RAY_LOG(ERROR) << "Retrying to connect to socket for pathname " << socket_pathname - << " (num_attempts = " << num_attempts - << ", num_retries = " << num_retries << ")"; - } - /* Sleep for timeout milliseconds. */ - usleep(timeout * 1000); - } - /* If we could not connect to the socket, exit. 
*/ - if (fd == -1) { - RAY_LOG(FATAL) << "Could not connect to socket " << socket_pathname; - } - return fd; -} - -int read_bytes(int fd, uint8_t *cursor, size_t length) { - ssize_t nbytes = 0; - /* Termination condition: EOF or read 'length' bytes total. */ - size_t bytesleft = length; - size_t offset = 0; - while (bytesleft > 0) { - nbytes = read(fd, cursor + offset, bytesleft); - if (nbytes < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { - continue; - } - return -1; /* Errno will be set. */ - } else if (0 == nbytes) { - /* Encountered early EOF. */ - return -1; - } - RAY_CHECK(nbytes > 0); - bytesleft -= nbytes; - offset += nbytes; - } - - return 0; -} - -void read_message(int fd, int64_t *type, int64_t *length, uint8_t **bytes) { - int64_t version; - int closed = read_bytes(fd, (uint8_t *)&version, sizeof(version)); - if (closed) { - goto disconnected; - } - RAY_CHECK(version == RayConfig::instance().ray_protocol_version()); - closed = read_bytes(fd, (uint8_t *)type, sizeof(*type)); - if (closed) { - goto disconnected; - } - closed = read_bytes(fd, (uint8_t *)length, sizeof(*length)); - if (closed) { - goto disconnected; - } - *bytes = (uint8_t *)malloc(*length * sizeof(uint8_t)); - closed = read_bytes(fd, *bytes, *length); - if (closed) { - free(*bytes); - goto disconnected; - } - return; - -disconnected: - /* Handle the case in which the socket is closed. */ - *type = static_cast(MessageType::DisconnectClient); - *length = 0; - *bytes = NULL; - return; -} - -int write_bytes(int fd, uint8_t *cursor, size_t length) { - ssize_t nbytes = 0; - size_t bytesleft = length; - size_t offset = 0; - while (bytesleft > 0) { - /* While we haven't written the whole message, write to the file - * descriptor, advance the cursor, and decrease the amount left to write. */ - nbytes = write(fd, cursor + offset, bytesleft); - if (nbytes < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { - continue; - } - return -1; /* Errno will be set. */ - } else if (0 == nbytes) { - /* Encountered early EOF. */ - return -1; - } - RAY_CHECK(nbytes > 0); - bytesleft -= nbytes; - offset += nbytes; - } - - return 0; -} - -int do_write_message(int fd, int64_t type, int64_t length, uint8_t *bytes) { - int64_t version = RayConfig::instance().ray_protocol_version(); - int closed; - closed = write_bytes(fd, (uint8_t *)&version, sizeof(version)); - if (closed) { - return closed; - } - closed = write_bytes(fd, (uint8_t *)&type, sizeof(type)); - if (closed) { - return closed; - } - closed = write_bytes(fd, (uint8_t *)&length, sizeof(length)); - if (closed) { - return closed; - } - closed = write_bytes(fd, bytes, length * sizeof(char)); - if (closed) { - return closed; - } - return 0; -} - -int write_message(int fd, int64_t type, int64_t length, uint8_t *bytes, - std::mutex *mutex) { - if (mutex != NULL) { - std::unique_lock guard(*mutex); - return do_write_message(fd, type, length, bytes); - } else { - return do_write_message(fd, type, length, bytes); - } -} - -LocalSchedulerConnection *LocalSchedulerConnection_init( - const char *local_scheduler_socket, const UniqueID &client_id, bool is_worker, - const JobID &driver_id, const Language &language) { - LocalSchedulerConnection *result = new LocalSchedulerConnection(); - result->conn = connect_ipc_sock_retry(local_scheduler_socket, -1, -1); - - /* Register with the local scheduler. - * NOTE(swang): If the local scheduler exits and we are registered as a - * worker, we will get killed. 
*/ - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateRegisterClientRequest( - fbb, is_worker, to_flatbuf(fbb, client_id), getpid(), to_flatbuf(fbb, driver_id), - language); - fbb.Finish(message); - /* Register the process ID with the local scheduler. */ - int success = write_message( - result->conn, static_cast(MessageType::RegisterClientRequest), - fbb.GetSize(), fbb.GetBufferPointer(), &result->write_mutex); - RAY_CHECK(success == 0) << "Unable to register worker with local scheduler"; - - return result; -} - -void LocalSchedulerConnection_free(LocalSchedulerConnection *conn) { - close(conn->conn); - delete conn; -} - -void local_scheduler_disconnect_client(LocalSchedulerConnection *conn) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateDisconnectClient(fbb); - fbb.Finish(message); - write_message(conn->conn, - static_cast(MessageType::IntentionalDisconnectClient), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); -} - -void local_scheduler_submit_raylet(LocalSchedulerConnection *conn, - const std::vector &execution_dependencies, - const ray::raylet::TaskSpecification &task_spec) { - flatbuffers::FlatBufferBuilder fbb; - auto execution_dependencies_message = to_flatbuf(fbb, execution_dependencies); - auto message = ray::protocol::CreateSubmitTaskRequest( - fbb, execution_dependencies_message, task_spec.ToFlatbuffer(fbb)); - fbb.Finish(message); - write_message(conn->conn, static_cast(MessageType::SubmitTask), fbb.GetSize(), - fbb.GetBufferPointer(), &conn->write_mutex); -} - -ray::raylet::TaskSpecification *local_scheduler_get_task_raylet( - LocalSchedulerConnection *conn) { - int64_t type; - int64_t reply_size; - uint8_t *reply; - { - std::unique_lock guard(conn->mutex); - write_message(conn->conn, static_cast(MessageType::GetTask), 0, NULL, - &conn->write_mutex); - // Receive a task from the local scheduler. This will block until the local - // scheduler gives this client a task. - read_message(conn->conn, &type, &reply_size, &reply); - } - if (type == static_cast(MessageType::DisconnectClient)) { - RAY_LOG(DEBUG) << "Exiting because local scheduler closed connection."; - exit(1); - } - if (type != static_cast(MessageType::ExecuteTask)) { - RAY_LOG(FATAL) << "Problem communicating with raylet from worker: check logs or " - "dmesg for previous errors."; - } - - // Parse the flatbuffer object. - auto reply_message = flatbuffers::GetRoot(reply); - - // Set the resource IDs for this task. 
- conn->resource_ids_.clear(); - for (size_t i = 0; i < reply_message->fractional_resource_ids()->size(); ++i) { - auto const &fractional_resource_ids = - reply_message->fractional_resource_ids()->Get(i); - auto &acquired_resources = conn->resource_ids_[string_from_flatbuf( - *fractional_resource_ids->resource_name())]; - - size_t num_resource_ids = fractional_resource_ids->resource_ids()->size(); - size_t num_resource_fractions = fractional_resource_ids->resource_fractions()->size(); - RAY_CHECK(num_resource_ids == num_resource_fractions); - RAY_CHECK(num_resource_ids > 0); - for (size_t j = 0; j < num_resource_ids; ++j) { - int64_t resource_id = fractional_resource_ids->resource_ids()->Get(j); - double resource_fraction = fractional_resource_ids->resource_fractions()->Get(j); - if (num_resource_ids > 1) { - int64_t whole_fraction = resource_fraction; - RAY_CHECK(whole_fraction == resource_fraction); - } - acquired_resources.push_back(std::make_pair(resource_id, resource_fraction)); - } - } - - ray::raylet::TaskSpecification *task_spec = new ray::raylet::TaskSpecification( - string_from_flatbuf(*reply_message->task_spec())); - - // Free the original message from the local scheduler. - free(reply); - - // Return the copy of the task spec and pass ownership to the caller. - return task_spec; -} - -void local_scheduler_task_done(LocalSchedulerConnection *conn) { - write_message(conn->conn, static_cast(MessageType::TaskDone), 0, NULL, - &conn->write_mutex); -} - -void local_scheduler_fetch_or_reconstruct(LocalSchedulerConnection *conn, - const std::vector &object_ids, - bool fetch_only, - const TaskID ¤t_task_id) { - flatbuffers::FlatBufferBuilder fbb; - auto object_ids_message = to_flatbuf(fbb, object_ids); - auto message = ray::protocol::CreateFetchOrReconstruct( - fbb, object_ids_message, fetch_only, to_flatbuf(fbb, current_task_id)); - fbb.Finish(message); - write_message(conn->conn, static_cast(MessageType::FetchOrReconstruct), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); - /* TODO(swang): Propagate the error. */ -} - -void local_scheduler_notify_unblocked(LocalSchedulerConnection *conn, - const TaskID ¤t_task_id) { - flatbuffers::FlatBufferBuilder fbb; - auto message = - ray::protocol::CreateNotifyUnblocked(fbb, to_flatbuf(fbb, current_task_id)); - fbb.Finish(message); - write_message(conn->conn, static_cast(MessageType::NotifyUnblocked), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); -} - -std::pair, std::vector> local_scheduler_wait( - LocalSchedulerConnection *conn, const std::vector &object_ids, - int num_returns, int64_t timeout_milliseconds, bool wait_local, - const TaskID ¤t_task_id) { - // Write request. - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateWaitRequest( - fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds, wait_local, - to_flatbuf(fbb, current_task_id)); - fbb.Finish(message); - int64_t type; - int64_t reply_size; - uint8_t *reply; - { - std::unique_lock guard(conn->mutex); - write_message(conn->conn, - static_cast(ray::protocol::MessageType::WaitRequest), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); - // Read result. - read_message(conn->conn, &type, &reply_size, &reply); - } - if (static_cast(type) != - ray::protocol::MessageType::WaitReply) { - RAY_LOG(FATAL) << "Problem communicating with raylet from worker: check logs or " - "dmesg for previous errors."; - } - auto reply_message = flatbuffers::GetRoot(reply); - // Convert result. 
- std::pair, std::vector> result; - auto found = reply_message->found(); - for (uint i = 0; i < found->size(); i++) { - ObjectID object_id = ObjectID::from_binary(found->Get(i)->str()); - result.first.push_back(object_id); - } - auto remaining = reply_message->remaining(); - for (uint i = 0; i < remaining->size(); i++) { - ObjectID object_id = ObjectID::from_binary(remaining->Get(i)->str()); - result.second.push_back(object_id); - } - /* Free the original message from the local scheduler. */ - free(reply); - return result; -} - -void local_scheduler_push_error(LocalSchedulerConnection *conn, const JobID &job_id, - const std::string &type, const std::string &error_message, - double timestamp) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreatePushErrorRequest( - fbb, to_flatbuf(fbb, job_id), fbb.CreateString(type), - fbb.CreateString(error_message), timestamp); - fbb.Finish(message); - - write_message(conn->conn, - static_cast(ray::protocol::MessageType::PushErrorRequest), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); -} - -void local_scheduler_push_profile_events(LocalSchedulerConnection *conn, - const ProfileTableDataT &profile_events) { - flatbuffers::FlatBufferBuilder fbb; - - auto message = CreateProfileTableData(fbb, &profile_events); - fbb.Finish(message); - - write_message(conn->conn, static_cast( - ray::protocol::MessageType::PushProfileEventsRequest), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); -} - -void local_scheduler_free_objects_in_object_store( - LocalSchedulerConnection *conn, const std::vector &object_ids, - bool local_only) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateFreeObjectsRequest(fbb, local_only, - to_flatbuf(fbb, object_ids)); - fbb.Finish(message); - - int success = write_message( - conn->conn, - static_cast(ray::protocol::MessageType::FreeObjectsInObjectStoreRequest), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); - RAY_CHECK(success == 0) << "Failed to write message to raylet."; -} diff --git a/src/ray/raylet/local_scheduler_client.h b/src/ray/raylet/local_scheduler_client.h deleted file mode 100644 index 66c76f37a61d3..0000000000000 --- a/src/ray/raylet/local_scheduler_client.h +++ /dev/null @@ -1,185 +0,0 @@ -#ifndef LOCAL_SCHEDULER_CLIENT_H -#define LOCAL_SCHEDULER_CLIENT_H - -#include - -#include "ray/raylet/task_spec.h" - -using ray::ObjectID; -using ray::JobID; -using ray::TaskID; -using ray::ActorID; -using ray::UniqueID; - -struct LocalSchedulerConnection { - /** File descriptor of the Unix domain socket that connects to local - * scheduler. */ - int conn; - /** The IDs of the GPUs that this client can use. NOTE(rkn): This is only used - * by legacy Ray and will be deprecated. */ - std::vector gpu_ids; - /// A map from resource name to the resource IDs that are currently reserved - /// for this worker. Each pair consists of the resource ID and the fraction - /// of that resource allocated for this worker. - std::unordered_map>> resource_ids_; - /// A mutex to protect stateful operations of the local scheduler client. - std::mutex mutex; - /// A mutext to protect write operations of the local scheduler client. - std::mutex write_mutex; -}; - -/** - * Connect to the local scheduler. - * - * @param local_scheduler_socket The name of the socket to use to connect to the - * local scheduler. - * @param worker_id A unique ID to represent the worker. - * @param is_worker Whether this client is a worker. 
If it is a worker, an - * additional message will be sent to register as one. - * @param driver_id The ID of the driver. This is non-nil if the client is a - * driver. - * @return The connection information. - */ -LocalSchedulerConnection *LocalSchedulerConnection_init( - const char *local_scheduler_socket, const UniqueID &worker_id, bool is_worker, - const JobID &driver_id, const Language &language); - -/** - * Disconnect from the local scheduler. - * - * @param conn Local scheduler connection information returned by - * LocalSchedulerConnection_init. - * @return Void. - */ -void LocalSchedulerConnection_free(LocalSchedulerConnection *conn); - -/// Submit a task using the raylet code path. -/// -/// \param The connection information. -/// \param The execution dependencies. -/// \param The task specification. -/// \return Void. -void local_scheduler_submit_raylet(LocalSchedulerConnection *conn, - const std::vector &execution_dependencies, - const ray::raylet::TaskSpecification &task_spec); - -/** - * Notify the local scheduler that this client is disconnecting gracefully. This - * is used by actors to exit gracefully so that the local scheduler doesn't - * propagate an error message to the driver. - * - * @param conn The connection information. - * @return Void. - */ -void local_scheduler_disconnect_client(LocalSchedulerConnection *conn); - -/// Get next task for this client. This will block until the scheduler assigns -/// a task to this worker. The caller takes ownership of the returned task -/// specification and must free it. -/// -/// \param conn The connection information. -/// \return The assigned task. -ray::raylet::TaskSpecification *local_scheduler_get_task_raylet( - LocalSchedulerConnection *conn); - -/** - * Tell the local scheduler that the client has finished executing a task. - * - * @param conn The connection information. - * @return Void. - */ -void local_scheduler_task_done(LocalSchedulerConnection *conn); - -/** - * Tell the local scheduler to reconstruct or fetch objects. - * - * @param conn The connection information. - * @param object_ids The IDs of the objects to reconstruct. - * @param fetch_only Only fetch objects, do not reconstruct them. - * @param current_task_id The task that needs the objects. - * @return Void. - */ -void local_scheduler_fetch_or_reconstruct(LocalSchedulerConnection *conn, - const std::vector &object_ids, - bool fetch_only, const TaskID ¤t_task_id); - -/** - * Notify the local scheduler that this client (worker) is no longer blocked. - * - * @param conn The connection information. - * @param current_task_id The task that is no longer blocked. - * @return Void. - */ -void local_scheduler_notify_unblocked(LocalSchedulerConnection *conn, - const TaskID ¤t_task_id); - -// /** -// * Get an actor's current task frontier. -// * -// * @param conn The connection information. -// * @param actor_id The ID of the actor whose frontier is returned. -// * @return A byte vector that can be traversed as an ActorFrontier flatbuffer. -// */ -// const std::vector local_scheduler_get_actor_frontier( -// LocalSchedulerConnection *conn, -// ActorID actor_id); - -// /** -// * Set an actor's current task frontier. -// * -// * @param conn The connection information. -// * @param frontier An ActorFrontier flatbuffer to set the frontier to. -// * @return Void. -// */ -// void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn, -// const std::vector &frontier); - -/// Wait for the given objects until timeout expires or num_return objects are -/// found. 
-/// -/// \param conn The connection information. -/// \param object_ids The objects to wait for. -/// \param num_returns The number of objects to wait for. -/// \param timeout_milliseconds Duration, in milliseconds, to wait before -/// returning. -/// \param wait_local Whether to wait for objects to appear on this node. -/// \param current_task_id The task that called wait. -/// \return A pair with the first element containing the object ids that were -/// found, and the second element the objects that were not found. -std::pair, std::vector> local_scheduler_wait( - LocalSchedulerConnection *conn, const std::vector &object_ids, - int num_returns, int64_t timeout_milliseconds, bool wait_local, - const TaskID ¤t_task_id); - -/// Push an error to the relevant driver. -/// -/// \param conn The connection information. -/// \param The ID of the job that the error is for. -/// \param The type of the error. -/// \param The error message. -/// \param The timestamp of the error. -/// \return Void. -void local_scheduler_push_error(LocalSchedulerConnection *conn, const JobID &job_id, - const std::string &type, const std::string &error_message, - double timestamp); - -/// Store some profile events in the GCS. -/// -/// \param conn The connection information. -/// \param profile_events A batch of profiling event information. -/// \return Void. -void local_scheduler_push_profile_events(LocalSchedulerConnection *conn, - const ProfileTableDataT &profile_events); - -/// Free a list of objects from object stores. -/// -/// \param conn The connection information. -/// \param object_ids A list of ObjectsIDs to be deleted. -/// \param local_only Whether keep this request with local object store -/// or send it to all the object stores. -/// \return Void. -void local_scheduler_free_objects_in_object_store( - LocalSchedulerConnection *conn, const std::vector &object_ids, - bool local_only); - -#endif diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 85fa4d0e9522f..02f7027892cd6 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -118,15 +118,15 @@ int main(int argc, char *argv[]) { << "max_receives = " << object_manager_config.max_receives << "\n" << "object_chunk_size = " << object_manager_config.object_chunk_size; + // Initialize the node manager. + boost::asio::io_service main_service; + // initialize mock gcs & object directory auto gcs_client = std::make_shared(redis_address, redis_port, redis_password); RAY_LOG(DEBUG) << "Initializing GCS client " << gcs_client->client_table().GetLocalClientId(); - // Initialize the node manager. 
- boost::asio::io_service main_service; - ray::raylet::Raylet server(main_service, raylet_socket_name, node_ip_address, redis_address, redis_port, redis_password, node_manager_config, object_manager_config, gcs_client); diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 9110f0c878819..fd70132485bd4 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -46,34 +46,37 @@ namespace raylet { NodeManager::NodeManager(boost::asio::io_service &io_service, const NodeManagerConfig &config, ObjectManager &object_manager, - std::shared_ptr gcs_client) + std::shared_ptr gcs_client, + std::shared_ptr object_directory) : io_service_(io_service), object_manager_(object_manager), - gcs_client_(gcs_client), + gcs_client_(std::move(gcs_client)), + object_directory_(std::move(object_directory)), heartbeat_timer_(io_service), heartbeat_period_(std::chrono::milliseconds(config.heartbeat_period_ms)), debug_dump_period_(config.debug_dump_period_ms), temp_dir_(config.temp_dir), object_manager_profile_timer_(io_service), - local_resources_(config.resource_config), + initial_config_(config), local_available_resources_(config.resource_config), worker_pool_(config.num_initial_workers, config.num_workers_per_process, config.maximum_startup_concurrency, config.worker_commands), scheduling_policy_(local_queues_), reconstruction_policy_( io_service_, - [this](const TaskID &task_id) { HandleTaskReconstruction(task_id); }, + [this](const TaskID &task_id, bool return_values_lost) { + HandleTaskReconstruction(task_id); + }, RayConfig::instance().initial_reconstruction_timeout_milliseconds(), - gcs_client_->client_table().GetLocalClientId(), gcs_client->task_lease_table(), - std::make_shared(io_service, gcs_client), - gcs_client_->task_reconstruction_log()), + gcs_client_->client_table().GetLocalClientId(), gcs_client_->task_lease_table(), + object_directory_, gcs_client_->task_reconstruction_log()), task_dependency_manager_( object_manager, reconstruction_policy_, io_service, gcs_client_->client_table().GetLocalClientId(), RayConfig::instance().initial_reconstruction_timeout_milliseconds(), - gcs_client->task_lease_table()), + gcs_client_->task_lease_table()), lineage_cache_(gcs_client_->client_table().GetLocalClientId(), - gcs_client->raylet_task_table(), gcs_client->raylet_task_table(), + gcs_client_->raylet_task_table(), gcs_client_->raylet_task_table(), config.max_lineage_size), remote_clients_(), remote_server_connections_(), @@ -92,7 +95,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, RAY_CHECK_OK(object_manager_.SubscribeObjDeleted( [this](const ObjectID &object_id) { HandleObjectMissing(object_id); })); - ARROW_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str(), "", 0)); + ARROW_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str())); } ray::Status NodeManager::RegisterGcs() { @@ -134,15 +137,19 @@ ray::Status NodeManager::RegisterGcs() { JobID::nil(), gcs_client_->client_table().GetLocalClientId(), task_lease_notification_callback, task_lease_empty_callback, nullptr)); - // Register a callback for actor creation notifications. - auto actor_creation_callback = [this](gcs::AsyncGcsClient *client, - const ActorID &actor_id, - const std::vector &data) { - HandleActorStateTransition(actor_id, data.back()); + // Register a callback to handle actor notifications. 
+ auto actor_notification_callback = [this](gcs::AsyncGcsClient *client, + const ActorID &actor_id, + const std::vector &data) { + if (!data.empty()) { + // We only need the last entry, because it represents the latest state of + // this actor. + HandleActorStateTransition(actor_id, data.back()); + } }; RAY_RETURN_NOT_OK(gcs_client_->actor_table().Subscribe( - UniqueID::nil(), UniqueID::nil(), actor_creation_callback, nullptr)); + UniqueID::nil(), UniqueID::nil(), actor_notification_callback, nullptr)); // Register a callback on the client table for new clients. auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id, @@ -332,7 +339,7 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) { if (client_id == gcs_client_->client_table().GetLocalClientId()) { // We got a notification for ourselves, so we are connected to the GCS now. // Save this NodeManager's resource information in the cluster resource map. - cluster_resource_map_[client_id] = local_resources_; + cluster_resource_map_[client_id] = initial_config_.resource_config; return; } @@ -398,17 +405,31 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { cluster_resource_map_.erase(client_id); // Remove the remote server connection. - remote_server_connections_.erase(client_id); + const auto connection_entry = remote_server_connections_.find(client_id); + if (connection_entry != remote_server_connections_.end()) { + connection_entry->second->Close(); + remote_server_connections_.erase(connection_entry); + } else { + RAY_LOG(WARNING) << "Received ClientRemoved callback for an unknown client " + << client_id << "."; + } // For any live actors that were on the dead node, broadcast a notification // about the actor's death // TODO(swang): This could be very slow if there are many actors. for (const auto &actor_entry : actor_registry_) { if (actor_entry.second.GetNodeManagerId() == client_id && - actor_entry.second.IsAlive()) { - HandleDisconnectedActor(actor_entry.first, /*was_local=*/false); + actor_entry.second.GetState() == ActorState::ALIVE) { + RAY_LOG(INFO) << "Actor " << actor_entry.first + << " is disconnected, because its node " << client_id + << " is removed from cluster. It may be reconstructed."; + HandleDisconnectedActor(actor_entry.first, /*was_local=*/false, + /*intentional_disconnect=*/false); } } + // Notify the object directory that the client has been removed so that it + // can remove it from any cached locations. + object_directory_->HandleClientRemoved(client_id); } void NodeManager::HeartbeatAdded(const ClientID &client_id, @@ -467,53 +488,49 @@ void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableDataT &heartbeat_ } } -void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_local) { - RAY_LOG(DEBUG) << "Actor disconnected " << actor_id; - auto actor_entry = actor_registry_.find(actor_id); - RAY_CHECK(actor_entry != actor_registry_.end()); - - // Release all the dummy objects for the dead actor. - if (was_local) { - for (auto &dummy_object : actor_entry->second.GetDummyObjects()) { - HandleObjectMissing(dummy_object); - } - } - - auto new_actor_data = - std::make_shared(actor_entry->second.GetTableData()); - new_actor_data->state = ActorState::DEAD; - HandleActorStateTransition(actor_id, *new_actor_data); - ray::gcs::ActorTable::WriteCallback failure_callback = nullptr; - if (was_local) { - // The actor was local to this node, so we are the only one who should try - // to update the log. 
- failure_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { - RAY_LOG(FATAL) << "Failed to update state to DEAD for actor " << id; - }; - } - // Actor reconstruction is disabled, so the actor can only go from ALIVE to - // DEAD. The DEAD entry must therefore be at the second index in the log. - RAY_CHECK_OK(gcs_client_->actor_table().AppendAt(JobID::nil(), actor_id, new_actor_data, - nullptr, failure_callback, - /*log_index=*/1)); +void NodeManager::PublishActorStateTransition( + const ActorID &actor_id, const ActorTableDataT &data, + const ray::gcs::ActorTable::WriteCallback &failure_callback) { + // Copy the actor notification data. + auto actor_notification = std::make_shared(data); + + // The actor log starts with an ALIVE entry. This is followed by 0 to N pairs + // of (RECONSTRUCTING, ALIVE) entries, where N is the maximum number of + // reconstructions. This is followed optionally by a DEAD entry. + int log_length = 2 * (actor_notification->max_reconstructions - + actor_notification->remaining_reconstructions); + if (actor_notification->state != ActorState::ALIVE) { + // RECONSTRUCTING or DEAD entries have an odd index. + log_length += 1; + } + RAY_CHECK_OK(gcs_client_->actor_table().AppendAt( + JobID::nil(), actor_id, actor_notification, nullptr, failure_callback, log_length)); } void NodeManager::HandleActorStateTransition(const ActorID &actor_id, const ActorTableDataT &data) { - RAY_LOG(DEBUG) << "Actor creation notification received: " << actor_id << " " - << static_cast(data.state); - - // Register the new actor. ActorRegistration actor_registration(data); + RAY_LOG(DEBUG) << "Actor notification received: actor_id = " << actor_id + << ", node_manager_id = " << actor_registration.GetNodeManagerId() + << ", state = " << static_cast(actor_registration.GetState()) + << ", remaining_reconstructions = " + << actor_registration.GetRemainingReconstructions(); // Update local registry. auto it = actor_registry_.find(actor_id); if (it == actor_registry_.end()) { it = actor_registry_.emplace(actor_id, actor_registration).first; } else { - RAY_CHECK(it->second.GetNodeManagerId() == actor_registration.GetNodeManagerId()); - if (actor_registration.GetState() > it->second.GetState()) { - // The new state is later than our current state. + // Only process the state transition if it is to a later state than ours. + if (actor_registration.GetState() > it->second.GetState() && + actor_registration.GetRemainingReconstructions() == + it->second.GetRemainingReconstructions()) { + // The new state is later than ours if it is about the same lifetime, but + // a greater state. + it->second = actor_registration; + } else if (actor_registration.GetRemainingReconstructions() < + it->second.GetRemainingReconstructions()) { + // The new state is also later than ours it is about a later lifetime of + // the actor. it->second = actor_registration; } else { // Our state is already at or past the update, so skip the update. @@ -521,7 +538,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, } } - if (it->second.IsAlive()) { + if (actor_registration.GetState() == ActorState::ALIVE) { // The actor's location is now known. Dequeue any methods that were // submitted before the actor's location was known. // (See design_docs/task_states.rst for the state transition diagram.) 
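Note: the AppendAt call in PublishActorStateTransition above encodes an invariant about the per-actor log: one ALIVE entry, then zero or more (RECONSTRUCTING, ALIVE) pairs, then optionally a DEAD entry, so the append index is fully determined by the reconstruction counters and the state being written. The following is a minimal, compilable sketch of just that index computation, included here only for illustration; it is not part of this patch, and the ActorState enum and the two counters are local stand-ins for the fields of the flatbuffers-generated ActorTableDataT.

#include <cassert>
#include <cstdint>
#include <iostream>

// Local stand-in for the generated ActorState enum.
enum class ActorState { ALIVE, RECONSTRUCTING, DEAD };

// The actor log begins with one ALIVE entry, followed by zero or more
// (RECONSTRUCTING, ALIVE) pairs, and optionally ends with a DEAD entry.
// The index at which the next entry must be appended is therefore a pure
// function of the reconstruction counters and the state being published.
int64_t ExpectedLogIndex(ActorState state, int64_t max_reconstructions,
                         int64_t remaining_reconstructions) {
  int64_t log_length = 2 * (max_reconstructions - remaining_reconstructions);
  if (state != ActorState::ALIVE) {
    // RECONSTRUCTING and DEAD entries always land at odd indices.
    log_length += 1;
  }
  return log_length;
}

int main() {
  // Initial creation: the first ALIVE entry goes to index 0.
  assert(ExpectedLogIndex(ActorState::ALIVE, 3, 3) == 0);
  // First failure: RECONSTRUCTING at index 1, and the re-created ALIVE at
  // index 2 (remaining_reconstructions is decremented on re-creation).
  assert(ExpectedLogIndex(ActorState::RECONSTRUCTING, 3, 3) == 1);
  assert(ExpectedLogIndex(ActorState::ALIVE, 3, 2) == 2);
  // No attempts left: DEAD closes the log at index 7 when max_reconstructions is 3.
  assert(ExpectedLogIndex(ActorState::DEAD, 3, 0) == 7);
  std::cout << "actor log index invariants hold" << std::endl;
  return 0;
}

Appending at a computed index, rather than blindly appending, means at most one node can win each transition; a concurrent writer hits the failure_callback instead of corrupting the log.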
@@ -543,9 +560,9 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, } // Maintain the invariant that if a task is in the // MethodsWaitingForActorCreation queue, then it is subscribed to its - // respective actor creation task and that task only. Since the actor - // location is now known, we can remove the task from the queue and - // forget its dependency on the actor creation task. + // respective actor creation task. Since the actor location is now known, + // we can remove the task from the queue and forget its dependency on the + // actor creation task. RAY_CHECK(task_dependency_manager_.UnsubscribeDependencies( method.GetTaskSpecification().TaskId())); // The task's uncommitted lineage was already added to the local lineage @@ -553,7 +570,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, // empty lineage this time. SubmitTask(method, Lineage()); } - } else { + } else if (actor_registration.GetState() == ActorState::DEAD) { // When an actor dies, loop over all of the queued tasks for that actor // and treat them as failed. auto tasks_to_remove = local_queues_.GetTaskIdsForActor(actor_id); @@ -561,6 +578,17 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, for (auto const &task : removed_tasks) { TreatTaskAsFailed(task); } + } else { + RAY_CHECK(actor_registration.GetState() == ActorState::RECONSTRUCTING); + RAY_LOG(DEBUG) << "Actor is being reconstructed: " << actor_id; + // When an actor fails but can be reconstructed, resubmit all of the queued + // tasks for that actor. This will mark the tasks as waiting for actor + // creation. + auto tasks_to_remove = local_queues_.GetTaskIdsForActor(actor_id); + auto removed_tasks = local_queues_.RemoveTasks(tasks_to_remove); + for (auto const &task : removed_tasks) { + SubmitTask(task, Lineage()); + } } } @@ -640,7 +668,7 @@ void NodeManager::ProcessClientMessage( return; } break; case protocol::MessageType::IntentionalDisconnectClient: { - ProcessDisconnectClientMessage(client, /* push_warning = */ false); + ProcessDisconnectClientMessage(client, /* intentional_disconnect = */ true); // We don't need to receive future messages from this client, // because it's already disconnected. return; @@ -702,6 +730,50 @@ void NodeManager::ProcessRegisterClientRequestMessage( } } +void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_local, + bool intentional_disconnect) { + auto actor_entry = actor_registry_.find(actor_id); + RAY_CHECK(actor_entry != actor_registry_.end()); + auto &actor_registration = actor_entry->second; + RAY_LOG(DEBUG) << "The actor with ID " << actor_id << " died " + << (intentional_disconnect ? "intentionally" : "unintentionally") + << ", remaining reconstructions = " + << actor_registration.GetRemainingReconstructions(); + + // Check if this actor needs to be reconstructed. + ActorState new_state = + actor_registration.GetRemainingReconstructions() > 0 && !intentional_disconnect + ? ActorState::RECONSTRUCTING + : ActorState::DEAD; + if (was_local) { + // Clean up the dummy objects from this actor. + RAY_LOG(DEBUG) << "Removing dummy objects for actor: " << actor_id; + for (auto &id : actor_entry->second.GetDummyObjects()) { + HandleObjectMissing(id); + } + } + // Update the actor's state. + ActorTableDataT new_actor_data = actor_entry->second.GetTableData(); + new_actor_data.state = new_state; + if (was_local) { + // If the actor was local, immediately update the state in actor registry. 
+ // So if we receive any actor tasks before we receive GCS notification, + // these tasks can be correctly routed to the `MethodsWaitingForActorCreation` queue, + // instead of being assigned to the dead actor. + HandleActorStateTransition(actor_id, new_actor_data); + } + ray::gcs::ActorTable::WriteCallback failure_callback = nullptr; + if (was_local) { + failure_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, + const ActorTableDataT &data) { + // If the disconnected actor was local, only this node will try to update actor + // state. So the update shouldn't fail. + RAY_LOG(FATAL) << "Failed to update state for actor " << id; + }; + } + PublishActorStateTransition(actor_id, new_actor_data, failure_callback); +} + void NodeManager::ProcessGetTaskMessage( const std::shared_ptr &client) { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); @@ -721,7 +793,7 @@ void NodeManager::ProcessGetTaskMessage( } void NodeManager::ProcessDisconnectClientMessage( - const std::shared_ptr &client, bool push_warning) { + const std::shared_ptr &client, bool intentional_disconnect) { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); bool is_worker = false, is_driver = false; if (worker) { @@ -768,7 +840,7 @@ void NodeManager::ProcessDisconnectClientMessage( const JobID &job_id = worker->GetAssignedDriverId(); - if (push_warning) { + if (!intentional_disconnect) { // TODO(rkn): Define this constant somewhere else. std::string type = "worker_died"; std::ostringstream error_message; @@ -786,7 +858,7 @@ void NodeManager::ProcessDisconnectClientMessage( if (!actor_id.is_nil()) { RAY_LOG(DEBUG) << "The actor with ID " << actor_id << " died on " << gcs_client_->client_table().GetLocalClientId(); - HandleDisconnectedActor(actor_id, /*was_local=*/true); + HandleDisconnectedActor(actor_id, /*was_local=*/true, intentional_disconnect); } const ClientID &client_id = gcs_client_->client_table().GetLocalClientId(); @@ -1064,7 +1136,7 @@ void NodeManager::TreatTaskAsFailed(const Task &task) { // Loop over the return IDs (except the dummy ID) and store a fake object in // the object store. int64_t num_returns = spec.NumReturns(); - if (spec.IsActorTask()) { + if (spec.IsActorCreationTask() || spec.IsActorTask()) { // TODO(rkn): We subtract 1 to avoid the dummy ID. However, this leaks // information about the TaskSpecification implementation. num_returns -= 1; @@ -1098,9 +1170,52 @@ void NodeManager::TreatTaskAsFailed(const Task &task) { task_dependency_manager_.UnsubscribeDependencies(spec.TaskId()); } +void NodeManager::TreatTaskAsFailedIfLost(const Task &task) { + const TaskSpecification &spec = task.GetTaskSpecification(); + RAY_LOG(DEBUG) << "Treating task " << spec.TaskId() << " as failed."; + // Loop over the return IDs (except the dummy ID) and check whether a + // location for the return ID exists. + int64_t num_returns = spec.NumReturns(); + if (spec.IsActorCreationTask() || spec.IsActorTask()) { + // TODO(rkn): We subtract 1 to avoid the dummy ID. However, this leaks + // information about the TaskSpecification implementation. + num_returns -= 1; + } + // Use a shared flag to make sure that we only treat the task as failed at + // most once. This flag will get deallocated once all of the object table + // lookup callbacks are fired. + auto task_marked_as_failed = std::make_shared(false); + for (int64_t i = 0; i < num_returns; i++) { + const ObjectID object_id = spec.ReturnId(i); + // Lookup the return value's locations. 
+ RAY_CHECK_OK(object_directory_->LookupLocations( + object_id, + [this, task_marked_as_failed, task]( + const ray::ObjectID &object_id, + const std::unordered_set &clients, bool has_been_created) { + if (!*task_marked_as_failed) { + // Only process the object locations if we haven't already marked the + // task as failed. + if (clients.empty() && has_been_created) { + // The object does not exist on any nodes but has been created + // before, so the object has been lost. Mark the task as failed to + // prevent any tasks that depend on this object from hanging. + TreatTaskAsFailed(task); + *task_marked_as_failed = true; + } + } + })); + } +} + void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineage, bool forwarded) { - const TaskID &task_id = task.GetTaskSpecification().TaskId(); + const TaskSpecification &spec = task.GetTaskSpecification(); + const TaskID &task_id = spec.TaskId(); + RAY_LOG(DEBUG) << "Submitting task: task_id = " << task_id + << ", actor_id = " << spec.ActorId() + << ", actor_creation_id = " << spec.ActorCreationId(); + if (local_queues_.HasTask(task_id)) { RAY_LOG(WARNING) << "Submitted task " << task_id << " is already queued and will not be reconstructed. This is most " @@ -1115,49 +1230,58 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag << " already in lineage cache. This is most likely due to reconstruction."; } - const TaskSpecification &spec = task.GetTaskSpecification(); if (spec.IsActorTask()) { // Check whether we know the location of the actor. const auto actor_entry = actor_registry_.find(spec.ActorId()); - if (actor_entry != actor_registry_.end()) { - if (!actor_entry->second.IsAlive()) { + bool seen = actor_entry != actor_registry_.end(); + // If we have already seen this actor and this actor is not being reconstructed, + // its location is known. + bool location_known = + seen && actor_entry->second.GetState() != ActorState::RECONSTRUCTING; + if (location_known) { + if (actor_entry->second.GetState() == ActorState::DEAD) { + // If this actor is dead, either because the actor process is dead + // or because its residing node is dead, treat this task as failed. TreatTaskAsFailed(task); } else { - // We have a known location for the actor. + // If this actor is alive, check whether this actor is local. auto node_manager_id = actor_entry->second.GetNodeManagerId(); if (node_manager_id == gcs_client_->client_table().GetLocalClientId()) { - // Queue the task for local execution, bypassing placement. + // If this actor is local, queue the task for local execution, bypassing + // placement. EnqueuePlaceableTask(task); } else { - // If the node manager has been removed, then it must have already been - // marked as DEAD in the handler for a removed GCS client. - RAY_CHECK(!gcs_client_->client_table().IsRemoved(node_manager_id)); - // The actor is remote. Attempt to forward the task to the node manager - // that owns the actor. If this fails to forward the task, the task - // will be resubmitted locally. + // The actor is remote. Forward the task to the node manager that owns + // the actor. + // Attempt to forward the task. If this fails to forward the task, + // the task will be resubmit locally. ForwardTaskOrResubmit(task, node_manager_id); } } } else { - // We do not have a registered location for the object, so either the - // actor has not yet been created or we missed the notification for the - // actor creation because this node joined the cluster after the actor - // was already created. 
Look up the actor's registered location in case - // we missed the creation notification. - // NOTE(swang): This codepath needs to be tested in a cluster setting. - auto lookup_callback = [this](gcs::AsyncGcsClient *client, const ActorID &actor_id, - const std::vector &data) { - if (!data.empty()) { - // The actor has been created. - HandleActorStateTransition(actor_id, data.back()); - } else { - // The actor has not yet been created. - // TODO(swang): Set a timer for reconstructing the actor creation - // task. - } - }; - RAY_CHECK_OK(gcs_client_->actor_table().Lookup(JobID::nil(), spec.ActorId(), - lookup_callback)); + ObjectID actor_creation_dummy_object; + if (!seen) { + // We do not have a registered location for the object, so either the + // actor has not yet been created or we missed the notification for the + // actor creation because this node joined the cluster after the actor + // was already created. Look up the actor's registered location in case + // we missed the creation notification. + auto lookup_callback = [this](gcs::AsyncGcsClient *client, + const ActorID &actor_id, + const std::vector &data) { + if (!data.empty()) { + // The actor has been created. We only need the last entry, because + // it represents the latest state of this actor. + HandleActorStateTransition(actor_id, data.back()); + } + }; + RAY_CHECK_OK(gcs_client_->actor_table().Lookup(JobID::nil(), spec.ActorId(), + lookup_callback)); + actor_creation_dummy_object = spec.ActorCreationDummyObjectId(); + } else { + actor_creation_dummy_object = actor_entry->second.GetActorCreationDependency(); + } + // Keep the task queued until we discover the actor's location. // (See design_docs/task_states.rst for the state transition diagram.) local_queues_.QueueMethodsWaitingForActorCreation({task}); @@ -1169,7 +1293,7 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag // waiting queue, the caller must make the corresponding call to // UnsubscribeDependencies. task_dependency_manager_.SubscribeDependencies(spec.TaskId(), - {spec.ActorCreationDummyObjectId()}); + {actor_creation_dummy_object}); // Mark the task as pending. It will be canceled once we discover the // actor's location and either execute the task ourselves or forward it // to another node. @@ -1328,7 +1452,11 @@ bool NodeManager::AssignTask(const Task &task) { // If this is an actor task, check that the new task has the correct counter. if (spec.IsActorTask()) { if (CheckDuplicateActorTask(actor_registry_, spec)) { - // This actor has been already assigned, so ignore it. + // The actor is alive, and a task that has already been executed before + // has been found. The task will be treated as failed if at least one of + // the task's return values have been evicted, to prevent the application + // from hanging. + TreatTaskAsFailedIfLost(task); return true; } } @@ -1443,38 +1571,47 @@ void NodeManager::FinishAssignedTask(Worker &worker) { // If this was an actor creation task, then convert the worker to an actor. auto actor_id = task.GetTaskSpecification().ActorCreationId(); worker.AssignActorId(actor_id); - const auto driver_id = task.GetTaskSpecification().DriverId(); - // Publish the actor creation event to all other nodes so that methods for // the actor will be forwarded directly to this node. 
- RAY_CHECK(actor_registry_.find(actor_id) == actor_registry_.end()) - << "Created an actor that already exists"; - auto actor_data = std::make_shared(); - actor_data->actor_id = actor_id.binary(); - actor_data->actor_creation_dummy_object_id = - task.GetTaskSpecification().ActorDummyObject().binary(); - actor_data->driver_id = driver_id.binary(); - actor_data->node_manager_id = gcs_client_->client_table().GetLocalClientId().binary(); - actor_data->state = ActorState::ALIVE; - - RAY_LOG(DEBUG) << "Publishing actor creation: " << actor_id - << " driver_id: " << driver_id; - HandleActorStateTransition(actor_id, *actor_data); - // The actor should not have been created before, so writing to the first - // index in the log should succeed. - auto failure_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { - // TODO(swang): Instead of making this a fatal check, we could just kill - // the duplicate actor process. If we do this, we must make sure to - // either resubmit the tasks that went to the duplicate actor, or wait - // for success before handling the actor state transition to ALIVE. - RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; - }; - RAY_CHECK_OK(gcs_client_->actor_table().AppendAt( - JobID::nil(), actor_id, actor_data, nullptr, failure_callback, /*log_index=*/0)); + auto actor_entry = actor_registry_.find(actor_id); + ActorTableDataT new_actor_data; + if (actor_entry == actor_registry_.end()) { + // Set all of the static fields for the actor. These fields will not + // change even if the actor fails or is reconstructed. + new_actor_data.actor_id = actor_id.binary(); + new_actor_data.actor_creation_dummy_object_id = + task.GetTaskSpecification().ActorDummyObject().binary(); + new_actor_data.driver_id = task.GetTaskSpecification().DriverId().binary(); + new_actor_data.max_reconstructions = + task.GetTaskSpecification().MaxActorReconstructions(); + // This is the first time that the actor has been created, so the number + // of remaining reconstructions is the max. + new_actor_data.remaining_reconstructions = + task.GetTaskSpecification().MaxActorReconstructions(); + } else { + // If we've already seen this actor, it means that this actor was reconstructed. + // Thus, its previous state must be RECONSTRUCTING. + RAY_CHECK(actor_entry->second.GetState() == ActorState::RECONSTRUCTING); + // Copy the static fields from the current actor entry. + new_actor_data = actor_entry->second.GetTableData(); + // We are reconstructing the actor, so subtract its + // remaining_reconstructions by 1. + new_actor_data.remaining_reconstructions--; + } - // Resources required by an actor creation task are acquired for the - // lifetime of the actor, so we do not release any resources here. + // Set the new fields for the actor's state to indicate that the actor is + // now alive on this node manager. + new_actor_data.node_manager_id = + gcs_client_->client_table().GetLocalClientId().binary(); + new_actor_data.state = ActorState::ALIVE; + HandleActorStateTransition(actor_id, new_actor_data); + PublishActorStateTransition( + actor_id, new_actor_data, + /*failure_callback=*/ + [](gcs::AsyncGcsClient *client, const ActorID &id, const ActorTableDataT &data) { + // Only one node at a time should succeed at creating the actor. + RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; + }); } else { // Release task's resources. 
local_available_resources_.Release(worker.GetTaskResourceIds()); @@ -1488,8 +1625,6 @@ void NodeManager::FinishAssignedTask(Worker &worker) { // If the finished task was an actor task, mark the returned dummy object as // locally available. This is not added to the object table, so the update // will be invisible to both the local object manager and the other nodes. - // NOTE(swang): These objects are never cleaned up. We should consider - // removing the objects, e.g., when an actor is terminated. if (task.GetTaskSpecification().IsActorCreationTask() || task.GetTaskSpecification().IsActorTask()) { ActorID actor_id; @@ -1543,7 +1678,12 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { [this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id) { // The task was not in the GCS task table. It must therefore be in the // lineage cache. - RAY_CHECK(lineage_cache_.ContainsTask(task_id)); + RAY_CHECK(lineage_cache_.ContainsTask(task_id)) + << "Task metadata not found in either GCS or lineage cache. It may have been " + "evicted " + << "by the redis LRU configuration. Consider increasing the memory " + "allocation via " + << "ray.init(redis_max_memory=)."; // Use a copy of the cached task spec to re-execute the task. const Task task = lineage_cache_.GetTask(task_id); ResubmitTask(task); @@ -1552,23 +1692,22 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { } void NodeManager::ResubmitTask(const Task &task) { - if (task.GetTaskSpecification().IsActorTask()) { - // Actor reconstruction is turned off by default right now. - const ActorID actor_id = task.GetTaskSpecification().ActorId(); - auto it = actor_registry_.find(actor_id); - RAY_CHECK(it != actor_registry_.end()); - if (it->second.IsAlive()) { - // If the actor is still alive, then do not resubmit. - RAY_LOG(ERROR) << "The output of an actor task is required, but the actor may " - "still be alive. If the output has been evicted, the job may " - "hang."; + RAY_LOG(DEBUG) << "Attempting to resubmit task " + << task.GetTaskSpecification().TaskId(); + + // Actors should only be recreated if the first initialization failed or if + // the most recent instance of the actor failed. + if (task.GetTaskSpecification().IsActorCreationTask()) { + const auto &actor_id = task.GetTaskSpecification().ActorCreationId(); + const auto it = actor_registry_.find(actor_id); + if (it != actor_registry_.end() && it->second.GetState() == ActorState::ALIVE) { + // If the actor is still alive, then do not resubmit the task. If the + // actor actually is dead and a result is needed, then reconstruction + // for this task will be triggered again. + RAY_LOG(WARNING) + << "Actor creation task resubmitted, but the actor is still alive."; return; } - // The actor is dead. The actor task will get resubmitted, at which point - // it will be treated as failed. - } else { - RAY_LOG(INFO) << "Reconstructing task " << task.GetTaskSpecification().TaskId() - << " on client " << gcs_client_->client_table().GetLocalClientId(); } // Driver tasks cannot be reconstructed. If this is a driver task, push an @@ -1586,6 +1725,8 @@ void NodeManager::ResubmitTask(const Task &task) { return; } + RAY_LOG(INFO) << "Resubmitting task " << task.GetTaskSpecification().TaskId() + << " on client " << gcs_client_->client_table().GetLocalClientId(); // The task may be reconstructed. Submit it with an empty lineage, since any // uncommitted lineage must already be in the lineage cache. 
At this point, // the task should not yet exist in the local scheduling queue. If it does, @@ -1604,6 +1745,7 @@ void NodeManager::HandleObjectLocal(const ObjectID &object_id) { // First filter out the tasks that should not be moved to READY. local_queues_.FilterState(ready_task_id_set, TaskState::BLOCKED); local_queues_.FilterState(ready_task_id_set, TaskState::DRIVER); + local_queues_.FilterState(ready_task_id_set, TaskState::WAITING_FOR_ACTOR_CREATION); // Make sure that the remaining tasks are all WAITING. auto ready_task_id_set_copy = ready_task_id_set; @@ -1773,7 +1915,7 @@ std::string NodeManager::DebugString() const { std::stringstream result; uint64_t now_ms = current_time_ms(); result << "NodeManager:"; - result << "\nLocalResources: " << local_resources_.DebugString(); + result << "\nInitialConfigResources: " << initial_config_.resource_config.ToString(); result << "\nClusterResources:"; for (auto &pair : cluster_resource_map_) { result << "\n" << pair.first.hex() << ": " << pair.second.DebugString(); @@ -1788,10 +1930,13 @@ std::string NodeManager::DebugString() const { result << "\nActorRegistry:"; int live_actors = 0; int dead_actors = 0; + int reconstructing_actors = 0; int max_num_handles = 0; for (auto &pair : actor_registry_) { - if (pair.second.IsAlive()) { + if (pair.second.GetState() == ActorState::ALIVE) { live_actors += 1; + } else if (pair.second.GetState() == ActorState::RECONSTRUCTING) { + reconstructing_actors += 1; } else { dead_actors += 1; } @@ -1800,6 +1945,7 @@ std::string NodeManager::DebugString() const { } } result << "\n- num live actors: " << live_actors; + result << "\n- num reconstructing actors: " << reconstructing_actors; result << "\n- num dead actors: " << dead_actors; result << "\n- max num handles: " << max_num_handles; result << "\nRemoteConnections:"; diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 7fd820a5e5adc..d8502e7d8005a 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -58,7 +58,8 @@ class NodeManager { /// \param object_manager A reference to the local object manager. NodeManager(boost::asio::io_service &io_service, const NodeManagerConfig &config, ObjectManager &object_manager, - std::shared_ptr gcs_client); + std::shared_ptr gcs_client, + std::shared_ptr object_directory_); /// Process a new client connection. /// @@ -153,9 +154,17 @@ class NodeManager { /// it were any other task that had been assigned, executed, and removed from /// the local queue. /// - /// \param spec The specification of the task. + /// \param task The task to fail. /// \return Void. void TreatTaskAsFailed(const Task &task); + /// This is similar to TreatTaskAsFailed, but it will only mark the task as + /// failed if at least one of the task's return values is lost. A return + /// value is lost if it has been created before, but no longer exists on any + /// nodes, due to either node failure or eviction. + /// + /// \param task The task to potentially fail. + /// \return Void. + void TreatTaskAsFailedIfLost(const Task &task); /// Handle specified task's submission to the local node manager. /// /// \param task The task being submitted. @@ -260,22 +269,25 @@ class NodeManager { /// \return Void. void KillWorker(std::shared_ptr worker); - /// Methods for actor scheduling. - /// Handler for an actor state transition, for a newly created actor or an - /// actor that died. This method is idempotent and will ignore old state - /// transitions.
+ /// The callback for handling an actor state transition (e.g., from ALIVE to + /// DEAD), whether as a notification from the actor table or as a handler for + /// a local actor's state transition. This method is idempotent and will ignore + /// old state transitions. /// - /// \param actor_id The actor ID of the actor that was created. - /// \param data Data associated with the actor state transition. + /// \param actor_id The actor ID of the actor whose state was updated. + /// \param data Data associated with this notification. /// \return Void. void HandleActorStateTransition(const ActorID &actor_id, const ActorTableDataT &data); - /// Handler for an actor dying. The actor may be remote. + /// Publish an actor's state transition to all other nodes. /// - /// \param actor_id The actor ID of the actor that died. - /// \param was_local Whether the actor was local. - /// \return Void. - void HandleDisconnectedActor(const ActorID &actor_id, bool was_local); + /// \param actor_id The actor ID of the actor whose state was updated. + /// \param data Data to publish. + /// \param failure_callback An optional callback to call if the publish is + /// unsuccessful. + void PublishActorStateTransition( + const ActorID &actor_id, const ActorTableDataT &data, + const ray::gcs::ActorTable::WriteCallback &failure_callback); /// When a driver dies, loop over all of the queued tasks for that driver and /// treat them as failed. @@ -332,10 +344,11 @@ class NodeManager { /// client. /// /// \param client The client that sent the message. - /// \param push_warning Propogate error message if true. + /// \param intentional_disconnect Whether the client was intentionally disconnected. /// \return Void. void ProcessDisconnectClientMessage( - const std::shared_ptr &client, bool push_warning = true); + const std::shared_ptr &client, + bool intentional_disconnect = false); /// Process client message of SubmitTask /// @@ -365,6 +378,18 @@ class NodeManager { /// \return Void. void ProcessPushErrorRequestMessage(const uint8_t *message_data); + /// Handle the case where an actor is disconnected, determine whether this + /// actor needs to be reconstructed, and then update the actor table. + /// This function needs to be called either when the actor process dies or + /// when a node dies. + /// + /// \param actor_id The ID of this actor. + /// \param was_local Whether the disconnected actor was on this local node. + /// \param intentional_disconnect Whether the client was intentionally disconnected. + /// \return Void. + void HandleDisconnectedActor(const ActorID &actor_id, bool was_local, + bool intentional_disconnect); + boost::asio::io_service &io_service_; ObjectManager &object_manager_; /// A Plasma object store client. This is used exclusively for creating new @@ -373,6 +398,8 @@ class NodeManager { plasma::PlasmaClient store_client_; /// A client connection to the GCS. std::shared_ptr gcs_client_; + /// The object table. This is shared with the object manager. + std::shared_ptr object_directory_; /// The timer used to send heartbeats. boost::asio::steady_timer heartbeat_timer_; /// The period used for the heartbeat timer. @@ -389,8 +416,8 @@ class NodeManager { uint64_t last_heartbeat_at_ms_; /// The time that the last debug string was logged to the console. uint64_t last_debug_dump_at_ms_; - /// The resources local to this node. - const SchedulingResources local_resources_; + /// Initial node manager configuration. + const NodeManagerConfig initial_config_; /// The resources (and specific resource IDs) that are currently available.
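The declarations above split the old handler into a local state-transition handler plus an explicit publish step. As a minimal sketch of the disconnect decision they support, assuming (consistently with the tests later in this diff) that an intentional exit or an exhausted reconstruction budget ends the actor while anything else allows reconstruction; the function and enum here are illustrative, not the raylet's API:

#include <cstdint>

enum class ActorState { ALIVE, RECONSTRUCTING, DEAD };

// Illustrative only: an intentional exit or an exhausted reconstruction
// budget marks the actor DEAD; otherwise it is marked RECONSTRUCTING so that
// another node may recreate it.
ActorState NextStateOnDisconnect(int64_t remaining_reconstructions,
                                 bool intentional_disconnect) {
  if (intentional_disconnect || remaining_reconstructions == 0) {
    return ActorState::DEAD;
  }
  return ActorState::RECONSTRUCTING;
}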
ResourceIdSet local_available_resources_; std::unordered_map cluster_resource_map_; diff --git a/src/ray/raylet/object_manager_integration_test.cc b/src/ray/raylet/object_manager_integration_test.cc index 83c7a9f8f8be7..0d514e2ecb3ff 100644 --- a/src/ray/raylet/object_manager_integration_test.cc +++ b/src/ray/raylet/object_manager_integration_test.cc @@ -74,8 +74,8 @@ class TestObjectManagerBase : public ::testing::Test { GetNodeManagerConfig("raylet_2", store_sock_2), om_config_2, gcs_client_2)); // connect to stores. - ARROW_CHECK_OK(client1.Connect(store_sock_1, "", plasma::kPlasmaDefaultReleaseDelay)); - ARROW_CHECK_OK(client2.Connect(store_sock_2, "", plasma::kPlasmaDefaultReleaseDelay)); + ARROW_CHECK_OK(client1.Connect(store_sock_1)); + ARROW_CHECK_OK(client2.Connect(store_sock_2)); } void TearDown() { @@ -206,7 +206,7 @@ class TestObjectManagerIntegration : public TestObjectManagerBase { << "\n"; ClientTableDataT data; gcs_client_2->client_table().GetClient(client_id_1, data); - RAY_LOG(INFO) << (ClientID::from_binary(data.client_id) == ClientID::nil()); + RAY_LOG(INFO) << (ClientID::from_binary(data.client_id).is_nil()); RAY_LOG(INFO) << "ClientID=" << ClientID::from_binary(data.client_id); RAY_LOG(INFO) << "ClientIp=" << data.node_manager_address; RAY_LOG(INFO) << "ClientPort=" << data.node_manager_port; diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index 679b18052920b..39028ce7f96b8 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -18,8 +18,10 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_ const ObjectManagerConfig &object_manager_config, std::shared_ptr gcs_client) : gcs_client_(gcs_client), - object_manager_(main_service, object_manager_config, gcs_client), - node_manager_(main_service, node_manager_config, object_manager_, gcs_client_), + object_directory_(std::make_shared(main_service, gcs_client_)), + object_manager_(main_service, object_manager_config, object_directory_), + node_manager_(main_service, node_manager_config, object_manager_, gcs_client_, + object_directory_), socket_name_(socket_name), acceptor_(main_service, boost::asio::local::stream_protocol::endpoint(socket_name)), socket_(main_service), diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index 9b424781af171..84274ea6ecfe1 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -70,6 +70,9 @@ class Raylet { /// A client connection to the GCS. std::shared_ptr gcs_client_; + /// The object table. This is shared between the object manager and node + /// manager. + std::shared_ptr object_directory_; /// Manages client requests for object transfers and availability. ObjectManager object_manager_; /// Manages client requests for task submission and execution. diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc new file mode 100644 index 0000000000000..884fc1f4fd70a --- /dev/null +++ b/src/ray/raylet/raylet_client.cc @@ -0,0 +1,360 @@ +#include "raylet_client.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ray/common/common_protocol.h" +#include "ray/ray_config.h" +#include "ray/raylet/format/node_manager_generated.h" +#include "ray/raylet/task_spec.h" +#include "ray/util/logging.h" + +using MessageType = ray::protocol::MessageType; + +// TODO(rkn): The io methods below should be removed. 
+int connect_ipc_sock(const std::string &socket_pathname) { + struct sockaddr_un socket_address; + int socket_fd; + + socket_fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (socket_fd < 0) { + RAY_LOG(ERROR) << "socket() failed for pathname " << socket_pathname; + return -1; + } + + memset(&socket_address, 0, sizeof(socket_address)); + socket_address.sun_family = AF_UNIX; + if (socket_pathname.length() + 1 > sizeof(socket_address.sun_path)) { + RAY_LOG(ERROR) << "Socket pathname is too long."; + close(socket_fd); + return -1; + } + strncpy(socket_address.sun_path, socket_pathname.c_str(), socket_pathname.length() + 1); + + if (connect(socket_fd, (struct sockaddr *)&socket_address, sizeof(socket_address)) != + 0) { + close(socket_fd); + return -1; + } + return socket_fd; +} + +int read_bytes(int socket_fd, uint8_t *cursor, size_t length) { + ssize_t nbytes = 0; + // Termination condition: EOF or read 'length' bytes total. + size_t bytesleft = length; + size_t offset = 0; + while (bytesleft > 0) { + nbytes = read(socket_fd, cursor + offset, bytesleft); + if (nbytes < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { + continue; + } + return -1; // Errno will be set. + } else if (0 == nbytes) { + // Encountered early EOF. + return -1; + } + RAY_CHECK(nbytes > 0); + bytesleft -= nbytes; + offset += nbytes; + } + return 0; +} + +int write_bytes(int socket_fd, uint8_t *cursor, size_t length) { + ssize_t nbytes = 0; + size_t bytesleft = length; + size_t offset = 0; + while (bytesleft > 0) { + // While we haven't written the whole message, write to the file + // descriptor, advance the cursor, and decrease the amount left to write. + nbytes = write(socket_fd, cursor + offset, bytesleft); + if (nbytes < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { + continue; + } + return -1; // Errno will be set. + } else if (0 == nbytes) { + // Encountered early EOF. + return -1; + } + RAY_CHECK(nbytes > 0); + bytesleft -= nbytes; + offset += nbytes; + } + return 0; +} + +RayletConnection::RayletConnection(const std::string &raylet_socket, int num_retries, + int64_t timeout) { + // Pick the default values if the user did not specify. + if (num_retries < 0) { + num_retries = RayConfig::instance().num_connect_attempts(); + } + if (timeout < 0) { + timeout = RayConfig::instance().connect_timeout_milliseconds(); + } + RAY_CHECK(!raylet_socket.empty()); + conn_ = -1; + for (int num_attempts = 0; num_attempts < num_retries; ++num_attempts) { + conn_ = connect_ipc_sock(raylet_socket); + if (conn_ >= 0) break; + if (num_attempts > 0) { + RAY_LOG(ERROR) << "Retrying to connect to socket for pathname " << raylet_socket + << " (num_attempts = " << num_attempts + << ", num_retries = " << num_retries << ")"; + } + // Sleep for timeout milliseconds. + usleep(timeout * 1000); + } + // If we could not connect to the socket, exit. + if (conn_ == -1) { + RAY_LOG(FATAL) << "Could not connect to socket " << raylet_socket; + } +} + +ray::Status RayletConnection::Disconnect() { + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateDisconnectClient(fbb); + fbb.Finish(message); + auto status = WriteMessage(MessageType::IntentionalDisconnectClient, &fbb); + // Don't be too strict for disconnection errors. + // Just create logs and prevent it from crash. 
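A hypothetical round-trip check for the read_bytes/write_bytes helpers above, using a local socketpair rather than the raylet's Unix socket path; it only assumes the two helpers as defined in this file (their declarations are repeated so the snippet is self-contained).

#include <sys/socket.h>
#include <cstdint>
#include <cstring>

// Declared in raylet_client.cc above; both return 0 on success.
int read_bytes(int socket_fd, uint8_t *cursor, size_t length);
int write_bytes(int socket_fd, uint8_t *cursor, size_t length);

int main() {
  int fds[2];
  // A connected pair of stream sockets stands in for the raylet connection.
  if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) != 0) {
    return 1;
  }
  char payload[] = "ping";
  char buffer[sizeof(payload)] = {};
  int write_status = write_bytes(fds[0], reinterpret_cast<uint8_t *>(payload),
                                 sizeof(payload));
  int read_status = read_bytes(fds[1], reinterpret_cast<uint8_t *>(buffer),
                               sizeof(buffer));
  // The payload should round-trip intact through the two helpers.
  return (write_status == 0 && read_status == 0 &&
          std::memcmp(buffer, payload, sizeof(payload)) == 0)
             ? 0
             : 1;
}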
+ if (!status.ok()) { + RAY_LOG(ERROR) << status.ToString() + << " [RayletClient] Failed to disconnect from raylet."; + } + return ray::Status::OK(); +} + +ray::Status RayletConnection::ReadMessage(MessageType type, + std::unique_ptr &message) { + int64_t version; + int64_t type_field; + int64_t length; + int closed = read_bytes(conn_, (uint8_t *)&version, sizeof(version)); + if (closed) goto disconnected; + RAY_CHECK(version == RayConfig::instance().ray_protocol_version()); + closed = read_bytes(conn_, (uint8_t *)&type_field, sizeof(type_field)); + if (closed) goto disconnected; + closed = read_bytes(conn_, (uint8_t *)&length, sizeof(length)); + if (closed) goto disconnected; + message = std::unique_ptr(new uint8_t[length]); + closed = read_bytes(conn_, message.get(), length); + if (closed) { + // Handle the case in which the socket is closed. + message.reset(nullptr); + disconnected: + message = nullptr; + type_field = static_cast(MessageType::DisconnectClient); + length = 0; + } + if (type_field == static_cast(MessageType::DisconnectClient)) { + return ray::Status::IOError("[RayletClient] Raylet connection closed."); + } + if (type_field != static_cast(type)) { + return ray::Status::TypeError( + std::string("[RayletClient] Raylet connection corrupted. ") + + "Expected message type: " + std::to_string(static_cast(type)) + + "; got message type: " + std::to_string(type_field) + + ". Check logs or dmesg for previous errors."); + } + return ray::Status::OK(); +} + +ray::Status RayletConnection::WriteMessage(MessageType type, + flatbuffers::FlatBufferBuilder *fbb) { + std::unique_lock guard(write_mutex_); + int64_t version = RayConfig::instance().ray_protocol_version(); + int64_t length = fbb ? fbb->GetSize() : 0; + uint8_t *bytes = fbb ? fbb->GetBufferPointer() : nullptr; + int64_t type_field = static_cast(type); + auto io_error = ray::Status::IOError("[RayletClient] Connection closed unexpectedly."); + int closed; + closed = write_bytes(conn_, (uint8_t *)&version, sizeof(version)); + if (closed) return io_error; + closed = write_bytes(conn_, (uint8_t *)&type_field, sizeof(type_field)); + if (closed) return io_error; + closed = write_bytes(conn_, (uint8_t *)&length, sizeof(length)); + if (closed) return io_error; + closed = write_bytes(conn_, bytes, length * sizeof(char)); + if (closed) return io_error; + return ray::Status::OK(); +} + +ray::Status RayletConnection::AtomicRequestReply( + MessageType request_type, MessageType reply_type, + std::unique_ptr &reply_message, flatbuffers::FlatBufferBuilder *fbb) { + std::unique_lock guard(mutex_); + auto status = WriteMessage(request_type, fbb); + if (!status.ok()) return status; + return ReadMessage(reply_type, reply_message); +} + +RayletClient::RayletClient(const std::string &raylet_socket, const UniqueID &client_id, + bool is_worker, const JobID &driver_id, + const Language &language) + : client_id_(client_id), + is_worker_(is_worker), + driver_id_(driver_id), + language_(language) { + // For C++14, we could use std::make_unique + conn_ = std::unique_ptr(new RayletConnection(raylet_socket, -1, -1)); + + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateRegisterClientRequest( + fbb, is_worker, to_flatbuf(fbb, client_id), getpid(), to_flatbuf(fbb, driver_id), + language); + fbb.Finish(message); + // Register the process ID with the raylet. + // NOTE(swang): If raylet exits and we are registered as a worker, we will get killed. 
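WriteMessage and ReadMessage above agree on a simple frame: protocol version, message type, payload length, then the flatbuffer payload. A small illustrative sketch of that layout (not the raylet's code, just the framing it implies):

#include <cstdint>
#include <cstring>
#include <vector>

// Build one frame in the order the connection writes it: version, type,
// length, payload.
std::vector<uint8_t> FrameMessage(int64_t version, int64_t type,
                                  const uint8_t *payload, int64_t length) {
  std::vector<uint8_t> frame(sizeof(version) + sizeof(type) + sizeof(length) +
                             length);
  uint8_t *out = frame.data();
  std::memcpy(out, &version, sizeof(version));
  out += sizeof(version);
  std::memcpy(out, &type, sizeof(type));
  out += sizeof(type);
  std::memcpy(out, &length, sizeof(length));
  out += sizeof(length);
  if (length > 0) {
    std::memcpy(out, payload, length);
  }
  return frame;
}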
+ auto status = conn_->WriteMessage(MessageType::RegisterClientRequest, &fbb); + RAY_CHECK_OK_PREPEND(status, "[RayletClient] Unable to register worker with raylet."); +} + +ray::Status RayletClient::SubmitTask(const std::vector &execution_dependencies, + const ray::raylet::TaskSpecification &task_spec) { + flatbuffers::FlatBufferBuilder fbb; + auto execution_dependencies_message = to_flatbuf(fbb, execution_dependencies); + auto message = ray::protocol::CreateSubmitTaskRequest( + fbb, execution_dependencies_message, task_spec.ToFlatbuffer(fbb)); + fbb.Finish(message); + return conn_->WriteMessage(MessageType::SubmitTask, &fbb); +} + +ray::Status RayletClient::GetTask( + std::unique_ptr *task_spec) { + std::unique_ptr reply; + // Receive a task from the raylet. This will block until the local + // scheduler gives this client a task. + auto status = + conn_->AtomicRequestReply(MessageType::GetTask, MessageType::ExecuteTask, reply); + if (!status.ok()) return status; + // Parse the flatbuffer object. + auto reply_message = flatbuffers::GetRoot(reply.get()); + // Set the resource IDs for this task. + resource_ids_.clear(); + for (size_t i = 0; i < reply_message->fractional_resource_ids()->size(); ++i) { + auto const &fractional_resource_ids = + reply_message->fractional_resource_ids()->Get(i); + auto &acquired_resources = + resource_ids_[string_from_flatbuf(*fractional_resource_ids->resource_name())]; + + size_t num_resource_ids = fractional_resource_ids->resource_ids()->size(); + size_t num_resource_fractions = fractional_resource_ids->resource_fractions()->size(); + RAY_CHECK(num_resource_ids == num_resource_fractions); + RAY_CHECK(num_resource_ids > 0); + for (size_t j = 0; j < num_resource_ids; ++j) { + int64_t resource_id = fractional_resource_ids->resource_ids()->Get(j); + double resource_fraction = fractional_resource_ids->resource_fractions()->Get(j); + if (num_resource_ids > 1) { + int64_t whole_fraction = resource_fraction; + RAY_CHECK(whole_fraction == resource_fraction); + } + acquired_resources.push_back(std::make_pair(resource_id, resource_fraction)); + } + } + + // Return the copy of the task spec and pass ownership to the caller. + task_spec->reset(new ray::raylet::TaskSpecification( + string_from_flatbuf(*reply_message->task_spec()))); + return ray::Status::OK(); +} + +ray::Status RayletClient::TaskDone() { + return conn_->WriteMessage(MessageType::TaskDone); +} + +ray::Status RayletClient::FetchOrReconstruct(const std::vector &object_ids, + bool fetch_only, + const TaskID ¤t_task_id) { + flatbuffers::FlatBufferBuilder fbb; + auto object_ids_message = to_flatbuf(fbb, object_ids); + auto message = ray::protocol::CreateFetchOrReconstruct( + fbb, object_ids_message, fetch_only, to_flatbuf(fbb, current_task_id)); + fbb.Finish(message); + auto status = conn_->WriteMessage(MessageType::FetchOrReconstruct, &fbb); + return status; +} + +ray::Status RayletClient::NotifyUnblocked(const TaskID ¤t_task_id) { + flatbuffers::FlatBufferBuilder fbb; + auto message = + ray::protocol::CreateNotifyUnblocked(fbb, to_flatbuf(fbb, current_task_id)); + fbb.Finish(message); + return conn_->WriteMessage(MessageType::NotifyUnblocked, &fbb); +} + +ray::Status RayletClient::Wait(const std::vector &object_ids, int num_returns, + int64_t timeout_milliseconds, bool wait_local, + const TaskID ¤t_task_id, WaitResultPair *result) { + // Write request. 
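GetTask above unpacks the fractional_resource_ids reply into a map from resource name to (resource ID, fraction) pairs. A small sketch of that shape and of the whole-fraction invariant it checks; the names and values here are illustrative only.

#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Each resource name maps to the specific resource IDs granted to the worker
// and the fraction of each ID it may use.
using ResourceMapping =
    std::unordered_map<std::string,
                       std::vector<std::pair<int64_t, double>>>;

int main() {
  ResourceMapping resources;
  // e.g. half of GPU 0 granted to this worker.
  resources["GPU"].push_back({0, 0.5});
  // e.g. two whole CPUs granted to this worker.
  resources["CPU"].push_back({0, 1.0});
  resources["CPU"].push_back({1, 1.0});
  // Fractional values are only allowed when a single ID is granted; with
  // several IDs, each fraction must be whole, mirroring GetTask's check.
  for (const auto &entry : resources) {
    if (entry.second.size() > 1) {
      for (const auto &id_fraction : entry.second) {
        assert(id_fraction.second ==
               static_cast<int64_t>(id_fraction.second));
      }
    }
  }
  return 0;
}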
+ flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateWaitRequest( + fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds, wait_local, + to_flatbuf(fbb, current_task_id)); + fbb.Finish(message); + std::unique_ptr reply; + auto status = conn_->AtomicRequestReply(MessageType::WaitRequest, + MessageType::WaitReply, reply, &fbb); + if (!status.ok()) return status; + // Parse the flatbuffer object. + auto reply_message = flatbuffers::GetRoot(reply.get()); + auto found = reply_message->found(); + for (uint i = 0; i < found->size(); i++) { + ObjectID object_id = ObjectID::from_binary(found->Get(i)->str()); + result->first.push_back(object_id); + } + auto remaining = reply_message->remaining(); + for (uint i = 0; i < remaining->size(); i++) { + ObjectID object_id = ObjectID::from_binary(remaining->Get(i)->str()); + result->second.push_back(object_id); + } + return ray::Status::OK(); +} + +ray::Status RayletClient::PushError(const JobID &job_id, const std::string &type, + const std::string &error_message, double timestamp) { + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreatePushErrorRequest( + fbb, to_flatbuf(fbb, job_id), fbb.CreateString(type), + fbb.CreateString(error_message), timestamp); + fbb.Finish(message); + + return conn_->WriteMessage(MessageType::PushErrorRequest, &fbb); +} + +ray::Status RayletClient::PushProfileEvents(const ProfileTableDataT &profile_events) { + flatbuffers::FlatBufferBuilder fbb; + auto message = CreateProfileTableData(fbb, &profile_events); + fbb.Finish(message); + + auto status = conn_->WriteMessage(MessageType::PushProfileEventsRequest, &fbb); + // Don't be too strict for profile errors. Just create logs and prevent it from crash. + if (!status.ok()) { + RAY_LOG(ERROR) << status.ToString() + << " [RayletClient] Failed to push profile events."; + } + return ray::Status::OK(); +} + +ray::Status RayletClient::FreeObjects(const std::vector &object_ids, + bool local_only) { + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateFreeObjectsRequest(fbb, local_only, + to_flatbuf(fbb, object_ids)); + fbb.Finish(message); + + auto status = conn_->WriteMessage(MessageType::FreeObjectsInObjectStoreRequest, &fbb); + return status; +} diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h new file mode 100644 index 0000000000000..9d93919717877 --- /dev/null +++ b/src/ray/raylet/raylet_client.h @@ -0,0 +1,171 @@ +#ifndef RAYLET_CLIENT_H +#define RAYLET_CLIENT_H + +#include +#include +#include +#include + +#include "ray/raylet/task_spec.h" +#include "ray/status.h" + +using ray::ActorID; +using ray::JobID; +using ray::ObjectID; +using ray::TaskID; +using ray::UniqueID; + +using MessageType = ray::protocol::MessageType; +using ResourceMappingType = + std::unordered_map>>; +using WaitResultPair = std::pair, std::vector>; + +class RayletConnection { + public: + /// Connect to the raylet. + /// + /// \param raylet_socket The name of the socket to use to connect to the raylet. + /// \param worker_id A unique ID to represent the worker. + /// \param is_worker Whether this client is a worker. If it is a worker, an + /// additional message will be sent to register as one. + /// \param driver_id The ID of the driver. This is non-nil if the client is a + /// driver. + /// \return The connection information. 
+ RayletConnection(const std::string &raylet_socket, int num_retries, int64_t timeout); + + ~RayletConnection() { close(conn_); } + /// Notify the raylet that this client is disconnecting gracefully. This + /// is used by actors to exit gracefully so that the raylet doesn't + /// propagate an error message to the driver. + /// + /// \return ray::Status. + ray::Status Disconnect(); + ray::Status ReadMessage(MessageType type, std::unique_ptr &message); + ray::Status WriteMessage(MessageType type, + flatbuffers::FlatBufferBuilder *fbb = nullptr); + ray::Status AtomicRequestReply(MessageType request_type, MessageType reply_type, + std::unique_ptr &reply_message, + flatbuffers::FlatBufferBuilder *fbb = nullptr); + + private: + /// File descriptor of the Unix domain socket that connects to raylet. + int conn_; + /// A mutex to protect stateful operations of the raylet client. + std::mutex mutex_; + /// A mutex to protect write operations of the raylet client. + std::mutex write_mutex_; +}; + +class RayletClient { + public: + /// Connect to the raylet. + /// + /// \param raylet_socket The name of the socket to use to connect to the raylet. + /// \param worker_id A unique ID to represent the worker. + /// \param is_worker Whether this client is a worker. If it is a worker, an + /// additional message will be sent to register as one. + /// \param driver_id The ID of the driver. This is non-nil if the client is a driver. + /// \return The connection information. + RayletClient(const std::string &raylet_socket, const UniqueID &client_id, + bool is_worker, const JobID &driver_id, const Language &language); + + ray::Status Disconnect() { return conn_->Disconnect(); }; + + /// Submit a task using the raylet code path. + /// + /// \param The execution dependencies. + /// \param The task specification. + /// \return ray::Status. + ray::Status SubmitTask(const std::vector &execution_dependencies, + const ray::raylet::TaskSpecification &task_spec); + + /// Get next task for this client. This will block until the scheduler assigns + /// a task to this worker. The caller takes ownership of the returned task + /// specification and must free it. + /// + /// \param task_spec The assigned task. + /// \return ray::Status. + ray::Status GetTask(std::unique_ptr *task_spec); + + /// Tell the raylet that the client has finished executing a task. + /// + /// \return ray::Status. + ray::Status TaskDone(); + + /// Tell the raylet to reconstruct or fetch objects. + /// + /// \param object_ids The IDs of the objects to reconstruct. + /// \param fetch_only Only fetch objects, do not reconstruct them. + /// \param current_task_id The task that needs the objects. + /// \return int 0 means correct, other numbers mean error. + ray::Status FetchOrReconstruct(const std::vector &object_ids, bool fetch_only, + const TaskID ¤t_task_id); + /// Notify the raylet that this client (worker) is no longer blocked. + /// + /// \param current_task_id The task that is no longer blocked. + /// \return ray::Status. + ray::Status NotifyUnblocked(const TaskID ¤t_task_id); + + /// Wait for the given objects until timeout expires or num_return objects are + /// found. + /// + /// \param object_ids The objects to wait for. + /// \param num_returns The number of objects to wait for. + /// \param timeout_milliseconds Duration, in milliseconds, to wait before returning. + /// \param wait_local Whether to wait for objects to appear on this node. + /// \param current_task_id The task that called wait. 
+ /// \param result A pair with the first element containing the object ids that were + /// found, and the second element the objects that were not found. + /// \return ray::Status. + ray::Status Wait(const std::vector &object_ids, int num_returns, + int64_t timeout_milliseconds, bool wait_local, + const TaskID ¤t_task_id, WaitResultPair *result); + + /// Push an error to the relevant driver. + /// + /// \param The ID of the job that the error is for. + /// \param The type of the error. + /// \param The error message. + /// \param The timestamp of the error. + /// \return ray::Status. + ray::Status PushError(const JobID &job_id, const std::string &type, + const std::string &error_message, double timestamp); + + /// Store some profile events in the GCS. + /// + /// \param profile_events A batch of profiling event information. + /// \return ray::Status. + ray::Status PushProfileEvents(const ProfileTableDataT &profile_events); + + /// Free a list of objects from object stores. + /// + /// \param object_ids A list of ObjectsIDs to be deleted. + /// \param local_only Whether keep this request with local object store + /// or send it to all the object stores. + /// \return ray::Status. + ray::Status FreeObjects(const std::vector &object_ids, bool local_only); + + Language GetLanguage() const { return language_; } + + JobID GetClientID() const { return client_id_; } + + JobID GetDriverID() const { return driver_id_; } + + bool IsWorker() const { return is_worker_; } + + const ResourceMappingType &GetResourceIDs() const { return resource_ids_; } + + private: + const UniqueID client_id_; + const bool is_worker_; + const JobID driver_id_; + const Language language_; + /// A map from resource name to the resource IDs that are currently reserved + /// for this worker. Each pair consists of the resource ID and the fraction + /// of that resource allocated for this worker. + ResourceMappingType resource_ids_; + /// The connection to the raylet server. + std::unique_ptr conn_; +}; + +#endif diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc index 6abf5b53d8246..d698402994a45 100644 --- a/src/ray/raylet/reconstruction_policy.cc +++ b/src/ray/raylet/reconstruction_policy.cc @@ -6,7 +6,7 @@ namespace raylet { ReconstructionPolicy::ReconstructionPolicy( boost::asio::io_service &io_service, - std::function reconstruction_handler, + std::function reconstruction_handler, int64_t initial_reconstruction_timeout_ms, const ClientID &client_id, gcs::PubsubInterface &task_lease_pubsub, std::shared_ptr object_directory, @@ -74,13 +74,14 @@ void ReconstructionPolicy::HandleReconstructionLogAppend(const TaskID &task_id, SetTaskTimeout(it, initial_reconstruction_timeout_ms_); if (success) { - reconstruction_handler_(task_id); + reconstruction_handler_(task_id, it->second.return_values_lost); } } void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id, const ObjectID &required_object_id, - int reconstruction_attempt) { + int reconstruction_attempt, + bool created) { // If we are no longer listening for objects created by this task, give up. auto it = listening_tasks_.find(task_id); if (it == listening_tasks_.end()) { @@ -92,6 +93,10 @@ void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id, return; } + if (created) { + it->second.return_values_lost = true; + } + // Suppress duplicate reconstructions of the same task. This can happen if, // for example, a task creates two different objects that both require // reconstruction. 
@@ -138,12 +143,13 @@ void ReconstructionPolicy::HandleTaskLeaseExpired(const TaskID &task_id) { for (const auto &created_object_id : it->second.created_objects) { RAY_CHECK_OK(object_directory_->LookupLocations( created_object_id, - [this, task_id, reconstruction_attempt](const std::vector &clients, - const ray::ObjectID &object_id) { + [this, task_id, reconstruction_attempt]( + const ray::ObjectID &object_id, + const std::unordered_set &clients, bool created) { if (clients.empty()) { // The required object no longer exists on any live nodes. Attempt // reconstruction. - AttemptReconstruction(task_id, object_id, reconstruction_attempt); + AttemptReconstruction(task_id, object_id, reconstruction_attempt, created); } })); } diff --git a/src/ray/raylet/reconstruction_policy.h b/src/ray/raylet/reconstruction_policy.h index f18290aa37254..d936a632e1f1c 100644 --- a/src/ray/raylet/reconstruction_policy.h +++ b/src/ray/raylet/reconstruction_policy.h @@ -40,7 +40,7 @@ class ReconstructionPolicy : public ReconstructionPolicyInterface { /// lease notifications from. ReconstructionPolicy( boost::asio::io_service &io_service, - std::function reconstruction_handler, + std::function reconstruction_handler, int64_t initial_reconstruction_timeout_ms, const ClientID &client_id, gcs::PubsubInterface &task_lease_pubsub, std::shared_ptr object_directory, @@ -93,6 +93,7 @@ class ReconstructionPolicy : public ReconstructionPolicyInterface { bool subscribed; // The number of times we've attempted reconstructing this task so far. int reconstruction_attempt; + bool return_values_lost; // The task's reconstruction timer. If this expires before a lease // notification is received, then the task will be reconstructed. std::unique_ptr reconstruction_timer; @@ -115,7 +116,7 @@ class ReconstructionPolicy : public ReconstructionPolicyInterface { /// reconstructions of the same task (e.g., if a task creates two objects /// that both require reconstruction). void AttemptReconstruction(const TaskID &task_id, const ObjectID &required_object_id, - int reconstruction_attempt); + int reconstruction_attempt, bool created); /// Handle expiration of a task lease. void HandleTaskLeaseExpired(const TaskID &task_id); @@ -127,7 +128,7 @@ class ReconstructionPolicy : public ReconstructionPolicyInterface { /// The event loop. boost::asio::io_service &io_service_; /// The handler to call for tasks that require reconstruction. - const std::function reconstruction_handler_; + const std::function reconstruction_handler_; /// The initial timeout within which a task lease notification must be /// received. Otherwise, reconstruction will be triggered. 
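The location-lookup callback above now receives the set of live clients that hold the object plus a created flag. A minimal sketch of the decision it feeds into (illustrative signature, not the policy's API): reconstruction is attempted only when no live node holds the object, and created records that a once-existing return value has been lost, which is what TreatTaskAsFailedIfLost later keys on.

#include <string>
#include <unordered_set>

// Illustrative only: attempt reconstruction when no live node holds the
// object; remember whether a previously created value has been lost.
bool ShouldAttemptReconstruction(
    const std::unordered_set<std::string> &client_ids, bool created,
    bool *return_values_lost) {
  if (!client_ids.empty()) {
    return false;  // Still available on some live node.
  }
  if (created) {
    *return_values_lost = true;
  }
  return true;
}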
const int64_t initial_reconstruction_timeout_ms_; diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 9f1499c31664a..5e9ae6d7e5218 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -19,7 +19,7 @@ class MockObjectDirectory : public ObjectDirectoryInterface { MockObjectDirectory() {} ray::Status LookupLocations(const ObjectID &object_id, - const OnLocationsFound &callback) { + const OnLocationsFound &callback) override { callbacks_.push_back({object_id, callback}); return ray::Status::OK(); } @@ -27,16 +27,29 @@ class MockObjectDirectory : public ObjectDirectoryInterface { void FlushCallbacks() { for (const auto &callback : callbacks_) { const ObjectID object_id = callback.first; - callback.second(locations_[object_id], object_id); + auto it = locations_.find(object_id); + if (it == locations_.end()) { + callback.second(object_id, std::unordered_set(), + /*created=*/false); + } else { + callback.second(object_id, it->second, /*created=*/true); + } } callbacks_.clear(); } - void SetObjectLocations(const ObjectID &object_id, std::vector locations) { + void SetObjectLocations(const ObjectID &object_id, + const std::unordered_set &locations) { locations_[object_id] = locations; } - std::string DebugString() const { return ""; } + void HandleClientRemoved(const ClientID &client_id) override { + for (auto &locations : locations_) { + locations.second.erase(client_id); + } + } + + std::string DebugString() const override { return ""; } MOCK_METHOD0(RegisterBackend, void(void)); MOCK_METHOD0(GetLocalClientID, ray::ClientID()); @@ -54,7 +67,7 @@ class MockObjectDirectory : public ObjectDirectoryInterface { private: std::vector> callbacks_; - std::unordered_map> locations_; + std::unordered_map> locations_; }; class MockGcs : public gcs::PubsubInterface, @@ -138,8 +151,8 @@ class ReconstructionPolicyTest : public ::testing::Test { mock_object_directory_(std::make_shared()), reconstruction_timeout_ms_(50), reconstruction_policy_(std::make_shared( - io_service_, - [this](const TaskID &task_id) { TriggerReconstruction(task_id); }, + io_service_, [this](const TaskID &task_id, + bool created) { TriggerReconstruction(task_id); }, reconstruction_timeout_ms_, ClientID::from_random(), mock_gcs_, mock_object_directory_, mock_gcs_)), timer_canceled_(false) { @@ -242,7 +255,32 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionEvicted) { ASSERT_EQ(reconstructed_tasks_[task_id], 0); // Simulate evicting one of the objects. - mock_object_directory_->SetObjectLocations(object_id, {}); + mock_object_directory_->SetObjectLocations(object_id, + std::unordered_set()); + // Run the test again. + Run(reconstruction_timeout_ms_ * 1.1); + // Check that reconstruction was triggered, since one of the objects was + // evicted. + ASSERT_EQ(reconstructed_tasks_[task_id], 1); +} + +TEST_F(ReconstructionPolicyTest, TestReconstructionObjectLost) { + TaskID task_id = TaskID::from_random(); + task_id = FinishTaskId(task_id); + ObjectID object_id = ComputeReturnId(task_id, 1); + ClientID client_id = ClientID::from_random(); + mock_object_directory_->SetObjectLocations(object_id, {client_id}); + + // Listen for both objects. + reconstruction_policy_->ListenAndMaybeReconstruct(object_id); + // Run the test for longer than the reconstruction timeout. + Run(reconstruction_timeout_ms_ * 1.1); + // Check that reconstruction was not triggered, since the objects still + // exist on a live node. 
+ ASSERT_EQ(reconstructed_tasks_[task_id], 0); + + // Simulate evicting one of the objects. + mock_object_directory_->HandleClientRemoved(client_id); // Run the test again. Run(reconstruction_timeout_ms_ * 1.1); // Check that reconstruction was triggered, since one of the objects was diff --git a/src/ray/raylet/scheduling_queue.cc b/src/ray/raylet/scheduling_queue.cc index b9045e57912e1..47c353a9a4c05 100644 --- a/src/ray/raylet/scheduling_queue.cc +++ b/src/ray/raylet/scheduling_queue.cc @@ -36,8 +36,7 @@ inline void QueueTasks(TaskQueue &queue, const std::vector &t // Helper function to filter out tasks of a given state. template inline void FilterStateFromQueue(const TaskQueue &queue, - std::unordered_set &task_ids, - ray::raylet::TaskState filter_state) { + std::unordered_set &task_ids) { for (auto it = task_ids.begin(); it != task_ids.end();) { if (queue.HasTask(*it)) { it = task_ids.erase(it); @@ -173,16 +172,19 @@ void SchedulingQueue::FilterState(std::unordered_set &task_ids, TaskState filter_state) const { switch (filter_state) { case TaskState::PLACEABLE: - FilterStateFromQueue(placeable_tasks_, task_ids, filter_state); + FilterStateFromQueue(placeable_tasks_, task_ids); + break; + case TaskState::WAITING_FOR_ACTOR_CREATION: + FilterStateFromQueue(methods_waiting_for_actor_creation_, task_ids); break; case TaskState::WAITING: - FilterStateFromQueue(waiting_tasks_, task_ids, filter_state); + FilterStateFromQueue(waiting_tasks_, task_ids); break; case TaskState::READY: - FilterStateFromQueue(ready_tasks_, task_ids, filter_state); + FilterStateFromQueue(ready_tasks_, task_ids); break; case TaskState::RUNNING: - FilterStateFromQueue(running_tasks_, task_ids, filter_state); + FilterStateFromQueue(running_tasks_, task_ids); break; case TaskState::BLOCKED: { const auto blocked_ids = GetBlockedTaskIds(); @@ -195,7 +197,7 @@ void SchedulingQueue::FilterState(std::unordered_set &task_ids, } } break; case TaskState::INFEASIBLE: - FilterStateFromQueue(infeasible_tasks_, task_ids, filter_state); + FilterStateFromQueue(infeasible_tasks_, task_ids); break; case TaskState::DRIVER: { const auto driver_ids = GetDriverTaskIds(); diff --git a/src/ray/raylet/scheduling_queue.h b/src/ray/raylet/scheduling_queue.h index bdd065fa176be..dad94a6d3b69c 100644 --- a/src/ray/raylet/scheduling_queue.h +++ b/src/ray/raylet/scheduling_queue.h @@ -18,6 +18,8 @@ enum class TaskState { INIT, // The task may be placed on a node. PLACEABLE, + // The task is for an actor whose location we do not know yet. + WAITING_FOR_ACTOR_CREATION, // The task has been placed on a node and is waiting for some object // dependencies to become local. 
WAITING, diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index 56a97dcb0a82f..1e05283172320 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -74,9 +74,10 @@ static inline Task ExampleTask(const std::vector &arguments, std::vector references = {argument}; task_arguments.emplace_back(std::make_shared(references)); } + std::vector function_descriptor(3); auto spec = TaskSpecification(UniqueID::nil(), UniqueID::from_random(), 0, - UniqueID::from_random(), task_arguments, num_returns, - required_resources, Language::PYTHON); + task_arguments, num_returns, required_resources, + Language::PYTHON, function_descriptor); auto execution_spec = TaskExecutionSpecification(std::vector()); execution_spec.IncrementNumForwards(); Task task = Task(execution_spec, spec); diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index 0a914adcb0c66..7e33a9acf29b5 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -56,24 +56,24 @@ TaskSpecification::TaskSpecification(const std::string &string) { TaskSpecification::TaskSpecification( const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, - const FunctionID &function_id, const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, - const Language &language) + const Language &language, const std::vector &function_descriptor) : TaskSpecification(driver_id, parent_task_id, parent_counter, ActorID::nil(), - ObjectID::nil(), ActorID::nil(), ActorHandleID::nil(), -1, - function_id, task_arguments, num_returns, required_resources, - std::unordered_map(), language) {} + ObjectID::nil(), 0, ActorID::nil(), ActorHandleID::nil(), -1, + task_arguments, num_returns, required_resources, + std::unordered_map(), language, + function_descriptor) {} TaskSpecification::TaskSpecification( const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, - const ActorID &actor_id, const ActorHandleID &actor_handle_id, int64_t actor_counter, - const FunctionID &function_id, + const int64_t max_actor_reconstructions, const ActorID &actor_id, + const ActorHandleID &actor_handle_id, int64_t actor_counter, const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - const Language &language) + const Language &language, const std::vector &function_descriptor) : spec_() { flatbuffers::FlatBufferBuilder fbb; @@ -96,11 +96,12 @@ TaskSpecification::TaskSpecification( auto spec = CreateTaskInfo( fbb, to_flatbuf(fbb, driver_id), to_flatbuf(fbb, task_id), to_flatbuf(fbb, parent_task_id), parent_counter, to_flatbuf(fbb, actor_creation_id), - to_flatbuf(fbb, actor_creation_dummy_object_id), to_flatbuf(fbb, actor_id), - to_flatbuf(fbb, actor_handle_id), actor_counter, false, - to_flatbuf(fbb, function_id), fbb.CreateVector(arguments), - fbb.CreateVector(returns), map_to_flatbuf(fbb, required_resources), - map_to_flatbuf(fbb, required_placement_resources), language); + to_flatbuf(fbb, actor_creation_dummy_object_id), max_actor_reconstructions, + to_flatbuf(fbb, actor_id), to_flatbuf(fbb, actor_handle_id), actor_counter, false, + fbb.CreateVector(arguments), fbb.CreateVector(returns), + map_to_flatbuf(fbb, required_resources), + map_to_flatbuf(fbb, 
required_placement_resources), language, + string_vec_to_flatbuf(fbb, function_descriptor)); fbb.Finish(spec); AssignSpecification(fbb.GetBufferPointer(), fbb.GetSize()); } @@ -133,9 +134,9 @@ int64_t TaskSpecification::ParentCounter() const { auto message = flatbuffers::GetRoot(spec_.data()); return message->parent_counter(); } -FunctionID TaskSpecification::FunctionId() const { +std::vector TaskSpecification::FunctionDescriptor() const { auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->function_id()); + return string_vec_from_flatbuf(*message->function_descriptor()); } int64_t TaskSpecification::NumArgs() const { @@ -196,7 +197,7 @@ const ResourceSet TaskSpecification::GetRequiredPlacementResources() const { bool TaskSpecification::IsDriverTask() const { // Driver tasks are empty tasks that have no function ID set. - return FunctionId().is_nil(); + return FunctionDescriptor().empty(); } Language TaskSpecification::GetLanguage() const { @@ -220,6 +221,11 @@ ObjectID TaskSpecification::ActorCreationDummyObjectId() const { return from_flatbuf(*message->actor_creation_dummy_object_id()); } +int64_t TaskSpecification::MaxActorReconstructions() const { + auto message = flatbuffers::GetRoot(spec_.data()); + return message->max_actor_reconstructions(); +} + ActorID TaskSpecification::ActorId() const { auto message = flatbuffers::GetRoot(spec_.data()); return from_flatbuf(*message->actor_id()); diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h index 5a86a443c78e3..da33275e9537d 100644 --- a/src/ray/raylet/task_spec.h +++ b/src/ray/raylet/task_spec.h @@ -91,17 +91,18 @@ class TaskSpecification { /// \param parent_task_id The task ID of the task that spawned this task. /// \param parent_counter The number of tasks that this task's parent spawned /// before this task. - /// \param function_id The ID of the function this task should execute. + /// \param function_descriptor The function descriptor. /// \param task_arguments The list of task arguments. /// \param num_returns The number of values returned by the task. /// \param required_resources The task's resource demands. /// \param language The language of the worker that must execute the function. TaskSpecification(const UniqueID &driver_id, const TaskID &parent_task_id, - int64_t parent_counter, const FunctionID &function_id, + int64_t parent_counter, const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, - const Language &language); + const Language &language, + const std::vector &function_descriptor); // TODO(swang): Define an actor task constructor. /// Create a task specification from the raw fields. @@ -119,7 +120,6 @@ class TaskSpecification { /// task. If this is not an actor task, then this is nil. /// \param actor_counter The number of tasks submitted before this task from /// the same actor handle. If this is not an actor task, then this is 0. - /// \param function_id The ID of the function this task should execute. /// \param task_arguments The list of task arguments. /// \param num_returns The number of values returned by the task. /// \param required_resources The task's resource demands. @@ -127,16 +127,17 @@ class TaskSpecification { /// task on a node. Typically, this should be an empty map in which case it /// will default to be equal to the required_resources argument. /// \param language The language of the worker that must execute the function. + /// \param function_descriptor The function descriptor. 
TaskSpecification( const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, - const ActorID &actor_id, const ActorHandleID &actor_handle_id, - int64_t actor_counter, const FunctionID &function_id, + int64_t max_actor_reconstructions, const ActorID &actor_id, + const ActorHandleID &actor_handle_id, int64_t actor_counter, const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - const Language &language); + const Language &language, const std::vector &function_descriptor); /// Deserialize a task specification from a flatbuffer's string data. /// @@ -158,7 +159,7 @@ class TaskSpecification { UniqueID DriverId() const; TaskID ParentTaskId() const; int64_t ParentCounter() const; - FunctionID FunctionId() const; + std::vector FunctionDescriptor() const; int64_t NumArgs() const; int64_t NumReturns() const; bool ArgByRef(int64_t arg_index) const; @@ -192,6 +193,7 @@ class TaskSpecification { bool IsActorTask() const; ActorID ActorCreationId() const; ObjectID ActorCreationDummyObjectId() const; + int64_t MaxActorReconstructions() const; ActorID ActorId() const; ActorHandleID ActorHandleId() const; int64_t ActorCounter() const; diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index 5a2342514d01a..c6686e4b6f6bb 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -15,8 +15,6 @@ Worker::Worker(pid_t pid, const Language &language, : pid_(pid), language_(language), connection_(connection), - assigned_task_id_(TaskID::nil()), - actor_id_(ActorID::nil()), dead_(false), blocked_(false) {} diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index abaf675ff625c..9b228457b2b5e 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -63,9 +63,10 @@ class WorkerPoolTest : public ::testing::Test { static inline TaskSpecification ExampleTaskSpec( const ActorID actor_id = ActorID::nil(), const Language &language = Language::PYTHON) { + std::vector function_descriptor(3); return TaskSpecification(UniqueID::nil(), UniqueID::nil(), 0, ActorID::nil(), - ObjectID::nil(), actor_id, ActorHandleID::nil(), 0, - FunctionID::nil(), {}, 0, {{}}, {{}}, language); + ObjectID::nil(), 0, actor_id, ActorHandleID::nil(), 0, {}, 0, + {{}}, {{}}, language, function_descriptor); } TEST_F(WorkerPoolTest, HandleWorkerRegistration) { diff --git a/test/actor_test.py b/test/actor_test.py index 1586f8827ff0d..9a42842d68ed5 100644 --- a/test/actor_test.py +++ b/test/actor_test.py @@ -8,18 +8,25 @@ import numpy as np import os import pytest +import signal import sys import time import ray import ray.ray_constants as ray_constants import ray.test.test_utils +import ray.test.cluster_utils @pytest.fixture def ray_start_regular(): # Start the Ray processes. - ray.init(num_cpus=1) + ray.init( + num_cpus=1, + _internal_config=json.dumps({ + "initial_reconstruction_timeout_milliseconds": 200, + "num_heartbeats_timeout": 10, + })) yield None # The code after the yield will run as teardown code. 
ray.shutdown() @@ -32,6 +39,23 @@ def shutdown_only(): ray.shutdown() +@pytest.fixture +def head_node_cluster(): + cluster = ray.test.cluster_utils.Cluster( + initialize_head=True, + connect=True, + head_node_args={ + "_internal_config": json.dumps({ + "initial_reconstruction_timeout_milliseconds": 200, + "num_heartbeats_timeout": 10, + }) + }) + yield cluster + # The code after the yield will run as teardown code. + ray.shutdown() + cluster.shutdown() + + def test_actor_init_error_propagated(ray_start_regular): @ray.remote class Actor(object): @@ -913,8 +937,8 @@ def get_location_and_ids(self): def test_actor_multiple_gpus_from_multiple_tasks(shutdown_only): - num_local_schedulers = 10 - num_gpus_per_scheduler = 10 + num_local_schedulers = 5 + num_gpus_per_scheduler = 5 ray.worker._init( start_ray_local=True, num_local_schedulers=num_local_schedulers, @@ -1259,15 +1283,8 @@ def blocking_method(self): assert remaining_ids == [x_id] -def test_exception_raised_when_actor_node_dies(shutdown_only): - ray.worker._init( - start_ray_local=True, - num_local_schedulers=2, - num_cpus=1, - _internal_config=json.dumps({ - "initial_reconstruction_timeout_milliseconds": 200, - "num_heartbeats_timeout": 10, - })) +def test_exception_raised_when_actor_node_dies(head_node_cluster): + remote_node = head_node_cluster.add_node() @ray.remote class Counter(object): @@ -1281,18 +1298,14 @@ def inc(self): self.x += 1 return self.x - local_plasma = ray.worker.global_worker.plasma_client.store_socket_name - # Create an actor that is not on the local scheduler. actor = Counter.remote() - while ray.get(actor.local_plasma.remote()) == local_plasma: + while (ray.get(actor.local_plasma.remote()) != + remote_node.get_plasma_store_name()): actor = Counter.remote() - # Kill the second plasma store to get rid of the cached objects and - # trigger the corresponding local scheduler to exit. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] - process.kill() + # Kill the second node. + head_node_cluster.remove_node(remote_node) # Submit some new actor tasks both before and after the node failure is # detected. Make sure that getting the result raises an exception. @@ -1300,132 +1313,74 @@ def inc(self): # Submit some new actor tasks. x_ids = [actor.inc.remote() for _ in range(5)] for x_id in x_ids: - with pytest.raises(ray.worker.RayGetError): + with pytest.raises(ray.worker.RayTaskError): # There is some small chance that ray.get will actually # succeed (if the object is transferred before the raylet # dies). ray.get(x_id) - # Make sure the process has exited. - process.wait() - -@pytest.mark.skip("This test does not work yet.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", reason="Hanging with new GCS API.") -def test_local_scheduler_dying(shutdown_only): - ray.worker._init( - start_ray_local=True, - num_local_schedulers=2, - num_cpus=1, - redirect_output=True) +def test_actor_init_fails(head_node_cluster): + remote_node = head_node_cluster.add_node() - @ray.remote + @ray.remote(max_reconstructions=1) class Counter(object): def __init__(self): self.x = 0 - def local_plasma(self): - return ray.worker.global_worker.plasma_client.store_socket_name - def inc(self): self.x += 1 return self.x - local_plasma = ray.worker.global_worker.plasma_client.store_socket_name - - # Create an actor that is not on the local scheduler. 
- actor = Counter.remote() - while ray.get(actor.local_plasma.remote()) == local_plasma: - actor = Counter.remote() - - ids = [actor.inc.remote() for _ in range(100)] - - # Wait for the last task to finish running. - ray.get(ids[-1]) - - # Kill the second plasma store to get rid of the cached objects and - # trigger the corresponding local scheduler to exit. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][1] - process.kill() - process.wait() + # Create many actors. It should take a while to finish initializing them. + actors = [Counter.remote() for _ in range(100)] + # Allow some time to forward the actor creation tasks to the other node. + time.sleep(0.1) + # Kill the second node. + head_node_cluster.remove_node(remote_node) # Get all of the results - results = ray.get(ids) - - assert results == list(range(1, 1 + len(results))) + results = ray.get([actor.inc.remote() for actor in actors]) + assert results == [1 for actor in actors] -@pytest.mark.skip("This test does not work yet.") -@pytest.mark.skipif( - os.environ.get("RAY_USE_NEW_GCS") == "on", - reason="Hanging with new GCS API.") -def test_many_local_schedulers_dying(shutdown_only): - # This test can be made more stressful by increasing the numbers below. - # The total number of actors created will be - # num_actors_at_a_time * num_local_schedulers. - num_local_schedulers = 5 - num_actors_at_a_time = 3 - num_function_calls_at_a_time = 10 - - ray.worker._init( - start_ray_local=True, - num_local_schedulers=num_local_schedulers, - num_cpus=3, - redirect_output=True) +def test_reconstruction_suppression(head_node_cluster): + num_local_schedulers = 10 + worker_nodes = [ + head_node_cluster.add_node() for _ in range(num_local_schedulers) + ] - @ray.remote - class SlowCounter(object): + @ray.remote(max_reconstructions=1) + class Counter(object): def __init__(self): self.x = 0 - def inc(self, duration): - time.sleep(duration) + def inc(self): self.x += 1 return self.x - # Create some initial actors. - actors = [SlowCounter.remote() for _ in range(num_actors_at_a_time)] + @ray.remote + def inc(actor_handle): + return ray.get(actor_handle.inc.remote()) - # Wait for the actors to start up. - time.sleep(1) + # Make sure all of the actors have started. + actors = [Counter.remote() for _ in range(20)] + ray.get([actor.inc.remote() for actor in actors]) - # This is a mapping from actor handles to object IDs returned by - # methods on that actor. - result_ids = collections.defaultdict(lambda: []) + # Kill a node. + head_node_cluster.remove_node(worker_nodes[0]) - # In a loop we are going to create some actors, run some methods, kill - # a local scheduler, and run some more methods. - for i in range(num_local_schedulers - 1): - # Create some actors. - actors.extend( - [SlowCounter.remote() for _ in range(num_actors_at_a_time)]) - # Run some methods. - for j in range(len(actors)): - actor = actors[j] - for _ in range(num_function_calls_at_a_time): - result_ids[actor].append(actor.inc.remote(j**2 * 0.000001)) - # Kill a plasma store to get rid of the cached objects and trigger - # exit of the corresponding local scheduler. Don't kill the first - # local scheduler since that is the one that the driver is - # connected to. - process = ray.services.all_processes[ - ray.services.PROCESS_TYPE_PLASMA_STORE][i + 1] - process.kill() - process.wait() - - # Run some more methods. 
-        for j in range(len(actors)):
-            actor = actors[j]
-            for _ in range(num_function_calls_at_a_time):
-                result_ids[actor].append(actor.inc.remote(j**2 * 0.000001))
-
-    # Get the results and check that they have the correct values.
-    for _, result_id_list in result_ids.items():
-        results = list(range(1, len(result_id_list) + 1))
-        assert ray.get(result_id_list) == results
+    # Submit several tasks per actor. These should be randomly scheduled to the
+    # nodes, so that multiple nodes will detect and try to reconstruct the
+    # actor that died, but only one should succeed.
+    results = []
+    for _ in range(10):
+        results += [inc.remote(actor) for actor in actors]
+    # Make sure that we can get the results from the reconstructed actor.
+    results = ray.get(results)
 def setup_counter_actor(test_checkpoint=False,
@@ -2142,3 +2097,215 @@ def method(self):
         ray.wait([object_id])
     ray.get(results)
+
+
+def test_actor_eviction(shutdown_only):
+    @ray.remote
+    class Actor(object):
+        def __init__(self):
+            pass
+
+        def create_object(self, size):
+            return np.random.rand(size)
+
+    object_store_memory = 10**8
+    ray.init(
+        object_store_memory=object_store_memory,
+        _internal_config=json.dumps({
+            "initial_reconstruction_timeout_milliseconds": 200
+        }))
+
+    a = Actor.remote()
+    # Submit enough methods on the actor so that they exceed the size of the
+    # object store.
+    objects = []
+    num_objects = 20
+    for _ in range(num_objects):
+        obj = a.create_object.remote(object_store_memory // num_objects)
+        objects.append(obj)
+        # Get each object once to make sure each object gets created.
+        ray.get(obj)
+
+    # Get each object again. At this point, the earlier objects should have
+    # been evicted.
+    num_evicted, num_success = 0, 0
+    for obj in objects:
+        try:
+            ray.get(obj)
+            num_success += 1
+        except ray.worker.RayTaskError:
+            num_evicted += 1
+    # Some objects should have been evicted, and some should still be in the
+    # object store.
+    assert num_evicted > 0
+    assert num_success > 0
+
+
+def test_actor_reconstruction(ray_start_regular):
+    """Test actor reconstruction when the actor process is killed."""
+
+    @ray.remote(max_reconstructions=1)
+    class ReconstructableActor(object):
+        """An actor that will be reconstructed at most once."""
+
+        def __init__(self):
+            self.value = 0
+
+        def increase(self):
+            self.value += 1
+            return self.value
+
+        def get_pid(self):
+            return os.getpid()
+
+    def kill_actor(actor):
+        """Kill the actor process."""
+        pid = ray.get(actor.get_pid.remote())
+        os.kill(pid, signal.SIGKILL)
+        time.sleep(1)
+
+    actor = ReconstructableActor.remote()
+    # Call increase 3 times.
+    for _ in range(3):
+        ray.get(actor.increase.remote())
+    # Kill the actor process.
+    kill_actor(actor)
+    # Call increase again.
+    # Check that the actor is reconstructed and the value is 4.
+    assert ray.get(actor.increase.remote()) == 4
+    # Kill the actor process one more time.
+    kill_actor(actor)
+    # The actor has exceeded max reconstructions, and this task should fail.
+    with pytest.raises(ray.worker.RayTaskError):
+        ray.get(actor.increase.remote())
+
+    # Create another actor.
+    actor = ReconstructableActor.remote()
+    # Intentionally exit the actor.
+    actor.__ray_terminate__.remote()
+    # Check that the actor won't be reconstructed.
+    with pytest.raises(ray.worker.RayTaskError):
+        ray.get(actor.increase.remote())
+
+
+def test_actor_reconstruction_on_node_failure(head_node_cluster):
+    """Test actor reconstruction when a node dies unexpectedly."""
+    cluster = head_node_cluster
+    max_reconstructions = 3
+    # Add a few nodes to the cluster.
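+    # Each worker node uses the same short reconstruction and heartbeat
+    # timeouts as the head node so that node failures are detected quickly.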
+ # Use custom resource to make sure the actor is only created on worker + # nodes, not on the head node. + for _ in range(max_reconstructions + 2): + cluster.add_node( + resources={"a": 1}, + _internal_config=json.dumps({ + "initial_reconstruction_timeout_milliseconds": 200, + "num_heartbeats_timeout": 10, + }), + ) + + def kill_node(object_store_socket): + node_to_remove = None + for node in cluster.worker_nodes: + if object_store_socket == node.get_plasma_store_name(): + node_to_remove = node + cluster.remove_node(node_to_remove) + + @ray.remote(max_reconstructions=max_reconstructions, resources={"a": 1}) + class MyActor(object): + def __init__(self): + self.value = 0 + + def increase(self): + self.value += 1 + return self.value + + def get_object_store_socket(self): + return ray.worker.global_worker.plasma_client.store_socket_name + + actor = MyActor.remote() + # Call increase 3 times. + for _ in range(3): + ray.get(actor.increase.remote()) + + for i in range(max_reconstructions): + object_store_socket = ray.get(actor.get_object_store_socket.remote()) + # Kill actor's node and the actor should be reconstructed + # on a different node. + kill_node(object_store_socket) + # Call increase again. + # Check that the actor is reconstructed and value is correct. + assert ray.get(actor.increase.remote()) == 4 + i + # Check that the actor is now on a different node. + assert object_store_socket != ray.get( + actor.get_object_store_socket.remote()) + + # kill the node again. + object_store_socket = ray.get(actor.get_object_store_socket.remote()) + kill_node(object_store_socket) + # The actor has exceeded max reconstructions, and this task should fail. + with pytest.raises(ray.worker.RayTaskError): + ray.get(actor.increase.remote()) + + +def test_multiple_actor_reconstruction(head_node_cluster): + # This test can be made more stressful by increasing the numbers below. + # The total number of actors created will be + # num_actors_at_a_time * num_local_schedulers. + num_local_schedulers = 5 + num_actors_at_a_time = 3 + num_function_calls_at_a_time = 10 + + worker_nodes = [ + head_node_cluster.add_node( + resources={"CPU": 3}, + _internal_config=json.dumps({ + "initial_reconstruction_timeout_milliseconds": 200, + "num_heartbeats_timeout": 10, + })) for _ in range(num_local_schedulers) + ] + + @ray.remote(max_reconstructions=ray.ray_constants.INFINITE_RECONSTRUCTION) + class SlowCounter(object): + def __init__(self): + self.x = 0 + + def inc(self, duration): + time.sleep(duration) + self.x += 1 + return self.x + + # Create some initial actors. + actors = [SlowCounter.remote() for _ in range(num_actors_at_a_time)] + + # Wait for the actors to start up. + time.sleep(1) + + # This is a mapping from actor handles to object IDs returned by + # methods on that actor. + result_ids = collections.defaultdict(lambda: []) + + # In a loop we are going to create some actors, run some methods, kill + # a local scheduler, and run some more methods. + for node in worker_nodes: + # Create some actors. + actors.extend( + [SlowCounter.remote() for _ in range(num_actors_at_a_time)]) + # Run some methods. + for j in range(len(actors)): + actor = actors[j] + for _ in range(num_function_calls_at_a_time): + result_ids[actor].append(actor.inc.remote(j**2 * 0.000001)) + # Kill a node. + head_node_cluster.remove_node(node) + + # Run some more methods. 
+ for j in range(len(actors)): + actor = actors[j] + for _ in range(num_function_calls_at_a_time): + result_ids[actor].append(actor.inc.remote(j**2 * 0.000001)) + + # Get the results and check that they have the correct values. + for _, result_id_list in result_ids.items(): + results = list(range(1, len(result_id_list) + 1)) + assert ray.get(result_id_list) == results diff --git a/test/component_failures_test.py b/test/component_failures_test.py index fd09a17599cfa..30071b3c1917c 100644 --- a/test/component_failures_test.py +++ b/test/component_failures_test.py @@ -408,7 +408,7 @@ def ping(self): for i, out in enumerate(children_out): try: ray.get(out) - except ray.worker.RayGetError: + except ray.worker.RayTaskError: children[i] = Child.remote(death_probability) # Remove a node. Any actor creation tasks that were forwarded to this # node must be reconstructed. diff --git a/test/failure_test.py b/test/failure_test.py index 027ed38d64111..3efb9bc69a7a8 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -3,13 +3,16 @@ from __future__ import print_function import numpy as np +import json import os import ray import sys import tempfile +import threading import time import ray.ray_constants as ray_constants +from ray.utils import _random_string import pytest @@ -570,7 +573,13 @@ class Foo(object): @pytest.fixture def ray_start_two_nodes(): # Start the Ray processes. - ray.worker._init(start_ray_local=True, num_local_schedulers=2, num_cpus=0) + ray.worker._init( + start_ray_local=True, + num_local_schedulers=2, + num_cpus=0, + _internal_config=json.dumps({ + "num_heartbeats_timeout": 40 + })) yield None # The code after the yield will run as teardown code. ray.shutdown() @@ -594,7 +603,7 @@ def test_warning_for_dead_node(ray_start_two_nodes): ray.services.all_processes[ray.services.PROCESS_TYPE_RAYLET][0].kill() # Check that we get warning messages for both raylets. - wait_for_errors(ray_constants.REMOVED_NODE_ERROR, 2, timeout=20) + wait_for_errors(ray_constants.REMOVED_NODE_ERROR, 2, timeout=40) # Extract the client IDs from the error messages. This will need to be # changed if the error message changes. @@ -604,3 +613,18 @@ def test_warning_for_dead_node(ray_start_two_nodes): } assert client_ids == warning_client_ids + + +def test_raylet_crash_when_get(ray_start_regular): + nonexistent_id = ray.ObjectID(_random_string()) + + def sleep_to_kill_raylet(): + # Don't kill raylet before default workers get connected. + time.sleep(2) + ray.services.all_processes[ray.services.PROCESS_TYPE_RAYLET][0].kill() + + thread = threading.Thread(target=sleep_to_kill_raylet) + thread.start() + with pytest.raises(Exception, match=r".*raylet client may be closed.*"): + ray.get(nonexistent_id) + thread.join() diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index 9b8d9295eae33..06a927b1873db 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -6,54 +6,57 @@ set -e # Show explicitly which commands are currently running. 
set -x +MEMORY_SIZE="20G" +SHM_SIZE="20G" + ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE:-$0}")"; pwd) DOCKER_SHA=$($ROOT_DIR/../../build-docker.sh --output-sha --no-cache) echo "Using Docker image" $DOCKER_SHA -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env PongDeterministic-v0 \ --run A3C \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env Pong-ram-v4 \ --run A3C \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env PongDeterministic-v0 \ --run A2C \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v1 \ --run PPO \ --stop '{"training_iteration": 2}' \ --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "lr": 1e-4, "sgd_minibatch_size": 64, "train_batch_size": 2000, "num_workers": 1, "model": {"free_log_std": true}}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v1 \ --run PPO \ --stop '{"training_iteration": 2}' \ --config '{"simple_optimizer": false, "num_sgd_iter": 2, "model": {"use_lstm": true}}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v1 \ --run PPO \ --stop '{"training_iteration": 2}' \ --config '{"simple_optimizer": true, "num_sgd_iter": 2, "model": {"use_lstm": true}}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v1 \ --run PPO \ @@ -61,180 +64,180 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ --config '{"num_gpus": 0.1}' \ --ray-num-gpus 1 -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v1 \ --run PPO \ --stop '{"training_iteration": 2}' \ --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "lr": 1e-4, "sgd_minibatch_size": 64, "train_batch_size": 2000, "num_workers": 1, "use_gae": false, "batch_mode": "complete_episodes"}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env Pendulum-v0 \ --run ES \ --stop '{"training_iteration": 2}' \ --config '{"stepsize": 0.01, "episodes_per_batch": 20, "train_batch_size": 100, "num_workers": 2}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env Pong-v0 \ --run ES \ --stop '{"training_iteration": 2}' \ --config '{"stepsize": 0.01, "episodes_per_batch": 20, "train_batch_size": 
100, "num_workers": 2}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run A3C \ --stop '{"training_iteration": 2}' \ -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run DQN \ --stop '{"training_iteration": 2}' \ --config '{"lr": 1e-3, "schedule_max_timesteps": 100000, "exploration_fraction": 0.1, "exploration_final_eps": 0.02, "dueling": false, "hiddens": [], "model": {"fcnet_hiddens": [64], "fcnet_activation": "relu"}}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run DQN \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run APEX \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2, "timesteps_per_iteration": 1000, "num_gpus": 0, "min_iter_time_s": 1}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env FrozenLake-v0 \ --run DQN \ --stop '{"training_iteration": 2}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env FrozenLake-v0 \ --run PPO \ --stop '{"training_iteration": 2}' \ --config '{"num_sgd_iter": 10, "sgd_minibatch_size": 64, "train_batch_size": 1000, "num_workers": 1}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env PongDeterministic-v4 \ --run DQN \ --stop '{"training_iteration": 2}' \ --config '{"lr": 1e-4, "schedule_max_timesteps": 2000000, "buffer_size": 10000, "exploration_fraction": 0.1, "exploration_final_eps": 0.01, "sample_batch_size": 4, "learning_starts": 10000, "target_network_update_freq": 1000, "gamma": 0.99, "prioritized_replay": true}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env MontezumaRevenge-v0 \ --run PPO \ --stop '{"training_iteration": 2}' \ --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "lr": 1e-4, "sgd_minibatch_size": 64, "train_batch_size": 2000, "num_workers": 1, "model": {"dim": 40, "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v1 \ --run A3C \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2, "model": {"use_lstm": true}}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run DQN \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2}' -docker run --rm 
--shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run PG \ --stop '{"training_iteration": 2}' \ --config '{"sample_batch_size": 500, "num_workers": 1}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run PG \ --stop '{"training_iteration": 2}' \ --config '{"sample_batch_size": 500, "num_workers": 1, "model": {"use_lstm": true, "max_seq_len": 100}}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run PG \ --stop '{"training_iteration": 2}' \ --config '{"sample_batch_size": 500, "num_workers": 1, "num_envs_per_worker": 10}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env Pong-v0 \ --run PG \ --stop '{"training_iteration": 2}' \ --config '{"sample_batch_size": 500, "num_workers": 1}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env FrozenLake-v0 \ --run PG \ --stop '{"training_iteration": 2}' \ --config '{"sample_batch_size": 500, "num_workers": 1}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env Pendulum-v0 \ --run DDPG \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 1}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run IMPALA \ --stop '{"training_iteration": 2}' \ --config '{"num_gpus": 0, "num_workers": 2, "min_iter_time_s": 1}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run IMPALA \ --stop '{"training_iteration": 2}' \ --config '{"num_gpus": 0, "num_workers": 2, "min_iter_time_s": 1, "model": {"use_lstm": true}}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run IMPALA \ --stop '{"training_iteration": 2}' \ - --config '{"num_gpus": 0, "num_workers": 2, "min_iter_time_s": 1, "num_parallel_data_loaders": 2, "replay_proportion": 1.0}' + --config '{"num_gpus": 0, "num_workers": 2, "min_iter_time_s": 1, "num_data_loader_buffers": 2, "replay_proportion": 1.0}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run IMPALA \ --stop '{"training_iteration": 2}' \ - --config '{"num_gpus": 0, "num_workers": 2, "min_iter_time_s": 1, "num_parallel_data_loaders": 2, "replay_proportion": 1.0, "model": {"use_lstm": true}}' + --config '{"num_gpus": 0, "num_workers": 2, "min_iter_time_s": 1, "num_data_loader_buffers": 2, "replay_proportion": 1.0, "model": {"use_lstm": 
true}}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env MountainCarContinuous-v0 \ --run DDPG \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 1}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ rllib train \ --env MountainCarContinuous-v0 \ --run DDPG \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 1}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env Pendulum-v0 \ --run APEX_DDPG \ @@ -242,154 +245,188 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2, "optimizer": {"num_replay_buffer_shards": 1}, "learning_starts": 100, "min_iter_time_s": 1}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_local.py -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + python /ray/python/ray/rllib/test/test_io.py + +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_checkpoint_restore.py -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_policy_evaluator.py -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_nested_spaces.py -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_external_env.py -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/examples/parametric_action_cartpole.py --run=PG --stop=50 -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/examples/parametric_action_cartpole.py --run=PPO --stop=50 -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/examples/parametric_action_cartpole.py --run=DQN --stop=50 -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_lstm.py -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/examples/batch_norm_model.py --num-iters=1 --run=PPO -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/examples/batch_norm_model.py --num-iters=1 --run=PG -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python 
/ray/python/ray/rllib/examples/batch_norm_model.py --num-iters=1 --run=DQN -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/examples/batch_norm_model.py --num-iters=1 --run=DDPG -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_multi_agent_env.py -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_supported_spaces.py -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ pytest /ray/python/ray/tune/test/cluster_tests.py -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_env_with_subprocess.py -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/python/ray/rllib/test/test_rollout.sh -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +# Run all single-agent regression tests (3x retry each) +for yaml in $(ls $ROOT_DIR/../../python/ray/rllib/tuned_examples/regression_tests); do + docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + python /ray/python/ray/rllib/test/run_regression_tests.py /ray/python/ray/rllib/tuned_examples/regression_tests/$yaml +done + +# Try a couple times since it's stochastic +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + python /ray/python/ray/rllib/test/multiagent_pendulum.py || \ + docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + python /ray/python/ray/rllib/test/multiagent_pendulum.py || \ + docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + python /ray/python/ray/rllib/test/multiagent_pendulum.py + +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/tune_mnist_ray.py \ --smoke-test -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/pbt_example.py \ --smoke-test -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/hyperband_example.py \ --smoke-test -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/async_hyperband_example.py \ --smoke-test -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/tune_mnist_ray_hyperband.py \ --smoke-test -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/tune_mnist_async_hyperband.py \ --smoke-test -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + python /ray/python/ray/tune/examples/logging_example.py \ + 
--smoke-test + +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/hyperopt_example.py \ --smoke-test -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/tune_mnist_keras.py \ --smoke-test -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/mnist_pytorch.py \ --smoke-test -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/mnist_pytorch_trainable.py \ --smoke-test -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/genetic_example.py \ --smoke-test -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/examples/multiagent_cartpole.py --num-iters=2 -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/examples/multiagent_two_trainers.py --num-iters=2 -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + python /ray/python/ray/rllib/test/test_avail_actions_qmix.py + +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/examples/cartpole_lstm.py --run=PPO --stop=200 -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/examples/cartpole_lstm.py --run=IMPALA --stop=100 -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/examples/cartpole_lstm.py --stop=200 --use-prev-action-reward -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/examples/custom_metrics_and_callbacks.py --num-iters=2 -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + python /ray/python/ray/rllib/contrib/random_agent/random_agent.py + +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + python /ray/python/ray/rllib/examples/twostep_game.py --stop=2000 --run=PG + +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + python /ray/python/ray/rllib/examples/twostep_game.py --stop=2000 --run=QMIX + +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + python /ray/python/ray/rllib/examples/twostep_game.py --stop=2000 --run=APEX_QMIX + +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2 \ --batch-size=1 --strategy=simple -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2 \ --batch-size=1 --strategy=ps 
-docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/experimental/sgd/mnist_example.py --num-iters=1 \ --num-workers=1 --devices-per-worker=1 --strategy=ps -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/experimental/sgd/mnist_example.py --num-iters=1 \ --num-workers=1 --devices-per-worker=1 --strategy=ps --tune -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env PongDeterministic-v4 \ --run A3C \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2, "use_pytorch": true, "sample_async": false, "model": {"use_lstm": false, "grayscale": true, "zero_mean": false, "dim": 84, "channel_major": true}, "preprocessor_pref": "rllib"}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v1 \ --run A3C \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2, "use_pytorch": true, "sample_async": false}' -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA python -m pytest /ray/test/object_manager_test.py - python3 $ROOT_DIR/multi_node_docker_test.py \ --docker-image=$DOCKER_SHA \ --num-nodes=5 \ diff --git a/test/runtest.py b/test/runtest.py index 767960a166580..91862023bb5dc 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -3,6 +3,7 @@ from __future__ import print_function import json +import logging import os import re import setproctitle @@ -21,6 +22,8 @@ import ray.test.cluster_utils import ray.test.test_utils +logger = logging.getLogger(__name__) + def assert_equal(obj1, obj2): module_numpy = (type(obj1).__module__ == np.__name__ @@ -378,6 +381,23 @@ def f(): assert ray.get(f.remote()) == ((3, "string1", Bar.__name__), "string2") +def test_serialization_final_fallback(ray_start): + pytest.importorskip("catboost") + # This test will only run when "catboost" is installed. + from catboost import CatBoostClassifier + + model = CatBoostClassifier( + iterations=2, + depth=2, + learning_rate=1, + loss_function="Logloss", + logging_level="Verbose") + + reconstructed_model = ray.get(ray.put(model)) + assert set(model.get_params().items()) == set( + reconstructed_model.get_params().items()) + + def test_register_class(shutdown_only): ray.init(num_cpus=2) @@ -682,7 +702,7 @@ def f(x): if val == 10: break else: - print("Still using old definition of f, trying again.") + logger.info("Still using old definition of f, trying again.") # Test that we can close over plain old data. data = [ @@ -1231,6 +1251,17 @@ def join(self): def test_free_objects_multi_node(shutdown_only): + # This test will do following: + # 1. Create 3 raylets that each hold an actor. + # 2. Each actor creates an object which is the deletion target. + # 3. Invoke 64 methods on each actor to flush plasma client. + # 4. After flushing, the plasma client releases the targets. + # 5. Check that the deletion targets have been deleted. + # Caution: if remote functions are used instead of actor methods, + # one raylet may create more than one worker to execute the + # tasks, so the flushing operations may be executed in different + # workers and the plasma client holding the deletion target + # may not be flushed. 
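+    # The count of 64 matches the size of the plasma client's release-history
+    # cache (see the comment in flush() below).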
config = json.dumps({"object_manager_repeated_push_delay_ms": 1000}) ray.worker._init( start_ray_local=True, @@ -1246,53 +1277,61 @@ def test_free_objects_multi_node(shutdown_only): _internal_config=config) @ray.remote(resources={"Custom0": 1}) - def run_on_0(): - return ray.worker.global_worker.plasma_client.store_socket_name + class ActorOnNode0(object): + def get(self): + return ray.worker.global_worker.plasma_client.store_socket_name @ray.remote(resources={"Custom1": 1}) - def run_on_1(): - return ray.worker.global_worker.plasma_client.store_socket_name + class ActorOnNode1(object): + def get(self): + return ray.worker.global_worker.plasma_client.store_socket_name @ray.remote(resources={"Custom2": 1}) - def run_on_2(): - return ray.worker.global_worker.plasma_client.store_socket_name + class ActorOnNode2(object): + def get(self): + return ray.worker.global_worker.plasma_client.store_socket_name - def create(): - a = run_on_0.remote() - b = run_on_1.remote() - c = run_on_2.remote() + def create(actors): + a = actors[0].get.remote() + b = actors[1].get.remote() + c = actors[2].get.remote() (l1, l2) = ray.wait([a, b, c], num_returns=3) assert len(l1) == 3 assert len(l2) == 0 return (a, b, c) - def flush(): + def flush(actors): # Flush the Release History. # Current Plasma Client Cache will maintain 64-item list. # If the number changed, this will fail. - print("Start Flush!") + logger.info("Start Flush!") for i in range(64): - ray.get([run_on_0.remote(), run_on_1.remote(), run_on_2.remote()]) - print("Flush finished!") + ray.get([actor.get.remote() for actor in actors]) + logger.info("Flush finished!") - def run_one_test(local_only): - (a, b, c) = create() + def run_one_test(actors, local_only): + (a, b, c) = create(actors) # The three objects should be generated on different object stores. assert ray.get(a) != ray.get(b) assert ray.get(a) != ray.get(c) assert ray.get(c) != ray.get(b) ray.internal.free([a, b, c], local_only=local_only) - flush() + flush(actors) return (a, b, c) + actors = [ + ActorOnNode0.remote(), + ActorOnNode1.remote(), + ActorOnNode2.remote() + ] # Case 1: run this local_only=False. All 3 objects will be deleted. - (a, b, c) = run_one_test(False) + (a, b, c) = run_one_test(actors, False) (l1, l2) = ray.wait([a, b, c], timeout=10, num_returns=1) # All the objects are deleted. assert len(l1) == 0 assert len(l2) == 3 # Case 2: run this local_only=True. Only 1 object will be deleted. - (a, b, c) = run_one_test(True) + (a, b, c) = run_one_test(actors, True) (l1, l2) = ray.wait([a, b, c], timeout=10, num_returns=3) # One object is deleted and 2 objects are not. assert len(l1) == 2 @@ -2097,7 +2136,7 @@ def attempt_to_load_balance(remote_function, [remote_function.remote(*args) for _ in range(total_tasks)]) names = set(locations) counts = [locations.count(name) for name in names] - print("Counts are {}.".format(counts)) + logger.info("Counts are {}.".format(counts)) if (len(names) == num_local_schedulers and all(count >= minimum_count for count in counts)): break @@ -2282,7 +2321,7 @@ def test_log_file_api(shutdown_only): @ray.remote def f(): - print(message) + logger.info(message) # The call to sys.stdout.flush() seems to be necessary when using # the system Python 2.7 on Ubuntu. 
sys.stdout.flush() diff --git a/test/stress_tests.py b/test/stress_tests.py index 3d4b0fb363e13..3771f58053a91 100644 --- a/test/stress_tests.py +++ b/test/stress_tests.py @@ -294,6 +294,12 @@ def foo(i, size): del values +def sorted_random_indexes(total, output_num): + random_indexes = [np.random.randint(total) for _ in range(output_num)] + random_indexes.sort() + return random_indexes + + @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", reason="Failing with new GCS API on Linux.") @@ -338,8 +344,8 @@ def single_dependency(i, arg): value = ray.get(args[i]) assert value[0] == i # Get 10 values randomly. - for _ in range(10): - i = np.random.randint(num_objects) + random_indexes = sorted_random_indexes(num_objects, 10) + for i in random_indexes: value = ray.get(args[i]) assert value[0] == i # Get values sequentially, in chunks. @@ -398,8 +404,8 @@ def multiple_dependency(i, arg1, arg2, arg3): value = ray.get(args[i]) assert value[0] == i # Get 10 values randomly. - for _ in range(10): - i = np.random.randint(num_objects) + random_indexes = sorted_random_indexes(num_objects, 10) + for i in random_indexes: value = ray.get(args[i]) assert value[0] == i @@ -535,8 +541,8 @@ def single_dependency(i, arg): # were evicted and whose originating tasks are still running, this # for-loop should hang on its first iteration and push an error to the # driver. - ray.worker.global_worker.local_scheduler_client.fetch_or_reconstruct( - [args[0]], False) + ray.worker.global_worker.raylet_client.fetch_or_reconstruct([args[0]], + False) def error_check(errors): return len(errors) > 1 diff --git a/test/stress_tests/stress_testing_config.yaml b/test/stress_tests/stress_testing_config.yaml index 6ef5285035e82..718d313953474 100644 --- a/test/stress_tests/stress_testing_config.yaml +++ b/test/stress_tests/stress_testing_config.yaml @@ -94,9 +94,9 @@ setup_commands: - echo 'export PATH="$HOME/anaconda3/bin:$PATH"' >> ~/.bashrc # # Build Ray. # - git clone https://github.com/ray-project/ray || true - - pip install boto3==1.4.8 cython==0.27.3 + - pip install boto3==1.4.8 cython==0.29.0 # - cd ray/python; git checkout master; git pull; pip install -e . --verbose - - pip install https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.0-cp36-cp36m-manylinux1_x86_64.whl + - pip install https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.6.1-cp36-cp36m-manylinux1_x86_64.whl # Custom commands that will be run on the head node after common setup. head_setup_commands: [] diff --git a/thirdparty/scripts/build_boost.sh b/thirdparty/scripts/build_boost.sh index 3f763be7e64e6..13de9f878ff2d 100755 --- a/thirdparty/scripts/build_boost.sh +++ b/thirdparty/scripts/build_boost.sh @@ -28,6 +28,6 @@ if [[ ! -d $TP_DIR/pkg/boost ]]; then # Compile boost. 
   pushd $TP_DIR/build/boost_$BOOST_VERSION_UNDERSCORE
   ./bootstrap.sh
-  ./bjam cxxflags=-fPIC cflags=-fPIC variant=release link=static --prefix=$TP_DIR/pkg/boost --with-filesystem --with-system --with-regex install > /dev/null
+  ./bjam cxxflags=-fPIC cflags=-fPIC variant=release link=static --prefix=$TP_DIR/pkg/boost --with-filesystem --with-system --with-thread --with-regex install > /dev/null
   popd
 fi
diff --git a/thirdparty/scripts/build_modin.sh b/thirdparty/scripts/build_modin.sh
index 96563fdb21067..0f976eca9ab69 100755
--- a/thirdparty/scripts/build_modin.sh
+++ b/thirdparty/scripts/build_modin.sh
@@ -14,7 +14,7 @@ fi
 PYTHON_VERSION="$($PYTHON_EXECUTABLE -c 'import sys; print(sys.version_info[0])')"
 TP_DIR=$(cd "$(dirname "${BASH_SOURCE:-$0}")"; pwd)/../
-MODIN_VERSION=0.2.4
+MODIN_VERSION=0.2.5
 MODIN_WHEELS_FNAME="modin-$MODIN_VERSION-py$PYTHON_VERSION-none-any.whl"
 MODIN_WHEELS_URL="https://github.com/modin-project/modin/releases/download/v$MODIN_VERSION/"
diff --git a/thirdparty/scripts/collect_dependent_libs.sh b/thirdparty/scripts/collect_dependent_libs.sh
new file mode 100755
index 0000000000000..ffd7dc218fdad
--- /dev/null
+++ b/thirdparty/scripts/collect_dependent_libs.sh
@@ -0,0 +1,116 @@
+#!/usr/bin/env bash
+set -x
+
+# Cause the script to exit if a single command fails.
+set -e
+
+ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE:-$0}")"; pwd)
+
+function usage() {
+  echo "Usage: collect_dependent_libs.sh [options]"
+  echo
+  echo "Options:"
+  echo "  -h|--help        print the help info"
+  echo "  -d|--target-dir  the target directory to put all the thirdparty libs"
+  echo "  -n|--no-build    do not build Ray; use this if Ray has already been built"
+  echo "  -r|--resource    the resource file name (default: resource.txt)"
+  echo
+}
+
+# By default all the libs will be put into ./thirdparty/external_project_libs.
+# However, this directory could be cleaned by `git clean`.
+# Users can provide another directory using the -d option.
+DIR="$ROOT_DIR/../external_project_libs"
+# By default Ray will be built before copying the libs.
+# Users can skip the building process if they have already built Ray.
+BUILD="YES"
+
+RESOURCE="resource.txt"
+
+# Parse options.
+while [[ $# -gt 0 ]]; do
+  key="$1"
+  case $key in
+    -h|--help)
+      usage
+      exit 0
+      ;;
+    -d|--target-dir)
+      DIR="$2"
+      shift
+      ;;
+    -n|--no-build)
+      BUILD="NO"
+      ;;
+    -r|--resource)
+      RESOURCE="$2"
+      shift
+      ;;
+    *)
+      echo "ERROR: unknown option \"$key\""
+      echo
+      usage
+      exit -1
+      ;;
+  esac
+  shift
+done
+
+echo "External project libs will be put in $DIR"
+if [ ! -d "$DIR" ]; then
+  mkdir -p $DIR
+fi
+
+pushd $ROOT_DIR
+if [ "$BUILD" = "YES" ]; then
+  echo "Build Ray first."
+  ../../build.sh
+fi
+
+RAY_BUILD_DIR=$ROOT_DIR/../../build/external/
+ARROW_BUILD_DIR=$ROOT_DIR/../../build/external/arrow/src/arrow_ep-build/
+
+function cp_one_lib() {
+  if [[ ! -d "$1" ]]; then
+    echo "Lib root dir $1 does not exist!"
+    exit -1
+  fi
+  if [[ ! -d "$1/include" ]]; then
+    echo "Lib include dir $1 does not exist!"
+    exit -1
+  fi
+  if [[ ! -d "$1/lib" && ! -d "$1/lib64" ]]; then
+    echo "Lib dir $1 does not exist!"
+    exit -1
+  fi
+  cp -rf $1 $DIR
+}
+
+# Copy the libs that Ray needs.
+cp_one_lib $RAY_BUILD_DIR/boost-install
+cp_one_lib $RAY_BUILD_DIR/flatbuffers-install
+cp_one_lib $RAY_BUILD_DIR/glog-install
+cp_one_lib $RAY_BUILD_DIR/googletest-install
+
+# Copy the libs that Arrow needs.
+cp_one_lib $ARROW_BUILD_DIR/snappy_ep/src/snappy_ep-install
+cp_one_lib $ARROW_BUILD_DIR/thrift_ep/src/thrift_ep-install
+
+# Generate the export script.
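+# The resource file exports one *_HOME/*_ROOT variable per copied library so
+# that a subsequent build can point at these prebuilt dependencies instead of
+# rebuilding them.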
+echo "Output the exporting resource file to $DIR/$RESOURCE." +echo "export BOOST_ROOT=$DIR/boost-install" > $DIR/$RESOURCE +echo "export RAY_BOOST_ROOT=\$BOOST_ROOT" >> $DIR/$RESOURCE + +echo "export FLATBUFFERS_HOME=$DIR/flatbuffers-install" >> $DIR/$RESOURCE +echo "export RAY_FLATBUFFERS_HOME=\$FLATBUFFERS_HOME" >> $DIR/$RESOURCE + +echo "export GTEST_HOME=$DIR/googletest-install" >> $DIR/$RESOURCE +echo "export RAY_GTEST_HOME=\$GTEST_HOME" >> $DIR/$RESOURCE + +echo "export GLOG_HOME=$DIR/glog-install" >> $DIR/$RESOURCE +echo "export RAY_GLOG_HOME=\$GLOG_HOME" >> $DIR/$RESOURCE + +echo "export SNAPPY_HOME=$DIR/snappy_ep-install" >> $DIR/$RESOURCE +echo "export THRIFT_HOME=$DIR/thrift_ep-install" >> $DIR/$RESOURCE + +popd