# Source code for revive.computation.funs_parser

''''''
"""
    POLIXIR REVIVE, copyright (C) 2021-2024 Polixir Technologies Co., Ltd., is 
    distributed under the GNU Lesser General Public License (GNU LGPL). 
    POLIXIR REVIVE is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 3 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.
"""
""" The function of this document is mainly for use on the SAAS platform, and maintenance will be temporarily suspended in the future."""
import re
import yaml
from shutil import copyfile
from loguru import logger

import revive.computation.operators as opt

# Every operator name exported by revive.computation.operators, in ascending
# lexicographic order. find_operator() scans source lines for these names.
OPERATORS_NAME = sorted(opt.__all__)

def get_nodes(yaml_file_path):
    """Map every node in a REVIVE YAML config to its list of column names.

    Each entry of ``metadata.columns`` is a mapping whose first key is the
    column name and whose first value holds a ``dim`` field naming the node
    the column belongs to. Columns are grouped by that ``dim`` value, keeping
    their order of appearance in the file.

    For each key ``next_<node>`` under ``metadata.graph`` whose ``<node>`` is
    already known, ``next_<node>`` is mapped to the same column list object.

    Args:
        yaml_file_path: Path of the YAML config file (read as UTF-8).

    Returns:
        dict mapping node name -> list of column names.
    """
    # NOTE(review): yaml.FullLoader is not safe for untrusted input. If this
    # YAML can come from end users (e.g. uploads on the SaaS platform),
    # consider yaml.safe_load instead.
    with open(yaml_file_path, 'r', encoding='UTF-8') as f:
        raw_config = yaml.load(f, Loader=yaml.FullLoader)

    # Group column names by node in a single pass. Insertion order follows
    # the file, so iteration order of the result is deterministic (the
    # previous set-based version depended on string hash randomisation).
    nodes = {}
    for column in raw_config['metadata']['columns']:
        column_name, column_spec = next(iter(column.items()))
        nodes.setdefault(column_spec['dim'], []).append(column_name)

    # "next_" must be a prefix: the base name is obtained by slicing it off,
    # which is only meaningful when the key actually starts with it.
    prefix = "next_"
    for node in raw_config['metadata']['graph'].keys():
        if not node.startswith(prefix):
            continue
        base = node[len(prefix):]
        if base in nodes:
            nodes[node] = nodes[base]
    return nodes
def matching_bracket(string, idx, brackets=None):
    """Return the index of the bracket that closes the one at ``string[idx]``.

    Only brackets of the same type as ``string[idx]`` affect nesting, so
    ``matching_bracket("[a(b]", 0)`` is 4.

    Args:
        string: Text to scan.
        idx: Index of an opening bracket: one of ``(``, ``[``, ``{``, ``<``.
        brackets: Optional iterable of opening-bracket characters to accept.
            When given and non-empty, only these bracket types are allowed.

    Returns:
        Index of the matching closing bracket.

    Raises:
        KeyError: ``string[idx]`` is not an accepted opening bracket.
        ValueError: ``string`` has no matching closing bracket.
    """
    pairs = {"(": ")", "[": "]", "{": "}", "<": ">"}
    if brackets:
        pairs = {o: c for o, c in pairs.items() if o in brackets}
    open_bracket = string[idx]
    close_bracket = pairs[open_bracket]
    if idx < 0:
        idx += len(string)

    # Single forward scan with a depth counter: O(n) and no recursion, so
    # long lines cannot exhaust the recursion limit.
    depth = 0
    for pos in range(idx, len(string)):
        char = string[pos]
        if char == open_bracket:
            depth += 1
        elif char == close_bracket:
            depth -= 1
            if depth == 0:
                return pos
    raise ValueError(
        f"no closing {close_bracket!r} for {open_bracket!r} at index {idx}"
    )
def strip(s):
    """Remove every whitespace character from ``s``.

    Example: ``"x = add(a, b)"`` -> ``"x=add(a,b)"``.
    """
    # Raw string: "\s" in a normal literal is an invalid escape sequence
    # (SyntaxWarning since Python 3.12). ``\s`` already covers tabs and
    # newlines, so separate alternatives for them are unnecessary.
    return re.sub(r"\s+", "", s)
