# Source code for revive.utils.causal_discovery_utils

from __future__ import annotations
''''''
"""
    POLIXIR REVIVE, copyright (C) 2021-2023 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""
from abc import ABC, abstractmethod
from operator import itemgetter
from typing import Any, Callable, Dict, Iterable, Optional, Union, Tuple, List
import multiprocessing as mp
from multiprocessing.dummy import Pool as ThreadPool
from functools import partial
import time

import numpy as np

from causallearn.search.ConstraintBased.PC import pc as _pc
from causallearn.search.ConstraintBased.FCI import fci as _fci
from causallearn.search.ConstraintBased.CDNOD import cdnod as _cdnod
from causallearn.search.ScoreBased.GES import ges as _ges
from causallearn.search.ScoreBased.ExactSearch import bic_exact_search
from causallearn.search.FCMBased import lingam as _lingam
from causallearn.search.FCMBased.ANM.ANM import ANM

from causallearn.utils.cit import fisherz, chisq, gsq, mv_fisherz, kci, CIT
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge


# Progress callback signature used by every discovery function below:
#   (current step, total steps, current graph, whether the graph is real-valued) -> Any
_CALLBACK_TYPE = Callable[[int, int, np.ndarray, bool], Any]


# causal-learn conditional independence tests, looked up by name
_CIT_METHODS = dict(
    fisherz=fisherz,
    chisq=chisq,
    gsq=gsq,
    mv_fisherz=mv_fisherz,
    kci=kci,
)


""" causal discovery methods """


# constraint-based
def pc(
    data: np.ndarray,
    indep: str = 'fisherz',
    thresh: float = 0.05,
    bg_rules: Optional[BackgroundKnowledge] = None,
    callback: Optional[_CALLBACK_TYPE] = None,
    **kwargs
) -> Tuple[np.ndarray, bool]:
    """Discover a causal graph with causal-learn's PC algorithm.

    :param data: ndarray, data matrix whose last axis indexes the variables
    :param indep: str, name of the conditional independence test passed to causal-learn
    :param thresh: float, significance level of the independence tests
    :param bg_rules: optional background knowledge (forbidden / required edges)
    :param callback: optional progress callback, invoked once before and once after the search
    :param kwargs: accepted for interface compatibility, ignored
    :return: (causal-learn style graph matrix, False); ``False`` marks the matrix
        as endpoint-coded rather than real-valued
    """
    if callback:
        width = data.shape[-1]
        # report an empty graph before the search starts
        callback(0, 1, np.zeros((width, width)), False)

    # positional arguments after `data` are presumably causal-learn's
    # (alpha, indep_test, stable, uc_rule, uc_priority) -- see causal-learn docs
    learned = _pc(data, thresh, indep, True, 0, -1, background_knowledge=bg_rules)
    adjacency = learned.G.graph

    if callback:
        callback(1, 1, adjacency, False)
    return adjacency, False
def fci(
    data: np.ndarray,
    indep: str = 'fisherz',
    thresh: float = 0.05,
    bg_rules: Union[BackgroundKnowledge, None] = None,
    callback: Optional[_CALLBACK_TYPE] = None,
    **kwargs
) -> Tuple[np.ndarray, bool]:
    """Discover a causal graph with causal-learn's Fast Causal Inference (FCI).

    :param data: ndarray, data matrix whose last axis indexes the variables
    :param indep: str, name of the conditional independence test passed to causal-learn
    :param thresh: float, significance level of the independence tests
    :param bg_rules: optional background knowledge (forbidden / required edges)
    :param callback: optional progress callback, invoked once before and once after the search
    :param kwargs: accepted for interface compatibility, ignored
    :return: (causal-learn style graph matrix, False); ``False`` marks the matrix
        as endpoint-coded rather than real-valued
    """
    if callback:
        width = data.shape[-1]
        callback(0, 1, np.zeros((width, width)), False)

    # _fci returns (graph, edges); only the graph is needed here
    pag, _edges = _fci(data, indep, thresh, verbose=False, background_knowledge=bg_rules)
    adjacency = pag.graph

    if callback:
        callback(1, 1, adjacency, False)
    return adjacency, False
def inter_cit(
    data: np.ndarray,
    indep: str = "fisherz",
    inter_classes: Optional[Iterable[Iterable[Iterable[int]]]] = None,
    in_parallel: bool = True,
    parallel_limit: int = 5,
    callback: Optional[_CALLBACK_TYPE] = None,
    **kwargs
) -> Tuple[np.ndarray, bool]:
    """ use cit to discover the relations of variables inter different classes (indicated by indices)

    For every ``(inputs, outputs)`` pair in ``inter_classes``, each input variable
    ``i`` is tested against each output variable ``o`` conditioned on all the
    other inputs of the same pair. ``1 - p_value`` is stored in ``result[i, o]``.

    :param data: ndarray, [num_samples x num_variables] data matrix
    :param indep: str, name of the causal-learn conditional independence test
    :param inter_classes: iterable of 2-element ``(input indices, output indices)`` pairs;
        ``None`` (the default) means no tests are run
    :param in_parallel: whether to run the tests on a thread pool
    :param parallel_limit: number of worker threads when running in parallel
    :param callback: optional progress callback, invoked before the first test and
        after every completed test
    :param kwargs: forwarded to causal-learn's ``CIT`` constructor
    :return: (matrix of ``1 - p_value`` scores, True); ``True`` marks the matrix as real-valued
    """
    import threading

    if inter_classes is None:
        inter_classes = []

    n = data.shape[1]
    p_values = np.zeros((n, n))
    cit = CIT(data, method=indep, **kwargs)

    # one task per (input index, output index, conditioning indices)
    tasks = []
    for inter_pair in inter_classes:
        inter_pair = list(inter_pair)
        assert len(inter_pair) == 2, "Can only test relation between two classes"
        input_indices = list(inter_pair[0])
        output_indices = list(inter_pair[1])
        for idx, i in enumerate(input_indices):
            # condition on every other input variable of the same class pair
            cond_indices = input_indices[:idx] + input_indices[idx + 1:]
            for o in output_indices:
                tasks.append((i, o, cond_indices))
    total = len(tasks)

    completed = 0
    counter_lock = threading.Lock()

    # start
    if callback:
        callback(0, total, p_values, True)

    def do_cit_once(task):
        """Run a single test, record its score and report progress."""
        nonlocal completed
        x, y, cond = task
        p_values[x, y] = 1 - cit(x, y, condition_set=cond)
        # snapshot the counter under the lock so the callback sees a consistent value
        with counter_lock:
            completed += 1
            done = completed
        if callback:
            callback(done, total, p_values, True)

    if in_parallel:
        pool = ThreadPool(parallel_limit)
        try:
            pool.map(do_cit_once, tasks)
        finally:
            # release the worker threads even if a test raised
            pool.close()
            pool.join()
    else:
        for task in tasks:
            do_cit_once(task)
    return p_values, True
# FCM-based
def lingam(
    data: np.ndarray,
    ver: str = 'ica',
    callback: Optional[_CALLBACK_TYPE] = None,
    **kwargs
) -> Tuple[np.ndarray, bool]:
    """Discover a causal graph with a LiNGAM variant from the ``lingam`` package.

    :param data: ndarray, [num_samples x num_variables] data matrix
    :param ver: 'ica' (ICALiNGAM), 'direct' (DirectLiNGAM), 'var' (VARLiNGAM) or
        'rcd' (RCD); any other value falls back to ICALiNGAM
    :param callback: optional progress callback, invoked once before and once after fitting
    :param kwargs: accepted for interface compatibility, ignored
    :return: (|B|^T, True) where B is the fitted adjacency matrix; ``True`` marks
        the matrix as real-valued
    """
    if ver == 'direct':
        model = _lingam.DirectLiNGAM()
    elif ver == 'var':
        model = _lingam.VARLiNGAM()
    elif ver == 'rcd':
        model = _lingam.RCD()
    else:
        # 'ica' and any unrecognised version
        model = _lingam.ICALiNGAM()

    # start
    if callback:
        num_vars = data.shape[-1]
        # np.zeros takes the shape as a single tuple (its 2nd positional arg is dtype)
        callback(0, 1, np.zeros((num_vars, num_vars)), True)

    model.fit(data)
    if ver == 'var':
        # VARLiNGAM exposes a stack of matrices; index 0 is (per lingam docs) the instantaneous one
        adj_matrix = model.adjacency_matrices_[0]
    else:
        adj_matrix = model.adjacency_matrix_

    # lingam's B[i, j] is the effect of variable j on variable i; transpose to
    # the [cause, effect] layout used elsewhere in this module
    graph = np.abs(adj_matrix).T

    # end
    if callback:
        callback(1, 1, graph, True)
    return graph, True
def anm(
    data: np.ndarray,
    kernelX: str = "Gaussian",
    kernelY: str = "Gaussian",
    callback: Optional[_CALLBACK_TYPE] = None,
    **kwargs
) -> Tuple[np.ndarray, bool]:
    """Orient every pair of variables with causal-learn's Additive Noise Model.

    Each unordered pair ``{i, j}`` (i < j) is tested once. ``cause_or_effect``
    returns the forward (i -> j) and backward (j -> i) p-values, which are
    stored in ``graph[i, j]`` and ``graph[j, i]``. The diagonal stays 0
    (a variable is never tested against itself).

    :param data: ndarray, [num_samples x num_variables] data matrix
    :param kernelX: kernel for the cause variable passed to ``ANM``
    :param kernelY: kernel for the effect variable passed to ``ANM``
    :param callback: optional progress callback, invoked before the first pair and
        after every tested pair
    :param kwargs: accepted for interface compatibility, ignored
    :return: (p-value matrix, True); ``True`` marks the matrix as real-valued
    """
    model = ANM(kernelX, kernelY)
    num_vars = data.shape[-1]
    graph = np.zeros((num_vars, num_vars))

    # each unordered pair once: testing (j, i) after (i, j) would recompute the
    # same two regressions, and (i, i) is a meaningless self-test
    pairs = [(i, j) for i in range(num_vars) for j in range(i + 1, num_vars)]
    total = len(pairs)

    # start
    if callback:
        callback(0, total, graph, True)

    for step, (i, j) in enumerate(pairs, start=1):
        p_value_f, p_value_b = model.cause_or_effect(data[:, i:i + 1], data[:, j:j + 1])
        graph[i, j], graph[j, i] = p_value_f, p_value_b
        if callback:
            callback(step, total, graph, True)
    return graph, True
# score-based _AVAILABLE_SCORE_FNS = [ "BIC", "BDeu", "CV_general", "marginal_general", "CV_multi", "marginal_multi", ]
def ges(
    data: np.ndarray,
    score_func: str = 'BIC',
    callback: Optional[_CALLBACK_TYPE] = None,
    **kwargs
) -> Tuple[np.ndarray, bool]:
    """Discover a causal graph with causal-learn's Greedy Equivalence Search.

    :param data: ndarray, data matrix whose last axis indexes the variables
    :param score_func: one of ``_AVAILABLE_SCORE_FNS``; mapped to causal-learn's
        ``local_score_<score_func>``
    :param callback: optional progress callback, invoked once before and once after the search
    :param kwargs: accepted for interface compatibility, ignored
    :return: (causal-learn style graph matrix, False); ``False`` marks the matrix
        as endpoint-coded rather than real-valued
    :raises AssertionError: if ``score_func`` is not supported
    """
    supported = "', '".join(_AVAILABLE_SCORE_FNS)
    assert score_func in _AVAILABLE_SCORE_FNS, \
        f"Do not support score function '{score_func}' (available score functions: '{supported}')"

    if callback:
        width = data.shape[-1]
        callback(0, 1, np.zeros((width, width)), False)

    record = _ges(data, f"local_score_{score_func}")
    adjacency = record['G'].graph

    if callback:
        callback(1, 1, adjacency, False)
    return adjacency, False
_AVAILABLE_SEARCH_METHODS = [ "astar", "dp", ]
class Graph:
    """ Causal graph class """

    def __init__(
        self,
        graph: np.ndarray,
        is_real: bool = False,
        thresh_info: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        :param graph: ndarray, square adjacency matrix; causal-learn endpoint-coded
            when ``is_real`` is False, real-valued scores otherwise
        :param is_real: whether every element in graph is a real number
        :param thresh_info: information about the threshold of the graph, only used
            when the elements of the graph are real numbers
        :raises AssertionError: if ``graph`` is not square
        """
        assert graph.shape[0] == graph.shape[1], "Graph required to be a square matrix"
        self._graph = graph.copy()
        self._is_real = is_real
        self._thresh_info = thresh_info
        self._adj_mat = self._format_graph(graph)

    def _format_graph(self, mat: np.ndarray) -> np.ndarray:
        """ format binary causal-learn style graph as adjacency matrix

        Endpoint coding: ``mat[i, j] == -1 and mat[j, i] == 1`` means i -> j.
        Output: ``out[i, j] == 1`` for i -> j; both directions set to 1 for
        undirected (-1/-1) or bidirected (1/1) pairs; everything else 0.
        Real-valued matrices are returned unchanged (as a copy).
        """
        mat = mat.copy()
        if not self._is_real:
            n = mat.shape[0]
            # no self loops
            mat[np.arange(n), np.arange(n)] = 0
            # i -> j
            single_direct = (mat == -1) & (mat.T == 1)
            # i - j or i <-> j (both endpoints must match: compare against the transpose)
            bi_direct = ((mat == -1) & (mat.T == -1)) | ((mat == 1) & (mat.T == 1))
            # any other endpoint combination is not an edge
            mat[(~single_direct) & (~bi_direct)] = 0
            # i -> j (m[i, j] = 1, m[j, i] = 0)
            mat[single_direct] = 1
            mat[single_direct.T] = 0
            # i - j or i <-> j (m[i, j] = m[j, i] = 1)
            mat[bi_direct] = 1
            mat[bi_direct.T] = 1
        return mat

    @property
    def graph(self):
        """ raw graph (copy of the matrix passed to the constructor) """
        return self._graph

    @property
    def thresh_info(self):
        """ information about threshold """
        return self._thresh_info

    def get_adj_matrix(self):
        """ return the formatted square adjacency matrix (binary or real) """
        return self._adj_mat

    def get_binary_adj_matrix(self, thresh=None):
        """ return binary adjacency matrix; real graphs keep entries strictly above
        ``thresh`` (default 0.), binary graphs are returned as-is """
        if self._is_real:
            if thresh is None:
                thresh = 0.
            return (self._adj_mat > thresh).astype(int)
        return self._adj_mat

    def get_binary_adj_matrix_by_sparsity(self, sparsity=None):
        """ return binary adjacency matrix with (approximately) the given fraction
        ``sparsity`` in [0, 1] of entries set to 0; binary graphs are returned as-is """
        if self._is_real:
            thresh = 0.
            if sparsity is not None:
                assert 0 <= sparsity and sparsity <= 1
                flatten_mat = self._adj_mat.reshape(-1)
                # index of the threshold element in ascending order, clamped to the last one
                last_ele = min(int(np.floor(len(flatten_mat) * sparsity)), len(flatten_mat) - 1)
                thresh = np.sort(flatten_mat)[last_ele]
            return self.get_binary_adj_matrix(thresh)
        return self._adj_mat
