浣跨敤 REVIVE SDK 鎺у埗 Mujoco-HalfCheetah 杩愬姩
===============================================

.. image:: images/halfcheetah.gif
 :alt: example-of-Mujoco-Halfcheetah
 :align: center


Mujoco-HalfCheetah 浠诲姟鎻忚堪
~~~~~~~~~~~~~~~~~~~~~~~~~~~
HalfCheetah 鏄紶缁熷己鍖栧涔犱腑涓€涓粡鍏哥殑鎺у埗闂銆�

================= ====================
Action Space      Continuous(6,)
Observation       Shape (17,)
Observation       High [inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf]
Observation       Low [-inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf]
================= ====================


鍏充簬杩欎釜浠诲姟鐨勮缁嗘弿杩板彲浠ュ湪 `Mujoco-HalfCheetah <https://www.gymlibrary.dev/environments/mujoco/half_cheetah/>`__ 杩欓噷鎵惧埌銆�
`D4RL <https://github.com/Farama-Foundation/D4RL>`__ 鏄竴涓粡鍏哥殑绂荤嚎寮哄寲瀛︿範鏁版嵁闆嗭紝鍦ㄨ繖涓緥瀛愪腑鎴戜滑灏嗛噰鐢ㄥ叾涓殑 halfcheetah-medium-v2 鏁版嵁闆嗕綔涓轰緥瀛愭潵杩涜 REVIVE 璁粌銆�
涓嬮潰灏嗕粙缁嶅浣曢€氳繃 REVIVE 鍦� halfcheetah-medium-v2 鏁版嵁闆嗕笂璁粌寰楀埌涓€涓悊鎯崇殑鐜涓庣瓥鐣ャ€傛渶鍚庯紝鎴戜滑灏嗗姣旀暟鎹泦涓殑鍘熷绛栫暐涓� REVIVE 寰楀埌鐨勭瓥鐣ワ紝鏉ョ洿瑙傚睍绀� REVIVE 鐨勫己澶ц兘鍔涖€�

鍔ㄤ綔绌洪棿
--------------------------

鍔ㄤ綔绌洪棿鐢辫繛缁殑 6 缁村悜閲忕粍鎴愶紝鍒嗗埆琛ㄧず鎺у埗鍚勫叧鑺傜殑鍔涚煩銆傛瘡涓淮搴︾殑鍙栧€艰寖鍥翠负 [-1, 1]銆�


瑙傚療绌洪棿 
--------------------------

鐘舵€佹槸涓€涓� 17 缁村悜閲忥紝鍖呮嫭鍚勫叧鑺傜殑瑙掑害锛岃閫熷害锛屼互鍙婃満鍣ㄤ汉鍦� :math:`X` 杞达紝 :math:`Z` 杞寸嚎閫熷害锛屼互鍙婃満鍣ㄤ汉鐨勯珮搴︺€�


Mujoco-HalfCheetah 鐨勭洰鏍�
--------------------------

鏈哄櫒浜虹殑鐩爣鏄湪鍥哄畾鐨勬鏁板唴锛屽敖鍙兘澶氬湴寰€鍓嶈窇锛屽悓鏃朵繚鎸佽嚜韬兘閲忔秷鑰楄緝灏忋€傝繖鐐瑰湪濂栧姳鍑芥暟鐨勫畾涔変腑寰椾互浣撶幇銆�


鍒濆鐘舵€�
--------------------------

鏈哄櫒浜轰粠鍘熺偣寮€濮嬶紝浠庡垵濮嬬姸鎬佸噯澶囧悜鍓嶅璺戙€�


浠诲姟缁撴潫
--------------------------

鏈哄櫒浜鸿窇瀹� 1000 姝ュ悗浠诲姟浼氳寮哄埗缁撴潫銆�


浣跨敤REVIVE SDK璁粌鎺у埗绛栫暐
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