def find_all(s, sub_s):
    """Return the start index of each non-overlapping occurrence of ``sub_s`` in ``s``.

    ``sub_s`` is matched literally. Node, column and operator names are
    passed here, and characters such as ``.`` or ``(`` in them must not be
    interpreted as regular-expression syntax.
    """
    return [m.start() for m in re.finditer(re.escape(sub_s), s)]
def find_node(line):
    """Locate node subscripts such as ``obs[...]`` in ``line``.

    For every name in the global ``NODES`` (visited in reverse-sorted order),
    each occurrence immediately followed by ``[`` is recorded. Subscripts
    whose contents already begin with ``...,`` are skipped.

    Returns:
        A list of ``[open_index, close_index + 1]`` pairs for the subscript
        brackets, sorted by ``open_index``.
    """
    spans = {}
    for name in sorted(NODES, reverse=True):
        width = len(name)
        for hit in find_all(line, name):
            open_at = hit + width
            if open_at == len(line):
                continue
            if line[open_at] != "[":
                continue
            # Already converted: the subscript starts with "...,".
            if line[open_at + 1:open_at + 5] == "...,":
                continue
            spans[open_at] = matching_bracket(line, open_at) + 1
    return [[start, spans[start]] for start in sorted(spans)]
def find_column(line):
    """Locate quoted column names (``"col"`` or ``'col'``) in ``line``.

    Every column of every node in the global ``NODES`` is searched. When
    several columns match at the same position, the one visited last wins.

    Returns:
        A list of ``[index, [column_name, node_name]]`` pairs sorted by
        ``index``. ``index`` points at the first character inside the quotes.
    """
    hits = {}
    for node_name, column_names in NODES.items():
        for column_name in column_names:
            width = len(column_name)
            double_quoted = f'"{column_name}"'
            single_quoted = f"'{column_name}'"
            for pos in find_all(line, column_name):
                window = line[pos - 1:pos + width + 1]
                if window in (double_quoted, single_quoted):
                    hits[pos] = [column_name, node_name]
    return [[pos, hits[pos]] for pos in sorted(hits)]
def find_operator(line):
    """Locate operator names from ``OPERATORS_NAME`` in ``line``.

    An occurrence counts as an operator when both of these hold:

    * The character before it is one of ``" "``, ``"="``, ``"("``, ``","``,
      or the occurrence starts the line.
    * The character after it is not ``[a-zA-Z0-9_]``, or the occurrence
      ends the line.

    Args:
        line: One line of source code.

    Returns:
        A list of ``[index, operator_name]`` pairs sorted by ``index``.
        When several names are accepted at the same index, the one visited
        last in ``OPERATORS_NAME`` wins.
    """
    left_boundary = (" ", "=", "(", ",")
    operators = {}
    for operator_name in OPERATORS_NAME:
        end_offset = len(operator_name)
        for index in find_all(line, operator_name):
            # At index 0 there is no preceding character. The old code read
            # line[-1] (the last character of the line) by accident.
            if index > 0 and line[index - 1] not in left_boundary:
                continue
            end = index + end_offset
            # The old code raised IndexError when the name ended the line.
            if end < len(line) and re.match(r"[a-zA-Z0-9_]", line[end]):
                continue
            operators[index] = operator_name
    return [[index, operators[index]] for index in sorted(operators)]
def convert_operator(oral_operator_name):
    """Qualify a bare operator name with the ``opt`` module alias.

    Example: ``add`` -> ``opt.add``.
    """
    qualified = "".join(("opt.", oral_operator_name))
    return qualified
def convert_column(oral_column, flag):
    """Return, as a string, the position of a column within its node.

    Args:
        oral_column: ``(column_name, node_name)`` pair, as produced by
            ``find_column``.
        flag: Accepted for interface compatibility. It is not used.

    Returns:
        ``str`` of the column's index in ``NODES[node_name]``.
    """
    column_name, node_name = oral_column
    position = NODES[node_name].index(column_name)
    return str(position)
