Spaces:
Runtime error
Runtime error
| # flake8: noqa | |
| import copy | |
| import json | |
| import random | |
| from pathlib import Path | |
| from pprint import pprint | |
| from tqdm import tqdm | |
| from transformers import AutoTokenizer | |
def init_random_input(len_range: int = 5, value_gen=5) -> list:
    """Generate a random list of integers used to seed a DSL program.

    Args:
        len_range: Controls the list length, which is drawn uniformly from
            ``2 .. len_range + 1`` (both ends inclusive).
        value_gen: Every element is drawn uniformly from
            ``-value_gen .. value_gen`` (both ends inclusive).

    Returns:
        The freshly generated list of integers.
    """
    length = random.randint(2, len_range + 1)
    candidates = list(range(-value_gen, value_gen + 1))
    return [random.choice(candidates) for _ in range(length)]
# Non-zero integers in -5..5; the gen_add_n / gen_sub_n / gen_mul_n / gen_div_n
# helpers sample their scalar argument from this list.
const_integer = [value for value in range(-5, 6) if value != 0]
| # Functions in the DSL | |
| # Each function defines a transformation in the given DSL Grammar. | |
def take(input_list: list, n: int) -> list:
    """Return the first ``n`` elements of ``input_list`` (slice semantics)."""
    head = slice(None, n)
    return input_list[head]
def drop(input_list: list, n: int) -> list:
    """Return ``input_list`` without its first ``n`` elements (slice semantics)."""
    tail = slice(n, None)
    return input_list[tail]
def minimum(input_list: list) -> int:
    """Return the smallest element of ``input_list``."""
    smallest = min(input_list)
    return smallest
def maximum(input_list: list) -> int:
    """Return the largest element of ``input_list``."""
    largest = max(input_list)
    return largest
def reverse(input_list: list) -> list:
    """Return a new list holding the elements of ``input_list`` in reverse order."""
    backwards = slice(None, None, -1)
    return input_list[backwards]
def sort_asc(input_list: list) -> list:
    """Return a new list with the elements of ``input_list`` in ascending order."""
    ordered = list(input_list)
    ordered.sort()
    return ordered
def sort_des(input_list: list) -> list:
    """Return a new list with the elements of ``input_list`` in descending order."""
    ordered = list(input_list)
    ordered.sort(reverse=True)
    return ordered
def add_n(input_list: list, n: int) -> list:
    """Return a new list in which ``n`` has been added to every element."""
    shifted = []
    for item in input_list:
        shifted.append(item + n)
    return shifted
def sub_n(input_list: list, n: int) -> list:
    """Return a new list in which ``n`` has been subtracted from every element."""
    shifted = []
    for item in input_list:
        shifted.append(item - n)
    return shifted
def mul_n(input_list: list, n: int) -> list:
    """Return a new list in which every element has been multiplied by ``n``."""
    scaled = []
    for item in input_list:
        scaled.append(item * n)
    return scaled
def div_n(input_list: list, n: int) -> list:
    """Return a new list in which every element has been divided by ``n``.

    True division (``/``) is used, so the elements of the result are floats
    when the inputs are integers.
    """
    quotients = []
    for item in input_list:
        quotients.append(item / n)
    return quotients
def expand_copy(input_list: list) -> list:
    """Return ``input_list`` concatenated with itself."""
    doubled = input_list + input_list
    return doubled
| # Main Production Rules for the Toy DSL. | |
# Registry of production name -> implementation. Each key is the name of the
# function it maps to.
# NOTE: minimum, maximum and div_n are defined in this file but are not
# registered here.
list_manip_dsl = {
    production.__name__: production
    for production in (
        take,
        drop,
        reverse,
        sort_asc,
        sort_des,
        add_n,
        sub_n,
        mul_n,
        expand_copy,
    )
}
| # Use this class to execute programs written in the DSL. | |
| class Interpreter: | |
| def __init__(self) -> None: | |
| self.parser = list_manip_dsl | |
| def __call__(self, statement_string: str): | |
| """ | |
| Evaluation Function for the interpreter. | |
| args: | |
| statement_string (str) : Statement String | |
| """ | |
| try: | |
| return eval(statement_string) # Adding an exception to unparsable strings | |
| except: | |
| return "ERROR" | |
# Module-level interpreter shared by the gen_* helpers below.
interpreter = Interpreter()