REVIVE SDK鏄竴涓巻鍙叉暟鎹┍鍔ㄧ殑宸ュ叿锛屾牴鎹枃妗f暀绋嬮儴鍒嗙殑鎻忚堪锛屽湪HalfCheetah杩愬姩浠诲姟涓婁娇鐢≧EVIVE SDK鍙互鍒嗕负浠ヤ笅鍑犳锛�

 1. 澶勭悊鍘嗗彶鍐崇瓥鏁版嵁锛�
 2. 缁撳悎涓氬姟鍦烘櫙鍜屾敹闆嗙殑鍘嗗彶鏁版嵁鏋勫缓 :doc:`鍐崇瓥娴佸浘鍜屾暟缁勬暟鎹�<../tutorial/data_preparation_cn>`锛屽叾涓喅绛栨祦鍥句富瑕佹弿杩颁簡涓氬姟鏁版嵁鐨勪氦浜掗€昏緫锛�
    浣跨敤 ``.yaml`` 鏂囦欢杩涘瓨鍌紝鏁扮粍鏁版嵁瀛樺偍浜嗗喅绛栨祦鍥句腑瀹氫箟鐨勮妭鐐规暟鎹紝浣跨敤 ``.npz`` 鎴� ``.h5`` 鏂囦欢杩涜瀛樺偍銆�
 3. 鏈変簡涓婅堪鐨勫喅绛栨祦鍥惧拰鏁扮粍鏁版嵁锛孯EVIVE SDK宸茬粡鍙互杩涜铏氭嫙鐜妯″瀷鐨勮缁冦€備絾涓轰簡鑾峰緱鏇翠紭鐨勬帶鍒剁瓥鐣ワ紝闇€瑕佹牴鎹换鍔$洰鏍囧畾涔� :doc:`濂栧姳鍑芥暟<../tutorial/reward_function_cn>` 锛�
    濂栧姳鍑芥暟瀹氫箟浜嗙瓥鐣ョ殑浼樺寲鐩爣锛屽彲浠ユ寚瀵兼帶鍒剁瓥鐣ヤ娇寰楁満鍣ㄤ汉鏇村揩锛屾洿绋冲畾鍦板悜鍓嶅璺戙€�
 4. 瀹氫箟瀹� :doc:`鍐崇瓥娴佸浘<../tutorial/data_preparation_cn>`锛� :doc:`璁粌鏁版嵁<../tutorial/data_preparation_cn>` 鍜� :doc:`濂栧姳鍑芥暟<../tutorial/reward_function_cn>` 涔嬪悗锛�
    鎴戜滑灏卞彲浠ヤ娇鐢≧EVIVE SDK寮€濮嬭櫄鎷熺幆澧冩ā鍨嬭缁冨拰绛栫暐妯″瀷璁粌銆�
 5. 鏈€鍚庡皢REVIVE SDK璁粌鐨勭瓥鐣ユā鍨嬭繘琛屼笂绾挎祴璇曘€�


鍑嗗鏁版嵁
---------------------------------

杩欓噷鎴戜滑涓嶉渶瑕佹墜鍔ㄦ敹闆嗗巻鍙叉暟鎹紝鍥犱负 D4RL 搴撳凡缁忔彁渚涙爣鍑嗙殑绂荤嚎鍘嗗彶鏁版嵁銆傞鍏堬紝鎴戜滑闇€瑕佷笅杞藉苟棰勫鐞� D4RL 鏁版嵁闆嗭紝浣垮緱瀹冪鍚� REVIVE 鐨勮緭鍏ュ舰寮忋€�
鏁版嵁澶勭悊鑴氭湰鍦� ``data/generate_data.py``锛屾垜浠彲浠ヨ繘鍏� ``data`` 鐩綍锛岃繍琛屼互涓嬪懡浠ゅ緱鍒板鐞嗗悗鐨勬暟鎹泦銆�

.. code:: bash

 python generate_data.py

澶勭悊杩囩▼涓湁鍑犵偣闇€瑕佹敞鎰忥細

1. 杞ㄨ抗鍒囧垎锛氭濡� :doc:`鍑嗗鏁版嵁 <../tutorial/data_preparation_cn>` 鎻愬埌鐨勶紝REVIVE 鐨勬暟鎹泦瀛楁涓姹傛湁 ``index`` 淇℃伅锛岃€岃繖涓€椤归渶瑕佹垜浠粠 halfcheetah-medium-v2 鏁版嵁闆嗕腑鏋勫缓銆�
   鍏蜂綋鏂规硶涓猴細鎴戜滑鏍规嵁鏁版嵁闆嗕腑 t+1 鏃跺埢鐨� ``obs`` 鏄惁涓� t 鏃跺埢鐨� ``next_obs`` 涓€鑷达紝鏉ュ垏鍒嗚建杩癸紝骞剁敓鎴� ``index`` 淇℃伅銆�
2. 杩樺師 ``delta_x`` 淇℃伅锛氱敱浜� halfcheetah-medium-v2 鏁版嵁闆嗗苟涓嶇洿鎺ユ彁渚� x 鍧愭爣淇℃伅锛岃€屽湪 HalfCeetah 浠诲姟涓紝x 鍧愭爣淇℃伅瀵逛簬璁$畻 ``reward`` 灏や负鍏抽敭銆傚洜姝わ紝鎴戜滑閫氳繃鏁版嵁闆嗕腑鐨� ``reward`` 淇℃伅鏉ヨ繕鍘� ``delta_x`` 淇℃伅銆傝繖閲岋細

.. math:: 
  delta\_x := x_{t+1} - x_{t}

澶勭悊缁嗚妭鐢ㄦ埛鍙互鍙傝€� ``data/generate_data.py``銆�

杩欐牱鎴戜滑灏卞緱鍒颁簡 ``.npz`` 鏂囦欢锛屾垜浠皢瀹冩斁鍏� ``data/`` 鏂囦欢澶逛腑銆�

瀹氫箟鍐崇瓥娴佸浘
--------------------------------------

涓嬮潰鐨勭ず渚嬫樉绀� ``.yaml`` 涓殑璇︾粏淇℃伅銆傞€氬父锛屾湁涓ら儴鍒嗕俊鎭瀯鎴� ``.yaml`` 鏂囦欢锛屽垎鍒槸 ``graph`` 鍜� ``columns`` 銆�
鍏朵腑 ``graph`` 閮ㄥ垎瀹氫箟浜嗗喅绛栨祦鍥俱€� ``columns`` 閮ㄥ垎瀹氫箟浜嗘暟鎹殑缁勬垚銆傚叿浣撹鍙傝€冩枃妗o細:doc:`鍑嗗鏁版嵁 <../tutorial/data_preparation_cn>` 銆�
璇锋敞鎰忥紝鐢变簬 ``obs`` 瀛樺湪 17 涓淮搴︼紝 ``obs`` 鐨勫垪搴旇 **鎸夐『搴�** 瀹氫箟鍦� ``columns`` 閮ㄥ垎銆� 濡� `Mujoco-HalfCheetah <https://www.gymlibrary.dev/environments/mujoco/half_cheetah/>`__ 鎵€绀猴紝
鐘舵€佸拰鍔ㄤ綔涓殑鍙橀噺鏄� **杩炵画** 鐨勶紝鎴戜滑浣跨敤 ``continuous`` 鏉ユ弿杩版瘡涓€鍒楁暟鎹€�
鍙﹀锛屾敞鎰忚繖閲屾垜浠 ``delta_x`` 瀹氫箟浜� ``min``, ``max`` 鑼冨洿 [-1, 1]銆�
杩欐槸鍥犱负鍦� policy 璁粌鏃讹紝姣忎竴姝ョ殑 ``delta_x`` 鍙兘浼氳秴瓒婃暟鎹泦鐨勮寖鍥达紙鏁版嵁闆嗕腑 ``delta_x`` 绾﹀湪 [-0.12, 0.43]锛夈€�
杩欏 REVIVE 涓� **鏁版嵁褰掍竴鍖�** 澶勭悊鏈夊緢澶у奖鍝嶃€� **榛樿鎯呭喌 REVIVE 灏嗚鍙栨暟鎹腑鐨� min, max 鍊间綔褰掍竴鍖栥€�**

.. code:: yaml

  metadata:
    columns:
    - obs_0:
        dim: obs
        type: continuous
    - obs_1:
        dim: obs
        type: continuous
    ...
    - obs_16:
        dim: obs
        type: continuous

    - action_0:
        dim: action
        type: continuous
    - action_1:
        dim: action
        type: continuous
    ...
    - action_5:
        dim: action
        type: continuous

    - delta_x:
        dim: delta_x
        type: continuous
        min: -1
        max: 1

    graph:
      action:
      - obs
      delta_x:
      - obs
      - action
      next_obs:
      - obs
      - action
      - delta_x

