Controlling Mujoco-HalfCheetah locomotion with REVIVE SDK
==============================================================

.. image:: images/halfcheetah.gif
   :alt: example-of-Mujoco-Halfcheetah
   :align: center

Mujoco-HalfCheetah task description
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

HalfCheetah is a classic control problem in reinforcement learning.

================= ====================
Action Space      Continuous(6,)
Observation Shape (17,)
Observation High  [inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf]
Observation Low   [-inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf]
================= ====================

A detailed description of the task is available at `Mujoco-HalfCheetah <https://www.gymlibrary.dev/environments/mujoco/half_cheetah/>`__.

`D4RL <https://github.com/Farama-Foundation/D4RL>`__ is a widely used offline reinforcement learning benchmark. In this example we train REVIVE on its halfcheetah-medium-v2 dataset.

The sections below show how to use REVIVE to learn a good virtual environment and policy from the halfcheetah-medium-v2 dataset. Finally, we compare the original policy in the dataset with the policy obtained by REVIVE to show concretely what REVIVE achieves.

Action space
------------------------------------------

The action space is a continuous 6-dimensional vector. Each dimension is the torque applied at one joint and takes values in [-1, 1].

Observation space
------------------------------------------

The state is a 17-dimensional vector containing the joint angles and angular velocities, the robot's linear velocities along the :math:`X` and :math:`Z` axes, and the robot's height.

Goal of Mujoco-HalfCheetah
------------------------------------------

Within a fixed number of steps, the robot should run forward as far as possible while keeping its energy consumption low. This objective is encoded in the reward function.

Initial state
------------------------------------------

The robot starts at the origin, in its initial state, ready to run forward.

Episode termination
------------------------------------------

The episode is forcibly terminated once the robot has run 1000 steps.

Training a control policy with REVIVE SDK
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

REVIVE SDK is a tool driven by historical data. Following the tutorial part of the documentation, applying REVIVE SDK to the HalfCheetah locomotion task takes the following steps:

1. Process the historical decision data.
2. Based on the business scenario and the collected historical data, build the :doc:`decision flow graph and array data <../tutorial/data_preparation_cn>`. The decision flow graph describes the interaction logic of the business data and is stored in a ``.yaml`` file. The array data holds the data of the nodes defined in the decision flow graph and is stored in a ``.npz`` or ``.h5`` file.
3. With the decision flow graph and array data, REVIVE SDK can already train a virtual environment model. To obtain a better control policy, however, you need to define a :doc:`reward function <../tutorial/reward_function_cn>` based on the task objective. The reward function defines what the policy optimizes and guides it to make the robot run forward faster and more stably.
4. Once the :doc:`decision flow graph <../tutorial/data_preparation_cn>`, :doc:`training data <../tutorial/data_preparation_cn>`, and :doc:`reward function <../tutorial/reward_function_cn>` are defined, we can use REVIVE SDK to train the virtual environment model and the policy model.
5. Finally, deploy the policy model trained by REVIVE SDK and test it online.

Preparing the data
------------------------------------------

We do not need to collect historical data by hand here, because the D4RL library already provides standard offline data. First, we download the D4RL dataset and preprocess it into the input format REVIVE expects. The processing script is ``data/generate_data.py``. Enter the ``data`` directory and run the following command to produce the processed dataset.

.. code:: bash

    python generate_data.py

A few points about the processing:

1. Trajectory splitting: as described in :doc:`Preparing data <../tutorial/data_preparation_cn>`, a REVIVE dataset must contain an ``index`` field, which we have to construct from the halfcheetah-medium-v2 dataset. We split trajectories by checking whether the ``obs`` at time t+1 matches the ``next_obs`` at time t, and generate ``index`` from the resulting split points.
2. Recovering ``delta_x``: the halfcheetah-medium-v2 dataset does not provide the x coordinate directly, yet in the HalfCheetah task the x coordinate is essential for computing ``reward``. We therefore recover ``delta_x`` from the ``reward`` stored in the dataset, where

   .. math::

      \mathrm{delta\_x} := x_{t+1} - x_{t}

   Since the reward function used below is :math:`r_t = w_f \, \mathrm{delta\_x} / dt - w_c \lVert a_t \rVert_2^2` with :math:`w_f = 1.0`, :math:`w_c = 0.1` and :math:`dt = 0.05`, ``delta_x`` can be solved for from the recorded reward and action.

See ``data/generate_data.py`` for the processing details. This produces the ``.npz`` file, which we place in the ``data/`` folder.

Defining the decision flow graph
------------------------------------------

The example below shows the contents of the ``.yaml`` file. A ``.yaml`` file usually consists of two parts, ``graph`` and ``columns``: ``graph`` defines the decision flow graph, and ``columns`` defines the composition of the data. For details see :doc:`Preparing data <../tutorial/data_preparation_cn>`. Note that because ``obs`` has 17 dimensions, its columns must be defined **in order** in the ``columns`` section. As shown in `Mujoco-HalfCheetah <https://www.gymlibrary.dev/environments/mujoco/half_cheetah/>`__, the state and action variables are **continuous**, so we describe each column with ``continuous``.

Also note that we set a ``min``/``max`` range of [-1, 1] for ``delta_x``. During policy training, the per-step ``delta_x`` may exceed the range seen in the dataset (roughly [-0.12, 0.43]), and this has a large effect on REVIVE's **data normalization**. **By default, REVIVE normalizes using the min and max values found in the data.**

.. code:: yaml

    metadata:
      columns:
        - obs_0:
            dim: obs
            type: continuous
        - obs_1:
            dim: obs
            type: continuous
        ...
        - obs_16:
            dim: obs
            type: continuous
        - action_0:
            dim: action
            type: continuous
        - action_1:
            dim: action
            type: continuous
        ...
        - action_5:
            dim: action
            type: continuous
        - delta_x:
            dim: delta_x
            type: continuous
            min: -1
            max: 1
      graph:
        action:
          - obs
        delta_x:
          - obs
          - action
        next_obs:
          - obs
          - action
          - delta_x

This gives the ``.yaml`` file, which we also place in the ``data/`` folder.

Building the reward function
------------------------------------------

Here we use the reward function that Mujoco defines for HalfCheetah; see `HalfCheetah-Env <https://github.com/openai/gym/blob/master/gym/envs/mujoco/half_cheetah_v3.py>`__ for details.

.. code:: python

    import torch
    import numpy as np
    from typing import Dict


    def get_reward(data: Dict[str, torch.Tensor]) -> torch.Tensor:
        action = data["action"]
        delta_x = data["delta_x"]

        forward_reward_weight = 1.0
        ctrl_cost_weight = 0.1
        dt = 0.05

        # ctrl_cost is the energy cost of the action: the squared L2 norm of the action.
        # Works for both NumPy arrays and torch tensors.
        if isinstance(action, np.ndarray):
            ctrl_cost = ctrl_cost_weight * np.sum(np.square(action), axis=-1, keepdims=True)
        else:
            ctrl_cost = ctrl_cost_weight * torch.sum(torch.square(action), dim=-1, keepdim=True)

        x_velocity = delta_x / dt
        # forward_reward rewards forward motion: the larger x_velocity, the higher the reward.
        forward_reward = forward_reward_weight * x_velocity

        # The final reward combines forward_reward and ctrl_cost. This matches the task goal:
        # run forward as far as possible within a fixed number of steps while keeping energy use low.
        reward = forward_reward - ctrl_cost

        return reward

This gives the reward function file, which we also place in the ``data/`` folder.

Training the control policy with REVIVE SDK
-------------------------------------------------------

We have now built all the files REVIVE SDK needs: the ``.npz`` data file, the ``.yaml`` file, and the ``halfcheetah_reward.py`` reward function. A fourth file, ``config.json``, stores the training hyperparameters. All four files are in the ``data/`` folder. The directory now looks like this::

    |-- data
    |   |-- config.json
    |   |-- generate_data.py
    |   |-- halfcheetah_medium-v2.hdf5
    |   |-- halfcheetah-medium-v2.npz
    |   |-- halfcheetah-medium-v2.yaml
    |   `-- halfcheetah_reward.py
    `-- train.py

Switch to the ``examples/task/HalfCheetah`` directory and run the following Python command to train the virtual environment model and the policy model. During training, you can open the log directory with tensorboard at any time to monitor progress.

.. code:: bash

    python train.py -df data/halfcheetah-medium-v2.npz \
                    -cf data/halfcheetah-medium-v2.yaml \
                    -rf data/halfcheetah_reward.py \
                    -rcf data/config.json \
                    --target_policy_name action \
                    -vm once -pm once \
                    --run_id halfcheetah-medium-v2-revive \
                    --revive_epoch 1500 --sac_epoch 1500

.. note::

   REVIVE SDK already provides the data and code needed for training. See the `REVIVE SDK repository <https://agit.ai/Polixir/revive/src/branch/master/examples/task/HalfCheetah>`__ for details.

Controlling HalfCheetah with the trained policy
------------------------------------------------------

After REVIVE SDK finishes training the virtual environment model and the policy model, the saved models (``.pkl`` or ``.onnx``) can be found in the log folder (``logs/<run_id>``). We test the policy in the real environment and compare it with the control performance recorded in the data. The test code below runs the policy for 100 episodes of 1000 steps each in the real environment and prints the average return (cumulative reward) over the 100 episodes. The REVIVE SDK policy reaches an average return of 7156.0, well above the 4770.3 of the policy in the data, an improvement of about 50%.

.. code:: python

    import pickle

    import d4rl  # noqa: F401  (importing d4rl registers its environments with gym)
    import gym
    import numpy as np

    # Load the policy saved by REVIVE SDK (found under logs/<run_id>/).
    with open('policy.pkl', 'rb') as f:
        policy_revive = pickle.load(f)


    def take_revive_action(obs):
        # The REVIVE policy takes a dict keyed by node name and returns the action.
        return policy_revive.infer({'obs': obs})


    env = gym.make('halfcheetah-medium-v2')

    returns = []
    for episode in range(100):
        obs = env.reset()
        episode_rewards = []
        done = False
        while not done:
            action = take_revive_action(obs)
            obs, reward, done, _ = env.step(action)
            episode_rewards.append(reward)
        episode_return = np.sum(episode_rewards)
        print(episode_return)
        returns.append(episode_return)

    print('mean return:', np.mean(returns),
          ' std:', np.std(returns),
          ' normal_score:', env.get_normalized_score(np.mean(returns)))

    # REVIVE average return:
    # mean return: 7155.900144836804 std: 63.78200350280033 normal_score: 0.5989506173038248

To compare the policies more directly, we render an animation of each policy controlling the robot. The REVIVE SDK policy makes HalfCheetah run faster and more stably, outperforming the original policy in the data.

.. image:: images/halfcheetah_result.gif
   :alt: example-of-Mujoco-Halfcheetah
   :align: center
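
The trajectory splitting in step 1 of *Preparing the data* can be sketched in a few lines of NumPy. This is a minimal sketch, not the code in ``data/generate_data.py``. The function name ``build_index`` and the tolerance ``atol`` are hypothetical. The sketch also assumes that ``index`` stores the exclusive end position of each trajectory, which is our reading of the REVIVE data-preparation tutorial. It uses ``np.isclose`` rather than exact equality so that float round-off does not create spurious splits.

.. code:: python

    import numpy as np


    def build_index(obs: np.ndarray, next_obs: np.ndarray, atol: float = 1e-6) -> np.ndarray:
        """Hypothetical helper: return the (exclusive) end position of each trajectory.

        A new trajectory is assumed to start at step t + 1 whenever
        obs[t + 1] does not match next_obs[t].
        """
        # True where transition t + 1 continues the trajectory of transition t.
        continues = np.all(np.isclose(obs[1:], next_obs[:-1], atol=atol), axis=-1)
        # Positions t + 1 where the chain breaks are the starts of new trajectories.
        breaks = np.nonzero(~continues)[0] + 1
        # The last trajectory always ends at the final row.
        return np.append(breaks, len(obs)).astype(np.int64)


    # Two toy trajectories of lengths 3 and 2 (1-dimensional observations).
    obs = np.array([[0.0], [1.0], [2.0], [10.0], [11.0]])
    next_obs = np.array([[1.0], [2.0], [3.0], [11.0], [12.0]])
    print(build_index(obs, next_obs))  # [3 5]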
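
Step 2 of *Preparing the data* recovers ``delta_x`` from the stored reward. The sketch below assumes the dataset rewards were produced by the same formula as ``get_reward`` in *Building the reward function*, with weights 1.0 and 0.1 and ``dt = 0.05``. Under that assumption the reward can be inverted in closed form: add the control cost back, then convert velocity to displacement. The helper name ``recover_delta_x`` is hypothetical, and ``data/generate_data.py`` remains the reference implementation.

.. code:: python

    import numpy as np

    # Constants copied from get_reward in the reward-function section.
    FORWARD_REWARD_WEIGHT = 1.0
    CTRL_COST_WEIGHT = 0.1
    DT = 0.05


    def recover_delta_x(reward: np.ndarray, action: np.ndarray) -> np.ndarray:
        """Hypothetical helper: solve reward = w_f * delta_x / dt - w_c * ||a||^2 for delta_x."""
        # Same control cost as in get_reward, shape (N, 1).
        ctrl_cost = CTRL_COST_WEIGHT * np.sum(np.square(action), axis=-1, keepdims=True)
        # Undo the control cost, then turn x velocity back into a per-step displacement.
        return (np.reshape(reward, (-1, 1)) + ctrl_cost) * DT / FORWARD_REWARD_WEIGHT


    # Round trip on made-up numbers: delta_x = 0.2 with a constant action of 0.5.
    action = np.full((1, 6), 0.5)
    reward = FORWARD_REWARD_WEIGHT * 0.2 / DT - CTRL_COST_WEIGHT * np.sum(action ** 2)
    print(recover_delta_x(np.array([reward]), action))  # should recover delta_x close to 0.2

Because the inversion is exact only when the stored rewards follow this formula, comparing the recovered values against a known environment rollout is a cheap sanity check.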
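
The explicit ``min``/``max`` for ``delta_x`` in the ``.yaml`` file matters because of normalization. The sketch below assumes a plain min-max scaling onto [-1, 1]; REVIVE's actual normalization code may differ, and ``minmax_normalize`` is a hypothetical helper. With bounds taken from the data (roughly [-0.12, 0.43]), an illustrative rollout value of 0.55 lands outside the normalized range the model saw during training. The explicit bounds [-1, 1] keep it inside.

.. code:: python

    import numpy as np


    def minmax_normalize(x, low, high):
        """Hypothetical helper: map [low, high] linearly onto [-1, 1] (assumed scheme)."""
        x = np.asarray(x, dtype=float)
        return 2.0 * (x - low) / (high - low) - 1.0


    # delta_x values a policy might produce during rollouts (illustrative numbers).
    rollout_delta_x = np.array([-0.12, 0.43, 0.55])

    # Bounds inferred from the dataset (approximately [-0.12, 0.43]).
    from_data = minmax_normalize(rollout_delta_x, -0.12, 0.43)
    # Explicit bounds from the .yaml file.
    from_yaml = minmax_normalize(rollout_delta_x, -1.0, 1.0)

    print(from_data)  # 0.55 maps to about 1.44, outside [-1, 1]
    print(from_yaml)  # every value stays inside [-1, 1]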