# TEMPLATE
# Blank record that stores the input, the output and the function template.
#   input             : list given as an input to the function.
#   function_template : the atomic function in the given DSL grammar.
#   output            : transformed output obtained by applying the function
#                       to the input.
generation_template = dict(function_template="NONE", output="NONE", input=[])
| # Each of the generate function is used to generate a | |
| # template for a given function | |
| # if chosen while sampling the dataset. | |
| # each function takes in expressions based on the grammar and generates a template. | |
| # Example: gen_take() generates a template for the take function. | |
| # take function has two arguments, | |
| # list_expression and a bounded integer (should not be more | |
| # than the length of the list). | |
def gen_take(expr1=None, expr2=None):
    """Build a generation template for ``take(expr1, expr2)``.

    Args:
        expr1: A list, or a DSL expression string that evaluates to a list.
            A random list is generated when omitted.
        expr2: Number of leading elements to keep. When omitted it is sampled
            from ``1 .. len(list) - 2``, falling back to ``1`` for short lists.

    Returns:
        A copy of ``generation_template`` with ``function_template``,
        ``output`` and ``input`` filled in.
    """
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        # A chained call passes expr1 as a string; bound the count by the
        # length of the list it evaluates to, not by the string's length.
        target = interpreter(expr1) if isinstance(expr1, str) else expr1
        size = len(target) if isinstance(target, list) else 0
        # range(1, size - 1) is empty for size <= 2, which made random.choice
        # raise IndexError; clamp the upper bound so a count always exists.
        expr2 = random.randint(1, max(1, size - 2))
    formatted_fn = f"take({expr1},{expr2})"
    template = copy.copy(generation_template)
    template["function_template"] = formatted_fn
    template["output"] = interpreter(formatted_fn)
    template["input"] = [expr1, expr2]
    return template
def gen_drop(expr1=None, expr2=None):
    """Build a generation template for ``drop(expr1, expr2)``.

    Args:
        expr1: A list, or a DSL expression string that evaluates to a list.
            A random list is generated when omitted.
        expr2: Number of leading elements to discard. When omitted it is
            sampled from ``1 .. len(list) - 2``, falling back to ``1`` for
            short lists.

    Returns:
        A copy of ``generation_template`` with ``function_template``,
        ``output`` and ``input`` filled in.
    """
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        # A chained call passes expr1 as a string; bound the count by the
        # length of the list it evaluates to, not by the string's length.
        target = interpreter(expr1) if isinstance(expr1, str) else expr1
        size = len(target) if isinstance(target, list) else 0
        # range(1, size - 1) is empty for size <= 2, which made random.choice
        # raise IndexError; clamp the upper bound so a count always exists.
        expr2 = random.randint(1, max(1, size - 2))
    formatted_fn = f"drop({expr1},{expr2})"
    template = copy.copy(generation_template)
    template["function_template"] = formatted_fn
    template["output"] = interpreter(formatted_fn)
    template["input"] = [expr1, expr2]
    return template
def gen_minimum(expr1=None):
    """Build a generation template for ``minimum(expr1)``.

    Args:
        expr1: A list or a DSL expression string; a random list is generated
            when omitted.

    Returns:
        A copy of ``generation_template`` with ``function_template``,
        ``output`` and ``input`` filled in.
    """
    # Identity check: `is None` is the idiomatic test for a missing argument.
    if expr1 is None:
        expr1 = init_random_input()
    formatted_fn = f"minimum({expr1})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1],
    )
    return template
def gen_maximum(expr1=None):
    """Build a generation template for ``maximum(expr1)``.

    Args:
        expr1: A list or a DSL expression string; a random list is generated
            when omitted.

    Returns:
        A copy of ``generation_template`` with ``function_template``,
        ``output`` and ``input`` filled in.
    """
    # Identity check: `is None` is the idiomatic test for a missing argument.
    if expr1 is None:
        expr1 = init_random_input()
    formatted_fn = f"maximum({expr1})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1],
    )
    return template
def gen_reverse(expr1=None):
    """Build a generation template for ``reverse(expr1)``.

    Args:
        expr1: A list or a DSL expression string; a random list is generated
            when omitted.

    Returns:
        A copy of ``generation_template`` with ``function_template``,
        ``output`` and ``input`` filled in.
    """
    # Identity check: `is None` is the idiomatic test for a missing argument.
    if expr1 is None:
        expr1 = init_random_input()
    formatted_fn = f"reverse({expr1})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1],
    )
    return template
def gen_sort_asc(expr1=None):
    """Build a generation template for ``sort_asc(expr1)``.

    Args:
        expr1: A list or a DSL expression string; a random list is generated
            when omitted.

    Returns:
        A copy of ``generation_template`` with ``function_template``,
        ``output`` and ``input`` filled in.
    """
    # Identity check: `is None` is the idiomatic test for a missing argument.
    if expr1 is None:
        expr1 = init_random_input()
    formatted_fn = f"sort_asc({expr1})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1],
    )
    return template
