| """Example of using PPO to train a T5 model for translation. | |
| Based on examples/summarize_daily_cnn/t5_summarize_daily_cnn.py""" | |
| import json | |
| import os | |
| import sys | |
| from typing import List | |
| import torch | |
| from datasets import load_dataset | |
| from tqdm import tqdm | |
| from transformers import AutoTokenizer | |
| import trlx | |
| from trlx.data.configs import ( | |
| ModelConfig, | |
| OptimizerConfig, | |
| SchedulerConfig, | |
| TokenizerConfig, | |
| TrainConfig, | |
| TRLConfig, | |
| ) | |
| from trlx.models.modeling_ppo import PPOConfig | |

try:
    import comet
    import evaluate

    if comet.__version__ != "1.1.3":
        raise ImportError
except ImportError:
    raise ImportError(
        "To run this example, please install the `evaluate`, `nltk` and `comet==1.1.3` packages by "
        "running `pip install evaluate nltk unbabel-comet==1.1.3`"
    )
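
# seq_length budgets 512 tokens for the prompt plus 100 generated tokens
# (max_length in main() below is computed as seq_length - max_new_tokens).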
default_config = TRLConfig(
    train=TrainConfig(
        seq_length=612,
        epochs=100,
        total_steps=100000,
        batch_size=12,
        checkpoint_interval=10000,
        eval_interval=200,
        pipeline="PromptPipeline",
        trainer="AcceleratePPOTrainer",
        tracker="wandb",
    ),
    model=ModelConfig(
        model_path="t5-large",
        model_arch_type="seq2seq",
        num_layers_unfrozen=-1,
    ),
    tokenizer=TokenizerConfig(
        tokenizer_path="t5-large",
        padding_side="right",
        truncation_side="right",
    ),
    optimizer=OptimizerConfig(
        name="adamw",
        kwargs={
            "lr": 2.0e-6,
            "betas": [0.9, 0.999],
            "eps": 1.0e-8,
            "weight_decay": 1.0e-6,
        },
    ),
    scheduler=SchedulerConfig(
        name="cosine_annealing",
        kwargs={
            "T_max": 10000,
            "eta_min": 1.0e-6,
        },
    ),
    method=PPOConfig(
        name="PPOConfig",
        num_rollouts=256,
        chunk_size=12,
        ppo_epochs=4,
        init_kl_coef=0.05,
        target=6,
        horizon=10000,
        gamma=0.99,
        lam=0.95,
        cliprange=0.2,
        cliprange_value=0.2,
        vf_coef=1.0,
        scale_reward=None,
        ref_mean=None,
        ref_std=None,
        cliprange_reward=10,
        gen_kwargs={
            "max_new_tokens": 100,
        },
        gen_experience_kwargs={
            "max_new_tokens": 100,
            "do_sample": False,
            "num_beams": 4,
            "temperature": 1.0,
        },
    ),
)


def main(hparams={}):
    config = TRLConfig.update(default_config, hparams)

    # COMET is the metric we are optimizing for
    comet_metric = evaluate.load("comet", "wmt20-comet-da", progress_bar=False)
    bleu_metric = evaluate.load("bleu")
    chrf_metric = evaluate.load("chrf")
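
    # Silence the huggingface tokenizers fork warning, and set the cuBLAS
    # workspace so deterministic algorithms can run on CUDA (needed because of
    # the determinism workaround in reward_fn/metric_fn below).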
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" | |
    def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> List[float]:
        original_sents = [translation_map[prompt.strip()] for prompt in prompts]

        scores = comet_metric.compute(
            predictions=[output.strip() for output in outputs],
            references=[original["tgt"] for original in original_sents],
            sources=[original["src"] for original in original_sents],
        )["scores"]

        # TODO: This is needed since there seems to be a bug in the comet metric
        # that changes torch's determinism setting. Remove this once the bug is fixed.
        torch.use_deterministic_algorithms(False, warn_only=True)

        return scores
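
    # Unlike reward_fn, metric_fn returns corpus-level numbers for logging; it
    # is run on `eval_prompts` at every `eval_interval` steps.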
    def metric_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> dict:
        """Compute COMET, BLEU and chrF for evaluation"""
        original_sents = [translation_map[prompt.strip()] for prompt in prompts]

        comet_score = comet_metric.compute(
            predictions=[output.strip() for output in outputs],
            references=[original["tgt"] for original in original_sents],
            sources=[original["src"] for original in original_sents],
        )["mean_score"]

        bleu_score = bleu_metric.compute(
            predictions=[output.strip() for output in outputs],
            references=[original["tgt"] for original in original_sents],
        )["bleu"]

        chrf_score = chrf_metric.compute(
            predictions=[output.strip() for output in outputs],
            references=[original["tgt"] for original in original_sents],
        )["score"]

        # TODO: This is needed since there seems to be a bug in the comet metric
        # that changes torch's determinism setting. Remove this once the bug is fixed.
        # Same issue as in `reward_fn`
        torch.use_deterministic_algorithms(False, warn_only=True)

        # For corpus-level metrics, report the aggregate scores and ignore the
        # sentence-level ones
        return {"bleu": bleu_score, "chrf": chrf_score, "comet": comet_score}

    # WMT16 is large, so we benefit from loading it as a streaming dataset
    train_dataset = load_dataset("wmt16", "de-en", split="train", streaming=True)
    valid_dataset = load_dataset("wmt16", "de-en", split="validation", streaming=True)

    src_lang = "en"
    tgt_lang = "de"
    PREFIX = "translate English to German: "

    # take 20,000 samples from the training set as prompts for training
    original_src_dataset = [sent_pair["translation"][src_lang] for sent_pair in train_dataset.take(20000)]
    tgt_dataset = [sent_pair["translation"][tgt_lang] for sent_pair in train_dataset.take(20000)]
    src_dataset = [PREFIX + src_sent for src_sent in original_src_dataset]

    # take 1,000 samples from the validation set as prompts for evaluation
    val_original_src_dataset = [sent_pair["translation"][src_lang] for sent_pair in valid_dataset.take(1000)]
    val_tgt_dataset = [sent_pair["translation"][tgt_lang] for sent_pair in valid_dataset.take(1000)]
    val_src_dataset = [PREFIX + src_sent for src_sent in val_original_src_dataset]

    # make a dictionary of prompts and labels to use for the reward function
    tokenizer = AutoTokenizer.from_pretrained(config.model.model_path)
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "right"
    tokenizer.sep_token = "<sep>"
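
    # Reserve room for generation: trlx truncates prompts that exceed
    # seq_length - max_new_tokens, so the lookup keys must be built with the
    # same truncation for reward_fn's lookups to succeed.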
    max_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"]

    translation_map = {}

    for i in tqdm(range(len(original_src_dataset))):
        # tokenize, truncate and re-decode so the key matches the prompt string trlx passes back
        key = tokenizer.decode(
            tokenizer(src_dataset[i], truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"],
            skip_special_tokens=True,
        )
        translation_map[key.strip()] = {"src": original_src_dataset[i], "tgt": tgt_dataset[i]}

    for i in tqdm(range(len(val_original_src_dataset))):
        # same key construction for the evaluation prompts
        key = tokenizer.decode(
            tokenizer(val_src_dataset[i], truncation=True, max_length=max_length, add_special_tokens=False)[
                "input_ids"
            ],
            skip_special_tokens=True,
        )
        translation_map[key.strip()] = {"src": val_original_src_dataset[i], "tgt": val_tgt_dataset[i]}
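
    # Launch PPO training: trlx generates translations for each prompt, scores
    # them with reward_fn, and logs metric_fn on eval_prompts at every
    # eval_interval steps.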
    trlx.train(
        reward_fn=reward_fn,
        metric_fn=metric_fn,
        prompts=src_dataset,
        eval_prompts=val_src_dataset,
        config=config,
    )


if __name__ == "__main__":
    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
    main(hparams)