def convert_node(bracket_start_index):
    """Unimplemented placeholder, kept so the public name still exists.

    The previous body referenced the undefined names ``node_name`` and
    ``column_name``, so every call failed with ``NameError``. Node subscripts
    are rewritten by ``convert_nodes``, and nothing in the visible module
    calls this function.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "convert_node is not implemented; use convert_nodes(line) instead"
    )
def convert_operators(line):
    """Prefix every recognised operator name in ``line`` with ``opt.``.

    Example: ``y=add(a,b)`` -> ``y=opt.add(a,b)``.

    The left-most operator is rewritten, the line is re-scanned, and this
    repeats until ``find_operator`` reports nothing.
    """
    while True:
        found = find_operator(line)
        if not found:
            return line
        start, name = found[0]
        line = line[:start] + convert_operator(name) + line[start + len(name):]
def convert_columns(line):
    """Replace every quoted column name in ``line`` with its column index.

    Example: ``obs["obs_1"]`` -> ``obs[1]`` when ``obs_1`` is at index 1 of
    node ``obs``. The quotes are removed together with the name.

    The left-most match is rewritten, the line is re-scanned, and this
    repeats until ``find_column`` reports nothing.
    """
    while True:
        columns = find_column(line)
        if not columns:
            return line
        index, column_info = columns[0]
        column_name = column_info[0]
        open_quote = index - 1
        close_quote = index + len(column_name)

        # Characters just outside the quotes, bounds-checked. The previous
        # version raised IndexError when the quoted name ended the line, and
        # wrapped around to line[-1] when the opening quote was at index 0.
        before = line[open_quote - 1] if open_quote >= 1 else ""
        after = line[close_quote + 1] if close_quote + 1 < len(line) else ""
        if before == "[":
            flag = "L"
        elif after == "]":
            flag = "R"
        else:
            flag = "C"

        line = (
            line[:open_quote]
            + convert_column(column_info, flag)
            + line[close_quote + 1:]
        )
def convert_nodes(line):
    """Rewrite node subscripts so they index the last axis.

    ``obs[1,2]`` becomes ``obs[...,[1,2]]``. A slice such as ``obs[0:2]``
    becomes ``obs[...,0:2]``; its own brackets are dropped because it
    contains ``:``.

    The left-most subscript is rewritten, the line is re-scanned, and this
    repeats until ``find_node`` reports nothing.
    """
    while True:
        spans = find_node(line)
        if not spans:
            return line
        start, stop = spans[0]
        subscript = line[start:stop]  # includes the surrounding brackets
        if ":" in subscript:
            replacement = "[...," + line[start + 1:stop - 1] + "]"
        else:
            replacement = "[...," + subscript + "]"
        line = line[:start] + replacement + line[stop:]
def convert_line(line):
    """Run the full conversion pipeline on one line.

    The steps are: whitespace removal, operator qualification, column-name
    to index conversion, then node subscript rewriting.
    """
    for transform in (strip, convert_operators, convert_columns, convert_nodes):
        line = transform(line)
    return line
def checkt_convert(line):
    """Report whether the conversion steps would change ``line``.

    Operators, columns and nodes are converted, without whitespace stripping.

    Returns:
        ``True`` if the converted text differs from ``line``, else ``False``.
    """
    converted = convert_nodes(convert_columns(convert_operators(line)))
    return converted != line
def convert_fn_def(origin_code : list):
    """Rewrite a function header into the ``data``-dict calling convention.

    The ``def f(a, b):`` line becomes
    ``def f(data: Dict[str, torch.Tensor]) -> torch.Tensor:``. It is followed
    by one unpacking line per parameter, of the form `` a=data["a"]``.

    The signature is taken to end on the first line that contains ``:``.
    Type annotations (``a: int``) and defaults (``a=1``) are dropped from the
    generated names. Empty entries, from ``def f():`` or a trailing comma,
    are skipped.

    Args:
        origin_code: Source lines of one function, starting at its ``def`` line.

    Returns:
        ``(codes, other_codes)``: the generated header lines, and the
        original lines that follow the signature.
    """
    codes = []
    index = 0
    for i, code in enumerate(origin_code):
        if code.startswith('def '):
            bracket_start_index = code.index("(")
            codes.append(code[:bracket_start_index] + "(data: Dict[str, torch.Tensor]) -> torch.Tensor:\n")
            index = i
        if ":" in code:
            index = i
            break

    signature = strip("".join(origin_code[:index + 1]))
    bracket_start_index = signature.index("(")
    bracket_stop_index = matching_bracket(signature, bracket_start_index)
    for arg in signature[bracket_start_index + 1:bracket_stop_index].split(","):
        name = arg.split(":", 1)[0].split("=", 1)[0]
        if not name:
            # "def f():" or a trailing comma previously emitted ' =data[""]',
            # which is a syntax error in the generated file.
            continue
        # NOTE(review): parameter lines use a one-space indent, while body
        # lines keep their original indentation. Confirm this literal matches
        # the indentation of the input files.
        codes.append(f' {name}=data["{name}"]\n')
    other_codes = origin_code[index + 1:]
    return codes, other_codes
def get_fn_list(origin_code_list):
    """Split source lines into one chunk per top-level function.

    A chunk starts at a line beginning with ``def `` and runs up to, but not
    including, the next such line. The last chunk runs to the end of the
    input. Lines before the first ``def`` are not returned.

    Args:
        origin_code_list: Lines of a source file.

    Returns:
        A list of slices of ``origin_code_list``, one per function. The list
        is empty when there is no ``def`` line or the input is empty.
    """
    starts = [i for i, code in enumerate(origin_code_list) if code.startswith('def ')]
    # Using len() instead of the leaked loop variable also makes empty input
    # work; it previously raised UnboundLocalError.
    stops = starts[1:] + [len(origin_code_list)]
    return [origin_code_list[i:j] for i, j in zip(starts, stops)]
def parser(input_file : str, output_file : str, yaml_file : str):
    """Translate a user function file into REVIVE's tensor-based form.

    The global ``NODES`` is loaded from ``yaml_file``. Each top-level
    function in ``input_file`` then has its header rewritten to take a
    ``data`` dict, and every body line that the conversion would change is
    rewritten:

    * operator names are qualified with ``opt.``;
    * quoted column names become indices;
    * node subscripts index the last axis.

    Torch/typing/operator imports are prepended, and the result is written
    to ``output_file``.

    Returns:
        ``False`` without writing anything if any input line starts with
        ``import torch``. Otherwise ``True`` after writing ``output_file``.
    """
    global NODES
    NODES = get_nodes(yaml_file)
    # Explicit UTF-8, consistent with get_nodes(), instead of the
    # locale-dependent default encoding.
    with open(input_file, 'r', encoding='UTF-8') as f:
        origin_code_list = f.readlines()

    if any(code.startswith("import torch") for code in origin_code_list):
        logger.info(f'Not parser function in {input_file}')
        return False

    logger.info(f'Parser function in {input_file}')
    output_codes = [
        "import torch\n",
        "from typing import Dict\n",
        "\n",
        "import revive.computation.operators as opt\n",
        "\n",
    ]
    for fn in get_fn_list(origin_code_list):
        start_codes, other_codes = convert_fn_def(fn)
        output_codes += start_codes
        output_codes += [_convert_body_line(code) for code in other_codes]

    with open(output_file, 'w', encoding='UTF-8') as f:
        f.writelines(output_codes)
    return True


def _convert_body_line(code):
    """Convert one function-body line for parser(), keeping its leading indent."""
    if not checkt_convert(code):
        return code
    indent = " " * (len(code) - len(code.lstrip()))
    if " return " in code:
        if "(" not in code and "[" not in code:
            return code
        # strip() would glue "return" onto the expression, so a "return="
        # prefix is used instead. The "=" is also a left boundary that
        # find_operator accepts. The keyword is restored afterwards.
        sub_code = convert_line("return=" + code[code.index("return") + len("return"):])
        sub_code = sub_code.replace("return=", "return ")
        return indent + sub_code + "\n"
    return indent + convert_line(code) + "\n"