def gen_sort_des(expr1=None):
    """Build a generation template for ``sort_des(expr1)``.

    Args:
        expr1: A list or a DSL expression string; a random list is generated
            when omitted.

    Returns:
        A copy of ``generation_template`` with ``function_template``,
        ``output`` and ``input`` filled in.
    """
    # Identity check: `is None` is the idiomatic test for a missing argument.
    if expr1 is None:
        expr1 = init_random_input()
    formatted_fn = f"sort_des({expr1})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1],
    )
    return template
def gen_add_n(expr1=None, expr2=None):
    """Build a generation template for ``add_n(expr1, expr2)``.

    Args:
        expr1: A list or a DSL expression string; a random list is generated
            when omitted.
        expr2: Value to add; sampled from ``const_integer`` when omitted.

    Returns:
        A copy of ``generation_template`` with ``function_template``,
        ``output`` and ``input`` filled in.
    """
    # Identity checks: `is None` is the idiomatic test for a missing argument.
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = random.choice(const_integer)
    formatted_fn = f"add_n({expr1},{expr2})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1, expr2],
    )
    return template
def gen_sub_n(expr1=None, expr2=None):
    """Build a generation template for ``sub_n(expr1, expr2)``.

    Args:
        expr1: A list or a DSL expression string; a random list is generated
            when omitted.
        expr2: Value to subtract; sampled from ``const_integer`` when omitted.

    Returns:
        A copy of ``generation_template`` with ``function_template``,
        ``output`` and ``input`` filled in.
    """
    # Identity checks: `is None` is the idiomatic test for a missing argument.
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = random.choice(const_integer)
    formatted_fn = f"sub_n({expr1},{expr2})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1, expr2],
    )
    return template
def gen_mul_n(expr1=None, expr2=None):
    """Build a generation template for ``mul_n(expr1, expr2)``.

    Args:
        expr1: A list or a DSL expression string; a random list is generated
            when omitted.
        expr2: Multiplier; sampled from ``const_integer`` when omitted.

    Returns:
        A copy of ``generation_template`` with ``function_template``,
        ``output`` and ``input`` filled in.
    """
    # Identity checks: `is None` is the idiomatic test for a missing argument.
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = random.choice(const_integer)
    formatted_fn = f"mul_n({expr1},{expr2})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1, expr2],
    )
    return template
def gen_div_n(expr1=None, expr2=None):
    """Build a generation template for ``div_n(expr1, expr2)``.

    Args:
        expr1: A list or a DSL expression string; a random list is generated
            when omitted.
        expr2: Divisor; sampled from ``const_integer`` when omitted.

    Returns:
        A copy of ``generation_template`` with ``function_template``,
        ``output`` and ``input`` filled in.
    """
    # Identity checks: `is None` is the idiomatic test for a missing argument.
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = random.choice(const_integer)
    formatted_fn = f"div_n({expr1},{expr2})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1, expr2],
    )
    return template
def gen_expand_copy(expr1=None, expr2=None):
    """Build a generation template for ``expand_copy(expr1)``.

    Args:
        expr1: A list or a DSL expression string; a random list is generated
            when omitted.
        expr2: Accepted for backward compatibility and ignored.

    Returns:
        A copy of ``generation_template`` with ``function_template``,
        ``output`` and ``input`` filled in.
    """
    if expr1 is None:
        expr1 = init_random_input()
    # expand_copy takes a single argument. Rendering a second one made the
    # interpreter raise TypeError, so every template came back as "ERROR".
    formatted_fn = f"expand_copy({expr1})"
    template = copy.copy(generation_template)
    template["function_template"] = formatted_fn
    template["output"] = interpreter(formatted_fn)
    template["input"] = [expr1]
    return template