class TransitionGraph(Graph):
    """ RL transition graph class

    Variables are laid out as [state, action, next_state]. The formatted
    adjacency matrix keeps the full square shape; edges that are impossible in a
    transition (anything into a current state, anything from action/next_state
    into an action) are zeroed.
    """

    def __init__(
        self,
        graph: np.ndarray,
        state_dim: int,
        action_dim: int,
        is_real: bool = False,
        thresh_info: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        :param graph: ndarray, causal-learn style adjacency matrix,
            shape [state_dim + action_dim + next_state_dim] squared
        :param state_dim: int, the dimension of state variables
        :param action_dim: int, the dimension of action variables
        :param is_real: whether every element in graph is a real number
        :param thresh_info: information about the threshold of the graph, only used
            when the elements of the graph are real numbers
        """
        # the dimensions must exist before Graph.__init__ calls _format_graph
        self._state_dim = state_dim
        self._action_dim = action_dim
        super().__init__(graph, is_real, thresh_info)

    def _format_graph(self, mat: np.ndarray) -> np.ndarray:
        """ format graph [S+A+S', S+A+S'] into a transition adjacency matrix of the same shape """
        s_end = self._state_dim
        sa_end = self._state_dim + self._action_dim
        out = mat.copy()

        if not self._is_real:
            # causal-learn endpoint coding: -1 marks the tail of an edge, 1 the head.
            # state, action -> next_state: any tail becomes an edge
            into_next = out[:sa_end, sa_end:]
            into_next[into_next == -1] = 1
            # state -> action
            into_action = out[:s_end, s_end:sa_end]
            into_action[into_action == -1] = 1

            # next_state -> next_state, judged on the pre-update endpoints
            block = out[sa_end:, sa_end:]
            before = block.copy()
            directed = (before == -1) & (before.T == 1)
            undirected = (before == -1) & (before.T == -1)
            # tail/head pair -> single directed edge
            block[directed] = 1
            block[directed.T] = 0
            # tail/tail pair -> edge in both directions
            block[undirected] = 1
            # head/head (start <-> end) pairs are already 1/1 and kept as-is
            # TODO: handle unobserved confounder
            diag = np.arange(block.shape[0])
            block[diag, diag] = 0  # no loop

        # no ... -> state
        out[:, :s_end] = 0.
        # no action | next state -> action
        out[s_end:, s_end:sa_end] = 0.
        return out
class DiscoveryModule(ABC):
    """Abstract base class for causal discovery modules.

    Concrete modules implement :meth:`fit`. The graph they discover is exposed
    through :attr:`graph`, which is ``None`` until a subclass stores one.
    """

    def __init__(self, **kwargs) -> None:
        # extra keyword arguments are accepted and ignored
        self._graph: Union[Graph, None] = None

    @abstractmethod
    def fit(self, data: Any, **kwargs) -> DiscoveryModule:
        """Fit the module to ``data``; implementations are expected to return ``self``."""

    @property
    def graph(self) -> Union[Graph, None]:
        """The discovered graph, or ``None`` if nothing has been fitted yet."""
        return self._graph
class ClassicalDiscovery(DiscoveryModule):
    """ Classical causal discovery algorithms """

    # NOTE(review): `exact_search` is not defined in the visible part of this
    # file; confirm a module-level `exact_search` exists before this class.
    CLASSICAL_ALGOS = {
        "pc": pc,
        "fci": fci,
        "lingam": lingam,
        "anm": anm,
        "ges": ges,
        "exact_search": exact_search,
    }
    # designed only for transition graph
    CLASSICAL_ALGOS_TRANSITION = {
        "inter_cit": inter_cit,
    }
    # value range ("min"/"max") and a commonly used binarization threshold
    # ("common") for algorithms whose output graph is real-valued
    CLASSICAL_ALGOS_THRESH_INFO = {
        "anm": {
            "min": 0.,
            "max": 1.,
            "common": 0.5,
        },
        "inter_cit": {
            "min": 0.,
            "max": 1.,
            "common": 0.8,
        },
        "lingam": {
            "min": 0.,
            "max": float("inf"),
            "common": 0.01,
        },
    }

    def __init__(
        self,
        alg: str = "inter_cit",
        alg_args: Optional[Dict[str, Any]] = None,
        state_keys: Optional[List[str]] = ["obs"],
        action_keys: Optional[List[str]] = ["action"],
        next_state_keys: Optional[List[str]] = ["next_obs"],
        limit: Optional[int] = 100,
        use_residual: bool = True,
        **kwargs
    ) -> None:
        """
        :param alg: str, algorithm name, options include
            'pc': Peter-Clark algorithm,
            'fci': Fast Causal Inference,
            'lingam': Linear Non-Gaussian Acyclic Model,
            'anm': Additive Noise Model,
            'ges': Greedy Equivalence Search,
            'exact_search': Exact Search,
            'inter_cit': conditional independence tests between variable classes
                (transition data only)
        :param alg_args: additional arguments for the algorithm used
            (default ``{"indep": "kci", "in_parallel": False}``; the dict is copied),
            pc:
                indep: str, conditional independence test used, including
                    'fisherz': Fisher's Z conditional independence test,
                    'chisq': Chi-squared conditional independence test,
                    'gsq': G-squared conditional independence test,
                    'kci': Kernel-based conditional independence test,
                    'mv_fisherz': Missing-value Fisher's Z conditional independence test,
                thresh: float, level of significance for conditional independence test
            fci: the same as pc
            inter_cit:
                indep: the same as 'indep' in pc,
                in_parallel: whether run the algorithm in parallel,
                parallel_limit: limit of the number of workers
            lingam:
                ver: the version of linear non-Gaussian model to use, including
                    'ica': ICA-based LiNGAM,
                    'direct': DirectLiNGAM,
                    'var': VAR-LiNGAM,
                    'rcd': RCD (repetitive causal discovery)
            anm:
                kernelX: the kernel function for cause data, including
                    'Gaussian': Gaussian kernel,
                    'Polynomial': Polynomial kernel,
                    'Linear': Linear kernel,
                kernelY: the kernel function for effect data, options are the same as kernelX
            ges:
                score_func: score function used to score the graph, including
                    'BIC': BIC score,
                    'BDeu': BDeu score,
                    'CV_general': Generalized score with cross validation for data
                        with single-dimensional variables
                    'marginal_general': Generalized score with marginal likelihood
                        for data with single-dimensional variables
                    'CV_multi': Generalized score with cross validation for data
                        with multi-dimensional variables
                    'marginal_multi': Generalized score with marginal likelihood
                        for data with multi-dimensional variables
            exact_search:
                method: the search method used, including:
                    'dp': dynamic programming (DP),
                    'astar': A* search,
        :param state_keys: list[str] | None, keys of states in the input data dictionary
            (None indicating not using transition data)
        :param action_keys: list[str] | None, keys of actions in the input data dictionary
            (None indicating not using transition data)
        :param next_state_keys: list[str] | None, keys of next states in the input data
            dictionary (None indicating not using transition data)
        :param limit: int | None, limit for the number of data samples used
        :param use_residual: bool, whether use residual (next_state - state) as next states
            (only for transition data)
        """
        super().__init__(**kwargs)
        assert alg in ClassicalDiscovery.CLASSICAL_ALGOS \
            or alg in ClassicalDiscovery.CLASSICAL_ALGOS_TRANSITION, \
            "Do not support algorithm {} (available: '{}')".format(
                alg, "', '".join(
                    list(ClassicalDiscovery.CLASSICAL_ALGOS.keys())
                    + list(ClassicalDiscovery.CLASSICAL_ALGOS_TRANSITION.keys())
                ))
        assert (state_keys is not None and action_keys is not None and next_state_keys is not None) \
            or (state_keys is None and action_keys is None and next_state_keys is None), \
            "state, action, next states keys should all be None or not None"

        self._support_transition = state_keys is not None
        if self._support_transition:
            assert len(state_keys) > 0 and len(action_keys) > 0 and len(next_state_keys) > 0, \
                "state, action, next states keys can not be empty"

        self._alg_name = alg
        if alg in ClassicalDiscovery.CLASSICAL_ALGOS:
            self._alg = ClassicalDiscovery.CLASSICAL_ALGOS[alg]
        else:
            assert self._support_transition, f"{alg} only work for transition data"
            self._alg = ClassicalDiscovery.CLASSICAL_ALGOS_TRANSITION[alg]

        if alg_args is None:
            alg_args = {"indep": "kci", "in_parallel": False}
        # private copy: never share (and let subclasses mutate) the caller's dict
        self._alg_args = dict(alg_args)
        self._state_keys = state_keys
        self._action_keys = action_keys
        self._next_state_keys = next_state_keys
        self._limit = limit
        self._use_residual = use_residual

    def _extract_data(
        self,
        data_dict: Dict[str, np.ndarray]
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """ gather states, actions and next states from ``data_dict``; several keys per
        role are concatenated along the last axis, a single key returns the stored array """
        states = itemgetter(*self._state_keys)(data_dict)
        actions = itemgetter(*self._action_keys)(data_dict)
        next_states = itemgetter(*self._next_state_keys)(data_dict)
        # itemgetter returns a tuple only when several keys are requested
        if isinstance(states, tuple):
            states = np.concatenate(states, axis=-1)
        if isinstance(actions, tuple):
            actions = np.concatenate(actions, axis=-1)
        if isinstance(next_states, tuple):
            next_states = np.concatenate(next_states, axis=-1)
        return states, actions, next_states

    def _build_graph(
        self,
        graph_mat: np.ndarray,
        is_real: bool,
        state_dim: Optional[int] = None,
        action_dim: Optional[int] = None,
    ):
        """ build graph with graph adjacency matrix and whether the elements are real numbers

        :param graph_mat: ndarray, adjacency matrix of the graph
        :param is_real: whether the elements in the graph are real numbers
        :param state_dim: int or None, if not None (together with action_dim), the graph is
            a transition graph with variables laid out as [state, action, next_state]
        :param action_dim: int or None, similar with state_dim
        """
        # information about the threshold of the edge existence (None if the graph is binary);
        # copied so graphs never alias the class-level table
        thresh_info = None
        if self._alg_name in ClassicalDiscovery.CLASSICAL_ALGOS_THRESH_INFO:
            thresh_info = dict(ClassicalDiscovery.CLASSICAL_ALGOS_THRESH_INFO[self._alg_name])

        if state_dim is not None and action_dim is not None:
            # build a transition graph
            self._graph = TransitionGraph(graph_mat, state_dim, action_dim, is_real, thresh_info)
        else:
            # build a regular graph
            self._graph = Graph(graph_mat, is_real, thresh_info)

    def _fit_transition(
        self,
        data: np.ndarray,
        state_dim: int,
        action_dim: int
    ) -> ClassicalDiscovery:
        """ fit the discovery module to transition data

        :param data: ndarray, transition data, arranged by states, actions, next_states
        :param state_dim: int, dimension of states
        :param action_dim: int, dimension of actions
        :return: the module itself
        """
        if self._alg_name in ClassicalDiscovery.CLASSICAL_ALGOS:
            # build general rl background knowledge rules; causal-learn names
            # variables X1..Xn (1-based)
            obs_act_pattern = "^(%s)$" % "|".join(
                "X%d" % i for i in range(1, state_dim + action_dim + 1))
            next_obs_pattern = "^(%s)$" % "|".join(
                "X%d" % i for i in range(state_dim + action_dim + 1, data.shape[-1] + 1))
            obs_pattern = "^(%s)$" % "|".join("X%d" % i for i in range(1, state_dim + 1))
            act_pattern = "^(%s)$" % "|".join(
                "X%d" % i for i in range(state_dim + 1, state_dim + action_dim + 1))
            bg_rules = BackgroundKnowledge()
            # forbid next state -> state | action
            bg_rules.add_forbidden_by_pattern(next_obs_pattern, obs_act_pattern)
            # forbid action -> state
            bg_rules.add_forbidden_by_pattern(act_pattern, obs_pattern)
            # causal discovery algorithm, return a graph matrix
            # and whether the element of the matrix is real number
            graph, is_real = self._alg(data, bg_rules=bg_rules, **self._alg_args)
        else:
            # build general rl transition variable classes
            inter_classes = [
                # state -> action
                [
                    list(range(state_dim)),
                    list(range(state_dim, state_dim + action_dim))
                ],
                # state, action -> next state
                [
                    list(range(state_dim + action_dim)),
                    list(range(state_dim + action_dim, data.shape[1]))
                ]
            ]
            # causal discovery algorithm, return a graph matrix
            # and whether the element of the matrix is real number
            graph, is_real = self._alg(data, inter_classes=inter_classes, **self._alg_args)

        # build transition graph
        self._build_graph(graph, is_real, state_dim, action_dim)
        return self

    def _fit_all(self, data: np.ndarray) -> ClassicalDiscovery:
        """ fit the discovery module to general data

        :param data: [num_samples x num_features]
        :return: the module itself
        """
        graph, is_real = self._alg(data, bg_rules=None, **self._alg_args)
        # build graph
        self._build_graph(graph, is_real)
        return self

    def fit(
        self,
        data: Union[Dict[str, np.ndarray], np.ndarray],
        fit_transition: bool = True,
    ) -> ClassicalDiscovery:
        """ fit the discovery module to transition data or general data

        :param data: dict[str, ndarray] | ndarray, transition data dictionary or general data matrix
        :param fit_transition: whether ``data`` is a transition data dictionary
        :return: the module itself
        """
        state_dim = action_dim = None
        if fit_transition:
            assert isinstance(data, dict), "need transition data format"
            assert self._support_transition, \
                "fitting transition data needs specifying keys"
            states, actions, next_states = self._extract_data(data)
            state_dim, action_dim = states.shape[-1], actions.shape[-1]
            assert states.shape[0] == actions.shape[0] == next_states.shape[0], \
                "Transition data shape mismatch"
            # prepare data
            if self._use_residual:
                assert next_states.shape[1] == states.shape[1]
                # out-of-place: with a single key `next_states` IS the caller's array
                next_states = next_states - states
            data = np.concatenate([states, actions, next_states], axis=-1)
        else:
            assert isinstance(data, np.ndarray), "need general data format"
            assert self._alg_name not in ClassicalDiscovery.CLASSICAL_ALGOS_TRANSITION, \
                f"{self._alg_name} only work for transition data"

        # limited data samples (sampled without replacement)
        limit = self._limit if self._limit is not None else data.shape[0]
        if limit <= data.shape[0]:
            indices = np.random.choice(data.shape[0], size=limit, replace=False)
            data = data[indices]

        # start fitting
        if fit_transition:
            return self._fit_transition(data, state_dim, action_dim)
        return self._fit_all(data)
class AsyncClassicalDiscovery(ClassicalDiscovery):
    """ Classical causal discovery algorithms (support asynchronous ver.)

    Injects a progress callback into the algorithm so that ``cur_step``,
    ``tot_step``, ``start_time``, ``elapsed_time``, ``remaining_time``,
    ``is_running`` and the intermediate graph are updated while :meth:`fit` runs.
    NOTE(review): ``fit`` itself runs in the calling thread; presumably these
    attributes are meant to be polled from another thread -- confirm with callers.
    """

    def __init__(
        self,
        alg: str = "inter_cit",
        alg_args: Optional[Dict[str, Any]] = None,
        state_keys: Optional[List[str]] = ["obs"],
        action_keys: Optional[List[str]] = ["action"],
        next_state_keys: Optional[List[str]] = ["next_obs"],
        limit: Optional[int] = 100,
        use_residual: bool = True,
        callback: Optional[_CALLBACK_TYPE] = None,
        **kwargs
    ) -> None:
        """
        See :class:`ClassicalDiscovery` for the shared parameters
        (``alg_args`` defaults to ``{"indep": "kci", "in_parallel": False}``).

        :param callback: optional custom callback, called after every progress update
        """
        if alg_args is None:
            alg_args = {"indep": "kci", "in_parallel": False}
        super().__init__(
            alg, alg_args, state_keys, action_keys, next_state_keys,
            limit, use_residual, **kwargs)
        self._custom_callback = callback
        # init progress management (all readable before the first callback)
        self.cur_step = None
        self.tot_step = None
        self.start_time = None
        self.elapsed_time = None
        self.remaining_time = None
        self.is_running = False

    def _before_start(self):
        """ recording done before discovery starts """
        self.cur_step = 0
        self.tot_step = 0
        self.start_time = time.time()
        self.elapsed_time = 0.
        self.remaining_time = float("inf")
        self.is_running = True

    def _after_end(self):
        """ recording done after discovery ends """
        self.is_running = False

    def _callback(
        self,
        cur_step: int,
        tot_step: int,
        graph: np.ndarray,
        is_real: bool,
        state_dim: Optional[int] = None,
        action_dim: Optional[int] = None,
    ):
        """ progress hook handed to the algorithm: update counters and timing, rebuild
        the current graph, then forward to the custom callback """
        # record progress
        self.cur_step = cur_step
        self.tot_step = tot_step
        self.elapsed_time = time.time() - self.start_time
        if self.cur_step != 0:
            # linear extrapolation from the average time per completed step
            self.remaining_time = self.elapsed_time * (tot_step / cur_step - 1)
        # record graph
        self._build_graph(graph, is_real, state_dim, action_dim)
        # custom callback
        if self._custom_callback:
            self._custom_callback(cur_step, tot_step, graph, is_real)

    def set_callback(self, callback: _CALLBACK_TYPE):
        """ set custom callback function """
        self._custom_callback = callback

    def fit(
        self,
        data: Union[Dict[str, np.ndarray], np.ndarray],
        fit_transition: bool = True
    ) -> ClassicalDiscovery:
        """ fit the discovery module while reporting progress through the callbacks

        :param data: dict[str, ndarray] | ndarray, transition data dictionary or general data matrix
        :param fit_transition: whether ``data`` is a transition data dictionary
        :return: the module itself
        """
        state_dim = action_dim = None
        if fit_transition:
            assert isinstance(data, dict), "need transition data format"
            assert self._support_transition, \
                "fitting transition data needs specifying keys"
            states, actions, next_states = self._extract_data(data)
            state_dim, action_dim = states.shape[-1], actions.shape[-1]
            assert states.shape[0] == actions.shape[0] == next_states.shape[0], \
                "Transition data shape mismatch"
            # prepare data
            if self._use_residual:
                assert next_states.shape[1] == states.shape[1]
                # out-of-place: with a single key `next_states` IS the caller's array
                next_states = next_states - states
            data = np.concatenate([states, actions, next_states], axis=-1)
        else:
            assert isinstance(data, np.ndarray), "need general data format"
            assert self._alg_name not in AsyncClassicalDiscovery.CLASSICAL_ALGOS_TRANSITION, \
                f"{self._alg_name} only work for transition data"

        # limited data samples (sampled without replacement)
        limit = self._limit if self._limit is not None else data.shape[0]
        if limit <= data.shape[0]:
            indices = np.random.choice(data.shape[0], size=limit, replace=False)
            data = data[indices]

        # add callback into the algorithm arguments; rebind a new dict instead of
        # mutating one that may be shared with other instances
        self._alg_args = dict(
            self._alg_args,
            callback=partial(self._callback, state_dim=state_dim, action_dim=action_dim),
        )

        # start fitting
        self._before_start()
        try:
            if fit_transition:
                return self._fit_transition(data, state_dim, action_dim)
            return self._fit_all(data)
        finally:
            # finished (also on error)
            self._after_end()
if __name__ == "__main__":

    def _report(module):
        """Print the discovered graph of `module` in several forms."""
        graph = module.graph
        # threshold information (maybe None)
        print("graph threshold information\n", graph.thresh_info)
        # for debug or private usage
        print("algorithm output graph\n", graph.graph)
        # raw adjacent matrix
        print("transition graph\n", graph.get_adj_matrix())
        # binary matrix by the algorithm's recommended threshold (recommended);
        # thresh_info["common"] is e.g. 0.8 for "inter_cit" and 0.01 for "lingam"
        print(
            "binary transition graph by threshold\n",
            graph.get_binary_adj_matrix(graph.thresh_info["common"])
        )
        # binary matrix by indicating sparsity (need prior knowledge)
        print(
            "binary transition graph by sparsity\n",
            graph.get_binary_adj_matrix_by_sparsity(0.75)
        )

    np.random.seed(0)

    """ basic case for transition data """
    states = np.random.normal(0, 1, size=10000).reshape(-1, 2)
    actions = np.tanh(states[:, 0:1])
    delta_states = 2 * states + actions + np.random.uniform(-0.1, 0.1, size=10000).reshape(-1, 2)
    next_states = delta_states + states
    # default keys
    data_dict = {"obs": states, "action": actions, "next_obs": next_states}

    # default keys are not None, the module supports transition data
    discover_module = ClassicalDiscovery()
    discover_module.fit(data_dict, fit_transition=True)
    _report(discover_module)

    """ multi-keys """
    states2 = np.random.uniform(-1, 1, size=5000).reshape(-1, 1)
    delta_states2 = states2 - actions \
        + np.random.uniform(-0.1, 0.1, size=5000).reshape(-1, 1)
    next_states2 = delta_states2 + states2
    data_dict2 = {
        "obs1": states, "obs2": states2,
        "action": actions,
        "next_obs1": next_states, "next_obs2": next_states2
    }

    # keys are not None, the module supports transition data
    discover_module = ClassicalDiscovery(
        state_keys=["obs1", "obs2"],
        action_keys=["action"],
        next_state_keys=["next_obs1", "next_obs2"]
    )
    discover_module.fit(data_dict2, fit_transition=True)
    _report(discover_module)

    """ basic case for general data (default algorithm does not support general data) """
    discover_module = ClassicalDiscovery(
        alg="lingam",
        alg_args={"ver": "direct"},
        limit=None,
    )
    s = np.random.uniform(-1, 1, 5000).reshape(-1, 1)
    a = 1.1 * s + np.random.uniform(-0.1, 0.1, size=5000).reshape(-1, 1)
    s_ = s + a + np.random.uniform(-0.01, 0.01, size=5000).reshape(-1, 1)
    data = np.concatenate((s, a, s_), axis=-1)
    discover_module.fit(data, fit_transition=False)
    _report(discover_module)

    """ use more data to acquire a more accurate result """
    discover_module = ClassicalDiscovery(
        # run in parallel
        alg_args={"indep": "kci", "in_parallel": True, "parallel_limit": 5},
        state_keys=["obs1", "obs2"],
        action_keys=["action"],
        next_state_keys=["next_obs1", "next_obs2"],
        limit=1000,
    )
    discover_module.fit(data_dict2, fit_transition=True)
    _report(discover_module)

    """ use callback to informing progress """
    async_discover_module = AsyncClassicalDiscovery(
        state_keys=["obs1", "obs2"],
        action_keys=["action"],
        next_state_keys=["next_obs1", "next_obs2"],
        limit=3000,
        # callback can be set when constructing
        callback=None,
    )

    # callback type should be `_CALLBACK_TYPE`, see related comments
    def callback(cur_step: int, tot_step: int, cur_raw_graph: np.ndarray, is_real: bool):
        print(f"current step {cur_step}, total step {tot_step}, current raw graph {cur_raw_graph},"
              f" whether the elements of the graph is real number {is_real}")
        # can also get progress information from the module
        print(f"current step {async_discover_module.cur_step},"
              f" total step {async_discover_module.tot_step}")
        print(f"start time {async_discover_module.start_time}")
        print(f"time elapsed {async_discover_module.elapsed_time}")
        print(f"is running {async_discover_module.is_running}")
        print(f"estimated remaining time {async_discover_module.remaining_time}")
        print("currnet graph {}".format(
            async_discover_module.graph.get_binary_adj_matrix(
                async_discover_module.graph.thresh_info["common"])))
        print("=" * 20)

    # callback can also be set after constructed
    async_discover_module.set_callback(callback)
    # start fitting
    async_discover_module.fit(data_dict2, fit_transition=True)
    print(f"is running {async_discover_module.is_running}")
    _report(async_discover_module)