"""
POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is
distributed under the GNU Lesser General Public License (GNU LGPL).
POLIXIR REVIVE is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
"""
import torch
import numpy as np
from copy import deepcopy
from loguru import logger
import matplotlib.pyplot as plt
from collections import deque
from torch import nn
from torch.nn import functional as F
from revive.computation.graph import *
from revive.computation.modules import *
from revive.utils.common_utils import *
from revive.data.dataset import data_creator
from revive.algo.venv.base import VenvOperator, catch_error
try:
    from revive.dist.algo.venv.revive import ReviveOperator
    logger.info("Imported encrypted venv algorithm module -> ReviveOperator!")
except ImportError:
    from revive.algo.venv.revive import ReviveOperator
    logger.info("Imported venv algorithm module -> ReviveOperator!")
class ReplayBuffer:
"""
    A simple FIFO experience replay buffer for off-policy agents (used here by TD3).
"""
def __init__(self, buffer_size):
self.data = None
self.buffer_size = int(buffer_size)
def put(self, batch_data : Batch):
batch_data.to_torch(device='cpu')
if self.data is None:
self.data = batch_data
else:
self.data.cat_(batch_data)
if len(self) > self.buffer_size:
self.data = self.data[len(self) - self.buffer_size : ]
def __len__(self):
if self.data is None: return 0
return self.data.shape[0]
def sample(self, batch_size):
assert len(self) > 0, 'Cannot sample from an empty buffer!'
indexes = np.random.randint(0, len(self), size=(batch_size))
return self.data[indexes]
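# Usage sketch (illustrative only; the field names and shapes below are
# hypothetical -- real batches come from the REVIVE rollout code):
#
#     buffer = ReplayBuffer(buffer_size=5e3)
#     buffer.put(Batch({'obs': np.zeros((8, 4)), 'action': np.zeros((8, 2))}))
#     minibatch = buffer.sample(batch_size=4)   # sampled uniformly, with replacement
#
# Once len(buffer) exceeds buffer_size, the oldest rows are dropped (FIFO).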
class TD3Operator(ReviveOperator):
NAME = "REVIVE_TD3"
PARAMETER_DESCRIPTION = [
{
"name" : "bc_batch_size",
"abbreviation" : "bbs",
"type" : int,
"default" : 256,
},
{
"name" : "bc_epoch",
"abbreviation" : "bep",
"type" : int,
"default" : 0,
},
{
"name" : "bc_lr",
"type" : float,
"default" : 1e-3,
},
{
"name" : "revive_batch_size",
"description": "Batch size of training process.",
"abbreviation" : "mbs",
"type" : int,
"default" : 256,
'doc': True,
},
{
"name" : "revive_epoch",
"description": "Number of epcoh for the training process",
"abbreviation" : "mep",
"type" : int,
"default" : 1000,
'doc': True,
},
{
"name" : "fintune",
"abbreviation" : "bet",
"type" : int,
"default" : 1,
'doc': True,
},
{
"name" : "finetune_fre",
"abbreviation" : "betfre",
"type" : int,
"default" : 1,
'doc': True,
},
{
"name" : "matcher_pretrain_epoch",
"abbreviation" : "dpe",
"type" : int,
"default" : 0,
},
{
"name" : "policy_hidden_features",
"description": "Number of neurons per layer of the policy network.",
"abbreviation" : "phf",
"type" : int,
"default" : 256,
'doc': True,
},
{
"name" : "policy_hidden_layers",
"description": "Depth of policy network.",
"abbreviation" : "phl",
"type" : int,
"default" : 4,
'doc': True,
},
{
"name" : "policy_activation",
"abbreviation" : "pa",
"type" : str,
"default" : 'leakyrelu',
},
{
"name" : "policy_normalization",
"abbreviation" : "pn",
"type" : str,
"default" : None,
},
{
"name" : "policy_backbone",
"description": "Backbone of policy network.",
"abbreviation" : "pb",
"type" : str,
"default" : "res",
'doc': True,
},
{
"name" : "transition_hidden_features",
"description": "Number of neurons per layer of the transition network.",
"abbreviation" : "thf",
"type" : int,
"default" : 256,
'doc': True,
},
{
"name" : "transition_hidden_layers",
"abbreviation" : "thl",
"type" : int,
"default" : 4,
'doc': True,
},
{
"name" : "transition_activation",
"abbreviation" : "ta",
"type" : str,
"default" : 'leakyrelu',
},
{
"name" : "transition_normalization",
"abbreviation" : "tn",
"type" : str,
"default" : None,
},
{
"name" : "transition_backbone",
"description": "Backbone of Transition network.",
"abbreviation" : "tb",
"type" : str,
"default" : "res",
'doc': True,
},
{
"name" : "matching_nodes",
"type" : list,
"default" : 'auto',
},
{
"name" : "matcher_hidden_features",
"description": "Number of neurons per layer of the matcher network.",
"abbreviation" : "dhf",
"type" : int,
"default" : 256,
'doc': True,
},
{
"name" : "matcher_hidden_layers",
"description": "Depth of the matcher network.",
"abbreviation" : "dhl",
"type" : int,
"default" : 4,
'doc': True,
},
{
"name" : "matcher_activation",
"abbreviation" : "da",
"type" : str,
"default" : 'leakyrelu',
},
{
"name" : "matcher_normalization",
"abbreviation" : "dn",
"type" : str,
"default" : None,
},
{
"name" : "state_nodes",
"type" : list,
"default" : 'auto',
},
{
"name" : "value_hidden_features",
"abbreviation" : "vhf",
"type" : int,
"default" : 256,
},
{
"name" : "value_hidden_layers",
"abbreviation" : "vhl",
"type" : int,
"default" : 4,
},
{
"name" : "value_activation",
"abbreviation" : "va",
"type" : str,
"default" : 'leakyrelu',
},
{
"name" : "value_normalization",
"abbreviation" : "vn",
"type" : str,
"default" : None,
},
{
"name" : "generator_type",
"abbreviation" : "gt",
"type" : str,
"default" : "res",
},
{
"name" : "matcher_type",
"abbreviation" : "dt",
"type" : str,
"default" : "res",
},
{
"name" : "birnn",
"type" : bool,
"default" : False,
},
{
"name" : "std_adapt_strategy",
"abbreviation" : "sas",
"type" : str,
"default" : None,
},
{
"name" : "g_steps",
"description": "The number of update rounds of the generator in each epoch.",
"type" : int,
"default" : 1,
"search_mode" : "grid",
"search_values" : [1, 3, 5],
'doc': True,
},
{
"name" : "d_steps",
"description": "Number of update rounds of matcher in each epoch.",
"type" : int,
"default" : 1,
"search_mode" : "grid",
"search_values" : [1, 3, 5],
'doc': True,
},
{
"name" : "g_lr",
"description": "Initial learning rate of the generator.",
"type" : float,
"default" : 4e-5,
"search_mode" : "continuous",
"search_values" : [1e-6, 1e-4],
'doc': True,
},
{
"name" : "d_lr",
"description": "Initial learning rate of the matcher.",
"type" : float,
"default" : 6e-4,
"search_mode" : "continuous",
"search_values" : [1e-6, 1e-3],
'doc': True,
},
{
"name" : "matcher_loss_length",
"description": "Matcher loss length.",
"type" : int,
"default" : 1,
},
{
"name" : "matcher_loss_high",
"description": "Matcher loss high value. When the matcher_loss beyond the value, the generator would stop train",
"type" : float,
"default" : 1.2,
},
{
"name" : "matcher_loss_low",
"description": "Matcher loss high value. When the matcher_loss low the value, the matcher would stop train",
"type" : float,
"default" : 0.6,
},
{
"name" : "matcher_sample",
"description": "Sample the data for tring the matcher.",
"type" : bool,
"default" : False,
},
{
"name" : "mae_reward_weight",
"description": "reward = (1-mae_reward_weight)*matcher_reward + mae_reward_weight*mae_reward.",
"type" : float,
"default" : 0,
},
{
"name" : "history_matcher_num",
"description": "Number of historical discriminators saved.",
"type" : int,
"default" : 0,
},
{
"name" : "buffer_size",
"description": "Size of the buffer to store data.",
"abbreviation" : "bfs",
"type" : int,
"default" : 5e3,
'doc': True,
},
{
"name" : "td3_steps_per_epoch",
"description": "td3_steps_per_epoch.",
"abbreviation" : "tsph",
"type" : int,
"default" : 10,
'doc': True,
},
]
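    # The dictionaries above describe the tunable hyperparameters of this operator.
    # A minimal, hypothetical override sketch (the exact way a configuration is
    # supplied depends on how REVIVE is launched; keys are the "name" fields above,
    # and unspecified keys fall back to the "default" values):
    #
    #     config_overrides = {
    #         'revive_batch_size': 512,    # abbreviation: mbs
    #         'revive_epoch': 2000,        # abbreviation: mep
    #         'g_lr': 1e-4,
    #         'buffer_size': 10000,
    #     }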
@catch_error
def setup(self, config):
super(TD3Operator, self).setup(config)
self.graph_train.register_target_nodes()
self.graph_val.register_target_nodes()
self.replay_buffer_train = ReplayBuffer(config['buffer_size'])
self.replay_buffer_val = ReplayBuffer(config['buffer_size'])
self.state_nodes = self._graph.get_relation_node_names() if config['state_nodes'] == 'auto' else config['state_nodes']
self.state_nodes = [node_name for node_name in self.state_nodes if node_name not in self.graph_train.transition_map.values()]
logger.info(f'Using {self.state_nodes} as state nodes!')
def generator_model_creator(self, config, graph):
"""
Create generator models.
:param config: configuration parameters
:return: all the models.
"""
total_dims = config['total_dims']
networks = []
for node_name in list(graph.keys()):
if node_name in graph.transition_map.values(): continue
node = graph.get_node(node_name)
if not node.node_type == 'network':
logger.info(f'skip {node_name}, the node is already registered as {node.node_type}.')
continue
input_dim = get_input_dim_from_graph(graph, node_name, total_dims)
node.initialize_network(input_dim, total_dims[node_name]['output'],
hidden_features=config['policy_hidden_features'],
hidden_layers=config['policy_hidden_layers'],
hidden_activation=config['policy_activation'],
norm=config['policy_normalization'],
backbone_type=config['policy_backbone'],
dist_config=config['dist_configs'][node_name],
is_transition=False)
networks.append(node.get_network())
for node_name in graph.transition_map.values():
node = graph.get_node(node_name)
if not node.node_type == 'network':
logger.info(f'skip {node_name}, the node is already registered as {node.node_type}.')
continue
input_dim = get_input_dim_from_graph(graph, node_name, total_dims)
if node_name[5:] not in node.input_names:
transition_mode = 'global'
if config['transition_mode']:
warnings.warn('Fallback to global transition mode since the transition variable is not provided as an input!')
else:
transition_mode = config['transition_mode']
node.initialize_network(input_dim, total_dims[node_name]['output'],
hidden_features=config['transition_hidden_features'],
hidden_layers=config['transition_hidden_layers'],
hidden_activation=config['transition_activation'],
norm=config['transition_normalization'],
backbone_type=config['transition_backbone'],
dist_config=config['dist_configs'][node_name],
is_transition=True,
transition_mode=transition_mode,
obs_dim=total_dims[node_name]['input'])
networks.append(node.get_network())
        assert len(networks) > 0, 'At least one node needs to be a network to run training!'
# create value function
input_dim = 0
state_nodes = graph.get_relation_node_names() if config['state_nodes'] == 'auto' else config['state_nodes']
# state_nodes = graph.get_leaf() if config['state_nodes'] == 'auto' else config['state_nodes']
state_nodes = [node_name for node_name in state_nodes if node_name not in graph.transition_map.values()]
for node_name in state_nodes:
input_dim += total_dims[node_name]['input']
networks.append(MLP(input_dim, 1,
config['value_hidden_features'], config['value_hidden_layers'],
norm=config['value_normalization'], hidden_activation=config['value_activation']))
networks.append(MLP(input_dim, 1,
config['value_hidden_features'], config['value_hidden_layers'],
norm=config['value_normalization'], hidden_activation=config['value_activation']))
networks += deepcopy(networks[-2:])
return networks
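        # Layout of the returned list, which downstream code relies on:
        #   [node networks ..., value_net_1, value_net_2, target_value_net_1, target_value_net_2]
        # The last two entries are deep copies of the value networks used as frozen
        # targets; the matcher is created elsewhere and appended after these, which
        # is why optimizer_creator reads it as models[-1].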
def optimizer_creator(self, models, config):
"""
Optimizer creator including generator optimizers and matcher optimizers.
:param models: node models, matcher, value_net
:param config: configuration parameters
:return: generator_optimizer, matcher_optimizer
"""
node_models = models[:2]
value_net_1, value_net_2 = models[2:4]
matcher = models[-1]
bc_optimizer = torch.optim.Adam(get_models_parameters(*node_models), lr=config['bc_lr'], weight_decay=1e-3)
generator_optimizer = torch.optim.Adam(get_models_parameters(*node_models), lr=config['g_lr'])
value_net_1_optimizer = torch.optim.Adam(get_models_parameters(value_net_1), lr=config['g_lr'])
value_net_2_optimizer = torch.optim.Adam(get_models_parameters(value_net_2), lr=config['g_lr'])
matcher_optimizer = torch.optim.Adam(matcher.parameters(), lr=config['d_lr'])
self.value_net_1_parameters = get_models_parameters(value_net_1)
return bc_optimizer, generator_optimizer, value_net_1_optimizer, value_net_2_optimizer, matcher_optimizer
def _action_add_noise(self, graph, generated_data):
        """Add clipped Gaussian noise to every learnable node's action (used for TD3 target policy smoothing)."""
        self._noise_clip = 0.5
        for node_name in graph.learnable_node_names:
            # zero-mean Gaussian noise with std 0.2, clipped to [-0.5, 0.5]
            noise = torch.randn(size=generated_data[node_name].shape, device=self._device) * 0.2
            if self._noise_clip > 0.0:
                noise = noise.clamp(-self._noise_clip, self._noise_clip)
            generated_data[node_name] += noise
        return generated_data
def _run_generator(self, expert_data, graph, generator_other_net, matcher, generator_optimizer=None, other_generator_optimizer=None):
#graph.not_use_target_network()
        # Roll out the current policy in the learned environment; the stochastic warm-up branch below is currently disabled ("and False").
with torch.no_grad():
if self._epoch_cnt < self.config['bc_epoch'] + self.config['matcher_pretrain_epoch'] + 5 and False:
generated_data = self._generate_rollout(expert_data, graph, matcher, generate_reward=True, sample_fn = lambda dist: dist.sample)
else:
generated_data = self._generate_rollout(expert_data, graph, matcher, generate_reward=True, sample_fn = lambda dist: dist.mode)
buffer = self.replay_buffer_train if self.scope == "train" else self.replay_buffer_val
buffer.put(Batch({k:v for k,v in generated_data.items()}))
value_net_1, value_net_2 = generator_other_net[:2]
self.target_actor_nets = generator_other_net[2:-2]
self.target_value_net_1, self.target_value_net_2 = generator_other_net[-2:]
value_net_1_optimizer, value_net_2_optimizer = other_generator_optimizer
info,generated_data = self.TD3_step(buffer, graph, value_net_1, value_net_2,
generator_optimizer, value_net_1_optimizer, value_net_2_optimizer)
#graph.use_target_network()
return info,generated_data
def TD3_step(self, buffer, graph, value_net_1, value_net_2,
generator_optimizer=None, value_net_1_optimizer=None, value_net_2_optimizer=None):
for i in range(self.config['td3_steps_per_epoch']):
generated_data = buffer.sample(self.config['revive_batch_size'])
generated_data.to_torch(device=self._device)
            # TODO: update the reward using the current matcher
# update critic
with torch.no_grad():
next_generated_data = self._generate_rollout(Batch({k:v.reshape(1,-1,v.shape[-1]) for k,v in generated_data.items()}), graph, None, generate_reward=False, sample_fn = lambda dist: dist.mode)
next_generated_data = self._action_add_noise(graph,next_generated_data)
next_generated_data_value_state = get_concat_traj(next_generated_data, self.state_nodes)
next_target_q = torch.min(self.target_value_net_1(next_generated_data_value_state), self.target_value_net_2(next_generated_data_value_state))
next_target_q = next_target_q.reshape(generated_data.reward.shape)
y = generated_data.reward + 0.99*next_target_q
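                # Clipped double-Q target, as in TD3:
                #     y = r + gamma * min(Q1_targ(s', a'), Q2_targ(s', a'))
                # where a' comes from the rollout above with clipped noise added by
                # _action_add_noise. The discount factor gamma = 0.99 is hardcoded
                # and no terminal mask is applied to the bootstrapped term.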
generated_data_value_state = get_concat_traj(generated_data, self.state_nodes)
v_loss_1 = ((value_net_1(generated_data_value_state) - y)**2).mean()
# for param in value_net_1.parameters():
# v_loss_1 += param.pow(2).sum() * self.config['gae_lambda']
value_net_1_optimizer.zero_grad()
v_loss_1.backward()
value_net_1_optimizer.step()
v_loss_2 = ((value_net_2(generated_data_value_state) - y)**2).mean()
# for param in value_net_2.parameters():
# v_loss_2 += param.pow(2).sum() * self.config['gae_lambda']
value_net_2_optimizer.zero_grad()
v_loss_2.backward()
value_net_2_optimizer.step()
info = {
"v_loss_1" : v_loss_1.mean().item(),
"v_loss_2" : v_loss_2.mean().item(),
}
# update actor
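            # Delayed policy update (TD3-style): the two critics are updated every step,
            # while the actor and the target networks are only refreshed on alternating
            # epochs (keyed to the epoch counter rather than to the gradient step, as in
            # canonical TD3), letting the value estimates settle before each policy step.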
if (self._epoch_cnt-1) % 2 == 0:
graph.reset()
generated_new_data = Batch()
for node_name in graph.keys():
node = graph.get_node(node_name)
if node.node_type == 'network':
action_dist = graph.compute_node(node_name, generated_data)
action = action_dist.mode
generated_new_data[node_name] = action
for node_name in self.state_nodes:
if node_name not in generated_new_data.keys():
generated_new_data[node_name] = generated_data[node_name].detach()
value_state = get_concat_traj(generated_new_data, self.state_nodes)
for p in self.value_net_1_parameters:
p.requires_grad = False
generator_optimizer.zero_grad()
actor_loss = - value_net_1(value_state).mean()
actor_loss.backward()
generator_optimizer.step()
for p in self.value_net_1_parameters:
p.requires_grad = True
with torch.no_grad():
polyak = 0.99
for p, p_targ in zip(value_net_1.parameters(), self.target_value_net_1.parameters()):
                        # NB: we use the in-place operations "mul_" and "add_" to update the
                        # target params, as opposed to "mul" and "add", which would create new tensors.
p_targ.data.mul_(polyak)
p_targ.data.add_((1 - polyak) * p.data)
for p, p_targ in zip(value_net_2.parameters(), self.target_value_net_2.parameters()):
                        # NB: we use the in-place operations "mul_" and "add_" to update the
                        # target params, as opposed to "mul" and "add", which would create new tensors.
p_targ.data.mul_(polyak)
p_targ.data.add_((1 - polyak) * p.data)
graph.update_target_network()
info.update({
"p_loss" : actor_loss.mean().item(),})
return info,generated_data