# Registry of production name -> template generator for that production.
list_manip_dsl_gen = dict(
    take=gen_take,
    drop=gen_drop,
    minimum=gen_minimum,
    maximum=gen_maximum,
    reverse=gen_reverse,
    sort_asc=gen_sort_asc,
    sort_des=gen_sort_des,
    add_n=gen_add_n,
    sub_n=gen_sub_n,
    mul_n=gen_mul_n,
    div_n=gen_div_n,
    expand_copy=gen_expand_copy,
)
class Sampler:
    """Randomly compose DSL productions into nested programs."""

    def __init__(
        self,
        max_sample_length: int = 5,
        code_sep: str = ";",
        interpreter_sep: str = "->",
    ):
        """Store the sampling configuration and the production registries.

        Args:
            max_sample_length: Nesting depth used by ``sample_production``
                when no explicit ``gen_length`` is given.
            code_sep: Stored on the instance; not used by this class.
            interpreter_sep: Stored on the instance; not used by this class.
        """
        self.max_sample_length = max_sample_length
        self.parser = Interpreter()
        self.production_list = list_manip_dsl
        # Names that may be sampled; each is looked up in production_gen_list.
        self.production_idt = list(self.production_list)
        self.production_gen_list = list_manip_dsl_gen
        self.code_sep = code_sep
        self.interpreter_sep = interpreter_sep

    def sample_production(self, gen_length=None):
        """Sample a chain of nested productions.

        The first production is generated from a fresh random input; each
        following one wraps the previous step's ``function_template``.
        Sampling stops early when a nested step evaluates to ``"ERROR"``
        (that failing step is not kept).

        Args:
            gen_length: Number of productions to chain. Defaults to
                ``self.max_sample_length``. The previous default of 5 meant
                the constructor's ``max_sample_length`` was never consulted.

        Returns:
            List of generation templates, innermost production first.
        """
        if gen_length is None:
            gen_length = self.max_sample_length
        chain = []
        for _ in range(gen_length):
            name = random.choice(self.production_idt)
            generator = self.production_gen_list[name]
            if not chain:
                # The first step is kept unconditionally, as before.
                step = generator()
            else:
                step = generator(chain[-1]["function_template"])
                if step["output"] == "ERROR":
                    break
            chain.append(step)
        return chain
def create_synthetic_dataset(size: int, io_size=3) -> list:
    """Sample up to ``size`` (prompt, program) examples from the DSL.

    Each attempt samples one nested program. Attempts that raise, or whose
    final output is ``[]`` or ``"ERROR"``, are skipped, so the result may hold
    fewer than ``size`` examples.

    Args:
        size: Number of sampling attempts.
        io_size: Currently unused.

    Returns:
        List of dicts with keys ``"input"`` (the prompt), ``"output"`` (the
        program string), ``"io_inp"`` and ``"io_out"``.
    """
    output_list = []
    sampler = Sampler()
    for _ in tqdm(range(size)):
        try:
            sampled = sampler.sample_production()
            inp = sampled[0]["input"][0]
            out = sampled[-1]["output"]
            function = sampled[-1]["function_template"]
        except Exception:
            # Best-effort generation: drop samples that fail to build.
            # Only Exception is caught so Ctrl-C / SystemExit still stop the
            # run; the previous bare `except:` swallowed them.
            continue
        if out == [] or out == "ERROR":
            continue
        output_list.append(
            {
                "input": f"Input: {inp} Output: {out} Function:",
                "output": function,
                "io_inp": inp,
                "io_out": out,
            }
        )
    return output_list
def write_to_json(data: dict, file_name: str):
    """Serialise ``data`` as JSON (2-space indent) into ``file_name``.

    An existing file at ``file_name`` is overwritten.
    """
    with open(file_name, "w") as handle:
        handle.write(json.dumps(data, indent=2))
def basic_stats(dataset, tokenizer):
    """Compute token-length statistics over the dataset.

    For every example, its ``"input"`` and ``"output"`` strings are joined by
    a space, ``"<|endoftext|>"`` is appended, and the result is passed to
    ``tokenizer``; the length of the returned ``"input_ids"`` is recorded.

    Returns:
        Dict with the ``"max"``, ``"min"`` and ``"mean"`` recorded lengths.
    """
    token_counts = [
        len(tokenizer(" ".join([example["input"], example["output"]]) + "<|endoftext|>")["input_ids"])
        for example in tqdm(dataset)
    ]
    longest = max(token_counts)
    shortest = min(token_counts)
    average = sum(token_counts) / len(token_counts)
    return {"max": longest, "min": shortest, "mean": average}
if __name__ == "__main__":
    # Generate both splits first, report their sizes, then write them to
    # JSON files under ./dataset.
    train_data = create_synthetic_dataset(2_000_000)
    test_data = create_synthetic_dataset(2_000)

    print(f"Train data size: {len(train_data)}")
    print(f"Test data size: {len(test_data)}")

    Path("dataset").mkdir(parents=True, exist_ok=True)
    for split, destination in (
        (train_data, "dataset/train.json"),
        (test_data, "dataset/test.json"),
    ):
        write_to_json(split, destination)