杩欐牱鎴戜滑灏卞緱鍒颁簡 ``.yaml`` 鏂囦欢锛屾垜浠篃灏嗗畠鏀惧叆 ``data/`` 鏂囦欢澶逛腑銆�

鏋勫缓濂栧姳鍑芥暟
----------------------------------------------

杩欓噷鎴戜滑鍙互浣跨敤 Mujoco 涓 HalfCheetah 瀹氫箟鐨勫鍔卞嚱鏁帮紝璇︽儏鍙傝€� `HalfCheetah-Env <https://github.com/openai/gym/blob/master/gym/envs/mujoco/half_cheetah_v3.py>`__

.. code:: python

  import torch
  import numpy as np
  from typing import Dict


  def get_reward(data : Dict[str, torch.Tensor]) -> torch.Tensor:
      action = data["action"]
      delta_x = data["delta_x"]
      
      forward_reward_weight = 1.0 
      ctrl_cost_weight = 0.1
      dt = 0.05
      
      if isinstance(action, np.ndarray):
          array_type = np
          ctrl_cost = ctrl_cost_weight * array_type.sum(array_type.square(action),axis=-1, keepdims=True)
      else:
          array_type = torch
          # ctrl_cost 浠h〃鍋� action 鐨勪綋鑳藉紑閿€锛岀敱 action 鐨勪簩鑼冩暟骞虫柟鏋勬垚
          ctrl_cost = ctrl_cost_weight * array_type.sum(array_type.square(action),axis=-1, keepdim=True)
      
      x_velocity = delta_x / dt
      # forward_reward 浠h〃 halfcheetah 鍚戝墠杩愬姩鐨勫鍔憋紝x_velocity 瓒婂ぇ锛屽鍔卞€艰秺楂�
      forward_reward = forward_reward_weight * x_velocity
      
      # 鏈€缁� halfcheetah 寰楀埌鐨� reward 鐢� forward_reward锛宑trl_cost 鏋勬垚
      # 杩欎篃瀵瑰簲浜� halfcheetah 浠诲姟鐨勭洰鏍囷細鍦ㄥ浐瀹氱殑姝ユ暟鍐咃紝halfcheetah 闇€瑕佸敖鍙兘澶氬湴寰€鍓嶈窇锛屽悓鏃朵繚鎸佽嚜韬兘閲忔秷鑰楄緝灏�
      reward = forward_reward - ctrl_cost
      
      return reward

杩欐牱鎴戜滑灏卞緱鍒颁簡濂栧姳鍑芥暟鏂囦欢锛屾垜浠篃灏嗗畠鏀惧叆 ``data/`` 鏂囦欢澶逛腑銆�

浣跨敤REVIVE SDK璁粌鎺у埗绛栫暐
--------------------------------------------------------

鐜板湪锛屾垜浠凡缁忔瀯寤哄畬鎴愯繍琛� REVIVE SDK 鎵€闇€鐨勬枃浠讹紝
鍖呮嫭 ``.npz`` 鏁版嵁鏂囦欢锛� ``.yaml`` 鏂囦欢鍜� ``reward.py`` 濂栧姳鍑芥暟銆�
杩樻湁鍙︿竴涓枃浠� ``config.json``锛岃鏂囦欢淇濆瓨浜嗚缁冩墍闇€鐨勮秴鍙傛暟銆傝繖鍥涗釜鏂囦欢浣嶄簬 ``data/`` 鏂囦欢澶逛腑銆�
鐜板湪鎴戜滑鐨勬枃浠剁洰褰曞涓嬫墍绀�::

  |-- data
  |   |-- config.json
  |   |-- generate_data.py
  |   |-- halfcheetah_medium-v2.hdf5
  |   |-- halfcheetah-medium-v2.npz
  |   |-- halfcheetah-medium-v2.yaml
  |   `-- halfcheetah_reward.py
  `-- train.py

鐢ㄦ埛鍙互鍒囨崲鍒� ``examples/task/HalfCheetah`` 鐩綍涓嬶紝杩愯涓嬮潰鐨� python 鍛戒护寮€鍚櫄鎷熺幆澧冩ā鍨嬭缁冨拰绛栫暐妯″瀷璁粌銆傚湪璁粌杩囩▼涓紝鎴戜滑鍙互闅忔椂浣跨敤tensorboard鎵撳紑鏃ュ織鐩綍浠ョ洃鎺ц缁冭繃绋嬨€�

.. code:: bash

 python train.py -df data/halfcheetah-medium-v2.npz -cf data/halfcheetah-medium-v2.yaml -rf data/halfcheetah_reward.py -rcf data/config.json --target_policy_name action -vm once -pm once --run_id halfcheetah-medium-v2-revive --revive_epoch 1500 --sac_epoch 1500

.. note:: REVIVE SDK宸茬粡鎻愪緵浜嗚缁冩墍闇€鐨勬暟鎹拰浠g爜锛岃鎯呰鍙傝€� `REVIVE SDK婧愮爜搴� <https://agit.ai/Polixir/revive/src/branch/master/examples/task/HalfCheetah>`__銆�

浣跨敤璁粌寰楀埌鐨勭瓥鐣ユ帶鍒� HalfCeetah
----------------------------------------------

褰揜EVIVE SDK瀹屾垚铏氭嫙鐜妯″瀷璁粌鍜岀瓥鐣ユā鍨嬭缁冨悗,
鎴戜滑鍙互鍦ㄦ棩蹇楁枃浠跺す锛� ``logs/<run_id>``锛変笅鎵惧埌淇濆瓨鐨勬ā鍨嬶紙 ``.pkl`` 鎴� ``.onnx``锛夈€�
鎴戜滑灏濊瘯鍦ㄧ湡瀹炵幆澧冧笂娴嬭瘯绛栫暐鐨勬晥鏋滐紝骞跺拰鏁版嵁涓殑鎺у埗鏁堟灉杩涜瀵规瘮銆�
鍦ㄤ笅闈㈢殑娴嬭瘯浠g爜涓紝 鎴戜滑灏嗙瓥鐣ュ湪鐪熷疄鐜涓窇 100 杞紝姣忚疆鎵ц 1000 姝ワ紝杈撳嚭杩�100娆℃€荤殑骞冲潎鍥炴姤锛堢疮璁″鍔憋級銆�
REVIVE SDK 鐨勭瓥鐣ヨ幏寰椾簡 7156.0 骞冲潎濂栧姳锛岃繙楂樹簬鏁版嵁涓瓥鐣ョ殑 4770.3 濂栧姳鍊硷紝鎺у埗鏁堟灉鎻愰珮浜嗙害 50%銆�

.. code:: python

  import pickle
  import d4rl
  import gym
  import numpy as np

  def take_revive_action(state):
      new_data = {}
      new_data['obs'] = state
      action = policy_revive.infer(new_data)
      return action

  policy_revive = pickle.load(open('policy.pkl', 'rb'))
  env = gym.make('halfcheetah-medium-v2')

  re_list = []
  for traj in range(100):
      state = env.reset()

      obs = state
      re_turn = []
      done = False
      while not done:
          action = take_revive_action(obs)
          next_state, reward, done, _ = env.step(action)
          obs = next_state
          re_turn.append(reward)
      
      print(np.sum(np.array(re_turn)[:]))
      re_list.append(np.sum(re_turn))

  print('mean return:',np.mean(re_list), ' std:',np.std(re_list), ' normal_score:', env.get_normalized_score(np.mean(re_list)) )

 # REVIVE骞冲潎鍥炴姤: 
 # mean return: 7155.900144836804  std: 63.78200350280033  normal_score: 0.5989506173038248

涓轰簡鏇寸洿瑙傚湴姣旇緝绛栫暐锛屾垜浠敓鎴愮瓥鐣ョ殑鎺у埗瀵规瘮鍔ㄧ敾銆傚彲浠ュ彂鐜帮紝REVIVE SDK 鐨勭瓥鐣ュ彲浠ユ帶鍒� HalfCheetah 璺戝緱鏇村揩鏇寸ǔ瀹氾紝
姣旀暟鎹腑鐨勫師濮嬬瓥鐣ユ洿鍔犱紭绉€銆�

.. image:: images/halfcheetah_result.gif
 :alt: example-of-Mujoco-Halfcheetah
 :align: center