diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..c27989d900614d1a53e5bdcc1f9c2bee23ef3969 --- /dev/null +++ b/LICENSE @@ -0,0 +1,51 @@ +Third-Party License Attribution for Audio Processing Module +=========================================================== + +This directory contains code derived from multiple open-source projects. +The following sections detail the licenses and attributions for third-party code. + +## XCodec Repository +The code in this directory is derived from: +https://github.com/zhenye234/xcodec + +## Individual File Attributions + +### Quantization Module (quantization/) +- Several files contain code derived from Meta Platforms, Inc. and the vector-quantize-pytorch repository +- Individual files contain their own license headers where applicable +- The vector-quantize-pytorch portions are licensed under the MIT License + +## License Terms + +### MIT License (for applicable portions) +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +## Attribution Requirements +When using this code, please ensure proper attribution to: +1. The original xcodec repository: https://github.com/zhenye234/xcodec +2. Any other repositories mentioned in individual file headers +3. This derivative work and its modifications + +## Disclaimer +This directory contains modified versions of the original code. Please refer to +the original repositories for the canonical implementations and their specific +license terms. + +For any questions about licensing or attribution, please check the individual +file headers and the original source repositories. 
\ No newline at end of file diff --git a/__pycache__/discriminator.cpython-312.pyc b/__pycache__/discriminator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8320e1f6a8860b8ab1455fdf71392be17670bf0 Binary files /dev/null and b/__pycache__/discriminator.cpython-312.pyc differ diff --git a/__pycache__/higgs_audio_tokenizer.cpython-311.pyc b/__pycache__/higgs_audio_tokenizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44d4dcd2665b94c5480d41bf849f549cf04d1766 Binary files /dev/null and b/__pycache__/higgs_audio_tokenizer.cpython-311.pyc differ diff --git a/__pycache__/higgs_audio_tokenizer.cpython-312.pyc b/__pycache__/higgs_audio_tokenizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af12d45c3b9a4b76d43e20f43b37f451b131f1fc Binary files /dev/null and b/__pycache__/higgs_audio_tokenizer.cpython-312.pyc differ diff --git a/__pycache__/loss.cpython-312.pyc b/__pycache__/loss.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa7769e4195ad6207d476bd34c9e6558fd1f86e8 Binary files /dev/null and b/__pycache__/loss.cpython-312.pyc differ diff --git a/__pycache__/semantic_module.cpython-312.pyc b/__pycache__/semantic_module.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..399996bbf868bbc0864fda626b032a941db4c237 Binary files /dev/null and b/__pycache__/semantic_module.cpython-312.pyc differ diff --git a/boson_codeit.py b/boson_codeit.py new file mode 100644 index 0000000000000000000000000000000000000000..98df2ce52f06a458d4e902b6272ac69b2ad92c0f --- /dev/null +++ b/boson_codeit.py @@ -0,0 +1,651 @@ +# #!/usr/bin/env python3 +# """ +# Audio Processing Script for Boson Codes +# Processes audio files in parallel using Higgs Audio Tokenizer +# and saves encoded representations as .pt files. 
+# """ + +# import os +# import sys +# import json +# import torch +# import librosa +# import numpy as np +# import warnings +# import argparse +# from pathlib import Path +# from multiprocessing import Pool +# from tqdm import tqdm + +# from datasets import load_from_disk +# from higgs_audio_tokenizer import HiggsAudioTokenizer + +# # Suppress PyTorch FutureWarnings +# warnings.filterwarnings("ignore", category=FutureWarning) + +# # Global configuration +# DEFAULT_OUTPUT_DIR = "/home/ubuntu/boson_codes" +# DEFAULT_NUM_CORES = 48 +# DEFAULT_SAMPLE_RATE = 44100 +# DEFAULT_DATASET_PATH = "/home/ubuntu/ttsar/Layla/src_bpe_2/data" + +# # Model paths +# CONFIG_PATH = "/home/ubuntu/.cache/huggingface/hub/models--bosonai--higgs-audio-v2-tokenizer/snapshots/9d4988fbd4ad07b4cac3a5fa462741a41810dbec/config.json" +# MODEL_PATH = "/home/ubuntu/.cache/huggingface/hub/models--bosonai--higgs-audio-v2-tokenizer/snapshots/9d4988fbd4ad07b4cac3a5fa462741a41810dbec/model.pth" + +# # Global model variable (initialized in each worker) +# model = None + + +# def init_worker(): +# """Initialize model once per worker process.""" +# global model +# device = 'cpu' + +# # Load config +# with open(CONFIG_PATH, 'r') as f: +# config = json.load(f) + +# # Initialize model +# model = HiggsAudioTokenizer( +# **config, +# device=device, +# ) + +# # Load weights +# parameter_dict = torch.load(MODEL_PATH, map_location=device) +# _ = model.load_state_dict(parameter_dict, strict=False) +# model = model.to(device) +# _ = model.eval() + +# print(f"Model loaded in worker {os.getpid()}") + + +# def process_audio_file(args): +# """Process a single audio file using pre-loaded model.""" +# filename, output_dir, sample_rate = args + +# try: +# # Output filename - same name, just change extension to .pt +# base_name = Path(filename).stem +# output_path = os.path.join(output_dir, f"{base_name}.pt") + +# # Skip if exists (double-check in case of race conditions) +# if os.path.exists(output_path): +# return ("skipped", filename) + +# # Load and process audio +# wav, sr = librosa.load(filename, sr=sample_rate) +# wav = torch.from_numpy(wav).unsqueeze(0).float().to('cpu') + +# # Encode using the pre-loaded model +# with torch.no_grad(): +# encoded = model._xcodec_encode(wav.unsqueeze(0)) + +# # Save codes only +# torch.save(encoded.audio_codes, output_path) + +# return ("success", filename) + +# except Exception as e: +# return ("error", filename, str(e)) + + +# def load_dataset(dataset_path): +# """Load and prepare the dataset.""" +# print(f"Loading dataset from: {dataset_path}") +# ds = load_from_disk(dataset_path) +# print(f"Dataset info: {ds}") + +# # Remove unnecessary columns +# columns_to_remove = ['spk', 'duration', 'codes', 'input_ids', 'attention_mask'] +# existing_columns = [col for col in columns_to_remove if col in ds.column_names] +# if existing_columns: +# ds = ds.remove_columns(existing_columns) +# print(f"Removed columns: {existing_columns}") + +# # Convert to pandas DataFrame +# df = ds.to_pandas() +# print(f"Loaded {len(df)} files from dataset") +# return df + + +# def main(args): +# """Main processing function.""" +# # Change to audio processing directory +# os.chdir("/home/ubuntu/ttsar/boson_audio_codec/audio_processing") +# print(f"Working directory: {os.getcwd()}") + +# # Create output directory +# os.makedirs(args.output_dir, exist_ok=True) +# print(f"Output directory: {args.output_dir}") + +# # Check if model files exist +# if not os.path.exists(CONFIG_PATH): +# print(f"Error: Config file not found at 
{CONFIG_PATH}") +# sys.exit(1) +# if not os.path.exists(MODEL_PATH): +# print(f"Error: Model file not found at {MODEL_PATH}") +# sys.exit(1) + +# # Load dataset +# df = load_dataset(args.dataset_path) + +# # Get filenames from dataframe +# all_filenames = df['filename'].tolist() + +# # Pre-filter to exclude already processed files +# filenames_to_process = [] +# already_processed = [] + +# print(f"\nChecking for already processed files...") +# for filename in all_filenames: +# base_name = Path(filename).stem +# output_path = os.path.join(args.output_dir, f"{base_name}.pt") +# if os.path.exists(output_path): +# already_processed.append(filename) +# else: +# filenames_to_process.append(filename) + +# print(f"\nTotal files: {len(all_filenames)}") +# print(f"Already processed: {len(already_processed)}") +# print(f"To process: {len(filenames_to_process)}") + +# if len(filenames_to_process) == 0: +# print("\nAll files have already been processed!") +# return + +# print(f"\nProcessing {len(filenames_to_process)} files using {args.num_cores} cores...") +# print(f"Sample rate: {args.sample_rate} Hz") + +# # Prepare arguments for multiprocessing +# process_args = [(filename, args.output_dir, args.sample_rate) +# for filename in filenames_to_process] + +# # Process in parallel with model reuse +# with Pool(processes=args.num_cores, initializer=init_worker) as pool: +# results = list(tqdm( +# pool.imap(process_audio_file, process_args, chunksize=args.chunksize), +# total=len(filenames_to_process), +# desc="Processing audio files" +# )) + +# # Count results +# processed = sum(1 for r in results if r[0] == "success") +# skipped = sum(1 for r in results if r[0] == "skipped") +# errors = sum(1 for r in results if r[0] == "error") + +# print(f"\nProcessing complete!") +# print(f" Successfully processed: {processed}") +# print(f" Previously processed: {len(already_processed)}") +# print(f" Skipped (race condition): {skipped}") +# print(f" Errors: {errors}") + +# # Show errors if any +# if errors > 0: +# print("\nErrors encountered:") +# error_log_path = os.path.join(args.output_dir, "processing_errors.log") +# with open(error_log_path, 'w') as f: +# for r in results: +# if r[0] == "error": +# error_msg = f"{r[1]}: {r[2]}" +# print(f" {error_msg}") +# f.write(error_msg + "\n") +# print(f"\nError log saved to: {error_log_path}") + +# # Show summary of all processed files +# total_processed_files = len(list(Path(args.output_dir).glob("*.pt"))) +# print(f"\nTotal .pt files in {args.output_dir}: {total_processed_files}") + + +# if __name__ == "__main__": +# parser = argparse.ArgumentParser( +# description="Process audio files using Higgs Audio Tokenizer and save as .pt files" +# ) + +# parser.add_argument( +# "--dataset-path", +# type=str, +# default=DEFAULT_DATASET_PATH, +# help=f"Path to the dataset (default: {DEFAULT_DATASET_PATH})" +# ) + +# parser.add_argument( +# "--output-dir", +# type=str, +# default=DEFAULT_OUTPUT_DIR, +# help=f"Output directory for .pt files (default: {DEFAULT_OUTPUT_DIR})" +# ) + +# parser.add_argument( +# "--num-cores", +# type=int, +# default=DEFAULT_NUM_CORES, +# help=f"Number of CPU cores to use (default: {DEFAULT_NUM_CORES})" +# ) + +# parser.add_argument( +# "--sample-rate", +# type=int, +# default=DEFAULT_SAMPLE_RATE, +# help=f"Sample rate for audio processing (default: {DEFAULT_SAMPLE_RATE})" +# ) + +# parser.add_argument( +# "--chunksize", +# type=int, +# default=1, +# help="Chunksize for multiprocessing pool (default: 1)" +# ) + +# args = parser.parse_args() + +# # Run main 
processing
+# try:
+#     main(args)
+# except KeyboardInterrupt:
+#     print("\n\nProcessing interrupted by user")
+#     sys.exit(1)
+# except Exception as e:
+#     print(f"\n\nError: {e}")
+#     sys.exit(1)
+
+#!/usr/bin/env python3
+"""
+GPU Batch Processing Script for Boson Codes with Dataset Loading
+"""
+
+import os
+import sys
+import json
+import warnings
+
+import librosa
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pathlib import Path
+from tqdm import tqdm
+from torch.nn.utils import remove_weight_norm, weight_norm
+
+from higgs_audio_tokenizer import HiggsAudioTokenizer
+# Alternative loading path, kept for reference:
+# from boson_multimodal.audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer
+# model = load_higgs_audio_tokenizer("bosonai/higgs-audio-v2-tokenizer")
+
+# Suppress warnings
+warnings.filterwarnings('ignore')
+
+
+def remove_weight_norms_from_model(model):
+    """Strip weight normalization from every module that carries it."""
+    for module in model.modules():
+        try:
+            remove_weight_norm(module)
+        except ValueError:
+            # Module has no weight norm to remove
+            continue
+    return model
+
+
+class EncodedResult:
+    def __init__(self, audio_codes):
+        self.audio_codes = audio_codes
+
+
+def encode_batch(model, x_batch):
+    """
+    Encodes a batch of audio tensors using the HiggsAudioTokenizer model.
+    Args:
+        model: The loaded HiggsAudioTokenizer model.
+        x_batch: A tensor of shape [B, 1, T]
+    """
+    # Acoustic and Semantic Feature Extraction
+    e_semantic_input = model.get_regress_target(x_batch).detach()
+    e_semantic = model.encoder_semantic(e_semantic_input.transpose(1, 2))
+    e_acoustic = model.encoder(x_batch)
+
+    # This block contains the fix for batch processing
+    if e_acoustic.shape[2] != e_semantic.shape[2]:
+        pad_size = 160 * model.semantic_downsample_factor
+
+        # 1. Remove channel dim, preserving batch dim -> [B, T]
+        x_slice = x_batch[:, 0, :]
+
+        # 2. Pad the tensor
+        x_padded = F.pad(x_slice, (pad_size, pad_size))
+
+        # 3. Re-add channel dim before passing to encoder -> [B, 1, T_padded]
+        e_acoustic = model.encoder(x_padded.unsqueeze(1))
+
+    # Ensure dimensions match before concatenating
+    min_len = min(e_acoustic.shape[2], e_semantic.shape[2])
+    e_acoustic = e_acoustic[:, :, :min_len]
+    e_semantic = e_semantic[:, :, :min_len]
+
+    # Remainder of the original encoding logic
+    e = torch.cat([e_acoustic, e_semantic], dim=1)
+    e = model.fc_prior(e.transpose(1, 2))
+
+    if model.quantizer_type == "RVQ":
+        e = e.transpose(1, 2)
+        _, codes, _, _ = model.quantizer(e, model.frame_rate, None)
+        codes = codes.permute(1, 0, 2)
+    else:  # RFSQ
+        quantized, codes = model.quantizer(e)
+        codes = codes.permute(0, 2, 1)
+
+    return EncodedResult(audio_codes=codes)
+
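+
+# Usage sketch for encode_batch (illustrative only; nothing here runs as part
+# of this script, and the clip paths are hypothetical placeholders): load a
+# few clips, zero-pad them to a common length, stack to [B, 1, T], encode.
+#
+#   clips = [librosa.load(p, sr=44100)[0] for p in ("a.wav", "b.wav")]
+#   max_len = max(len(c) for c in clips)
+#   batch = torch.stack([
+#       F.pad(torch.from_numpy(c).float(), (0, max_len - len(c)))
+#       for c in clips
+#   ]).unsqueeze(1)                                     # [B, 1, T]
+#   with torch.no_grad():
+#       codes = encode_batch(model, batch).audio_codes  # [B, n_q, T_frames]
+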
+
+def fix_all_inference_issues(model):
+    """
+    Comprehensive fix for all potential inference issues
+    """
+    device = next(model.parameters()).device
+
+    # 1. Force everything to eval mode
+    model.eval()
+    with torch.no_grad():
+        for module in model.modules():
+            if isinstance(module, nn.Module):
+                module.eval()
+                if hasattr(module, 'training'):
+                    module.training = False
+
+    # 2. Fix semantic model specifically
+    if hasattr(model, 'semantic_model'):
+        print("Fixing semantic model...")
+
+        # Move to correct device
+        model.semantic_model = model.semantic_model.to(device)
+        model.semantic_model.eval()
+
+        # Disable ALL gradient checkpointing
+        def disable_gradient_checkpointing(module):
+            if hasattr(module, 'gradient_checkpointing'):
+                module.gradient_checkpointing = False
+            if hasattr(module, 'gradient_checkpointing_disable'):
+                try:
+                    module.gradient_checkpointing_disable()
+                except Exception:
+                    # Some modules expose the attribute without supporting the call
+                    pass
+            for child in module.children():
+                disable_gradient_checkpointing(child)
+
+        disable_gradient_checkpointing(model.semantic_model)
+
+        # For HuBERT specifically
+        if hasattr(model.semantic_model, 'encoder'):
+            model.semantic_model.encoder.gradient_checkpointing = False
+            if hasattr(model.semantic_model.encoder, 'layers'):
+                for layer in model.semantic_model.encoder.layers:
+                    if hasattr(layer, 'gradient_checkpointing'):
+                        layer.gradient_checkpointing = False
+
+    # 3. Set all dropout to eval mode
+    def set_dropout_eval(module):
+        if isinstance(module, nn.Dropout):
+            module.eval()
+            module.training = False
+        for child in module.children():
+            set_dropout_eval(child)
+
+    set_dropout_eval(model)
+
+    # 4. Clear any cached computations
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    return model
+
+def inference_pipeline(checkpoint_path, config_path, device='cuda'):
+    """
+    Complete pipeline for inference with your trained model
+    """
+    # Load config
+    print("Loading config...")
+    with open(config_path, 'r') as f:
+        config = json.load(f)
+
+    # Create model
+    print("Creating model...")
+    model = HiggsAudioTokenizer(
+        n_filters=config['n_filters'],
+        D=config['D'],
+        target_bandwidths=config['target_bandwidths'],
+        ratios=config['ratios'],
+        sample_rate=config['sample_rate'],
+        bins=config['bins'],
+        n_q=config['n_q'],
+        codebook_dim=config.get('codebook_dim', None),
+        semantic_techer=config['semantic_techer'],
+        device=device
+    ).to(device)
+
+    # Load checkpoint
+    print("Loading checkpoint...")
+    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
+
+    if 'model_state_dict' in checkpoint:
+        state_dict = checkpoint['model_state_dict']
+    else:
+        state_dict = checkpoint
+
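+    # Checkpoints saved from a DistributedDataParallel-wrapped model carry a
+    # 'module.' prefix on every key; strip it so the keys match a bare
+    # HiggsAudioTokenizer before load_state_dict.
+    # Remove 'module.'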
prefix if present (from DDP) + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith('module.'): + new_state_dict[k[7:]] = v + else: + new_state_dict[k] = v + + model.load_state_dict(new_state_dict, strict=False) + + # Fix all inference issues + print("Fixing inference issues...") + model = fix_all_inference_issues(model) + + + return model + + + +# # Add paths +# sys.path.insert(0, "/home/ubuntu/AP-BWE") + +# Suppress warnings +warnings.filterwarnings("ignore") + +# Configuration +OUTPUT_DIR = "/home/ubuntu/data_boson_44.1khz" +BATCH_SIZE = 32 +SAMPLE_RATE = 44100 +DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' +DATASET_PATH = "/home/ubuntu/ttsar/Layla/src_bpe_2/Qanary_data" + +# # Model paths +# CONFIG_PATH = "/home/ubuntu/.cache/huggingface/hub/models--bosonai--higgs-audio-v2-tokenizer/snapshots/9d4988fbd4ad07b4cac3a5fa462741a41810dbec/config.json" +# MODEL_PATH = "/home/ubuntu/.cache/huggingface/hub/models--bosonai--higgs-audio-v2-tokenizer/snapshots/9d4988fbd4ad07b4cac3a5fa462741a41810dbec/model.pth" + +# --- Setup --- +print(f"Using device: {DEVICE}") + +# Change to working directory +os.chdir("/home/ubuntu/ttsar/boson_audio_codec/audio_processing") + +# Load dataset +from datasets import load_from_disk + + +print(f"Loading dataset from: {DATASET_PATH}") +ds = load_from_disk(DATASET_PATH) +print(f"Dataset info: {ds}") + +# Remove unnecessary columns +columns_to_remove = ['spk', 'duration', 'codes', 'input_ids', 'attention_mask'] +existing_columns = [col for col in columns_to_remove if col in ds.column_names] +if existing_columns: + ds = ds.remove_columns(existing_columns) + +df = ds.to_pandas() +print(f"Loaded {len(df)} files from dataset") + +os.makedirs(OUTPUT_DIR, exist_ok=True) +print(f"Output directory '{OUTPUT_DIR}' is ready.") + +# --- Filter already processed --- +print("Checking for already processed files...") + +def get_output_path(audio_path): + base_name = Path(audio_path).stem + return os.path.join(OUTPUT_DIR, f"{base_name}.pt") + +# Filter +original_count = len(df) +df['output_exists'] = df['filename'].apply(lambda x: os.path.exists(get_output_path(x))) +df_filtered = df[~df['output_exists']].copy() +skipped_count = original_count - len(df_filtered) + +print(f"Found {skipped_count} already processed files. 
Skipping them.") +print(f"Processing {len(df_filtered)} remaining files.") + +if len(df_filtered) == 0: + print("All files have already been processed!") + exit() + +# --- Load Model --- +print("Loading Higgs Audio Tokenizer model...") + +from transformers import HubertModel +from higgs_audio_tokenizer import HiggsAudioTokenizer + +# Load config +# with open(CONFIG_PATH, 'r') as f: +# config = json.load(f) + +# # Initialize model +# model = HiggsAudioTokenizer( +# **config, +# device=DEVICE, +# ) + +# Load weights +# parameter_dict = torch.load(MODEL_PATH, map_location=DEVICE) +# _ = model.load_state_dict(parameter_dict, strict=False) +# model = model.to(DEVICE) +# _ = model.eval() + + +checkpoint_path = '/home/ubuntu/ttsar/boson_audio_codec/audio_processing/outputs_CQT/checkpoints/step_99000.pth' +config_path = '/home/ubuntu/ttsar/boson_audio_codec/audio_processing/config copy.json' +device = 'cuda' +model = inference_pipeline(checkpoint_path, config_path, device) +_ = model.eval() + +model = remove_weight_norms_from_model(model) + +print(f"Model loaded on {DEVICE}") + +# Get hop length +hop_length = model.hop_length +print(f"Encoder hop length: {hop_length}") + +# --- Batch Processing --- +print(f"\nStarting batch processing with batch size {BATCH_SIZE}...") + +# Process in batches +filenames = df_filtered['filename'].tolist() +total_processed = 0 +total_errors = 0 + +with torch.no_grad(): + for batch_start in tqdm(range(0, len(filenames), BATCH_SIZE), desc="Processing batches"): + batch_end = min(batch_start + BATCH_SIZE, len(filenames)) + batch_filenames = filenames[batch_start:batch_end] + + batch_audio = [] + batch_lengths = [] + batch_outputs = [] + + # Load batch + for filename in batch_filenames: + output_path = get_output_path(filename) + + # Skip if exists (race condition check) + if os.path.exists(output_path): + continue + + try: + # Load audio + wav, _ = librosa.load(filename, sr=SAMPLE_RATE) + wav_tensor = torch.from_numpy(wav).float() + + batch_audio.append(wav_tensor) + batch_lengths.append(len(wav)) + batch_outputs.append(output_path) + + except Exception as e: + print(f"\nError loading {filename}: {e}") + total_errors += 1 + continue + + if not batch_audio: + continue + + # Pad batch to same length + max_len = max(len(x) for x in batch_audio) + padded_batch = [] + + for audio in batch_audio: + pad_len = max_len - len(audio) + if pad_len > 0: + audio = F.pad(audio, (0, pad_len), mode='constant', value=0) + # Don't add extra dimensions here, just collect the padded audio + padded_batch.append(audio) + + # Convert list to tensor and add channel dimension + # Stack along batch dimension to get [B, T] + batch_tensor = torch.stack(padded_batch, dim=0) # [B, T] + # Add channel dimension + batch_tensor = batch_tensor.unsqueeze(1) # [B, 1, T] + batch_tensor = batch_tensor.to(DEVICE) + + # Encode batch + try: + encoded = encode_batch(model, batch_tensor) + codes = encoded.audio_codes # [B, n_codebooks, T_compressed] + + # Save each item + for idx, (output_path, orig_len) in enumerate(zip(batch_outputs, batch_lengths)): + # Calculate true code length + true_code_len = int(np.ceil(orig_len / hop_length)) + + # Extract non-padded codes + item_codes = codes[idx, :, :true_code_len].cpu() + + # Save + torch.save(item_codes, output_path) + total_processed += 1 + + except Exception as e: + print(f"\nError encoding batch: {e}") + total_errors += len(batch_outputs) + +print("\n" + "="*50) +print("PROCESSING COMPLETE!") +print("="*50) +print(f"Successfully processed: {total_processed} 
files") +print(f"Previously processed: {skipped_count} files") +print(f"Errors encountered: {total_errors} files") +print(f"Output directory: {OUTPUT_DIR}") + +# Final count +final_count = len(list(Path(OUTPUT_DIR).glob("*.pt"))) +print(f"Total .pt files in output: {final_count}") \ No newline at end of file diff --git a/descriptaudiocodec/__init__.py b/descriptaudiocodec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/descriptaudiocodec/__pycache__/__init__.cpython-311.pyc b/descriptaudiocodec/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5417b2d7e5a3a3836bc433bce2c67551c932df5 Binary files /dev/null and b/descriptaudiocodec/__pycache__/__init__.cpython-311.pyc differ diff --git a/descriptaudiocodec/__pycache__/__init__.cpython-312.pyc b/descriptaudiocodec/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a6afaf3f849bc8775320106dba961dbd9d7a1b2 Binary files /dev/null and b/descriptaudiocodec/__pycache__/__init__.cpython-312.pyc differ diff --git a/descriptaudiocodec/dac/model/__pycache__/base.cpython-311.pyc b/descriptaudiocodec/dac/model/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17e0763835c7f9d868e9da6e0737fd6aa174c813 Binary files /dev/null and b/descriptaudiocodec/dac/model/__pycache__/base.cpython-311.pyc differ diff --git a/descriptaudiocodec/dac/model/__pycache__/base.cpython-312.pyc b/descriptaudiocodec/dac/model/__pycache__/base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29eec54b51ac6980d447d1f0c709f2c200eb0d34 Binary files /dev/null and b/descriptaudiocodec/dac/model/__pycache__/base.cpython-312.pyc differ diff --git a/descriptaudiocodec/dac/model/__pycache__/dac.cpython-311.pyc b/descriptaudiocodec/dac/model/__pycache__/dac.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bda175a90551968307c86fc6903b66a018e61dd Binary files /dev/null and b/descriptaudiocodec/dac/model/__pycache__/dac.cpython-311.pyc differ diff --git a/descriptaudiocodec/dac/model/__pycache__/dac.cpython-312.pyc b/descriptaudiocodec/dac/model/__pycache__/dac.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86d77726110b92fadaa674c0886ebae32d036122 Binary files /dev/null and b/descriptaudiocodec/dac/model/__pycache__/dac.cpython-312.pyc differ diff --git a/descriptaudiocodec/dac/model/base.py b/descriptaudiocodec/dac/model/base.py new file mode 100644 index 0000000000000000000000000000000000000000..08e39a2d9016c6ddc2491d0e2644b80c8efe3986 --- /dev/null +++ b/descriptaudiocodec/dac/model/base.py @@ -0,0 +1,286 @@ +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Union + +import numpy as np +import torch +import tqdm +from audiotools import AudioSignal +from torch import nn + +SUPPORTED_VERSIONS = ["1.0.0"] + + +@dataclass +class DACFile: + codes: torch.Tensor + + # Metadata + chunk_length: int + original_length: int + input_db: float + channels: int + sample_rate: int + padding: bool + dac_version: str + + def save(self, path): + artifacts = { + "codes": self.codes.numpy().astype(np.uint16), + "metadata": { + "input_db": self.input_db.numpy().astype(np.float32), + "original_length": self.original_length, + "sample_rate": self.sample_rate, + "chunk_length": self.chunk_length, + "channels": 
self.channels, + "padding": self.padding, + "dac_version": SUPPORTED_VERSIONS[-1], + }, + } + path = Path(path).with_suffix(".dac") + with open(path, "wb") as f: + np.save(f, artifacts) + return path + + @classmethod + def load(cls, path): + artifacts = np.load(path, allow_pickle=True)[()] + codes = torch.from_numpy(artifacts["codes"].astype(int)) + if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS: + raise RuntimeError(f"Given file {path} can't be loaded with this version of descript-audio-codec.") + return cls(codes=codes, **artifacts["metadata"]) + + +class CodecMixin: + @property + def padding(self): + if not hasattr(self, "_padding"): + self._padding = True + return self._padding + + @padding.setter + def padding(self, value): + assert isinstance(value, bool) + + layers = [l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))] + + for layer in layers: + if value: + if hasattr(layer, "original_padding"): + layer.padding = layer.original_padding + else: + layer.original_padding = layer.padding + layer.padding = tuple(0 for _ in range(len(layer.padding))) + + self._padding = value + + def get_delay(self): + # Any number works here, delay is invariant to input length + l_out = self.get_output_length(0) + L = l_out + + layers = [] + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + layers.append(layer) + + for layer in reversed(layers): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.ConvTranspose1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.Conv1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.ceil(L) + + l_in = L + + return (l_in - l_out) // 2 + + def get_output_length(self, input_length): + L = input_length + # Calculate output length + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.Conv1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.ConvTranspose1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.floor(L) + return L + + @torch.no_grad() + def compress( + self, + audio_path_or_signal: Union[str, Path, AudioSignal], + win_duration: float = 1.0, + verbose: bool = False, + normalize_db: float = -16, + n_quantizers: int = None, + ) -> DACFile: + """Processes an audio signal from a file or AudioSignal object into + discrete codes. This function processes the signal in short windows, + using constant GPU memory. 
+
+        Parameters
+        ----------
+        audio_path_or_signal : Union[str, Path, AudioSignal]
+            audio signal to reconstruct
+        win_duration : float, optional
+            window duration in seconds, by default 1.0
+        verbose : bool, optional
+            by default False
+        normalize_db : float, optional
+            normalize db, by default -16
+        n_quantizers : int, optional
+            number of quantizers to use, by default None (use all)
+
+        Returns
+        -------
+        DACFile
+            Object containing compressed codes and metadata
+            required for decompression
+        """
+        audio_signal = audio_path_or_signal
+        if isinstance(audio_signal, (str, Path)):
+            audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
+
+        self.eval()
+        original_padding = self.padding
+        original_device = audio_signal.device
+
+        audio_signal = audio_signal.clone()
+        original_sr = audio_signal.sample_rate
+
+        resample_fn = audio_signal.resample
+        loudness_fn = audio_signal.loudness
+
+        # If audio is >= 10 hours long, use the ffmpeg versions
+        if audio_signal.signal_duration >= 10 * 60 * 60:
+            resample_fn = audio_signal.ffmpeg_resample
+            loudness_fn = audio_signal.ffmpeg_loudness
+
+        original_length = audio_signal.signal_length
+        resample_fn(self.sample_rate)
+        input_db = loudness_fn()
+
+        if normalize_db is not None:
+            audio_signal.normalize(normalize_db)
+        audio_signal.ensure_max_of_audio()
+
+        nb, nac, nt = audio_signal.audio_data.shape
+        audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
+        win_duration = audio_signal.signal_duration if win_duration is None else win_duration
+
+        if audio_signal.signal_duration <= win_duration:
+            # Unchunked compression (used if signal length < win duration)
+            self.padding = True
+            n_samples = nt
+            hop = nt
+        else:
+            # Chunked inference
+            self.padding = False
+            # Zero-pad signal on either side by the delay
+            audio_signal.zero_pad(self.delay, self.delay)
+            n_samples = int(win_duration * self.sample_rate)
+            # Round n_samples to nearest hop length multiple
+            n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
+            hop = self.get_output_length(n_samples)
+
+        codes = []
+        range_fn = range if not verbose else tqdm.trange
+
+        for i in range_fn(0, nt, hop):
+            x = audio_signal[..., i : i + n_samples]
+            x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
+
+            audio_data = x.audio_data.to(self.device)
+            audio_data = self.preprocess(audio_data, self.sample_rate)
+            _, c, _, _, _ = self.encode(audio_data, n_quantizers)
+            codes.append(c.to(original_device))
+            chunk_length = c.shape[-1]
+
+        codes = torch.cat(codes, dim=-1)
+
+        dac_file = DACFile(
+            codes=codes,
+            chunk_length=chunk_length,
+            original_length=original_length,
+            input_db=input_db,
+            channels=nac,
+            sample_rate=original_sr,
+            padding=self.padding,
+            dac_version=SUPPORTED_VERSIONS[-1],
+        )
+
+        if n_quantizers is not None:
+            codes = codes[:, :n_quantizers, :]
+
+        self.padding = original_padding
+        return dac_file
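+
+    # Round-trip sketch (hedged; the paths are hypothetical placeholders):
+    #
+    #   signal = AudioSignal("input.wav")
+    #   dac_file = model.compress(signal, win_duration=5.0)
+    #   dac_file.save("input.dac")
+    #   recons = model.decompress("input.dac")   # -> AudioSignal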
+
+    @torch.no_grad()
+    def decompress(
+        self,
+        obj: Union[str, Path, DACFile],
+        verbose: bool = False,
+    ) -> AudioSignal:
+        """Reconstruct audio from a given .dac file
+
+        Parameters
+        ----------
+        obj : Union[str, Path, DACFile]
+            .dac file location or corresponding DACFile object.
+        verbose : bool, optional
+            Prints progress if True, by default False
+
+        Returns
+        -------
+        AudioSignal
+            Object with the reconstructed audio
+        """
+        self.eval()
+        if isinstance(obj, (str, Path)):
+            obj = DACFile.load(obj)
+
+        original_padding = self.padding
+        self.padding = obj.padding
+
+        range_fn = range if not verbose else tqdm.trange
+        codes = obj.codes
+        original_device = codes.device
+        chunk_length = obj.chunk_length
+        recons = []
+
+        for i in range_fn(0, codes.shape[-1], chunk_length):
+            c = codes[..., i : i + chunk_length].to(self.device)
+            z = self.quantizer.from_codes(c)[0]
+            r = self.decode(z)
+            recons.append(r.to(original_device))
+
+        recons = torch.cat(recons, dim=-1)
+        recons = AudioSignal(recons, self.sample_rate)
+
+        resample_fn = recons.resample
+        loudness_fn = recons.loudness
+
+        # If audio is >= 10 hours long, use the ffmpeg versions
+        if recons.signal_duration >= 10 * 60 * 60:
+            resample_fn = recons.ffmpeg_resample
+            loudness_fn = recons.ffmpeg_loudness
+
+        recons.normalize(obj.input_db)
+        resample_fn(obj.sample_rate)
+        recons = recons[..., : obj.original_length]
+        loudness_fn()
+        recons.audio_data = recons.audio_data.reshape(-1, obj.channels, obj.original_length)
+
+        self.padding = original_padding
+        return recons
diff --git a/descriptaudiocodec/dac/model/dac.py b/descriptaudiocodec/dac/model/dac.py
new file mode 100644
index 0000000000000000000000000000000000000000..efaed1c25eee7cbb55a96b4f12376b9d26d4a685
--- /dev/null
+++ b/descriptaudiocodec/dac/model/dac.py
@@ -0,0 +1,365 @@
+import math
+from typing import List
+from typing import Union
+
+import numpy as np
+import torch
+from audiotools import AudioSignal
+from audiotools.ml import BaseModel
+from torch import nn
+
+from .base import CodecMixin
+from dac.nn.layers import Snake1d
+from dac.nn.layers import WNConv1d
+from dac.nn.layers import WNConvTranspose1d
+from dac.nn.quantize import ResidualVectorQuantize
+
+
+def init_weights(m):
+    if isinstance(m, nn.Conv1d):
+        nn.init.trunc_normal_(m.weight, std=0.02)
+        nn.init.constant_(m.bias, 0)
+
+
+class ResidualUnit(nn.Module):
+    def __init__(self, dim: int = 16, dilation: int = 1):
+        super().__init__()
+        pad = ((7 - 1) * dilation) // 2
+        self.block = nn.Sequential(
+            Snake1d(dim),
+            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
+            Snake1d(dim),
+            WNConv1d(dim, dim, kernel_size=1),
+        )
+
+    def forward(self, x):
+        y = self.block(x)
+        pad = (x.shape[-1] - y.shape[-1]) // 2
+        if pad > 0:
+            x = x[..., pad:-pad]
+        return x + y
+
+
+class EncoderBlock(nn.Module):
+    def __init__(self, dim: int = 16, stride: int = 1):
+        super().__init__()
+        self.block = nn.Sequential(
+            ResidualUnit(dim // 2, dilation=1),
+            ResidualUnit(dim // 2, dilation=3),
+            ResidualUnit(dim // 2, dilation=9),
+            Snake1d(dim // 2),
+            WNConv1d(
+                dim // 2,
+                dim,
+                kernel_size=2 * stride,
+                stride=stride,
+                padding=math.ceil(stride / 2),
+            ),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        d_model: int = 64,
+        strides: list = [2, 4, 8, 8],
+        d_latent: int = 256,
+    ):
+        super().__init__()
+        # Create first convolution
+        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
+
+        # Create EncoderBlocks that double channels as they downsample by `stride`
+        for stride in strides:
+            d_model *= 2
+            self.block += [EncoderBlock(d_model, stride=stride)]
+
+        # Create last convolution
+        self.block += [
+            Snake1d(d_model),
+            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
+        ]
+
+        # Wrap the block list into nn.Sequential
self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + def forward(self, x): + return self.block(x) + + +class DecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, out_pad=0): + super().__init__() + self.block = nn.Sequential( + Snake1d(input_dim), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + output_padding=stride % 2, # out_pad, + ), + ResidualUnit(output_dim, dilation=1), + ResidualUnit(output_dim, dilation=3), + ResidualUnit(output_dim, dilation=9), + ) + + def forward(self, x): + return self.block(x) + + +class Decoder(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + d_out: int = 1, + ): + super().__init__() + + # Add first conv layer + layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, stride in enumerate(rates): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + if i == 1: + out_pad = 1 + else: + out_pad = 0 + layers += [DecoderBlock(input_dim, output_dim, stride, out_pad)] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + WNConv1d(output_dim, d_out, kernel_size=7, padding=3), + # nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class DAC(BaseModel, CodecMixin): + def __init__( + self, + encoder_dim: int = 64, + encoder_rates: List[int] = [2, 4, 8, 8], + latent_dim: int = None, + decoder_dim: int = 1536, + decoder_rates: List[int] = [8, 8, 4, 2], + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: bool = False, + sample_rate: int = 44100, + ): + super().__init__() + + self.encoder_dim = encoder_dim + self.encoder_rates = encoder_rates + self.decoder_dim = decoder_dim + self.decoder_rates = decoder_rates + self.sample_rate = sample_rate + + if latent_dim is None: + latent_dim = encoder_dim * (2 ** len(encoder_rates)) + + self.latent_dim = latent_dim + + self.hop_length = np.prod(encoder_rates) + self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim) + + self.n_codebooks = n_codebooks + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.quantizer = ResidualVectorQuantize( + input_dim=latent_dim, + n_codebooks=n_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + self.decoder = Decoder( + latent_dim, + decoder_dim, + decoder_rates, + ) + self.sample_rate = sample_rate + self.apply(init_weights) + + self.delay = self.get_delay() + + def preprocess(self, audio_data, sample_rate): + if sample_rate is None: + sample_rate = self.sample_rate + assert sample_rate == self.sample_rate + + length = audio_data.shape[-1] + right_pad = math.ceil(length / self.hop_length) * self.hop_length - length + audio_data = nn.functional.pad(audio_data, (0, right_pad)) + + return audio_data + + def encode( + self, + audio_data: torch.Tensor, + n_quantizers: int = None, + ): + """Encode given audio data and return quantized latent codes + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + n_quantizers : int, optional + Number of quantizers to use, by default None + If None, all quantizers are used. 
+ + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + """ + z = self.encoder(audio_data) + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers) + return z, codes, latents, commitment_loss, codebook_loss + + def decode(self, z: torch.Tensor): + """Decode given latent codes and return audio data + + Parameters + ---------- + z : Tensor[B x D x T] + Quantized continuous representation of input + length : int, optional + Number of samples in output audio, by default None + + Returns + ------- + dict + A dictionary with the following keys: + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + return self.decoder(z) + + def forward( + self, + audio_data: torch.Tensor, + sample_rate: int = None, + n_quantizers: int = None, + ): + """Model forward pass + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + sample_rate : int, optional + Sample rate of audio data in Hz, by default None + If None, defaults to `self.sample_rate` + n_quantizers : int, optional + Number of quantizers to use, by default None. + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + length = audio_data.shape[-1] + audio_data = self.preprocess(audio_data, sample_rate) + z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers) + + x = self.decode(z) + return { + "audio": x[..., :length], + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + +if __name__ == "__main__": + import numpy as np + from functools import partial + + model = DAC().to("cpu") + + for n, m in model.named_modules(): + o = m.extra_repr() + p = sum([np.prod(p.size()) for p in m.parameters()]) + fn = lambda o, p: o + f" {p / 1e6:<.3f}M params." 
+ setattr(m, "extra_repr", partial(fn, o=o, p=p)) + print(model) + print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) + + length = 88200 * 2 + x = torch.randn(1, 1, length).to(model.device) + x.requires_grad_(True) + x.retain_grad() + + # Make a forward pass + out = model(x)["audio"] + print("Input shape:", x.shape) + print("Output shape:", out.shape) + + # Create gradient variable + grad = torch.zeros_like(out) + grad[:, :, grad.shape[-1] // 2] = 1 + + # Make a backward pass + out.backward(grad) + + # Check non-zero values + gradmap = x.grad.squeeze(0) + gradmap = (gradmap != 0).sum(0) # sum across features + rf = (gradmap != 0).sum() + + print(f"Receptive field: {rf.item()}") + + x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100) + model.decompress(model.compress(x, verbose=True), verbose=True) diff --git a/descriptaudiocodec/dac/nn/layers.py b/descriptaudiocodec/dac/nn/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..44fbc2929715e11d843b24195d7042a528969a94 --- /dev/null +++ b/descriptaudiocodec/dac/nn/layers.py @@ -0,0 +1,33 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) diff --git a/descriptaudiocodec/dac/nn/quantize.py b/descriptaudiocodec/dac/nn/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..8861224cbb49813816dc41b63059faa13d246cc7 --- /dev/null +++ b/descriptaudiocodec/dac/nn/quantize.py @@ -0,0 +1,251 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +from dac.nn.layers import WNConv1d + + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. 
l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = z_e + (z_q - z_e).detach() # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [VectorQuantize(input_dim, codebook_size, codebook_dim[i]) for i in range(n_codebooks)] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. 
of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual) + + # Create mask to apply quantizer dropout + mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. 
+
+        Parameters
+        ----------
+        latents : Tensor[B x N x T]
+            Continuous representation of input after projection
+
+        Returns
+        -------
+        Tensor[B x D x T]
+            Quantized representation of full-projected space
+        Tensor[B x D x T]
+            Quantized representation of latent space
+        """
+        z_q = 0
+        z_p = []
+        codes = []
+        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
+
+        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
+        for i in range(n_codebooks):
+            j, k = dims[i], dims[i + 1]
+            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
+            z_p.append(z_p_i)
+            codes.append(codes_i)
+
+            z_q_i = self.quantizers[i].out_proj(z_p_i)
+            z_q = z_q + z_q_i
+
+        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
+
+
+if __name__ == "__main__":
+    rvq = ResidualVectorQuantize(quantizer_dropout=True)
+    x = torch.randn(16, 512, 80)
+    # forward returns a tuple, not a dict, so unpack it before printing
+    z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
+    print(latents.shape)
diff --git a/discriminator.py b/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5c7e9e84f72badb87459b4ca69c8acb585dd3f9
--- /dev/null
+++ b/discriminator.py
@@ -0,0 +1,596 @@
+# import torch
+# import torch.nn as nn
+# import torch.nn.functional as F
+# from audiotools import AudioSignal
+# from audiotools import ml
+# from audiotools import STFTParams
+# from einops import rearrange
+# from torch.nn.utils import weight_norm
+
+
+# def WNConv1d(*args, **kwargs):
+#     act = kwargs.pop("act", True)
+#     conv = weight_norm(nn.Conv1d(*args, **kwargs))
+#     if not act:
+#         return conv
+#     return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+# def WNConv2d(*args, **kwargs):
+#     act = kwargs.pop("act", True)
+#     conv = weight_norm(nn.Conv2d(*args, **kwargs))
+#     if not act:
+#         return conv
+#     return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+# class MPD(nn.Module):
+#     def __init__(self, period):
+#         super().__init__()
+#         self.period = period
+#         self.convs = nn.ModuleList(
+#             [
+#                 WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
+#                 WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
+#                 WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
+#                 WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
+#                 WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
+#             ]
+#         )
+#         self.conv_post = WNConv2d(
+#             1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
+#         )
+
+#     def pad_to_period(self, x):
+#         t = x.shape[-1]
+#         x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
+#         return x
+
+#     def forward(self, x):
+#         fmap = []
+
+#         x = self.pad_to_period(x)
+#         x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
+
+#         for layer in self.convs:
+#             x = layer(x)
+#             fmap.append(x)
+
+#         x = self.conv_post(x)
+#         fmap.append(x)
+
+#         return fmap
+
+
+# class MSD(nn.Module):
+#     def __init__(self, rate: int = 1, sample_rate: int = 44100):
+#         super().__init__()
+#         self.convs = nn.ModuleList(
+#             [
+#                 WNConv1d(1, 16, 15, 1, padding=7),
+#                 WNConv1d(16, 64, 41, 4, groups=4, padding=20),
+#                 WNConv1d(64, 256, 41, 4, groups=16, padding=20),
+#                 WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
+#                 WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
+#                 WNConv1d(1024, 1024, 5, 1, padding=2),
+#             ]
+#         )
+#         self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
+#         self.sample_rate = sample_rate
+#         self.rate = rate
+
+#     def forward(self, x):
+#         x = AudioSignal(x, self.sample_rate)
+#         x.resample(self.sample_rate // self.rate)
+#         x = x.audio_data
+
+#         fmap = []
+
+#         for l in self.convs:
+#             x = l(x)
+#             fmap.append(x)
+#         x = self.conv_post(x)
+#         fmap.append(x)
+
+#         return fmap
+
+
+# BANDS =
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from audiotools import AudioSignal, STFTParams
+from audiotools import ml
+from einops import rearrange
+from torch.nn.utils import weight_norm
+import torchaudio
+import nnAudio.features as features
+from munch import Munch
+
+
+BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
+
+
+def WNConv1d(*args, **kwargs):
+    act = kwargs.pop("act", True)
+    conv = weight_norm(nn.Conv1d(*args, **kwargs))
+    if not act:
+        return conv
+    return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+def WNConv2d(*args, **kwargs):
+    act = kwargs.pop("act", True)
+    conv = weight_norm(nn.Conv2d(*args, **kwargs))
+    if not act:
+        return conv
+    return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def get_2d_padding(kernel_size, dilation=(1, 1)):
+    return (int((kernel_size[0] * dilation[0] - dilation[0]) / 2),
+            int((kernel_size[1] * dilation[1] - dilation[1]) / 2))
+
+
+class NormConv2d(nn.Module):
+    """Conv2d with normalization"""
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True, norm="weight_norm"):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
+                              stride, padding, dilation, groups, bias)
+        if norm == "weight_norm":
+            self.conv = weight_norm(self.conv)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class MPD(nn.Module):
+    def __init__(self, period):
+        super().__init__()
+        self.period = period
+        self.convs = nn.ModuleList([
+            WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
+            WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
+            WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
+            WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
+            WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
+        ])
+        self.conv_post = WNConv2d(1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False)
+
+    def pad_to_period(self, x):
+        t = x.shape[-1]
+        x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
+        return x
+
+    def forward(self, x):
+        fmap = []
+        x = self.pad_to_period(x)
+        x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
+
+        for layer in self.convs:
+            x = layer(x)
+            fmap.append(x)
+
+        x = self.conv_post(x)
+        fmap.append(x)
+        return fmap
+
+
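+# Shape sketch (illustrative, not part of the model): with period p, a
+# (B, 1, T) waveform is reflect-padded so that p divides T, then folded to
+# (B, 1, T // p, p); the (5, 1) kernels above therefore convolve samples that
+# are exactly p apart. For example, MPD(period=2)(torch.zeros(1, 1, 44100))
+# returns six feature maps, the last of which has a single logit channel.
+
+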
+class MSD(nn.Module): + def __init__(self, rate: int = 1, sample_rate: int = 44100): + super().__init__() + self.convs = nn.ModuleList([ + WNConv1d(1, 16, 15, 1, padding=7), + WNConv1d(16, 64, 41, 4, groups=4, padding=20), + WNConv1d(64, 256, 41, 4, groups=16, padding=20), + WNConv1d(256, 1024, 41, 4, groups=64, padding=20), + WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), + WNConv1d(1024, 1024, 5, 1, padding=2), + ]) + self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) + self.sample_rate = sample_rate + self.rate = rate + + def forward(self, x): + x = AudioSignal(x, self.sample_rate) + x.resample(self.sample_rate // self.rate) + x = x.audio_data + + fmap = [] + for l in self.convs: + x = l(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + return fmap + + +class DiscriminatorCQT(nn.Module): + def __init__(self, cfg, hop_length, n_octaves, bins_per_octave): + super().__init__() + self.cfg = cfg + self.filters = cfg.filters + self.max_filters = cfg.max_filters + self.filters_scale = cfg.filters_scale + self.kernel_size = (3, 9) + self.dilations = cfg.dilations + self.stride = (1, 2) + self.in_channels = cfg.in_channels + self.out_channels = cfg.out_channels + self.fs = cfg.sampling_rate + self.hop_length = hop_length + self.n_octaves = n_octaves + self.bins_per_octave = bins_per_octave + + self.cqt_transform = features.cqt.CQT2010v2( + sr=self.fs * 2, + hop_length=self.hop_length, + n_bins=self.bins_per_octave * self.n_octaves, + bins_per_octave=self.bins_per_octave, + output_format="Complex", + pad_mode="constant", + ) + + self.conv_pres = nn.ModuleList() + for i in range(self.n_octaves): + self.conv_pres.append( + NormConv2d( + self.in_channels * 2, # Real + Imaginary + self.in_channels * 2, + kernel_size=self.kernel_size, + padding=get_2d_padding(self.kernel_size), + norm="weight_norm", + ) + ) + + self.convs = nn.ModuleList() + self.convs.append( + NormConv2d( + self.in_channels * 2, + self.filters, + kernel_size=self.kernel_size, + padding=get_2d_padding(self.kernel_size), + ) + ) + + in_chs = min(self.filters_scale * self.filters, self.max_filters) + for i, dilation in enumerate(self.dilations): + out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters) + self.convs.append( + NormConv2d( + in_chs, + out_chs, + kernel_size=self.kernel_size, + stride=self.stride, + dilation=(dilation, 1), + padding=get_2d_padding(self.kernel_size, (dilation, 1)), + norm="weight_norm", + ) + ) + in_chs = out_chs + + out_chs = min( + (self.filters_scale ** (len(self.dilations) + 1)) * self.filters, + self.max_filters, + ) + self.convs.append( + NormConv2d( + in_chs, + out_chs, + kernel_size=(self.kernel_size[0], self.kernel_size[0]), + padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])), + norm="weight_norm", + ) + ) + + self.conv_post = NormConv2d( + out_chs, + self.out_channels, + kernel_size=(self.kernel_size[0], self.kernel_size[0]), + padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])), + norm="weight_norm", + ) + + self.activation = torch.nn.LeakyReLU(negative_slope=0.1) + self.resample = torchaudio.transforms.Resample( + orig_freq=self.fs, new_freq=self.fs * 2 + ) + + def forward(self, x): + fmap = [] + x = self.resample(x) + z = self.cqt_transform(x) + + + z_amplitude = z[:, :, :, 0].unsqueeze(1) + z_phase = z[:, :, :, 1].unsqueeze(1) + z = torch.cat([z_amplitude, z_phase], dim=1) + z = rearrange(z, "b c w t -> b c t w") + + latent_z = [] + for i in range(self.n_octaves): + octave_band = z[:, :, :, i * 
self.bins_per_octave : (i + 1) * self.bins_per_octave]
+            processed_band = self.conv_pres[i](octave_band)
+            latent_z.append(processed_band)
+        latent_z = torch.cat(latent_z, dim=-1)
+
+        for i, l in enumerate(self.convs):
+            latent_z = l(latent_z)
+            latent_z = self.activation(latent_z)
+            fmap.append(latent_z)
+
+        latent_z = self.conv_post(latent_z)
+        fmap.append(latent_z)
+
+        return fmap
+
+
+class MultiScaleSubbandCQT(nn.Module):
+    """CQT discriminator at multiple scales"""
+    def __init__(self, sample_rate=44100):
+        super().__init__()
+        cfg = Munch({
+            "hop_lengths": [1024, 512, 512],
+            "sampling_rate": sample_rate,
+            "filters": 32,
+            "max_filters": 1024,
+            "filters_scale": 1,
+            "dilations": [1, 2, 4],
+            "in_channels": 1,
+            "out_channels": 1,
+            "n_octaves": [10, 10, 10],
+            "bins_per_octaves": [24, 36, 48],
+        })
+        self.cfg = cfg
+        self.discriminators = nn.ModuleList([
+            DiscriminatorCQT(
+                cfg,
+                hop_length=cfg.hop_lengths[i],
+                n_octaves=cfg.n_octaves[i],
+                bins_per_octave=cfg.bins_per_octaves[i],
+            )
+            for i in range(len(cfg.hop_lengths))
+        ])
+
+    def forward(self, x):
+        fmap = []
+        for disc in self.discriminators:
+            fmap.extend(disc(x))
+        return fmap
+
+
+class MRD(nn.Module):
+    def __init__(self, window_length: int, hop_factor: float = 0.25,
+                 sample_rate: int = 44100, bands: list = BANDS):
+        """Multi-resolution spectrogram discriminator.
+
+        Parameters
+        ----------
+        window_length : int
+            Window length of the STFT.
+        hop_factor : float, optional
+            Hop length as a fraction of the window length, by default 0.25
+        sample_rate : int, optional
+            Sampling rate of audio in Hz, by default 44100
+        bands : list, optional
+            Frequency bands, as fractions of the spectrum, to run separate
+            conv stacks over, by default ``BANDS``
+        """
+        super().__init__()
+        self.window_length = window_length
+        self.hop_factor = hop_factor
+        self.sample_rate = sample_rate
+        self.stft_params = STFTParams(
+            window_length=window_length,
+            hop_length=int(window_length * hop_factor),
+            match_stride=True,
+        )
+
+        n_fft = window_length // 2 + 1
+        # e.g. window_length=2048 gives 1025 bins and band edges
+        # [(0, 102), (102, 256), (256, 512), (512, 768), (768, 1025)]
+        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+        self.bands = bands
+
+        ch = 32
+        convs = lambda: nn.ModuleList([
+            WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
+            WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+            WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+            WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+            WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
+        ])
+        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+        self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
+
+    def spectrogram(self, x):
+        x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
+        x = torch.view_as_real(x.stft())
+        x = rearrange(x, "b 1 f t c -> (b 1) c t f")
+        # Split into bands
+        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
+        return x_bands
+
+    def forward(self, x):
+        x_bands = self.spectrogram(x)
+        fmap = []
+
+        x = []
+        for band, stack in zip(x_bands, self.band_convs):
+            for layer in stack:
+                band = layer(band)
+                fmap.append(band)
+            x.append(band)
+
+        x = torch.cat(x, dim=-1)
+        x = self.conv_post(x)
+        fmap.append(x)
+        return fmap
+
+
+class Discriminator(ml.BaseModel):
+    def __init__(
+        self,
+        rates: list = [],
+        periods: list = [2, 3, 5, 7, 11],
+        fft_sizes: list = [2048, 1024, 512],
+        sample_rate: int = 44100,
+    ):
+        """Discriminator combining MPD, MSD, MRD and CQT.
+
+        Parameters
+        ----------
+        rates : list, optional
+            Sampling rates for MSD, by default [] (MSD is not used)
+        periods : list, optional
+            Periods for MPD, by default [2, 3, 5, 7, 11]
+        fft_sizes : list, optional
+            FFT sizes for MRD, by default [2048, 1024, 512]
+        sample_rate : int, optional
+            Sampling rate of audio in Hz, by default 44100
+        """
+        super().__init__()
+        discs = []
+        # Time-domain discriminators
+        discs += [MPD(p) for p in periods]
+        discs += [MSD(r, sample_rate=sample_rate) for r in rates]
+
+        # Frequency-domain discriminators (both STFT and CQT)
+        discs += [MRD(f, sample_rate=sample_rate) for f in fft_sizes]
+        discs += [MultiScaleSubbandCQT(sample_rate=sample_rate)]
+
+        self.discriminators = nn.ModuleList(discs)
+
+    def preprocess(self, y):
+        # Remove DC offset
+        y = y - y.mean(dim=-1, keepdim=True)
+        # Peak normalize
+        y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+        return y
+
+    def forward(self, x):
+        x = self.preprocess(x)
+        fmaps = [d(x) for d in self.discriminators]
+        return fmaps
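+
+
+if __name__ == "__main__":
+    # Smoke test (illustrative): run silence through every sub-discriminator
+    # and print the shape and statistics of each feature map.
+    disc = Discriminator()
+    x = torch.zeros(1, 1, 44100)
+    for i, result in enumerate(disc(x)):
+        print(f"disc{i}")
+        for r in result:
+            print(tuple(r.shape), r.mean().item(), r.min().item(), r.max().item())
+        print()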
\ No newline at end of file
diff --git a/higgs_audio_tokenizer.py b/higgs_audio_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..31753aa5427763a739c854959afa32c9149cc2b1
--- /dev/null
+++ b/higgs_audio_tokenizer.py
@@ -0,0 +1,373 @@
+# Based on code from: https://github.com/zhenye234/xcodec
+# Licensed under MIT License
+# Modifications by BosonAI
+
+import math
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Union, Sequence
+import numpy as np
+from transformers import AutoModel
+import torchaudio
+import json
+import librosa
+from huggingface_hub import snapshot_download
+
+from vector_quantize_pytorch import ResidualFSQ
+from descriptaudiocodec.dac.model import dac as dac2
+from quantization.vq import ResidualVectorQuantizer
+from semantic_module import Encoder, Decoder
+
+from transformers import HubertModel
+
+
+def WNConv1d(*args, **kwargs):
+    """Applies weight normalization to a 1D convolutional layer."""
+    return nn.utils.weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNLinear(*args, **kwargs):
+    """Applies weight normalization to a linear layer."""
+    return nn.utils.weight_norm(nn.Linear(*args, **kwargs))
+
+
+def init_weights(m):
+    """
+    Applies truncated-normal initialization (std=0.02) to Conv, Linear and
+    Embedding layers, a conservative and widely used scheme.
+    """
+    if isinstance(m, (nn.Conv1d, nn.Conv2d)):
+        # Truncated normal initialization for convolutional layers
+        nn.init.trunc_normal_(m.weight, std=0.02)
+        if m.bias is not None:
+            nn.init.constant_(m.bias, 0)
+    elif isinstance(m, nn.Linear):
+        # Also apply to linear layers for consistency
+        nn.init.trunc_normal_(m.weight, std=0.02)
+        if m.bias is not None:
+            nn.init.constant_(m.bias, 0)
+    elif isinstance(m, nn.Embedding):
+        # Initialize the codebook gently as well
+        nn.init.trunc_normal_(m.weight, std=0.02)
+
+
+class EncodedResult:
+    """Minimal container exposing an ``audio_codes`` attribute."""
+    def __init__(self, audio_codes):
+        self.audio_codes = audio_codes
+
+
+class HiggsAudioFeatureExtractor(nn.Module):
+    def __init__(self, sampling_rate=16000):
+        super().__init__()
+        self.sampling_rate = sampling_rate
+
+    def forward(self, raw_audio, sampling_rate=16000, return_tensors="pt"):
+        audio_signal = torch.tensor(raw_audio)
+        audio_signal = audio_signal.unsqueeze(0)
+        if len(audio_signal.shape) < 3:
+            audio_signal = audio_signal.unsqueeze(0)
+        return {"input_values": audio_signal}
+
+
+class HiggsAudioTokenizer(nn.Module):
+    def __init__(
+        self,
+        n_filters: int = 32,
+        D: int = 128,
+        target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6],
+        ratios: Sequence[int] = [8, 5, 4, 2],  # downsampling by 320
+        sample_rate: int = 16000,
+        bins: int = 1024,
+        n_q: int = 8,
+        codebook_dim: int = None,
+        normalize: bool = False,
+        causal: bool = False,
+        semantic_techer: str = "hubert_base_general",
+        last_layer_semantic: bool = True,
+        merge_mode: str = "concat",
+        downsample_mode: str = "step_down",
+        semantic_mode: str = "classic",
+        vq_scale: int = 1,
+        semantic_sample_rate: int = None,
+        device: str = "cuda",
+    ):
+        super().__init__()
+        self.hop_length = np.prod(ratios)
+        self.semantic_techer = semantic_techer
+
+        self.frame_rate = math.ceil(sample_rate / np.prod(ratios))  # 50 Hz
+
+        self.target_bandwidths = target_bandwidths
+        self.n_q = n_q
+        self.sample_rate = sample_rate
+        self.encoder = dac2.Encoder(64, ratios, D)
+
+        self.decoder_2 = dac2.Decoder(D, 1024, ratios)
+        self.last_layer_semantic = last_layer_semantic
+        self.device = device
+        if semantic_techer == "hubert_base":
+            self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960")
+            self.semantic_sample_rate = 16000
+            self.semantic_dim = 768
+            self.encoder_semantic_dim = 768
+
+        elif semantic_techer == "wavlm_base_plus":
+            self.semantic_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus")
+            self.semantic_sample_rate = 16000
+            self.semantic_dim = 768
+            self.encoder_semantic_dim = 768
+
+        elif semantic_techer == "mHubert_base":
+            self.semantic_model = AutoModel.from_pretrained("utter-project/mHuBERT-147")
+            self.semantic_sample_rate = 16000
+            self.semantic_dim = 768
+            self.encoder_semantic_dim = 768
+
+        elif semantic_techer == "hubert_base_general":
+            self.semantic_model = HubertModel.from_pretrained("/home/ubuntu/.cache/huggingface/hub/models--bosonai--hubert_base/snapshots/b4b85f1652c16ad63fdc818221b215b79ff55934", trust_remote_code=False)
+            self.semantic_sample_rate = 16000
+            self.semantic_dim = 768
+            self.encoder_semantic_dim = 768
+
+        # Overwrite semantic model sr to ensure semantic_downsample_factor is an integer
+        if semantic_sample_rate is not None:
+            self.semantic_sample_rate = semantic_sample_rate
+
+        self.semantic_model.eval()
+
+        # Freeze the semantic teacher; its parameters receive no gradients
+        for param in self.semantic_model.parameters():
+            param.requires_grad = False
+
+        self.semantic_downsample_factor = int(self.hop_length / (self.sample_rate / self.semantic_sample_rate) / 320)
+
+        self.quantizer_dim = int((D + self.encoder_semantic_dim) // vq_scale)
+        self.encoder_semantic = Encoder(input_channels=self.semantic_dim, encode_channels=self.encoder_semantic_dim)
+        self.decoder_semantic = Decoder(
+            code_dim=self.encoder_semantic_dim, output_channels=self.semantic_dim, decode_channels=self.semantic_dim
+        )
+
+        # out_D = D + 768
+        if isinstance(bins, int):  # RVQ
+            self.quantizer = ResidualVectorQuantizer(
+                dimension=self.quantizer_dim, codebook_dim=codebook_dim, n_q=n_q, bins=bins
+            )
+            self.quantizer_type = "RVQ"
+        else:  # RFSQ
+            self.quantizer = ResidualFSQ(dim=self.quantizer_dim, levels=bins, num_quantizers=n_q)
+            self.quantizer_type = "RFSQ"
+
+        self.fc_prior = WNLinear(D + self.encoder_semantic_dim, self.quantizer_dim)
+        self.fc_post1 = WNLinear(self.quantizer_dim, self.encoder_semantic_dim)
+        self.fc_post2 = WNLinear(self.quantizer_dim, D)
+
+        self.downsample_mode = downsample_mode
+        if downsample_mode == "avg":
+            self.semantic_pooling = nn.AvgPool1d(
+                kernel_size=self.semantic_downsample_factor, stride=self.semantic_downsample_factor
+            )
+
+        self.audio_tokenizer_feature_extractor = HiggsAudioFeatureExtractor(sampling_rate=self.sample_rate)
+
+        self.apply(init_weights)
+
+    @property
+    def tps(self):
+        return self.frame_rate
+
+    @property
+    def sampling_rate(self):
+        return self.sample_rate
+
+    @property
+    def num_codebooks(self):
+        return self.n_q
+
+    @property
+    def codebook_size(self):
+        # NOTE: returns the quantizer dimension, not the number of codebook
+        # entries (`bins`)
+        return self.quantizer_dim
+
+    def get_last_layer(self):
+        # NOTE: kept from the original xcodec code; `self.decoder` is not
+        # defined on this class (the acoustic decoder here is `decoder_2`)
+        return self.decoder.layers[-1].weight
+
+    def calculate_rec_loss(self, rec, target):
+        # Cosine-distance loss between L2-normalized feature vectors
+        target = target / target.norm(dim=-1, keepdim=True)
+        rec = rec / rec.norm(dim=-1, keepdim=True)
+        rec_loss = (1 - (target * rec).sum(-1)).mean()
+
+        return rec_loss
+
+    @torch.no_grad()
+    def get_regress_target(self, x):
+        x = torchaudio.functional.resample(x, self.sample_rate, self.semantic_sample_rate)
+
+        if (
+            self.semantic_techer == "hubert_base"
+            or self.semantic_techer == "hubert_base_general"
+            or self.semantic_techer == "wavlm_base_plus"
+        ):
+            x = x[:, 0, :]
+            x = F.pad(x, (160, 160))
+            target = self.semantic_model(x, output_hidden_states=True).hidden_states
+            target = torch.stack(target, dim=1)
+
+            # average over all hidden layers
+            target = target.mean(1)
+
+        elif self.semantic_techer == "w2v_bert2":
+            target = self.semantic_model(x)
+
+        elif self.semantic_techer.startswith("whisper"):
+            if self.last_layer_semantic:
+                target = self.semantic_model(x, avg_layers=False)
+            else:
+                target = self.semantic_model(x, avg_layers=True)
+
+        elif self.semantic_techer.startswith("mert_music"):
+            if self.last_layer_semantic:
+                target = self.semantic_model(x, avg_layers=False)
+            else:
+                target = self.semantic_model(x, avg_layers=True)
+
+        elif self.semantic_techer.startswith("qwen_audio_omni"):
+            target = self.semantic_model(x)
+
+        if self.downsample_mode == "step_down":
+            if self.semantic_downsample_factor > 1:
+                target = target[:, :: self.semantic_downsample_factor, :]
+
+        elif self.downsample_mode == "avg":
+            target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
+        return target
+
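+    # Rate sketch (illustrative): at 16 kHz with ratios [8, 5, 4, 2] the hop
+    # is 320 samples (50 Hz frames); HuBERT features also arrive at 50 Hz
+    # (stride 320 at 16 kHz), so semantic_downsample_factor = 1 and
+    # "step_down" keeps every semantic frame.
+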
+    def forward(self, x: torch.Tensor, bw: int):
+        e_semantic_input = self.get_regress_target(x).detach()
+
+        e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
+        e_acoustic = self.encoder(x)
+
+        e = torch.cat([e_acoustic, e_semantic], dim=1)
+
+        e = self.fc_prior(e.transpose(1, 2))
+
+        if self.quantizer_type == "RVQ":
+            e = e.transpose(1, 2)
+            quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
+            quantized = quantized.transpose(1, 2)
+        else:
+            quantized, codes = self.quantizer(e)
+            commit_loss = torch.tensor(0.0)
+
+        quantized_semantic = self.fc_post1(quantized).transpose(1, 2)
+        quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
+
+        o = self.decoder_2(quantized_acoustic)
+
+        o_semantic = self.decoder_semantic(quantized_semantic)
+        semantic_recon_loss = F.mse_loss(e_semantic_input.transpose(1, 2).detach(), o_semantic)
+
+        return o, commit_loss, semantic_recon_loss, None
+
+    def encode(self, audio_path_or_wv, sr=None, loudness_normalize=False, loudness_threshold=-23.0):
+        if isinstance(audio_path_or_wv, str):
+            wv, sr = librosa.load(audio_path_or_wv, mono=True, sr=None)
+        else:
+            wv = audio_path_or_wv
+            assert sr is not None
+        if loudness_normalize:
+            import pyloudnorm as pyln
+
+            meter = pyln.Meter(sr)
+            loudness = meter.integrated_loudness(wv)
+            wv = pyln.normalize.loudness(wv, loudness, loudness_threshold)
+        if sr != self.sampling_rate:
+            wv = librosa.resample(wv, orig_sr=sr, target_sr=self.sampling_rate)
+        if self.audio_tokenizer_feature_extractor is not None:
+            inputs = self.audio_tokenizer_feature_extractor(
+                raw_audio=wv, sampling_rate=self.audio_tokenizer_feature_extractor.sampling_rate, return_tensors="pt"
+            )
+            input_values = inputs["input_values"].to(self.device)
+        else:
+            input_values = torch.from_numpy(wv).float().unsqueeze(0)
+        with torch.no_grad():
+            encoder_outputs = self._xcodec_encode(input_values)
+            vq_code = encoder_outputs.audio_codes[0]
+        return vq_code
+
+    def _xcodec_encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> EncodedResult:
+        bw = target_bw
+
+        e_semantic_input = self.get_regress_target(x).detach()
+
+        e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
+        e_acoustic = self.encoder(x)
+
+        if e_acoustic.shape[2] != e_semantic.shape[2]:
+            # Re-encode with symmetric padding so both feature streams align
+            pad_size = 160 * self.semantic_downsample_factor
+            e_acoustic = self.encoder(F.pad(x[:, 0, :], (pad_size, pad_size)).unsqueeze(0))
+
+        if e_acoustic.shape[2] != e_semantic.shape[2]:
+            # Fall back to truncating the longer stream
+            if e_acoustic.shape[2] > e_semantic.shape[2]:
+                e_acoustic = e_acoustic[:, :, : e_semantic.shape[2]]
+            else:
+                e_semantic = e_semantic[:, :, : e_acoustic.shape[2]]
+
+        e = torch.cat([e_acoustic, e_semantic], dim=1)
+
+        e = self.fc_prior(e.transpose(1, 2))
+
+        if self.quantizer_type == "RVQ":
+            e = e.transpose(1, 2)
+            quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
+            codes = codes.permute(1, 0, 2)
+        else:
+            quantized, codes = self.quantizer(e)
+            codes = codes.permute(0, 2, 1)
+
+        return EncodedResult(codes)
+
+    def decode(self, vq_code: torch.Tensor) -> np.ndarray:
+        if self.quantizer_type == "RVQ":
+            vq_code = vq_code.permute(1, 0, 2)
+            quantized = self.quantizer.decode(vq_code)
+            quantized = quantized.transpose(1, 2)
+        else:
+            vq_code = vq_code.permute(0, 2, 1)
+            quantized = self.quantizer.get_output_from_indices(vq_code)
+        quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
+
+        o = self.decoder_2(quantized_acoustic)
+        return o.detach().cpu().numpy()
+
+
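+# Illustrative round trip (path and checkpoint name are placeholders):
+#
+#     tok = load_higgs_audio_tokenizer("path/or/hub-id/of/tokenizer")
+#     codes = tok.encode("sample.wav")      # discrete codes, no gradients
+#     wav = tok.decode(codes.unsqueeze(0))  # numpy waveform
+#
+# On the RVQ path, `encode` returns per-quantizer indices for one clip, so
+# `unsqueeze(0)` restores the batch axis that `decode` expects.
+
+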
+def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"):
+    is_local = os.path.exists(tokenizer_name_or_path)
+    if not is_local:
+        tokenizer_path = snapshot_download(tokenizer_name_or_path)
+    else:
+        tokenizer_path = tokenizer_name_or_path
+    config_path = os.path.join(tokenizer_path, "config.json")
+    model_path = os.path.join(tokenizer_path, "model.pth")
+    with open(config_path) as f:
+        config = json.load(f)
+    model = HiggsAudioTokenizer(
+        **config,
+        device=device,
+    )
+    parameter_dict = torch.load(model_path, map_location=device, weights_only=False)
+    model.load_state_dict(parameter_dict, strict=False)
+    model.to(device)
+    model.eval()
+    return model
diff --git a/loss.py b/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecf676cbbb15b43833d34c9bb53210c4a320bf24
--- /dev/null
+++ b/loss.py
@@ -0,0 +1,368 @@
+import torch
+import torch.nn.functional as F
+from audiotools import AudioSignal
+from audiotools import STFTParams
+from torch import nn
+import typing
+from typing import List
+
+
+class L1Loss(nn.L1Loss):
+    """L1 loss between AudioSignals. Defaults to comparing ``audio_data``,
+    but any attribute of an AudioSignal can be used.
+
+    Parameters
+    ----------
+    attribute : str, optional
+        Attribute of signal to compare, defaults to ``audio_data``.
+    weight : float, optional
+        Weight of this loss, defaults to 1.0.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+    """
+
+    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
+        self.attribute = attribute
+        self.weight = weight
+        super().__init__(**kwargs)
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate AudioSignal
+        y : AudioSignal
+            Reference AudioSignal
+
+        Returns
+        -------
+        torch.Tensor
+            L1 loss between AudioSignal attributes.
+        """
+        if isinstance(x, AudioSignal):
+            x = getattr(x, self.attribute)
+            y = getattr(y, self.attribute)
+        return super().forward(x, y)
+
+
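+# SI-SDR sketch (illustrative): with zero-mean signals, the estimate is
+# compared against its projection onto the reference,
+# s_t = (<s_hat, s> / ||s||^2) * s, via 10 * log10(||s_t||^2 / ||s_hat - s_t||^2);
+# the class below returns the negated ratio as a loss.
+
+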
+class SISDRLoss(nn.Module):
+    """
+    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
+    of estimated and reference audio signals or aligned features.
+
+    Parameters
+    ----------
+    scaling : bool, optional
+        Whether to use scale-invariant (True) or plain
+        signal-to-noise ratio (False), by default True
+    reduction : str, optional
+        How to reduce across the batch (either 'mean',
+        'sum', or 'none'), by default 'mean'
+    zero_mean : bool, optional
+        Zero-mean the references and estimates before
+        computing the loss, by default True
+    clip_min : float, optional
+        The minimum possible loss value. Helps the network
+        not focus on making already-good examples better, by default None
+    weight : float, optional
+        Weight of this loss, defaults to 1.0.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+    """
+
+    def __init__(
+        self,
+        scaling: bool = True,
+        reduction: str = "mean",
+        zero_mean: bool = True,
+        clip_min: float = None,
+        weight: float = 1.0,
+    ):
+        self.scaling = scaling
+        self.reduction = reduction
+        self.zero_mean = zero_mean
+        self.clip_min = clip_min
+        self.weight = weight
+        super().__init__()
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        eps = 1e-8
+        # nb, nc, nt
+        if isinstance(x, AudioSignal):
+            references = x.audio_data
+            estimates = y.audio_data
+        else:
+            references = x
+            estimates = y
+
+        nb = references.shape[0]
+        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
+        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
+
+        # samples now on axis 1
+        if self.zero_mean:
+            mean_reference = references.mean(dim=1, keepdim=True)
+            mean_estimate = estimates.mean(dim=1, keepdim=True)
+        else:
+            mean_reference = 0
+            mean_estimate = 0
+
+        _references = references - mean_reference
+        _estimates = estimates - mean_estimate
+
+        references_projection = (_references**2).sum(dim=-2) + eps
+        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
+
+        scale = (
+            (references_on_estimates / references_projection).unsqueeze(1)
+            if self.scaling
+            else 1
+        )
+
+        e_true = scale * _references
+        e_res = _estimates - e_true
+
+        signal = (e_true**2).sum(dim=1)
+        noise = (e_res**2).sum(dim=1)
+        sdr = -10 * torch.log10(signal / noise + eps)
+
+        if self.clip_min is not None:
+            sdr = torch.clamp(sdr, min=self.clip_min)
+
+        if self.reduction == "mean":
+            sdr = sdr.mean()
+        elif self.reduction == "sum":
+            sdr = sdr.sum()
+        return sdr
+
+
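+# Usage sketch (illustrative): SISDRLoss()(AudioSignal(est, 16000),
+# AudioSignal(ref, 16000)) returns a scalar loss; the closer the estimate is
+# to the reference, the more negative the value (higher SI-SDR).
+
+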
+class MultiScaleSTFTLoss(nn.Module):
+    """Computes the multi-scale STFT loss from [1].
+
+    Parameters
+    ----------
+    window_lengths : List[int], optional
+        Length of each window of each STFT, by default [2048, 512]
+    loss_fn : typing.Callable, optional
+        How to compare each loss, by default nn.L1Loss()
+    clamp_eps : float, optional
+        Floor applied to the magnitude before taking the log, by default 1e-5
+    mag_weight : float, optional
+        Weight of raw magnitude portion of loss, by default 1.0
+    log_weight : float, optional
+        Weight of log magnitude portion of loss, by default 1.0
+    pow : float, optional
+        Power to raise magnitude to before taking log, by default 2.0
+    weight : float, optional
+        Weight of this loss, by default 1.0
+    match_stride : bool, optional
+        Whether to match the stride of convolutional layers, by default False
+
+    References
+    ----------
+
+    1.  Engel, Jesse, Chenjie Gu, and Adam Roberts.
+        "DDSP: Differentiable Digital Signal Processing."
+        International Conference on Learning Representations. 2019.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+    """
+
+    def __init__(
+        self,
+        window_lengths: List[int] = [2048, 512],
+        loss_fn: typing.Callable = nn.L1Loss(),
+        clamp_eps: float = 1e-5,
+        mag_weight: float = 1.0,
+        log_weight: float = 1.0,
+        pow: float = 2.0,
+        weight: float = 1.0,
+        match_stride: bool = False,
+        window_type: str = None,
+    ):
+        super().__init__()
+        self.stft_params = [
+            STFTParams(
+                window_length=w,
+                hop_length=w // 4,
+                match_stride=match_stride,
+                window_type=window_type,
+            )
+            for w in window_lengths
+        ]
+        self.loss_fn = loss_fn
+        self.log_weight = log_weight
+        self.mag_weight = mag_weight
+        self.clamp_eps = clamp_eps
+        self.weight = weight
+        self.pow = pow
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """Computes multi-scale STFT loss between an estimate and a reference
+        signal.
+
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate signal
+        y : AudioSignal
+            Reference signal
+
+        Returns
+        -------
+        torch.Tensor
+            Multi-scale STFT loss.
+        """
+        loss = 0.0
+        for s in self.stft_params:
+            x.stft(s.window_length, s.hop_length, s.window_type)
+            y.stft(s.window_length, s.hop_length, s.window_type)
+            loss += self.log_weight * self.loss_fn(
+                x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+                y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+            )
+            loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
+        return loss
+
+
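+# Scale sketch (illustrative): window_lengths=[2048, 512] compares spectra at
+# hops of 512 and 128 samples (w // 4); each scale contributes
+# log_weight * L1(log10 |S|^2) plus mag_weight * L1(|S|).
+
+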
+class MelSpectrogramLoss(nn.Module):
+    """Compute distance between mel spectrograms. Can be used
+    in a multi-scale way.
+
+    Parameters
+    ----------
+    n_mels : List[int]
+        Number of mels per STFT, by default [150, 80]
+    window_lengths : List[int], optional
+        Length of each window of each STFT, by default [2048, 512]
+    loss_fn : typing.Callable, optional
+        How to compare each loss, by default nn.L1Loss()
+    clamp_eps : float, optional
+        Floor applied to the magnitude before taking the log, by default 1e-5
+    mag_weight : float, optional
+        Weight of raw magnitude portion of loss, by default 1.0
+    log_weight : float, optional
+        Weight of log magnitude portion of loss, by default 1.0
+    pow : float, optional
+        Power to raise magnitude to before taking log, by default 2.0
+    weight : float, optional
+        Weight of this loss, by default 1.0
+    match_stride : bool, optional
+        Whether to match the stride of convolutional layers, by default False
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+    """
+
+    def __init__(
+        self,
+        n_mels: List[int] = [150, 80],
+        window_lengths: List[int] = [2048, 512],
+        loss_fn: typing.Callable = nn.L1Loss(),
+        clamp_eps: float = 1e-5,
+        mag_weight: float = 1.0,
+        log_weight: float = 1.0,
+        pow: float = 2.0,
+        weight: float = 1.0,
+        match_stride: bool = False,
+        mel_fmin: List[float] = [0.0, 0.0],
+        mel_fmax: List[float] = [None, None],
+        window_type: str = None,
+    ):
+        super().__init__()
+        self.stft_params = [
+            STFTParams(
+                window_length=w,
+                hop_length=w // 4,
+                match_stride=match_stride,
+                window_type=window_type,
+            )
+            for w in window_lengths
+        ]
+        self.n_mels = n_mels
+        self.loss_fn = loss_fn
+        self.clamp_eps = clamp_eps
+        self.log_weight = log_weight
+        self.mag_weight = mag_weight
+        self.weight = weight
+        self.mel_fmin = mel_fmin
+        self.mel_fmax = mel_fmax
+        self.pow = pow
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """Computes mel loss between an estimate and a reference
+        signal.
+
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate signal
+        y : AudioSignal
+            Reference signal
+
+        Returns
+        -------
+        torch.Tensor
+            Mel loss.
+        """
+        loss = 0.0
+        for n_mels, fmin, fmax, s in zip(
+            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
+        ):
+            kwargs = {
+                "window_length": s.window_length,
+                "hop_length": s.hop_length,
+                "window_type": s.window_type,
+            }
+            x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+            y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+
+            loss += self.log_weight * self.loss_fn(
+                x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+                y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+            )
+            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
+        return loss
+
+
+class GANLoss(nn.Module):
+    """
+    Computes a discriminator loss, given a discriminator on
+    generated waveforms/spectrograms compared to ground truth
+    waveforms/spectrograms. Computes the loss for both the
+    discriminator and the generator in separate functions.
+    """
+
+    def __init__(self, discriminator):
+        super().__init__()
+        self.discriminator = discriminator
+
+    def forward(self, fake, real):
+        d_fake = self.discriminator(fake.audio_data)
+        d_real = self.discriminator(real.audio_data)
+        return d_fake, d_real
+
+    def discriminator_loss(self, fake, real):
+        d_fake, d_real = self.forward(fake.clone().detach(), real)
+
+        # Least-squares GAN objective: fake logits -> 0, real logits -> 1
+        loss_d = 0
+        for x_fake, x_real in zip(d_fake, d_real):
+            loss_d += torch.mean(x_fake[-1] ** 2)
+            loss_d += torch.mean((1 - x_real[-1]) ** 2)
+        return loss_d
+
+    def generator_loss(self, fake, real):
+        d_fake, d_real = self.forward(fake, real)
+
+        # Generator pushes fake logits toward 1
+        loss_g = 0
+        for x_fake in d_fake:
+            loss_g += torch.mean((1 - x_fake[-1]) ** 2)
+
+        # Feature-matching loss over all intermediate feature maps
+        loss_feature = 0
+        for i in range(len(d_fake)):
+            for j in range(len(d_fake[i]) - 1):
+                loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
+        return loss_g, loss_feature
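+
+
+if __name__ == "__main__":
+    # Illustrative smoke test wiring the GAN losses to the discriminator
+    # defined in this repo's discriminator.py (assumed importable).
+    from discriminator import Discriminator
+
+    fake = AudioSignal(torch.randn(1, 1, 44100), 44100)
+    real = AudioSignal(torch.randn(1, 1, 44100), 44100)
+    gan = GANLoss(Discriminator())
+    print("disc loss:", gan.discriminator_loss(fake, real).item())
+    loss_g, loss_feat = gan.generator_loss(fake, real)
+    print("gen loss:", loss_g.item(), "feature loss:", loss_feat.item())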
0000000000000000000000000000000000000000..20b4af4b0a573d1a66c384a6a3ca32261ab7ef3a --- /dev/null +++ b/outputs_44/logs/250731-225223/events.out.tfevents.1754002343.192-222-50-191.251408.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:39a5afdc3cce09e86e690161fb5ffc86dc91f658c0ed98f59bb5b009fc41913c +size 88 diff --git a/outputs_44/logs/250731-225403/events.out.tfevents.1754002443.192-222-50-191.253585.0 b/outputs_44/logs/250731-225403/events.out.tfevents.1754002443.192-222-50-191.253585.0 new file mode 100644 index 0000000000000000000000000000000000000000..1ea0d4b64c7da0b0af7b67934f46bb8661fed1e3 --- /dev/null +++ b/outputs_44/logs/250731-225403/events.out.tfevents.1754002443.192-222-50-191.253585.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ace267edf757f67cbefd8f9b608c65cb7203a64ec9ff67abcb6fdfc133f3c6d9 +size 88 diff --git a/outputs_44/logs/250731-225825/events.out.tfevents.1754002705.192-222-50-191.256776.0 b/outputs_44/logs/250731-225825/events.out.tfevents.1754002705.192-222-50-191.256776.0 new file mode 100644 index 0000000000000000000000000000000000000000..843a90198d89862fa573e6a3f46fc9a818b7cc74 --- /dev/null +++ b/outputs_44/logs/250731-225825/events.out.tfevents.1754002705.192-222-50-191.256776.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:95f6a49c1abbada3f43a053c6f3a7b71b2350e69465e7d472d7cdf7ab6e3efb2 +size 88 diff --git a/outputs_44/logs/250731-225933/events.out.tfevents.1754002773.192-222-50-191.258726.0 b/outputs_44/logs/250731-225933/events.out.tfevents.1754002773.192-222-50-191.258726.0 new file mode 100644 index 0000000000000000000000000000000000000000..86a2d22147abde0d679192106cdcb5e3b1a9c313 --- /dev/null +++ b/outputs_44/logs/250731-225933/events.out.tfevents.1754002773.192-222-50-191.258726.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f3ecfd4d1e631beefb9eb6b44e06d00109fb4eeb4cb66fbe6c18d6f715929a9c +size 88 diff --git a/outputs_44/logs/250731-230141/events.out.tfevents.1754002901.192-222-50-191.261044.0 b/outputs_44/logs/250731-230141/events.out.tfevents.1754002901.192-222-50-191.261044.0 new file mode 100644 index 0000000000000000000000000000000000000000..7be3b1a1e5bca2dee2cb4a77697251ce1a28609f --- /dev/null +++ b/outputs_44/logs/250731-230141/events.out.tfevents.1754002901.192-222-50-191.261044.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d7ea0d8fd1b641cb2e72757aae1dc40e8d46543592881178e77bd1f25d2fc60 +size 657 diff --git a/outputs_44/logs/250731-230933/events.out.tfevents.1754003373.192-222-50-191.265600.0 b/outputs_44/logs/250731-230933/events.out.tfevents.1754003373.192-222-50-191.265600.0 new file mode 100644 index 0000000000000000000000000000000000000000..1774d903673adc2c122a2f0ce26bb6461f805adb --- /dev/null +++ b/outputs_44/logs/250731-230933/events.out.tfevents.1754003373.192-222-50-191.265600.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4fb8f0ff63b6bc3672799ef1f942a3550e7835fe8f1bc9323c6285dced6bac8e +size 88 diff --git a/outputs_44/logs/250731-231116/events.out.tfevents.1754003476.192-222-50-191.267758.0 b/outputs_44/logs/250731-231116/events.out.tfevents.1754003476.192-222-50-191.267758.0 new file mode 100644 index 0000000000000000000000000000000000000000..a3d44e8fa50f1581331537d71b60da1642d53a91 --- /dev/null +++ b/outputs_44/logs/250731-231116/events.out.tfevents.1754003476.192-222-50-191.267758.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:85e3914a1cb506a4db5c835143cef42a9ead8e2575c507be4db58500589a76a9 +size 88 diff --git a/outputs_44/logs/250731-231231/events.out.tfevents.1754003551.192-222-50-191.269776.0 b/outputs_44/logs/250731-231231/events.out.tfevents.1754003551.192-222-50-191.269776.0 new file mode 100644 index 0000000000000000000000000000000000000000..241894d3062576066e70298db0ef6da8985646b9 --- /dev/null +++ b/outputs_44/logs/250731-231231/events.out.tfevents.1754003551.192-222-50-191.269776.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:35880f64dbda81c78cb2db9896489ac4e57ce40f1479604619d8c53a393ef564 +size 8351 diff --git a/outputs_44/logs/250731-231530/events.out.tfevents.1754003730.192-222-50-191.272651.0 b/outputs_44/logs/250731-231530/events.out.tfevents.1754003730.192-222-50-191.272651.0 new file mode 100644 index 0000000000000000000000000000000000000000..2320d7c1d657c45a3d165004400287ef8f8ffee3 --- /dev/null +++ b/outputs_44/logs/250731-231530/events.out.tfevents.1754003730.192-222-50-191.272651.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0982d44c25d89ec5049534538b6544a1451a15f0d90ca2b51e3c1e8d0602e62e +size 657 diff --git a/outputs_44/logs/250731-231657/events.out.tfevents.1754003817.192-222-50-191.274780.0 b/outputs_44/logs/250731-231657/events.out.tfevents.1754003817.192-222-50-191.274780.0 new file mode 100644 index 0000000000000000000000000000000000000000..2ad5402f3ee8e01418b4819a2a41dca17662ff35 --- /dev/null +++ b/outputs_44/logs/250731-231657/events.out.tfevents.1754003817.192-222-50-191.274780.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1aea22b39c5a92a988ed0eca38561ee0e63e133671ccce7b99ab2fcc7a7ffdd9 +size 88 diff --git a/outputs_44/logs/250731-231755/events.out.tfevents.1754003875.192-222-50-191.276659.0 b/outputs_44/logs/250731-231755/events.out.tfevents.1754003875.192-222-50-191.276659.0 new file mode 100644 index 0000000000000000000000000000000000000000..c6d8be44c6355a1f1f7b869f8ed3dd6a1ed6f98a --- /dev/null +++ b/outputs_44/logs/250731-231755/events.out.tfevents.1754003875.192-222-50-191.276659.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:89f45e59d2d8dd3f1082b83ce80273e6ec9616c1fed0160a7cf8f1db48f54373 +size 88 diff --git a/outputs_44/logs/250731-231848/events.out.tfevents.1754003928.192-222-50-191.278498.0 b/outputs_44/logs/250731-231848/events.out.tfevents.1754003928.192-222-50-191.278498.0 new file mode 100644 index 0000000000000000000000000000000000000000..d34ff181f55264b84f218827866ee09f8dbe260a --- /dev/null +++ b/outputs_44/logs/250731-231848/events.out.tfevents.1754003928.192-222-50-191.278498.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5899d7d0cce9c0fdba40cde59ac07ef652dec015ff4cf8e3127e3a6eb0a21746 +size 88 diff --git a/outputs_44/logs/250731-232239/events.out.tfevents.1754004159.192-222-50-191.281551.0 b/outputs_44/logs/250731-232239/events.out.tfevents.1754004159.192-222-50-191.281551.0 new file mode 100644 index 0000000000000000000000000000000000000000..2b5abd4bd4bf1d6d398ca8b67a6e42b11962e2aa --- /dev/null +++ b/outputs_44/logs/250731-232239/events.out.tfevents.1754004159.192-222-50-191.281551.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16dc81b44f9d3942f12fc4116d32f22a271c8f8da8b89ff7f3b4724fad34b707 +size 88 diff --git a/outputs_44/logs/250731-232509/events.out.tfevents.1754004309.192-222-50-191.284044.0 b/outputs_44/logs/250731-232509/events.out.tfevents.1754004309.192-222-50-191.284044.0 new file 
mode 100644 index 0000000000000000000000000000000000000000..66ab69f695fb25af6d671c30224b3b6bfbfec583 --- /dev/null +++ b/outputs_44/logs/250731-232509/events.out.tfevents.1754004309.192-222-50-191.284044.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3129f65104d4265cf0694184eb10b57ca97feb090d8e0fa28bf9221cbfc6fc3 +size 20391 diff --git a/outputs_44/logs/250731-233010/events.out.tfevents.1754004610.192-222-50-191.287890.0 b/outputs_44/logs/250731-233010/events.out.tfevents.1754004610.192-222-50-191.287890.0 new file mode 100644 index 0000000000000000000000000000000000000000..ba2eea551eb0d442636e4cdc9ad104c8aa141d3c --- /dev/null +++ b/outputs_44/logs/250731-233010/events.out.tfevents.1754004610.192-222-50-191.287890.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:929af03bf641b4e5b79fc9636677d488105c781df166a228f921b9a61af26ff8 +size 88 diff --git a/outputs_44/logs/250731-233059/events.out.tfevents.1754004659.192-222-50-191.289779.0 b/outputs_44/logs/250731-233059/events.out.tfevents.1754004659.192-222-50-191.289779.0 new file mode 100644 index 0000000000000000000000000000000000000000..1f3a8f3b4e69040b0b3013b4247e9323e9be8097 --- /dev/null +++ b/outputs_44/logs/250731-233059/events.out.tfevents.1754004659.192-222-50-191.289779.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65d5f637e5642d72342dee49cbc2f317adb4ae9efa2d0a4c9c05493ce75a7866 +size 88 diff --git a/outputs_44/logs/250731-233155/events.out.tfevents.1754004715.192-222-50-191.291676.0 b/outputs_44/logs/250731-233155/events.out.tfevents.1754004715.192-222-50-191.291676.0 new file mode 100644 index 0000000000000000000000000000000000000000..4203046b95847bd6442be52c7bcf6613ba816e0e --- /dev/null +++ b/outputs_44/logs/250731-233155/events.out.tfevents.1754004715.192-222-50-191.291676.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:feee42abcd03fb6b1218db207645a65c9e50827fa95686eaf988bf3bc086f5af +size 88 diff --git a/outputs_44/logs/250731-233252/events.out.tfevents.1754004772.192-222-50-191.293587.0 b/outputs_44/logs/250731-233252/events.out.tfevents.1754004772.192-222-50-191.293587.0 new file mode 100644 index 0000000000000000000000000000000000000000..c7c0e26b8f20321d9bb2afac55012bfa2b5b207f --- /dev/null +++ b/outputs_44/logs/250731-233252/events.out.tfevents.1754004772.192-222-50-191.293587.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:05e371ed5b6c7e034cfb9a5f5292ba534ed48afd30189a478edb5d9e94c8e0e8 +size 88 diff --git a/outputs_44/logs/250731-233447/events.out.tfevents.1754004887.192-222-50-191.296551.0 b/outputs_44/logs/250731-233447/events.out.tfevents.1754004887.192-222-50-191.296551.0 new file mode 100644 index 0000000000000000000000000000000000000000..9dfcf1237efa2cabb4aa182d76837ac2efaa3f32 --- /dev/null +++ b/outputs_44/logs/250731-233447/events.out.tfevents.1754004887.192-222-50-191.296551.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df433a42dc166db97ac8efe2beda59e0878c9a6e2c0474904a8ee39bbe35084a +size 301525 diff --git a/outputs_44/logs/250801-000341/events.out.tfevents.1754006621.192-222-50-191.314103.0 b/outputs_44/logs/250801-000341/events.out.tfevents.1754006621.192-222-50-191.314103.0 new file mode 100644 index 0000000000000000000000000000000000000000..a8edb72a9e3cdf042e4031cf6ae2ab0b2c8f38c2 --- /dev/null +++ b/outputs_44/logs/250801-000341/events.out.tfevents.1754006621.192-222-50-191.314103.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:42dbafb5a562cd905fbd0197474ad7f38588c48968831fab3c31ecaa0c9a7a41 +size 33033 diff --git a/outputs_44/logs/250801-000836/events.out.tfevents.1754006916.192-222-50-191.318094.0 b/outputs_44/logs/250801-000836/events.out.tfevents.1754006916.192-222-50-191.318094.0 new file mode 100644 index 0000000000000000000000000000000000000000..79edfe7255f907a40a7eb7c0f409e8586eac6671 --- /dev/null +++ b/outputs_44/logs/250801-000836/events.out.tfevents.1754006916.192-222-50-191.318094.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:324fb528f88d11de2e0f97a31f0bb690e8fdabb9ea3f42d84eacff73381a89be +size 6058164 diff --git a/outputs_44/logs/250801-061127/events.out.tfevents.1754028687.192-222-50-191.479069.0 b/outputs_44/logs/250801-061127/events.out.tfevents.1754028687.192-222-50-191.479069.0 new file mode 100644 index 0000000000000000000000000000000000000000..f8cc639e5b342cf462ac4061ae5a62d188f2486d --- /dev/null +++ b/outputs_44/logs/250801-061127/events.out.tfevents.1754028687.192-222-50-191.479069.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6430f08402c7796e984a55d22106bed2341f436647d83f020e385e62bb0ed7cc +size 25834 diff --git a/outputs_44/logs/250801-061443/events.out.tfevents.1754028883.192-222-50-191.482224.0 b/outputs_44/logs/250801-061443/events.out.tfevents.1754028883.192-222-50-191.482224.0 new file mode 100644 index 0000000000000000000000000000000000000000..f2207ac2f19cc73d3c4639083af01d114a00c537 --- /dev/null +++ b/outputs_44/logs/250801-061443/events.out.tfevents.1754028883.192-222-50-191.482224.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d6ac7bd202994fe4eede4e537ec45506ec808cd36c2c4819b4d29f78e44a7540 +size 3489477 diff --git a/outputs_CQT/logs/250802-193238/events.out.tfevents.1754163158.192-222-50-191.450847.0 b/outputs_CQT/logs/250802-193238/events.out.tfevents.1754163158.192-222-50-191.450847.0 new file mode 100644 index 0000000000000000000000000000000000000000..9b247cbbaa3d5061f47965609d8582d3feaec242 --- /dev/null +++ b/outputs_CQT/logs/250802-193238/events.out.tfevents.1754163158.192-222-50-191.450847.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b1ca13aeae9e9cfceb714cee0843a49e65f60a79d2c8241edef02e7b647c18dd +size 88 diff --git a/outputs_CQT/logs/250802-193416/events.out.tfevents.1754163256.192-222-50-191.453400.0 b/outputs_CQT/logs/250802-193416/events.out.tfevents.1754163256.192-222-50-191.453400.0 new file mode 100644 index 0000000000000000000000000000000000000000..cb7714b9d3753aa4587550114be2be683fb8a518 --- /dev/null +++ b/outputs_CQT/logs/250802-193416/events.out.tfevents.1754163256.192-222-50-191.453400.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b74b5b578aa9e1424559b140c2627cbab5cc7d1af12be20bc51e1094954a699f +size 34072 diff --git a/outputs_CQT/logs/250802-194115/events.out.tfevents.1754163675.192-222-50-191.461717.0 b/outputs_CQT/logs/250802-194115/events.out.tfevents.1754163675.192-222-50-191.461717.0 new file mode 100644 index 0000000000000000000000000000000000000000..7116ad81f3bd36b79f344b17fd4798a55ca49964 --- /dev/null +++ b/outputs_CQT/logs/250802-194115/events.out.tfevents.1754163675.192-222-50-191.461717.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f94b5a3c27265764126c48243122dbb3a375cdf01c660de6cfacd31266a2336 +size 2803 diff --git a/outputs_CQT/logs/250802-194301/events.out.tfevents.1754163781.192-222-50-191.464869.0 
b/outputs_CQT/logs/250802-194301/events.out.tfevents.1754163781.192-222-50-191.464869.0 new file mode 100644 index 0000000000000000000000000000000000000000..32de7672e9031c7201997f4d78aa7b8bd546b41d --- /dev/null +++ b/outputs_CQT/logs/250802-194301/events.out.tfevents.1754163781.192-222-50-191.464869.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fb59cef5dc8ba7c397749cac7acfc2042df74553b99425455899d1ebe4f2111f +size 761 diff --git a/outputs_CQT/logs/250802-194610/events.out.tfevents.1754163970.192-222-50-191.469115.0 b/outputs_CQT/logs/250802-194610/events.out.tfevents.1754163970.192-222-50-191.469115.0 new file mode 100644 index 0000000000000000000000000000000000000000..596cfba320faeb74485f523a53294814f3c3390d --- /dev/null +++ b/outputs_CQT/logs/250802-194610/events.out.tfevents.1754163970.192-222-50-191.469115.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7252461be718428b60229fa881683ad554174558f1b4004f6b9d53029b899026 +size 95173 diff --git a/outputs_CQT/logs/250802-195930/events.out.tfevents.1754164770.192-222-50-191.484506.0 b/outputs_CQT/logs/250802-195930/events.out.tfevents.1754164770.192-222-50-191.484506.0 new file mode 100644 index 0000000000000000000000000000000000000000..fa086827c1b5a41ccfc0bfee2e69af80cfa90a9b --- /dev/null +++ b/outputs_CQT/logs/250802-195930/events.out.tfevents.1754164770.192-222-50-191.484506.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:acaabd03d00d477ca4fae7bf5ccd4e83b4ad162c36af56005d884000d3262222 +size 1407 diff --git a/outputs_CQT/logs/250802-200219/events.out.tfevents.1754164939.192-222-50-191.488453.0 b/outputs_CQT/logs/250802-200219/events.out.tfevents.1754164939.192-222-50-191.488453.0 new file mode 100644 index 0000000000000000000000000000000000000000..e1fd299dd7c51fcf758ca3e88fc92af64ccb43c7 --- /dev/null +++ b/outputs_CQT/logs/250802-200219/events.out.tfevents.1754164939.192-222-50-191.488453.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e7ff6fa6cb292bcfbe1a2283aaeac96e410d2cd570fcddd85bf34c2b96987f8 +size 19177 diff --git a/outputs_CQT/logs/250802-200725/events.out.tfevents.1754165245.192-222-50-191.494780.0 b/outputs_CQT/logs/250802-200725/events.out.tfevents.1754165245.192-222-50-191.494780.0 new file mode 100644 index 0000000000000000000000000000000000000000..1836039ae398046e826aa329ee2a3eff174df5ee --- /dev/null +++ b/outputs_CQT/logs/250802-200725/events.out.tfevents.1754165245.192-222-50-191.494780.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f07aa6b9973757d72e42cf672491ad911a8eed11d39b8b9edb7d8c35d66e6433 +size 189 diff --git a/outputs_CQT/logs/250802-200854/events.out.tfevents.1754165334.192-222-50-191.497569.0 b/outputs_CQT/logs/250802-200854/events.out.tfevents.1754165334.192-222-50-191.497569.0 new file mode 100644 index 0000000000000000000000000000000000000000..59f5f156e236ed7a00f84b62f8f0bc020e7c0692 --- /dev/null +++ b/outputs_CQT/logs/250802-200854/events.out.tfevents.1754165334.192-222-50-191.497569.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:adc0523fd31d421bb02c44a431341a1045fcd8d5a11f96d5622fc0c8758213ec +size 709 diff --git a/outputs_CQT/logs/250802-201010/events.out.tfevents.1754165410.192-222-50-191.500209.0 b/outputs_CQT/logs/250802-201010/events.out.tfevents.1754165410.192-222-50-191.500209.0 new file mode 100644 index 0000000000000000000000000000000000000000..73b020d64849ebd2e68ab3e8747961c0441769b9 --- /dev/null +++ 
b/outputs_CQT/logs/250802-201010/events.out.tfevents.1754165410.192-222-50-191.500209.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed8af193fd5c65aaacc6f77ff677adcb061a11c069bce578caa400ea0605cb53 +size 5956910 diff --git a/outputs_CQT/logs/250802-234327/events.out.tfevents.1754178207.192-222-50-191.737317.0 b/outputs_CQT/logs/250802-234327/events.out.tfevents.1754178207.192-222-50-191.737317.0 new file mode 100644 index 0000000000000000000000000000000000000000..795244e024101aac8f7c2a7619420f93aba851c7 --- /dev/null +++ b/outputs_CQT/logs/250802-234327/events.out.tfevents.1754178207.192-222-50-191.737317.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e8db19b40728df6f0d614abb02acf1a8268c31b38059cf40d67586008d659af1 +size 88 diff --git a/outputs_CQT/logs/250802-234408/events.out.tfevents.1754178248.192-222-50-191.739490.0 b/outputs_CQT/logs/250802-234408/events.out.tfevents.1754178248.192-222-50-191.739490.0 new file mode 100644 index 0000000000000000000000000000000000000000..a0617d1362fe5b44082e7589b8c8dc1c9462763e --- /dev/null +++ b/outputs_CQT/logs/250802-234408/events.out.tfevents.1754178248.192-222-50-191.739490.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b30be993a5f281fba79d9a56f128e9e4f0a0c44ee2795e1cc9135df3837e689e +size 88 diff --git a/outputs_CQT/logs/250802-234456/events.out.tfevents.1754178296.192-222-50-191.741743.0 b/outputs_CQT/logs/250802-234456/events.out.tfevents.1754178296.192-222-50-191.741743.0 new file mode 100644 index 0000000000000000000000000000000000000000..30e0541237e4cfc2009b9afd5b1b94bbbc91401e --- /dev/null +++ b/outputs_CQT/logs/250802-234456/events.out.tfevents.1754178296.192-222-50-191.741743.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3be82bf3dd09dd55bb51bb73c602e5027bab6eae116f2c5f6e9d52130408a152 +size 88 diff --git a/outputs_CQT/logs/250802-234612/events.out.tfevents.1754178372.192-222-50-191.744771.0 b/outputs_CQT/logs/250802-234612/events.out.tfevents.1754178372.192-222-50-191.744771.0 new file mode 100644 index 0000000000000000000000000000000000000000..07cb71417126009b9839e01f91082c85ecf5ac27 --- /dev/null +++ b/outputs_CQT/logs/250802-234612/events.out.tfevents.1754178372.192-222-50-191.744771.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:093ce03e0b22bedcf43d5331d3c567070bb45a8990e6ef37b62599e1963dc490 +size 88 diff --git a/outputs_CQT/logs/250802-234646/events.out.tfevents.1754178406.192-222-50-191.745574.0 b/outputs_CQT/logs/250802-234646/events.out.tfevents.1754178406.192-222-50-191.745574.0 new file mode 100644 index 0000000000000000000000000000000000000000..ccaca85dce596530cd2a1118f7b0ab150788e113 --- /dev/null +++ b/outputs_CQT/logs/250802-234646/events.out.tfevents.1754178406.192-222-50-191.745574.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3a84a5981e2ed222a0a3738b4cce9fca427ff86d5a8cf59d0850b64385229f75 +size 2873 diff --git a/outputs_CQT/logs/250802-234803/events.out.tfevents.1754178483.192-222-50-191.748386.0 b/outputs_CQT/logs/250802-234803/events.out.tfevents.1754178483.192-222-50-191.748386.0 new file mode 100644 index 0000000000000000000000000000000000000000..166ade3ee07eb2ea48786cc47e0552d279a306ae --- /dev/null +++ b/outputs_CQT/logs/250802-234803/events.out.tfevents.1754178483.192-222-50-191.748386.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b2cba767f89f649810920900d7c44738cbf36604db255547bbbebcc42faad0b3 +size 1200 diff --git 
a/outputs_CQT/logs/250802-234912/events.out.tfevents.1754178552.192-222-50-191.750989.0 b/outputs_CQT/logs/250802-234912/events.out.tfevents.1754178552.192-222-50-191.750989.0 new file mode 100644 index 0000000000000000000000000000000000000000..28d1fe3c4ac6524a2febfac8a6de3ac919f1207d --- /dev/null +++ b/outputs_CQT/logs/250802-234912/events.out.tfevents.1754178552.192-222-50-191.750989.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:082153cc5725f1638314707bcc5e23dfbf4174ae14f1682cb03fb5e650d5fb89 +size 6397810 diff --git a/outputs_CQT/logs/250803-072434/events.out.tfevents.1754205874.192-222-50-191.849437.0 b/outputs_CQT/logs/250803-072434/events.out.tfevents.1754205874.192-222-50-191.849437.0 new file mode 100644 index 0000000000000000000000000000000000000000..675b4f510a05f2e4da222e2407d09823e0e8f3a1 --- /dev/null +++ b/outputs_CQT/logs/250803-072434/events.out.tfevents.1754205874.192-222-50-191.849437.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c16249fadde72237b852855dcfe7d165769e00ee780041777f9a0ac9c36fceb7 +size 88 diff --git a/outputs_CQT/logs/250803-072505/events.out.tfevents.1754205905.192-222-50-191.849986.0 b/outputs_CQT/logs/250803-072505/events.out.tfevents.1754205905.192-222-50-191.849986.0 new file mode 100644 index 0000000000000000000000000000000000000000..4de71f4b0f037c33e19b71f38853082e2aa52550 --- /dev/null +++ b/outputs_CQT/logs/250803-072505/events.out.tfevents.1754205905.192-222-50-191.849986.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:30bdf426c8446754b31773e465a8b840508f0b200fb3c8aeddf7f676bf4c896c +size 9587259 diff --git a/outputs_mp_2k/logs/250802-061508/events.out.tfevents.1754115308.192-222-50-191.55173.0 b/outputs_mp_2k/logs/250802-061508/events.out.tfevents.1754115308.192-222-50-191.55173.0 new file mode 100644 index 0000000000000000000000000000000000000000..e5637b7508510682a2bcfb460d08f54513121c9c --- /dev/null +++ b/outputs_mp_2k/logs/250802-061508/events.out.tfevents.1754115308.192-222-50-191.55173.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0848ed11360d7e3dd6828ab18dd5bd017ead6d5fa704dbe7c689e7ffbc3eff42 +size 8449 diff --git a/outputs_mp_2k/logs/250802-061754/events.out.tfevents.1754115474.192-222-50-191.57865.0 b/outputs_mp_2k/logs/250802-061754/events.out.tfevents.1754115474.192-222-50-191.57865.0 new file mode 100644 index 0000000000000000000000000000000000000000..6e6e910487e08ad17f0d9e1b77ad340bbb1e5330 --- /dev/null +++ b/outputs_mp_2k/logs/250802-061754/events.out.tfevents.1754115474.192-222-50-191.57865.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f8f0b131b7e62a48f03076da0b6d0f7d855ded99edc9d3de9c36c72d80d7020 +size 88 diff --git a/outputs_mp_2k/logs/250802-061907/events.out.tfevents.1754115547.192-222-50-191.59898.0 b/outputs_mp_2k/logs/250802-061907/events.out.tfevents.1754115547.192-222-50-191.59898.0 new file mode 100644 index 0000000000000000000000000000000000000000..035b05a3a37d7fc794f37fb062476920d273df83 --- /dev/null +++ b/outputs_mp_2k/logs/250802-061907/events.out.tfevents.1754115547.192-222-50-191.59898.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84af44f7c1b95f44c3ae8681904c3ffc9303a24597ae2da58759a28ce7a1187f +size 709 diff --git a/outputs_mp_2k/logs/250802-062153/events.out.tfevents.1754115713.192-222-50-191.62603.0 b/outputs_mp_2k/logs/250802-062153/events.out.tfevents.1754115713.192-222-50-191.62603.0 new file mode 100644 index 
0000000000000000000000000000000000000000..c9390b2d452e2acc9becbd1b313a1d5bff4e9f6e --- /dev/null +++ b/outputs_mp_2k/logs/250802-062153/events.out.tfevents.1754115713.192-222-50-191.62603.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:90f979a46e09bac300037bb4e8902afe11bf9efd1870a0680322ecec9c219b1c +size 66091 diff --git a/outputs_mp_2k/logs/250802-063428/events.out.tfevents.1754116468.192-222-50-191.70131.0 b/outputs_mp_2k/logs/250802-063428/events.out.tfevents.1754116468.192-222-50-191.70131.0 new file mode 100644 index 0000000000000000000000000000000000000000..70055cde6e086d691161492d5f3083281353cbaa --- /dev/null +++ b/outputs_mp_2k/logs/250802-063428/events.out.tfevents.1754116468.192-222-50-191.70131.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:392e82f9da909b6ca6094f1ba8090cc292b946e5668398aef5cbcd1bd45ade4f +size 66044 diff --git a/outputs_mp_2k/logs/250802-065546/events.out.tfevents.1754117746.192-222-50-191.84723.0 b/outputs_mp_2k/logs/250802-065546/events.out.tfevents.1754117746.192-222-50-191.84723.0 new file mode 100644 index 0000000000000000000000000000000000000000..d2b88ddce74155b412c21ccfe9e319e2b9e5d4b3 --- /dev/null +++ b/outputs_mp_2k/logs/250802-065546/events.out.tfevents.1754117746.192-222-50-191.84723.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d61b668cc2b3d6ab44a7d70775e656cb2c4d761cc1033fec0da78595e1fdb8a +size 709 diff --git a/outputs_mp_2k/logs/250802-072237/events.out.tfevents.1754119357.192-222-50-191.103131.0 b/outputs_mp_2k/logs/250802-072237/events.out.tfevents.1754119357.192-222-50-191.103131.0 new file mode 100644 index 0000000000000000000000000000000000000000..e46bb29e4817656e8eb52dffebfa9d2fdffe69ad --- /dev/null +++ b/outputs_mp_2k/logs/250802-072237/events.out.tfevents.1754119357.192-222-50-191.103131.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd4eef1756581821da82e50125f64f67cd3e2caf8bae6b2c35b052f562258d57 +size 88 diff --git a/outputs_mp_2k/logs/250802-073157/events.out.tfevents.1754119917.192-222-50-191.107469.0 b/outputs_mp_2k/logs/250802-073157/events.out.tfevents.1754119917.192-222-50-191.107469.0 new file mode 100644 index 0000000000000000000000000000000000000000..c0cfc2be151257b067865a971b3d9b95d0ebd169 --- /dev/null +++ b/outputs_mp_2k/logs/250802-073157/events.out.tfevents.1754119917.192-222-50-191.107469.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c0db545588901e9213f53d9bcbc561f143e69716ce4302b2bf975b69a9f7235 +size 21769 diff --git a/outputs_mp_2k/logs/250802-073456/events.out.tfevents.1754120096.192-222-50-191.110787.0 b/outputs_mp_2k/logs/250802-073456/events.out.tfevents.1754120096.192-222-50-191.110787.0 new file mode 100644 index 0000000000000000000000000000000000000000..61a521964b5657d26e17d247a68764f4f1d87197 --- /dev/null +++ b/outputs_mp_2k/logs/250802-073456/events.out.tfevents.1754120096.192-222-50-191.110787.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dc3279bf0d25bb01b1e9546cff756cfc69ef0253265301e2be6dea8d3a19eca1 +size 21769 diff --git a/outputs_mp_2k/logs/250802-073738/events.out.tfevents.1754120258.192-222-50-191.113691.0 b/outputs_mp_2k/logs/250802-073738/events.out.tfevents.1754120258.192-222-50-191.113691.0 new file mode 100644 index 0000000000000000000000000000000000000000..74d10da0b4355694c9d8f4d43411170d422bc911 --- /dev/null +++ b/outputs_mp_2k/logs/250802-073738/events.out.tfevents.1754120258.192-222-50-191.113691.0 @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:c352d9eb7d45231bfd644b79e20344243efa08855c64b163f2f18d749262a986 +size 56590 diff --git a/outputs_mp_2k/logs/250802-074336/events.out.tfevents.1754120616.192-222-50-191.118429.0 b/outputs_mp_2k/logs/250802-074336/events.out.tfevents.1754120616.192-222-50-191.118429.0 new file mode 100644 index 0000000000000000000000000000000000000000..7c069f02ba3bd8f9e3271135c53031fed4b12701 --- /dev/null +++ b/outputs_mp_2k/logs/250802-074336/events.out.tfevents.1754120616.192-222-50-191.118429.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0d9d37ddccf8bff4894c51dfb85905914967ab3e2a3a7c33f8cd9b6542181545 +size 2059 diff --git a/outputs_mp_2k/logs/250802-074425/events.out.tfevents.1754120665.192-222-50-191.120287.0 b/outputs_mp_2k/logs/250802-074425/events.out.tfevents.1754120665.192-222-50-191.120287.0 new file mode 100644 index 0000000000000000000000000000000000000000..7c1ecca4f4b4ef2a675c5edadfb597822e437fc6 --- /dev/null +++ b/outputs_mp_2k/logs/250802-074425/events.out.tfevents.1754120665.192-222-50-191.120287.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1bad8d9bb621d3d56a874871fd0f14d615b7819f8b871a57b02f9c796017dd2a +size 88 diff --git a/outputs_mp_2k/logs/250802-074456/events.out.tfevents.1754120696.192-222-50-191.122150.0 b/outputs_mp_2k/logs/250802-074456/events.out.tfevents.1754120696.192-222-50-191.122150.0 new file mode 100644 index 0000000000000000000000000000000000000000..d6bdc056685e0b339e823e7c5ec38378ac719d7c --- /dev/null +++ b/outputs_mp_2k/logs/250802-074456/events.out.tfevents.1754120696.192-222-50-191.122150.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e14b36393f92dd7322b0d9b760b99f0def346d131a10e57343a513adccf362b +size 2716 diff --git a/outputs_mp_2k/logs/250802-074646/events.out.tfevents.1754120806.192-222-50-191.124400.0 b/outputs_mp_2k/logs/250802-074646/events.out.tfevents.1754120806.192-222-50-191.124400.0 new file mode 100644 index 0000000000000000000000000000000000000000..1a74dde06bcefc0061a9f290cd7232de60bc9858 --- /dev/null +++ b/outputs_mp_2k/logs/250802-074646/events.out.tfevents.1754120806.192-222-50-191.124400.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c8e2de6f8f34dc22a811a001bdefef42495c5b6250492c8498bc176563c48460 +size 5689841 diff --git a/outputs_mp_2k/logs/250802-101742/events.out.tfevents.1754129862.192-222-50-191.146107.0 b/outputs_mp_2k/logs/250802-101742/events.out.tfevents.1754129862.192-222-50-191.146107.0 new file mode 100644 index 0000000000000000000000000000000000000000..ceb8b8502a8ccfe8b1ad904e55977e07df1ee77a --- /dev/null +++ b/outputs_mp_2k/logs/250802-101742/events.out.tfevents.1754129862.192-222-50-191.146107.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2258c17fc289889fbd65a3f79eb5baf3f9596028d7878dc11289ceed4cc294bb +size 24841 diff --git a/outputs_mp_2k/logs/250802-102040/events.out.tfevents.1754130040.192-222-50-191.149567.0 b/outputs_mp_2k/logs/250802-102040/events.out.tfevents.1754130040.192-222-50-191.149567.0 new file mode 100644 index 0000000000000000000000000000000000000000..68626af26534ba5474a5a9dd8d1e34f19f91cefb --- /dev/null +++ b/outputs_mp_2k/logs/250802-102040/events.out.tfevents.1754130040.192-222-50-191.149567.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c410ce5a421f653fc834489b143e6fc96cc96335c6bfa647531c54fc3a6cc64 +size 6402446 diff --git 
a/outputs_mp_2k/logs/250802-152757/events.out.tfevents.1754148477.192-222-50-191.310460.0 b/outputs_mp_2k/logs/250802-152757/events.out.tfevents.1754148477.192-222-50-191.310460.0 new file mode 100644 index 0000000000000000000000000000000000000000..cabf7b6f7b00ba2a118de11c07f7bd885706d001 --- /dev/null +++ b/outputs_mp_2k/logs/250802-152757/events.out.tfevents.1754148477.192-222-50-191.310460.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3424997fa63473acd32f5dc66e0bbfd05212fe349bfcc02029270c2e6f49096d +size 18912 diff --git a/outputs_mp_2k/logs/250802-153123/events.out.tfevents.1754148683.192-222-50-191.314251.0 b/outputs_mp_2k/logs/250802-153123/events.out.tfevents.1754148683.192-222-50-191.314251.0 new file mode 100644 index 0000000000000000000000000000000000000000..f2665334ef8ff02c699563f94b5ec58d747ef218 --- /dev/null +++ b/outputs_mp_2k/logs/250802-153123/events.out.tfevents.1754148683.192-222-50-191.314251.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0f78bd6f5b87e39b8c617025bf6f0a96faf91879cd290f1ebdec44c9eaa5159d +size 88 diff --git a/outputs_mp_2k/logs/250802-153508/events.out.tfevents.1754148908.192-222-50-191.317237.0 b/outputs_mp_2k/logs/250802-153508/events.out.tfevents.1754148908.192-222-50-191.317237.0 new file mode 100644 index 0000000000000000000000000000000000000000..573df30a4e54f467e0b80cfe0aac3195677dfa4e --- /dev/null +++ b/outputs_mp_2k/logs/250802-153508/events.out.tfevents.1754148908.192-222-50-191.317237.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f44c201c5d8883c6c4b8d40552aad7493b60d32b1dd75ff4055edc078650c47f +size 3275693
diff --git a/quantization/__init__.py b/quantization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfabe52b8cb6f260cdda6137b34df2f4736bd02f
--- /dev/null
+++ b/quantization/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .vq import QuantizedResult, ResidualVectorQuantizer
diff --git a/quantization/__pycache__/__init__.cpython-312.pyc b/quantization/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb99178733bad93336b9563185eee9fa5fa0ac69
Binary files /dev/null and b/quantization/__pycache__/__init__.cpython-312.pyc differ
diff --git a/quantization/__pycache__/core_vq_lsx_version.cpython-312.pyc b/quantization/__pycache__/core_vq_lsx_version.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..51d41ec0d168ef5a540100e01b18fe7998ad6e45
Binary files /dev/null and b/quantization/__pycache__/core_vq_lsx_version.cpython-312.pyc differ
diff --git a/quantization/__pycache__/ddp_utils.cpython-312.pyc b/quantization/__pycache__/ddp_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a2a3f017a068ac1d2a0ae1312bfe7c67c3751d29
Binary files /dev/null and b/quantization/__pycache__/ddp_utils.cpython-312.pyc differ
diff --git a/quantization/__pycache__/distrib.cpython-312.pyc b/quantization/__pycache__/distrib.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b04fc25c36e1c4f9d8cae1714e8d7a166e48082c
Binary files /dev/null and b/quantization/__pycache__/distrib.cpython-312.pyc differ
diff --git a/quantization/__pycache__/vq.cpython-312.pyc b/quantization/__pycache__/vq.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bae729c0c13d0c7dce4c84a9d84612104785ab3e
Binary files /dev/null and b/quantization/__pycache__/vq.cpython-312.pyc differ
diff --git a/quantization/ac.py b/quantization/ac.py
new file mode 100644
index 0000000000000000000000000000000000000000..318d993b610c78a46f3d605e7e2ccbdde4b915ec
--- /dev/null
+++ b/quantization/ac.py
@@ -0,0 +1,292 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Arithmetic coder."""
+
+import io
+import math
+import random
+import typing as tp
+import torch
+
+from ..binary import BitPacker, BitUnpacker
+
+
+def build_stable_quantized_cdf(
+    pdf: torch.Tensor, total_range_bits: int, roundoff: float = 1e-8, min_range: int = 2, check: bool = True
+) -> torch.Tensor:
+    """Turn the given PDF into a quantized CDF that splits
+    [0, 2 ** total_range_bits - 1] into chunks of size roughly proportional
+    to the PDF.
+
+    Args:
+        pdf (torch.Tensor): probability distribution, shape should be `[N]`.
+        total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
+            during the coding process is `[0, 2 ** total_range_bits - 1]`.
+        roundoff (float): will round the pdf up to that level to remove differences coming
+            from e.g. evaluating the Language Model on different architectures.
+        min_range (int): minimum range width. Should always be at least 2 for numerical
+            stability. Use this to avoid pathological behavior if a value
+            that is expected to be rare actually happens in real life.
+        check (bool): if True, checks that nothing bad happened; can be deactivated for speed.
+    """
+    if min_range < 2:
+        raise ValueError("min_range must be at least 2.")
+    pdf = pdf.detach()
+    if roundoff:
+        pdf = (pdf / roundoff).floor() * roundoff
+    # Interpolate with the uniform distribution to achieve the desired minimum probability.
+    total_range = 2**total_range_bits
+    cardinality = len(pdf)
+    alpha = min_range * cardinality / total_range
+    assert alpha <= 1, "you must reduce min_range"
+    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
+    ranges += min_range
+    quantized_cdf = torch.cumsum(ranges, dim=-1)
+    if check:
+        assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
+        if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
+            raise ValueError("You must increase your total_range_bits.")
+    return quantized_cdf
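For intuition, here is a minimal sketch of what `build_stable_quantized_cdf` produces for a skewed two-symbol distribution. The `xcodec.quantization.ac` import path is an assumption, inferred from the `xcodec.quantization.distrib` import used by `core_vq.py` below; adjust it to your checkout.

```python
# Hypothetical usage sketch; the import path is assumed from the package
# layout implied elsewhere in this diff.
import torch
from xcodec.quantization.ac import build_stable_quantized_cdf

pdf = torch.tensor([0.75, 0.25])  # two symbols, heavily skewed
q_cdf = torch.as_tensor(build_stable_quantized_cdf(pdf, total_range_bits=24))

# Bucket widths are roughly proportional to the pdf, each at least
# min_range wide, and the cumulative total stays within 2 ** 24.
widths = torch.cat([q_cdf[:1], q_cdf[1:] - q_cdf[:-1]])
assert (widths >= 2).all()
assert q_cdf[-1] <= 2**24
print(q_cdf.tolist(), widths.tolist())
```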
+
+
+class ArithmeticCoder:
+    """ArithmeticCoder.
+
+    Let us take a distribution `p` over `N` symbols, and assume we have a stream
+    of random variables `s_t` sampled from `p`. Let us assume that we have a budget
+    of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
+    corresponding to the range `[0, 2 ** B - 1]`. We can map each of those numbers to a single
+    sequence `(s_t)` by doing the following:
+
+    1) Initialize the current range to `[0, 2 ** B - 1]`.
+    2) For each time step t, split the current range into contiguous chunks,
+        one for each possible outcome, with size roughly proportional to `p`.
+        For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
+        would be `{[0, 2], [3, 3]}`.
+    3) Select the chunk corresponding to `s_t`, and replace the current range with this.
+    4) When done encoding all the values, just select any value remaining in the range.
+
+    You will notice that this procedure can fail: for instance if at any point in time
+    the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
+    possible outcome. Intuitively, the more likely a value is, the less the range width
+    will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
+    coding scheme, likely outcomes would take fewer bits, and more of them can be coded
+    with a fixed budget.
+
+    In practice, we do not know `B` ahead of time, but we have a way to inject new bits
+    when the current range decreases below a given limit (given by `total_range_bits`), without
+    having to redo all the computations. If we mostly encode likely values, we will seldom
+    need to inject new bits, but a single rare value can deplete our stock of entropy!
+
+    In this explanation, we assumed that the distribution `p` was constant. In fact, the present
+    code works for any sequence `(p_t)`, possibly different for each timestep.
+    We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
+    the KL between the true distribution and `p_t`, the more efficient the coding will be.
+
+    Args:
+        fo (IO[bytes]): file-like object to which the bytes will be written.
+        total_range_bits (int): the range `M` described above is `2 ** total_range_bits`.
+            Any time the current range width falls below this limit, new bits will
+            be injected to rescale the initial range.
+    """
+
+    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
+        assert total_range_bits <= 30
+        self.total_range_bits = total_range_bits
+        self.packer = BitPacker(bits=1, fo=fo)  # we push single bits at a time.
+        self.low: int = 0
+        self.high: int = 0
+        self.max_bit: int = -1
+        self._dbg: tp.List[tp.Any] = []
+        self._dbg2: tp.List[tp.Any] = []
+
+    @property
+    def delta(self) -> int:
+        """Return the current range width."""
+        return self.high - self.low + 1
+
+    def _flush_common_prefix(self):
+        # If self.low and self.high start with the same bits,
+        # those won't change anymore as we always just increase the range
+        # by powers of 2, and we can flush them out to the bit stream.
+        assert self.high >= self.low, (self.low, self.high)
+        assert self.high < 2 ** (self.max_bit + 1)
+        while self.max_bit >= 0:
+            b1 = self.low >> self.max_bit
+            b2 = self.high >> self.max_bit
+            if b1 == b2:
+                self.low -= b1 << self.max_bit
+                self.high -= b1 << self.max_bit
+                assert self.high >= self.low, (self.high, self.low, self.max_bit)
+                assert self.low >= 0
+                self.max_bit -= 1
+                self.packer.push(b1)
+            else:
+                break
+
+    def push(self, symbol: int, quantized_cdf: torch.Tensor):
+        """Push the given symbol on the stream, flushing out bits
+        if possible.
+
+        Args:
+            symbol (int): symbol to encode with the AC.
+            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
+                to build this from your pdf estimate.
+        """
+        while self.delta < 2**self.total_range_bits:
+            self.low *= 2
+            self.high = self.high * 2 + 1
+            self.max_bit += 1
+
+        range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
+        range_high = quantized_cdf[symbol].item() - 1
+        effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
+        effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
+        assert self.low <= self.high
+        self.high = self.low + effective_high
+        self.low = self.low + effective_low
+        assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
+        self._dbg.append((self.low, self.high))
+        self._dbg2.append((self.low, self.high))
+        outs = self._flush_common_prefix()
+        assert self.low <= self.high
+        assert self.max_bit >= -1
+        assert self.max_bit <= 61, self.max_bit
+        return outs
+
+    def flush(self):
+        """Flush the remaining information to the stream."""
+        while self.max_bit >= 0:
+            b1 = (self.low >> self.max_bit) & 1
+            self.packer.push(b1)
+            self.max_bit -= 1
+        self.packer.flush()
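A hedged usage sketch of the encoder: pushing a highly skewed symbol stream should compress well below the naive two bits per symbol. Exact byte counts depend on the `BitPacker` implementation in the sibling `binary` module, and the import path is the same assumption as above.

```python
# Sketch only: encode a skewed stream and inspect the compressed size.
import io
import torch
from xcodec.quantization.ac import ArithmeticCoder, build_stable_quantized_cdf

pdf = torch.tensor([0.9, 0.05, 0.05])
symbols = [0] * 95 + [1, 2, 1, 2, 0]  # mostly the most likely symbol

fo = io.BytesIO()
coder = ArithmeticCoder(fo)
for s in symbols:
    # In real use the pdf (e.g. a language-model prediction) may change at
    # every step; here it is constant for simplicity.
    coder.push(s, build_stable_quantized_cdf(pdf, coder.total_range_bits))
coder.flush()

# 100 symbols over a 3-letter alphabet would need 25 bytes at 2 bits each;
# this pdf has an entropy of roughly 0.57 bits/symbol, so the output should
# land in the vicinity of 8-10 bytes.
print(f"{len(fo.getvalue())} bytes for {len(symbols)} symbols")
```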
+
+
+class ArithmeticDecoder:
+    """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
+
+    Note that this must be called with **exactly** the same parameters and sequence
+    of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
+
+    If the AC encoder current range is [L, H], with `L` and `H` having the same common
+    prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
+    For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
+    `[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
+    for a specific sequence of symbols, and a binary search allows us to decode those symbols.
+    At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
+    and we will need to read new bits from the stream and repeat the process.
+    """
+
+    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
+        self.total_range_bits = total_range_bits
+        self.low: int = 0
+        self.high: int = 0
+        self.current: int = 0
+        self.max_bit: int = -1
+        self.unpacker = BitUnpacker(bits=1, fo=fo)  # we pull single bits at a time.
+        # Following is for debugging.
+        self._dbg: tp.List[tp.Any] = []
+        self._dbg2: tp.List[tp.Any] = []
+        self._last: tp.Any = None
+
+    @property
+    def delta(self) -> int:
+        return self.high - self.low + 1
+
+    def _flush_common_prefix(self):
+        # Given the current range [L, H], if both have a common prefix,
+        # we know we can remove it from our representation to avoid handling large numbers.
+        while self.max_bit >= 0:
+            b1 = self.low >> self.max_bit
+            b2 = self.high >> self.max_bit
+            if b1 == b2:
+                self.low -= b1 << self.max_bit
+                self.high -= b1 << self.max_bit
+                self.current -= b1 << self.max_bit
+                assert self.high >= self.low
+                assert self.low >= 0
+                self.max_bit -= 1
+            else:
+                break
+
+    def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
+        """Pull a symbol, reading as many bits from the stream as required.
+        This returns `None` when the stream has been exhausted.
+
+        Args:
+            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
+                to build this from your pdf estimate. This must be **exactly**
+                the same cdf as the one used at encoding time.
+        """
+        while self.delta < 2**self.total_range_bits:
+            bit = self.unpacker.pull()
+            if bit is None:
+                return None
+            self.low *= 2
+            self.high = self.high * 2 + 1
+            self.current = self.current * 2 + bit
+            self.max_bit += 1
+
+        def bin_search(low_idx: int, high_idx: int):
+            # Binary search is not just for coding interviews :)
+            if high_idx < low_idx:
+                raise RuntimeError("Binary search failed")
+            mid = (low_idx + high_idx) // 2
+            range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
+            range_high = quantized_cdf[mid].item() - 1
+            effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
+            effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
+            low = effective_low + self.low
+            high = effective_high + self.low
+            if self.current >= low:
+                if self.current <= high:
+                    return (mid, low, high, self.current)
+                else:
+                    return bin_search(mid + 1, high_idx)
+            else:
+                return bin_search(low_idx, mid - 1)
+
+        self._last = (self.low, self.high, self.current, self.max_bit)
+        sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
+        self._dbg.append((self.low, self.high, self.current))
+        self._flush_common_prefix()
+        self._dbg2.append((self.low, self.high, self.current))
+
+        return sym
+
+
+def test():
+    torch.manual_seed(1234)
+    random.seed(1234)
+    for _ in range(4):
+        pdfs = []
+        cardinality = random.randrange(4000)
+        steps = random.randrange(100, 500)
+        fo = io.BytesIO()
+        encoder = ArithmeticCoder(fo)
+        symbols = []
+        for step in range(steps):
+            pdf = torch.softmax(torch.randn(cardinality), dim=0)
+            pdfs.append(pdf)
+            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
+            symbol = torch.multinomial(pdf, 1).item()
+            symbols.append(symbol)
+            encoder.push(symbol, q_cdf)
+        encoder.flush()
+
+        fo.seek(0)
+        decoder = ArithmeticDecoder(fo)
+        for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
+            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
+            decoded_symbol = decoder.pull(q_cdf)
+            assert decoded_symbol == symbol, idx
+        assert decoder.pull(torch.zeros(1)) is None
+
+
+if __name__ == "__main__":
+    test()
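The randomized `test()` above doubles as usage documentation; for reference, a deterministic round trip looks like the sketch below (same assumed `xcodec.quantization.ac` import path as before).

```python
import io
import torch
from xcodec.quantization.ac import (
    ArithmeticCoder,
    ArithmeticDecoder,
    build_stable_quantized_cdf,
)

pdf = torch.tensor([0.5, 0.25, 0.125, 0.125])
symbols = [0, 0, 1, 3, 0, 2, 0, 1]

fo = io.BytesIO()
encoder = ArithmeticCoder(fo)
for s in symbols:
    encoder.push(s, build_stable_quantized_cdf(pdf, encoder.total_range_bits))
encoder.flush()

fo.seek(0)
decoder = ArithmeticDecoder(fo)
decoded = [
    decoder.pull(build_stable_quantized_cdf(pdf, decoder.total_range_bits))
    for _ in symbols
]
assert decoded == symbols  # decoder must see exactly the same CDFs
```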
diff --git a/quantization/core_vq.py b/quantization/core_vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad368a980582bcbba901f28b568c4bfb8f4099e6
--- /dev/null
+++ b/quantization/core_vq.py
@@ -0,0 +1,360 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# This implementation is inspired from
+# https://github.com/lucidrains/vector-quantize-pytorch
+# which is released under MIT License. Hereafter, the original license:
+# MIT License
+#
+# Copyright (c) 2020 Phil Wang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+"""Core vector quantization implementation."""
+
+import typing as tp
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from xcodec.quantization.distrib import broadcast_tensors, rank
+
+
+def default(val: tp.Any, d: tp.Any) -> tp.Any:
+    return val if val is not None else d
+
+
+def ema_inplace(moving_avg, new, decay: float):
+    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
+    return (x + epsilon) / (x.sum() + n_categories * epsilon)
+
+
+def uniform_init(*shape: int):
+    t = torch.empty(shape)
+    nn.init.kaiming_uniform_(t)
+    return t
+
+
+def sample_vectors(samples, num: int):
+    num_samples, device = samples.shape[0], samples.device
+
+    if num_samples >= num:
+        indices = torch.randperm(num_samples, device=device)[:num]
+    else:
+        indices = torch.randint(0, num_samples, (num,), device=device)
+
+    return samples[indices]
+
+
+def kmeans(samples, num_clusters: int, num_iters: int = 10):
+    dim, dtype = samples.shape[-1], samples.dtype
+
+    means = sample_vectors(samples, num_clusters)
+
+    for _ in range(num_iters):
+        diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
+        dists = -(diffs**2).sum(dim=-1)
+
+        buckets = dists.max(dim=-1).indices
+        bins = torch.bincount(buckets, minlength=num_clusters)
+        zero_mask = bins == 0
+        bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
+        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
+        new_means = new_means / bins_min_clamped[..., None]
+
+        means = torch.where(zero_mask[..., None], means, new_means)
+
+    return means, bins
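A small sanity-check sketch for the `kmeans` helper above, assuming the helpers are in scope: three well-separated Gaussian blobs should be recovered as the three means, with `bins` counting the samples assigned to each cluster on the last iteration.

```python
import torch

torch.manual_seed(0)
samples = torch.cat([
    torch.randn(100, 4) + 10.0,   # blob around +10
    torch.randn(100, 4),          # blob around 0
    torch.randn(100, 4) - 10.0,   # blob around -10
])
means, bins = kmeans(samples, num_clusters=3, num_iters=10)
print(means.shape)  # torch.Size([3, 4])
print(bins)         # ideally ~100 samples per cluster (init is random)
```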
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True)) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) # get embedding based on index + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + embed_ind = self.quantize(x) # get index based on Euclidean distance + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) + + self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = self.postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. 
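+            # The EMA codebook update below follows the VQ-VAE recipe: decay the
+            # per-code cluster counts and assignment sums, Laplace-smooth the
+            # counts, then renormalize the running sums into fresh codebook
+            # entries (see the eq. (6)-(8) annotations in core_vq_lsx_version.py).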
+ self.expire_codes_(x) + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = x.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently supports only euclidean distance. + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + commitment_weight (float): Weight for commitment loss. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + decay: float = 0.99, + epsilon: float = 1e-5, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + commitment_weight: float = 1.0, + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() + self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) + self.codebook_size = codebook_size + + @property + def codebook(self): + return self._codebook.embed + + def encode(self, x): + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def forward(self, x): + device = x.device + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)]) + + def forward(self, x, n_q: tp.Optional[int] = None): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + + n_q = n_q or len(self.layers) + + for layer in self.layers[:n_q]: + quantized, indices, loss = layer(residual) + residual = residual - quantized + quantized_out = quantized_out + quantized + + all_indices.append(indices) + all_losses.append(loss) + + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, out_indices, out_losses + + def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + for layer in self.layers[:n_q]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out diff --git a/quantization/core_vq_lsx_version.py b/quantization/core_vq_lsx_version.py new file mode 100644 index 0000000000000000000000000000000000000000..d9add3f3016093a744804ae089d525c45d24ad16 --- /dev/null +++ b/quantization/core_vq_lsx_version.py @@ -0,0 +1,425 @@ +# Copyright (c) +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# This implementation is inspired from +# https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py and +# https://github.com/clementchadebec/benchmark_VAE/blob/dfa0dcf6c79172df5d27769c09c860c42008baaa/src/pythae/models/vq_vae/vq_vae_utils.py#L81 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Core vector quantization implementation.""" + +import typing as tp + +from einops import rearrange +import torch +from torch import nn +import torch.nn.functional as F +import torch.distributed as dist + +from .distrib import broadcast_tensors, is_distributed +from .ddp_utils import SyncFunction + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10, frames_to_use: int = 10_000, batch_size: int = 64): + """ + Memory-efficient K-means clustering. + Args: + samples (tensor): shape [N, D] + num_clusters (int): number of centroids. + num_iters (int): number of iterations. + frames_to_use (int): subsample size from total samples. + batch_size (int): batch size used in distance computation. + Returns: + means: [num_clusters, D] + bins: [num_clusters] (number of points per cluster) + """ + N, D = samples.shape + dtype, device = samples.dtype, samples.device + + if frames_to_use < N: + indices = torch.randperm(N, device=device)[:frames_to_use] + samples = samples[indices] + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + # Store cluster assignments + all_assignments = [] + + for i in range(0, samples.shape[0], batch_size): + batch = samples[i : i + batch_size] # [B, D] + dists = torch.cdist(batch, means, p=2) # [B, C] + assignments = dists.argmin(dim=1) # [B] + all_assignments.append(assignments) + + buckets = torch.cat(all_assignments, dim=0) # [N] + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + # Compute new means + new_means = torch.zeros_like(means) + for i in range(num_clusters): + mask = buckets == i + if mask.any(): + new_means[i] = samples[mask].mean(dim=0) + + means = torch.where(zero_mask[:, None], means, new_means) + + return means, bins + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. 
Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + # Flag variable to indicate whether the codebook is initialized + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + # Runing EMA cluster size/count: N_i^t in eq. (6) in vqvae paper + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + # Codebook + self.register_buffer("embed", embed) + # EMA codebook: eq. (7) in vqvae paper + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + """Initialize codebook. + Args: + data (tensor): [B * T, D]. + """ + if self.inited: + return + + ## NOTE (snippet added by Songxiang Liu): gather data from all gpus + if dist.is_available() and dist.is_initialized(): + # [B * T * world_size, D] + data = SyncFunction.apply(data) + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + ## NOTE (snippet added by Songxiang Liu): gather data from all gpus + if is_distributed(): + # [B * T * world_size, D] + batch_samples = SyncFunction.apply(batch_samples) + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) 
d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True)) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) # [B, T, D] -> [B*T, D] + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + # shape: [B, T, D] + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) # [B, T, D] -> [B*T, D] + + # Initialize codebook + self.init_embed_(x) + + embed_ind = self.quantize(x) # [B*T,] + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) # [B*T, cb-size] + embed_ind = self.postprocess_emb(embed_ind, shape) # [B, T] + quantize = self.dequantize(embed_ind) # [B, T, D] + + if self.training: + ### Update codebook by EMA + embed_onehot_sum = embed_onehot.sum(0) # [cb-size,] + embed_sum = x.t() @ embed_onehot # [D, cb-size] + if is_distributed(): + dist.all_reduce(embed_onehot_sum) + dist.all_reduce(embed_sum) + # Update ema cluster count N_i^t, eq. (6) in vqvae paper + self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay) + # Update ema embed: eq. (7) in vqvae paper + self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay) + # apply laplace smoothing + n = self.cluster_size.sum() + cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n + # Update ema embed: eq. (8) in vqvae paper + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. + self.expire_codes_(x) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently supports only euclidean distance. + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + commitment_weight (float): Weight for commitment loss. 
+ """ + + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + decay: float = 0.99, + epsilon: float = 1e-5, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + commitment_weight: float = 1.0, + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() + self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) + self.codebook_size = codebook_size + + @property + def codebook(self): + return self._codebook.embed + + def encode(self, x): + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def forward(self, x): + device = x.device + x = x.transpose(1, 2).contiguous() # [b d n] -> [b n d] + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + quantize = self.project_out(quantize) + quantize = quantize.transpose(1, 2).contiguous() # [b n d] -> [b d n] + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf
+    """
+
+    def __init__(self, *, num_quantizers, **kwargs):
+        super().__init__()
+        self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
+
+    def forward(self, x, n_q: tp.Optional[int] = None):
+        quantized_out = 0.0
+        residual = x
+
+        all_losses = []
+        all_indices = []
+
+        n_q = n_q or len(self.layers)
+
+        for layer in self.layers[:n_q]:
+            quantized, indices, loss = layer(residual)
+            residual = residual - quantized
+            quantized_out = quantized_out + quantized
+
+            all_indices.append(indices)
+            all_losses.append(loss)
+
+        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+        return quantized_out, out_indices, out_losses
+
+    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+        residual = x
+        all_indices = []
+        n_q = n_q or len(self.layers)
+        for layer in self.layers[:n_q]:
+            indices = layer.encode(residual)
+            quantized = layer.decode(indices)
+            residual = residual - quantized
+            all_indices.append(indices)
+        out_indices = torch.stack(all_indices)
+        return out_indices
+
+    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+        quantized_out = torch.tensor(0.0, device=q_indices.device)
+        for i, indices in enumerate(q_indices):
+            layer = self.layers[i]
+            quantized = layer.decode(indices)
+            quantized_out = quantized_out + quantized
+        return quantized_out
diff --git a/quantization/ddp_utils.py b/quantization/ddp_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..990dca85fd518f09e2fcd528e28e7d256f64a15a
--- /dev/null
+++ b/quantization/ddp_utils.py
@@ -0,0 +1,197 @@
+import logging
+import logging.config
+import random
+import subprocess
+from datetime import datetime
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel
+from torch.nn.parallel.distributed import _find_tensors
+import torch.optim
+import torch.utils.data
+from packaging import version
+from omegaconf import OmegaConf
+
+
+def set_random_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+def is_logging_process():
+    return not dist.is_initialized() or dist.get_rank() == 0
+
+
+def get_logger(cfg, name=None):
+    # log_file_path is used when unit testing
+    if is_logging_process():
+        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_config, resolve=True))
+        return logging.getLogger(name)
+
+
+# from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
+class SyncFunction(torch.autograd.Function):
+    @staticmethod
+    # @torch.no_grad()
+    def forward(ctx, tensor):
+        ctx.batch_size = tensor.shape[0]
+
+        gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
+
+        torch.distributed.all_gather(gathered_tensor, tensor)
+        gathered_tensor = torch.cat(gathered_tensor, 0)
+
+        return gathered_tensor
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_input = grad_output.clone()
+        torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
+
+        idx_from = torch.distributed.get_rank() * ctx.batch_size
+        idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
+        return grad_input[idx_from:idx_to]
+
+
+def get_timestamp():
+    return datetime.now().strftime("%y%m%d-%H%M%S")
+
+
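To see why `SyncFunction` exists: a bare `torch.distributed.all_gather` does not propagate gradients, so this wrapper gathers activations in `forward` and hands each rank its slice of the summed gradient in `backward`. A minimal sketch of the guarded usage pattern, the same one `init_embed_` in core_vq_lsx_version.py uses (the helper name here is illustrative):

```python
import torch
import torch.distributed as dist

from quantization.ddp_utils import SyncFunction  # path per this diff


def gather_features(local_feats: torch.Tensor) -> torch.Tensor:
    """Gather per-rank features into one tensor while keeping gradients flowing."""
    if dist.is_available() and dist.is_initialized():
        # [B, D] per rank -> [B * world_size, D], differentiable w.r.t. local_feats
        return SyncFunction.apply(local_feats)
    return local_feats  # single-process fallback
```

+def get_commit_hash():
+    message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
+    return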
message.strip().decode("utf-8") + + +class DDP(DistributedDataParallel): + """ + Override the forward call in lightning so it goes to training and validation step respectively + """ + + def forward(self, *inputs, **kwargs): # pragma: no cover + if version.parse(torch.__version__[:6]) < version.parse("1.11"): + self._sync_params() + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + assert len(self.device_ids) == 1 + if self.module.training: + output = self.module.training_step(*inputs[0], **kwargs[0]) + elif self.module.testing: + output = self.module.test_step(*inputs[0], **kwargs[0]) + else: + output = self.module.validation_step(*inputs[0], **kwargs[0]) + if torch.is_grad_enabled(): + # We'll return the output object verbatim since it is a freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + from torch.nn.parallel.distributed import ( + logging, + Join, + _DDPSink, + _tree_flatten_with_rref, + _tree_unflatten_with_rref, + ) + + with torch.autograd.profiler.record_function("DistributedDataParallel.forward"): + if torch.is_grad_enabled() and self.require_backward_grad_sync: + self.logger.set_runtime_stats_and_log() + self.num_iterations += 1 + self.reducer.prepare_for_forward() + + # Notify the join context that this process has not joined, if + # needed + work = Join.notify_join_context(self) + if work: + self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size) + + # Calling _rebuild_buckets before forward compuation, + # It may allocate new buckets before deallocating old buckets + # inside _rebuild_buckets. To save peak memory usage, + # call _rebuild_buckets before the peak memory usage increases + # during forward computation. + # This should be called only once during whole training period. + if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): + logging.info("Reducer buckets have been rebuilt in this iteration.") + self._has_rebuilt_buckets = True + + # sync params according to location (before/after forward) user + # specified as part of hook, if hook was specified. + buffer_hook_registered = hasattr(self, "buffer_hook") + if self._check_sync_bufs_pre_fwd(): + self._sync_buffers() + + if self._join_config.enable: + # Notify joined ranks whether they should sync in backwards pass or not. + self._check_global_requires_backward_grad_sync(is_joined_rank=False) + + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if self.module.training: + output = self.module.training_step(*inputs[0], **kwargs[0]) + elif self.module.testing: + output = self.module.test_step(*inputs[0], **kwargs[0]) + else: + output = self.module.validation_step(*inputs[0], **kwargs[0]) + + # sync params according to location (before/after forward) user + # specified as part of hook, if hook was specified. + if self._check_sync_bufs_post_fwd(): + self._sync_buffers() + + if torch.is_grad_enabled() and self.require_backward_grad_sync: + self.require_forward_param_sync = True + # We'll return the output object verbatim since it is a freeform + # object. 
We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if self.find_unused_parameters and not self.static_graph: + # Do not need to populate this for static graph. + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + self.require_forward_param_sync = False + + # TODO: DDPSink is currently enabled for unused parameter detection and + # static graph training for first iteration. + if (self.find_unused_parameters and not self.static_graph) or ( + self.static_graph and self.num_iterations == 1 + ): + state_dict = { + "static_graph": self.static_graph, + "num_iterations": self.num_iterations, + } + + output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output) + output_placeholders = [None for _ in range(len(output_tensor_list))] + # Do not touch tensors that have no grad_fn, which can cause issues + # such as https://github.com/pytorch/pytorch/issues/60733 + for i, output in enumerate(output_tensor_list): + if torch.is_tensor(output) and output.grad_fn is None: + output_placeholders[i] = output + + # When find_unused_parameters=True, makes tensors which require grad + # run through the DDPSink backward pass. When not all outputs are + # used in loss, this makes those corresponding tensors receive + # undefined gradient which the reducer then handles to ensure + # param.grad field is not touched and we don't error out. + passthrough_tensor_list = _DDPSink.apply( + self.reducer, + state_dict, + *output_tensor_list, + ) + for i in range(len(output_placeholders)): + if output_placeholders[i] is None: + output_placeholders[i] = passthrough_tensor_list[i] + + # Reconstruct output data structure. + output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref) + return output diff --git a/quantization/distrib.py b/quantization/distrib.py new file mode 100644 index 0000000000000000000000000000000000000000..cabf8f8a24eb710ab0eb83ce29ba054b7c11ccf3 --- /dev/null +++ b/quantization/distrib.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Torch distributed utilities.""" + +import typing as tp + +import torch + + +def rank(): + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + else: + return 0 + + +def world_size(): + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + else: + return 1 + + +def is_distributed(): + return world_size() > 1 + + +def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): + if is_distributed(): + return torch.distributed.all_reduce(tensor, op) + + +def _is_complex_or_float(tensor): + return torch.is_floating_point(tensor) or torch.is_complex(tensor) + + +def _check_number_of_params(params: tp.List[torch.Tensor]): + # utility function to check that the number of params in all workers is the same, + # and thus avoid a deadlock with distributed all reduce. 
+    if not is_distributed() or not params:
+        return
+    # print('params[0].device ', params[0].device)
+    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
+    all_reduce(tensor)
+    if tensor.item() != len(params) * world_size():
+        # If not all the workers have the same number, for at least one of them,
+        # this inequality will be verified.
+        raise RuntimeError(
+            f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one."
+        )
+
+
+def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
+    """Broadcast the tensors from the given parameters to all workers.
+    This can be used to ensure that all workers have the same model to start with.
+    """
+    if not is_distributed():
+        return
+    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
+    _check_number_of_params(tensors)
+    handles = []
+    for tensor in tensors:
+        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
+        handles.append(handle)
+    for handle in handles:
+        handle.wait()
+
+
+def sync_buffer(buffers, average=True):
+    """
+    Sync buffers across workers. If average is False, broadcast instead of averaging.
+    """
+    if not is_distributed():
+        return
+    handles = []
+    for buffer in buffers:
+        if torch.is_floating_point(buffer.data):
+            if average:
+                handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
+            else:
+                handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
+            handles.append((buffer, handle))
+    for buffer, handle in handles:
+        handle.wait()
+        if average:
+            buffer.data /= world_size()
+
+
+def sync_grad(params):
+    """
+    Simpler alternative to DistributedDataParallel, that doesn't rely
+    on any black magic. For simple models it can also be as fast.
+    Just call this on your model parameters after the call to backward!
+    """
+    if not is_distributed():
+        return
+    handles = []
+    for p in params:
+        if p.grad is not None:
+            handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
+            handles.append((p, handle))
+    for p, handle in handles:
+        handle.wait()
+        p.grad.data /= world_size()
+
+
+def average_metrics(metrics: tp.Dict[str, float], count=1.0):
+    """Average a dictionary of metrics across all workers, using the optional
+    `count` as unnormalized weight.
+    """
+    if not is_distributed():
+        return metrics
+    keys, values = zip(*metrics.items())
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
+    tensor *= count
+    all_reduce(tensor)
+    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
+    return dict(zip(keys, averaged))
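Taken together, these helpers support a hand-rolled alternative to DistributedDataParallel: broadcast the weights once at startup, then all-reduce gradients after every backward pass. A sketch under that assumption (import path per this diff; without an initialized process group every helper is simply a no-op):

```python
import torch

from quantization.distrib import broadcast_tensors, sync_grad  # path per this diff

model = torch.nn.Linear(16, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Ensure every worker starts from identical weights.
broadcast_tensors(model.parameters())


def train_step(x: torch.Tensor, y: torch.Tensor) -> None:
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    sync_grad(model.parameters())  # average gradients across workers
    opt.step()
```

diff --git a/quantization/vq.py b/quantization/vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..dac26ba2a3bc2c97d6178fa33c629f324980d5a0
--- /dev/null
+++ b/quantization/vq.py
@@ -0,0 +1,116 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.

Since `ResidualVectorQuantizer` defined below is the public entry point to the RVQ stack, a short usage sketch may help. It assumes the `sample_rate` argument is really the frame rate of the latent sequence, which is what the bandwidth formula below implies: log2(1024) = 10 bits per code, so at 50 frames/s each quantizer costs 10 * 50 / 1000 = 0.5 kbps, and a 4.0 kbps budget selects n_q = floor(4.0 / 0.5) = 8 quantizers:

```python
import torch

from quantization.vq import ResidualVectorQuantizer  # path per this diff

rvq = ResidualVectorQuantizer(dimension=256, n_q=8, bins=1024)
latents = torch.randn(2, 256, 100)  # [batch, dim, frames]

# Assuming a 50 Hz latent frame rate: 0.5 kbps per quantizer,
# so a 4.0 kbps budget keeps all 8 residual stages.
quantized, codes, bw, commit_loss = rvq(latents, sample_rate=50, bandwidth=4.0)
print(codes.shape)  # torch.Size([8, 2, 100]) -> [n_q, batch, frames]
```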
+ +"""Residual vector quantizer implementation.""" + +from dataclasses import dataclass, field +import math +import typing as tp + +import torch +from torch import nn + +# from .core_vq import ResidualVectorQuantization +from .core_vq_lsx_version import ResidualVectorQuantization + + +@dataclass +class QuantizedResult: + quantized: torch.Tensor + codes: torch.Tensor + bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. + penalty: tp.Optional[torch.Tensor] = None + metrics: dict = field(default_factory=dict) + + +class ResidualVectorQuantizer(nn.Module): + """Residual Vector Quantizer. + Args: + dimension (int): Dimension of the codebooks. + n_q (int): Number of residual vector quantizers used. + bins (int): Codebook size. + decay (float): Decay for exponential moving average over the codebooks. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dimension: int = 256, + codebook_dim: int = None, + n_q: int = 8, + bins: int = 1024, + decay: float = 0.99, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.n_q = n_q + self.dimension = dimension + self.codebook_dim = codebook_dim + self.bins = bins + self.decay = decay + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.threshold_ema_dead_code = threshold_ema_dead_code + self.vq = ResidualVectorQuantization( + dim=self.dimension, + codebook_dim=self.codebook_dim, + codebook_size=self.bins, + num_quantizers=self.n_q, + decay=self.decay, + kmeans_init=self.kmeans_init, + kmeans_iters=self.kmeans_iters, + threshold_ema_dead_code=self.threshold_ema_dead_code, + ) + + def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None): # -> QuantizedResult: + """Residual vector quantization on the given input tensor. + Args: + x (torch.Tensor): Input tensor. + sample_rate (int): Sample rate of the input tensor. + bandwidth (float): Target bandwidth. + Returns: + QuantizedResult: + The quantized (or approximately quantized) representation with + the associated bandwidth and any penalty term for the loss. 
+ """ + bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) + n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) + quantized, codes, commit_loss = self.vq(x, n_q=n_q) + bw = torch.tensor(n_q * bw_per_q).to(x) + return quantized, codes, bw, torch.mean(commit_loss) + # return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) + + def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int: + """Return n_q based on specified target bandwidth.""" + bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) + n_q = self.n_q + if bandwidth and bandwidth > 0.0: + n_q = int(max(1, math.floor(bandwidth / bw_per_q))) + return n_q + + def get_bandwidth_per_quantizer(self, sample_rate: int): + """Return bandwidth per quantizer for a given input sample rate.""" + return math.log2(self.bins) * sample_rate / 1000 + + def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor: + """Encode a given input tensor with the specified sample rate at the given bandwidth. + The RVQ encode method sets the appropriate number of quantizer to use + and returns indices for each quantizer. + """ + n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) + codes = self.vq.encode(x, n_q=n_q) + return codes + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation.""" + quantized = self.vq.decode(codes) + return quantized diff --git a/semantic_module.py b/semantic_module.py new file mode 100644 index 0000000000000000000000000000000000000000..f75efaf74e748cd733782133a65c2d0b6638b3a6 --- /dev/null +++ b/semantic_module.py @@ -0,0 +1,282 @@ +# Based on code from: https://github.com/zhenye234/xcodec +# Licensed under MIT License +# Modifications by BosonAI + +import torch +import torch.nn as nn + + +class Conv1d1x1(nn.Conv1d): + """1x1 Conv1d.""" + + def __init__(self, in_channels, out_channels, bias=True): + super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias) + + +class Conv1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = -1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + if padding < 0: + padding = (kernel_size - 1) // 2 * dilation + self.dilation = dilation + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x): + """ + Args: + x (Tensor): Float tensor variable with the shape (B, C, T). + Returns: + Tensor: Float tensor variable with the shape (B, C, T). 
+ """ + x = self.conv(x) + return x + + +class ResidualUnit(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + dilation=1, + bias=False, + nonlinear_activation="ELU", + nonlinear_activation_params={}, + ): + super().__init__() + self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params) + self.conv1 = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + dilation=dilation, + bias=bias, + ) + self.conv2 = Conv1d1x1(out_channels, out_channels, bias) + + def forward(self, x): + y = self.conv1(self.activation(x)) + y = self.conv2(self.activation(y)) + return x + y + + +class ConvTranspose1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding=-1, + output_padding=-1, + groups=1, + bias=True, + ): + super().__init__() + if padding < 0: + padding = (stride + 1) // 2 + if output_padding < 0: + output_padding = 1 if stride % 2 else 0 + self.deconv = nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + ) + + def forward(self, x): + """ + Args: + x (Tensor): Float tensor variable with the shape (B, C, T). + Returns: + Tensor: Float tensor variable with the shape (B, C', T'). + """ + x = self.deconv(x) + return x + + +class EncoderBlock(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, stride: int, dilations=(1, 1), unit_kernel_size=3, bias=True + ): + super().__init__() + self.res_units = torch.nn.ModuleList() + for dilation in dilations: + self.res_units += [ResidualUnit(in_channels, in_channels, kernel_size=unit_kernel_size, dilation=dilation)] + self.num_res = len(self.res_units) + + self.conv = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3 if stride == 1 else (2 * stride), # special case: stride=1, do not use kernel=2 + stride=stride, + bias=bias, + ) + + def forward(self, x): + for idx in range(self.num_res): + x = self.res_units[idx](x) + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + input_channels: int, + encode_channels: int, + channel_ratios=(1, 1), + strides=(1, 1), + kernel_size=3, + bias=True, + block_dilations=(1, 1), + unit_kernel_size=3, + ): + super().__init__() + assert len(channel_ratios) == len(strides) + + self.conv = Conv1d( + in_channels=input_channels, out_channels=encode_channels, kernel_size=kernel_size, stride=1, bias=False + ) + self.conv_blocks = torch.nn.ModuleList() + in_channels = encode_channels + for idx, stride in enumerate(strides): + out_channels = int(encode_channels * channel_ratios[idx]) # could be float + self.conv_blocks += [ + EncoderBlock( + in_channels, + out_channels, + stride, + dilations=block_dilations, + unit_kernel_size=unit_kernel_size, + bias=bias, + ) + ] + in_channels = out_channels + self.num_blocks = len(self.conv_blocks) + self.out_channels = out_channels + + def forward(self, x): + x = self.conv(x) + for i in range(self.num_blocks): + x = self.conv_blocks[i](x) + return x + + +class DecoderBlock(nn.Module): + """Decoder block (no up-sampling)""" + + def __init__( + self, in_channels: int, out_channels: int, stride: int, dilations=(1, 1), unit_kernel_size=3, bias=True + ): + super().__init__() + + if stride == 1: + self.conv = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, # fix 
kernel=3 when stride=1 for unchanged shape + stride=stride, + bias=bias, + ) + else: + self.conv = ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(2 * stride), + stride=stride, + bias=bias, + ) + + self.res_units = torch.nn.ModuleList() + for idx, dilation in enumerate(dilations): + self.res_units += [ + ResidualUnit(out_channels, out_channels, kernel_size=unit_kernel_size, dilation=dilation) + ] + self.num_res = len(self.res_units) + + def forward(self, x): + x = self.conv(x) + for idx in range(self.num_res): + x = self.res_units[idx](x) + return x + + +class Decoder(nn.Module): + def __init__( + self, + code_dim: int, + output_channels: int, + decode_channels: int, + channel_ratios=(1, 1), + strides=(1, 1), + kernel_size=3, + bias=True, + block_dilations=(1, 1), + unit_kernel_size=3, + ): + super().__init__() + assert len(channel_ratios) == len(strides) + + self.conv1 = Conv1d( + in_channels=code_dim, + out_channels=int(decode_channels * channel_ratios[0]), + kernel_size=kernel_size, + stride=1, + bias=False, + ) + + self.conv_blocks = torch.nn.ModuleList() + for idx, stride in enumerate(strides): + in_channels = int(decode_channels * channel_ratios[idx]) + if idx < (len(channel_ratios) - 1): + out_channels = int(decode_channels * channel_ratios[idx + 1]) + else: + out_channels = decode_channels + self.conv_blocks += [ + DecoderBlock( + in_channels, + out_channels, + stride, + dilations=block_dilations, + unit_kernel_size=unit_kernel_size, + bias=bias, + ) + ] + self.num_blocks = len(self.conv_blocks) + + self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False) + + def forward(self, z): + x = self.conv1(z) + for i in range(self.num_blocks): + x = self.conv_blocks[i](x) + x = self.conv2(x) + return x diff --git a/train_boson.py b/train_boson.py new file mode 100644 index 0000000000000000000000000000000000000000..044890dc991eedcdb0fd525a9087098d25b1abc5 --- /dev/null +++ b/train_boson.py @@ -0,0 +1,891 @@ +#!/usr/bin/env python3 +""" +Training script for Boson Audio Codec with DAC-inspired losses +""" + +import os +import json +import argparse +import random +from pathlib import Path +from datetime import datetime +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.utils.tensorboard import SummaryWriter +import torchaudio +import librosa +from tqdm import tqdm +from audiotools import AudioSignal, STFTParams + +# Import from the provided codebase +from higgs_audio_tokenizer import HiggsAudioTokenizer +from quantization.distrib import broadcast_tensors, sync_buffer, is_distributed, world_size, rank +from quantization.ddp_utils import set_random_seed, is_logging_process, get_timestamp + +# Import DAC losses and discriminator +import sys +sys.path.append('.') # Add current directory to path +from loss import L1Loss, MultiScaleSTFTLoss, MelSpectrogramLoss, GANLoss +from discriminator import Discriminator + + +class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler): + """Cosine scheduler with linear warmup""" + def __init__(self, optimizer, warmup_steps, total_steps, eta_min=1e-6, last_epoch=-1): + self.warmup_steps = warmup_steps + self.total_steps = total_steps + self.eta_min = eta_min + super().__init__(optimizer, last_epoch) + + def get_lr(self): 
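+        # Stepped once per batch, so `last_epoch` counts optimizer steps:
+        #   warmup (step < warmup_steps): lr = base_lr * step / warmup_steps
+        #   cosine (otherwise):           lr = eta_min + (base_lr - eta_min)
+        #                                      * 0.5 * (1 + cos(pi * progress))
+        # with progress = (step - warmup_steps) / (total_steps - warmup_steps).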
+ if self.last_epoch < self.warmup_steps: + # Linear warmup + warmup_factor = self.last_epoch / self.warmup_steps + return [base_lr * warmup_factor for base_lr in self.base_lrs] + else: + # Cosine annealing + progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps) + cosine_factor = 0.5 * (1 + np.cos(np.pi * progress)) + return [self.eta_min + (base_lr - self.eta_min) * cosine_factor for base_lr in self.base_lrs] + + +class AudioDataset(Dataset): + """Dataset for loading audio files from CSV""" + def __init__(self, csv_path, sample_rate=44100, segment_duration=2.0, is_train=True): + self.df = pd.read_csv(csv_path) + self.sample_rate = sample_rate + self.segment_duration = segment_duration + self.segment_length = int(sample_rate * segment_duration) + self.is_train = is_train + + # Filter out files that don't exist + valid_files = [] + for idx, row in self.df.iterrows(): + if os.path.exists(row.iloc[0]): + valid_files.append(row.iloc[0]) + self.audio_paths = valid_files + print(f"Found {len(self.audio_paths)} valid audio files") + + def __len__(self): + return len(self.audio_paths) + + def __getitem__(self, idx): + audio_path = self.audio_paths[idx] + + try: + # Load audio using librosa + audio, sr = librosa.load(audio_path, sr=self.sample_rate, mono=True) + + # Random segment extraction for training + if len(audio) > self.segment_length: + if self.is_train: + start = random.randint(0, len(audio) - self.segment_length) + else: + start = 0 # Always use beginning for validation + audio = audio[start:start + self.segment_length] + else: + # Pad if too short + audio = np.pad(audio, (0, self.segment_length - len(audio))) + + # Convert to tensor and add batch dimension + audio_tensor = torch.FloatTensor(audio).unsqueeze(0) + + return audio_tensor, audio_path + + except Exception as e: + print(f"Error loading {audio_path}: {e}") + # Return silence if loading fails + return torch.zeros(1, self.segment_length), audio_path + + +class BosonTrainer: + def __init__(self, args): + self.args = args + self.distributed = False + + # Check if we're in a distributed environment + if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) > 1: + self.distributed = True + self.setup_ddp() + self.device = torch.device(f'cuda:{args.local_rank}') + else: + # Single GPU mode + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + torch.cuda.set_device(0) + set_random_seed(args.seed) + + # Load config + with open(args.config, 'r') as f: + self.config = json.load(f) + + # Initialize models + self.model = self.build_model() + self.discriminator = self.build_discriminator() if args.use_discriminator else None + + # Setup data loaders + self.train_loader, self.val_loader = self.setup_data_loaders() + + # Setup optimizers + self.optimizer_g = torch.optim.AdamW( + self.model.parameters(), + lr=args.learning_rate, + betas=(0.5, 0.9), + weight_decay=args.weight_decay + ) + + if self.discriminator is not None: + self.optimizer_d = torch.optim.AdamW( + self.discriminator.parameters(), + lr=args.learning_rate * 2, # Typically discriminator learns faster + betas=(0.5, 0.9), + weight_decay=args.weight_decay + ) + + # Calculate total training steps + self.total_steps = args.num_epochs * len(self.train_loader) + + # Setup schedulers with warmup + self.scheduler_g = CosineWarmupScheduler( + self.optimizer_g, + warmup_steps=args.warmup_steps, + total_steps=self.total_steps, + eta_min=1e-6 + ) + + if self.discriminator is not None: + self.scheduler_d = CosineWarmupScheduler( 
+ self.optimizer_d, + warmup_steps=args.warmup_steps, + total_steps=self.total_steps, + eta_min=1e-6 + ) + + # Setup losses + self.setup_losses() + + # Setup tensorboard + if not self.distributed or rank() == 0: + self.writer = SummaryWriter( + log_dir=os.path.join(args.output_dir, 'logs', get_timestamp()) + ) + + self.global_step = 0 + self.start_epoch = 0 + + # Load checkpoint if exists + if args.resume: + self.load_checkpoint() + + def setup_ddp(self): + """Initialize DDP""" + if 'LOCAL_RANK' in os.environ: + self.args.local_rank = int(os.environ['LOCAL_RANK']) + dist.init_process_group(backend='nccl') + torch.cuda.set_device(self.args.local_rank) + set_random_seed(self.args.seed + rank()) + + def build_model(self): + """Build and wrap model with DDP if needed""" + + print(self.config) + model = HiggsAudioTokenizer( + n_filters=self.config['n_filters'], + D=self.config['D'], + target_bandwidths=self.config['target_bandwidths'], + ratios=self.config['ratios'], + sample_rate=self.config['sample_rate'], + bins=self.config['bins'], + n_q=self.config['n_q'], + codebook_dim=self.config.get('codebook_dim', None), + semantic_techer=self.config['semantic_techer'], + device=self.device + ).to(self.device) + + if self.distributed: + # Broadcast model parameters to ensure all ranks have same initialization + broadcast_tensors(model.parameters()) + # Wrap with DDP + model = DDP(model, device_ids=[self.args.local_rank]) + + return model + + def build_discriminator(self): + """Build discriminator with DDP if needed""" + # Use sample rate from config + discriminator = Discriminator( + rates=[], # No multi-rate discriminator for now + periods=[2, 3, 5, 7, 11], + fft_sizes=[2048, 1024, 512], + sample_rate=self.config['sample_rate'], + ).to(self.device) + + if self.distributed: + broadcast_tensors(discriminator.parameters()) + discriminator = DDP(discriminator, device_ids=[self.args.local_rank]) + + return discriminator + + def setup_losses(self): + """Setup all loss functions""" + # Basic losses + self.l1_loss = L1Loss() + self.stft_loss = MultiScaleSTFTLoss( + window_lengths=[2048, 1024, 512, 256, 128], + loss_fn=nn.L1Loss(), + clamp_eps=1e-5, + mag_weight=1.0, + log_weight=1.0, + ) + self.mel_loss = MelSpectrogramLoss( + n_mels=[150, 80], + window_lengths=[2048, 512], + mel_fmin=[0.0, 0.0], + mel_fmax=[None, None], + clamp_eps=1e-5, + mag_weight=1.0, + log_weight=1.0, + ) + + # GAN loss if using discriminator + if self.discriminator is not None: + self.gan_loss = GANLoss(self.discriminator) + + # Loss weights (matching DAC's proven configuration) + self.loss_weights = { + 'rec': 1., # Waveform L1 loss + 'stft': 1., # Multi-scale STFT loss + # 'mel': 15.0, # Mel-spectrogram loss (ENABLE it after 20-25k steps) + 'mel': 0.0, # Mel-spectrogram loss (DISABLED) + 'commit': 0.25, # Commitment loss + 'semantic': 1., # Semantic loss + 'gen': 1., # Generator adversarial loss + 'feat': 1.0, # Feature matching loss + } + + def setup_data_loaders(self): + """Setup data loaders (distributed or single GPU)""" + # Split data into train/val + df = pd.read_csv(self.args.data_csv) + n_total = len(df) + n_train = int(n_total * 0.9) + + # Create temporary CSV files for train/val split + train_csv = '/tmp/train_audio.csv' + val_csv = '/tmp/val_audio.csv' + + if not self.distributed or rank() == 0: + df[:n_train].to_csv(train_csv, index=False) + df[n_train:].to_csv(val_csv, index=False) + + # Synchronize across processes if distributed + if self.distributed: + dist.barrier() + + # Create datasets + train_dataset = 
AudioDataset( + train_csv, + sample_rate=self.config['sample_rate'], + segment_duration=self.args.segment_duration, + is_train=True + ) + + val_dataset = AudioDataset( + val_csv, + sample_rate=self.config['sample_rate'], + segment_duration=self.args.segment_duration, + is_train=False + ) + + # Create samplers and loaders + if self.distributed: + train_sampler = DistributedSampler(train_dataset, shuffle=True) + val_sampler = DistributedSampler(val_dataset, shuffle=False) + else: + train_sampler = None + val_sampler = None + + train_loader = DataLoader( + train_dataset, + batch_size=self.args.batch_size, + sampler=train_sampler, + shuffle=(train_sampler is None), + num_workers=self.args.num_workers, + pin_memory=True, + drop_last=True + ) + + val_loader = DataLoader( + val_dataset, + batch_size=self.args.batch_size, + sampler=val_sampler, + shuffle=False, + num_workers=self.args.num_workers, + pin_memory=True, + drop_last=False + ) + + return train_loader, val_loader + + def is_main_process(self): + """Check if this is the main process""" + return not self.distributed or rank() == 0 + + def train_epoch(self, epoch): + """Train for one epoch""" + self.model.train() + if self.discriminator is not None: + self.discriminator.train() + + if self.distributed: + self.train_loader.sampler.set_epoch(epoch) + + total_losses = { + 'total': 0, 'rec': 0, 'stft': 0, 'mel': 0, + 'commit': 0, 'semantic': 0, 'gen': 0, 'feat': 0, 'disc': 0 + } + + pbar = tqdm(self.train_loader, desc=f'Epoch {epoch}', disable=not self.is_main_process()) + + for batch_idx, (audio, paths) in enumerate(pbar): + audio = audio.to(self.device) + + # Create AudioSignal objects for loss computation + audio_signal = AudioSignal(audio, self.config['sample_rate']) + + # Forward pass with random bandwidth + bw_idx = random.randint(0, len(self.config['target_bandwidths']) - 1) + bw = self.config['target_bandwidths'][bw_idx] + + output, commit_loss, semantic_loss, _ = self.model(audio, bw) + recons_signal = AudioSignal(output, self.config['sample_rate']) + + # Check if discriminator should be active (after discriminator_start_step) + use_discriminator = (self.discriminator is not None and + self.global_step >= self.args.discriminator_start_step) + + # Train discriminator first if using GAN and past the start step + if use_discriminator and self.global_step % self.args.disc_interval == 0: + self.optimizer_d.zero_grad() + disc_loss = self.gan_loss.discriminator_loss(recons_signal, audio_signal) + disc_loss.backward() + torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), 10.0) + self.optimizer_d.step() + self.scheduler_d.step() + total_losses['disc'] += disc_loss.item() + + # Train generator + losses = {} + + # Reconstruction losses + losses['rec'] = self.l1_loss(recons_signal, audio_signal) + losses['stft'] = self.stft_loss(recons_signal, audio_signal) + # losses['mel'] = self.mel_loss(recons_signal, audio_signal) + losses['mel'] = torch.tensor(0.0, device=self.device) # 15. 
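+            # NOTE: the mel loss is zeroed here on purpose; per the weight table
+            # in setup_losses, the intent is to enable it (weight 15.0) only
+            # after ~20-25k steps, once the waveform/STFT losses have stabilized.
+            # The adversarial terms below are likewise gated by discriminator_start_step.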
+ losses['commit'] = commit_loss + losses['semantic'] = semantic_loss + + # GAN losses if discriminator is active + if use_discriminator: + gen_loss, feat_loss = self.gan_loss.generator_loss(recons_signal, audio_signal) + losses['gen'] = gen_loss + losses['feat'] = feat_loss + else: + # Set to zero for logging purposes + losses['gen'] = torch.tensor(0.0, device=self.device) + losses['feat'] = torch.tensor(0.0, device=self.device) + + # Total weighted loss + total_loss = sum(self.loss_weights.get(k, 0) * v for k, v in losses.items() + if k not in ['gen', 'feat'] or use_discriminator) + + # Backward pass + self.optimizer_g.zero_grad() + total_loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + self.optimizer_g.step() + self.scheduler_g.step() + + # Update metrics + total_losses['total'] += total_loss.item() + for k, v in losses.items(): + total_losses[k] += v.item() + + # Update progress bar + if self.is_main_process(): + pbar.set_postfix({ + 'loss': f'{total_loss.item():.4f}', + 'rec': f'{losses["rec"].item():.4f}', + 'mel': f'{losses["mel"].item():.4f}', + 'commit_loss': f'{losses["commit"].item():.4f}', + 'semantic_loss': f'{losses["semantic"].item():.4f}', + 'lr': f'{self.scheduler_g.get_last_lr()[0]:.9f}', + 'disc': 'ON' if use_discriminator else 'OFF', + 'step': self.global_step + }) + + # Log to tensorboard + if self.is_main_process() and self.global_step % self.args.log_interval == 0: + for k, v in losses.items(): + self.writer.add_scalar(f'train/{k}_loss', v.item(), self.global_step) + self.writer.add_scalar('train/total_loss', total_loss.item(), self.global_step) + self.writer.add_scalar('train/lr', self.scheduler_g.get_last_lr()[0], self.global_step) + self.writer.add_scalar('train/bandwidth', bw, self.global_step) + self.writer.add_scalar('train/discriminator_active', float(use_discriminator), self.global_step) + if use_discriminator: + self.writer.add_scalar('train/disc_loss', total_losses['disc'] / max(1, batch_idx), self.global_step) + + # Save checkpoint at step intervals + if self.global_step > 0 and self.global_step % self.args.save_step_interval == 0: + self.save_checkpoint_step(self.global_step) + if self.is_main_process(): + print(f"\nSaved checkpoint at step {self.global_step}") + + self.global_step += 1 + + # Return average losses + n_batches = len(self.train_loader) + return {k: v / n_batches for k, v in total_losses.items()} + + @torch.no_grad() + def validate(self, epoch): + """Validation loop""" + self.model.eval() + + total_losses = { + 'total': 0, 'rec': 0, 'stft': 0, 'mel': 0, + 'commit': 0, 'semantic': 0 + } + + # Store audio samples for tensorboard + audio_samples = {'train': [], 'val': []} + + for batch_idx, (audio, paths) in enumerate(tqdm(self.val_loader, desc='Validation', disable=not self.is_main_process())): + audio = audio.to(self.device) + audio_signal = AudioSignal(audio, self.config['sample_rate']) + + # Use medium bandwidth for validation + bw = self.config['target_bandwidths'][2] + + output, commit_loss, semantic_loss, _ = self.model(audio, bw) + recons_signal = AudioSignal(output, self.config['sample_rate']) + + # Compute losses + losses = { + 'rec': self.l1_loss(recons_signal, audio_signal), + 'stft': self.stft_loss(recons_signal, audio_signal), + 'mel': self.mel_loss(recons_signal, audio_signal), + 'commit': commit_loss, + 'semantic': semantic_loss + } + + total_loss = sum(self.loss_weights.get(k, 0) * v for k, v in losses.items()) + + total_losses['total'] += total_loss.item() + for k, v in losses.items(): + 
total_losses[k] += v.item() + + # Collect audio samples for tensorboard (first 3 from validation) + if self.is_main_process() and len(audio_samples['val']) < 3: + audio_samples['val'].append({ + 'original': audio[0].cpu(), + 'reconstructed': output[0].cpu(), + 'path': paths[0] + }) + + # Get train samples for comparison + if self.is_main_process(): + self.model.eval() + for batch_idx, (audio, paths) in enumerate(self.train_loader): + if len(audio_samples['train']) >= 3: + break + audio = audio.to(self.device) + bw = self.config['target_bandwidths'][2] + output, _, _, _ = self.model(audio, bw) + audio_samples['train'].append({ + 'original': audio[0].cpu(), + 'reconstructed': output[0].cpu(), + 'path': paths[0] + }) + + # Log audio samples to tensorboard + if self.is_main_process(): + for split in ['train', 'val']: + for idx, sample in enumerate(audio_samples[split]): + self.writer.add_audio( + f'{split}/original_{idx}', + sample['original'], + epoch, + sample_rate=self.config['sample_rate'] + ) + self.writer.add_audio( + f'{split}/reconstructed_{idx}', + sample['reconstructed'], + epoch, + sample_rate=self.config['sample_rate'] + ) + + # Average losses + n_batches = len(self.val_loader) + val_metrics = {k: v / n_batches for k, v in total_losses.items()} + + # Log validation metrics + if self.is_main_process(): + for key, value in val_metrics.items(): + self.writer.add_scalar(f'val/{key}_loss', value, epoch) + + return val_metrics + + def save_checkpoint(self, epoch, is_best=False): + """Save model checkpoint (epoch-based)""" + if not self.is_main_process(): + return + + model_state = self.model.module.state_dict() if self.distributed else self.model.state_dict() + + # Get current learning rates for verification + current_lr_g = self.scheduler_g.get_last_lr()[0] + + checkpoint = { + 'epoch': epoch, + 'global_step': self.global_step, + 'model_state_dict': model_state, + 'optimizer_g_state_dict': self.optimizer_g.state_dict(), + 'scheduler_g_state_dict': self.scheduler_g.state_dict(), + 'scheduler_g_last_epoch': self.scheduler_g.last_epoch, # Explicitly save this + 'current_lr_g': current_lr_g, # Save for verification + 'config': self.config, + 'args': self.args + } + + if self.discriminator is not None: + disc_state = self.discriminator.module.state_dict() if self.distributed else self.discriminator.state_dict() + current_lr_d = self.scheduler_d.get_last_lr()[0] + checkpoint['discriminator_state_dict'] = disc_state + checkpoint['optimizer_d_state_dict'] = self.optimizer_d.state_dict() + checkpoint['scheduler_d_state_dict'] = self.scheduler_d.state_dict() + checkpoint['scheduler_d_last_epoch'] = self.scheduler_d.last_epoch + checkpoint['current_lr_d'] = current_lr_d + + # Save latest checkpoint + checkpoint_path = os.path.join(self.args.output_dir, 'checkpoints', 'latest.pth') + os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) + torch.save(checkpoint, checkpoint_path) + + # Save best checkpoint + if is_best: + best_path = os.path.join(self.args.output_dir, 'checkpoints', 'best.pth') + torch.save(checkpoint, best_path) + + # Save periodic checkpoint + if epoch % self.args.save_interval == 0: + epoch_path = os.path.join(self.args.output_dir, 'checkpoints', f'epoch_{epoch}.pth') + torch.save(checkpoint, epoch_path) + + + def save_checkpoint_step(self, step): + """Save model checkpoint (step-based)""" + if not self.is_main_process(): + return + + # Get current epoch from training loop + current_epoch = step // len(self.train_loader) + + model_state = self.model.module.state_dict() if 
self.distributed else self.model.state_dict() + + # Get current learning rates for verification + current_lr_g = self.scheduler_g.get_last_lr()[0] + + checkpoint = { + 'epoch': current_epoch, + 'global_step': step, + 'model_state_dict': model_state, + 'optimizer_g_state_dict': self.optimizer_g.state_dict(), + 'scheduler_g_state_dict': self.scheduler_g.state_dict(), + 'scheduler_g_last_epoch': self.scheduler_g.last_epoch, # Explicitly save this + 'current_lr_g': current_lr_g, # Save for verification + 'config': self.config, + 'args': self.args + } + + if self.discriminator is not None: + disc_state = self.discriminator.module.state_dict() if self.distributed else self.discriminator.state_dict() + current_lr_d = self.scheduler_d.get_last_lr()[0] + checkpoint['discriminator_state_dict'] = disc_state + checkpoint['optimizer_d_state_dict'] = self.optimizer_d.state_dict() + checkpoint['scheduler_d_state_dict'] = self.scheduler_d.state_dict() + checkpoint['scheduler_d_last_epoch'] = self.scheduler_d.last_epoch + checkpoint['current_lr_d'] = current_lr_d + + # Save step-based checkpoint + step_path = os.path.join(self.args.output_dir, 'checkpoints', f'step_{step}.pth') + torch.save(checkpoint, step_path) + + # Also update latest checkpoint + latest_path = os.path.join(self.args.output_dir, 'checkpoints', 'latest.pth') + torch.save(checkpoint, latest_path) + + # Keep only the last N step-based checkpoints to save disk space + if self.args.keep_last_n_steps > 0: + checkpoint_dir = os.path.join(self.args.output_dir, 'checkpoints') + step_checkpoints = sorted([f for f in os.listdir(checkpoint_dir) if f.startswith('step_')]) + if len(step_checkpoints) > self.args.keep_last_n_steps: + for old_checkpoint in step_checkpoints[:-self.args.keep_last_n_steps]: + os.remove(os.path.join(checkpoint_dir, old_checkpoint)) + + + def load_checkpoint(self): + """Load checkpoint with proper state restoration""" + checkpoint_path = os.path.join(self.args.output_dir, 'checkpoints', 'latest.pth') + if os.path.exists(checkpoint_path): + print(f"Loading checkpoint from {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False) + + # Load model state + if self.distributed: + self.model.module.load_state_dict(checkpoint['model_state_dict']) + else: + self.model.load_state_dict(checkpoint['model_state_dict']) + + # Load optimizer state + self.optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict']) + + # Load scheduler state + self.scheduler_g.load_state_dict(checkpoint['scheduler_g_state_dict']) + + # Restore scheduler's last_epoch from checkpoint + if 'scheduler_g_last_epoch' in checkpoint: + self.scheduler_g.last_epoch = checkpoint['scheduler_g_last_epoch'] + else: + # Fallback: use global_step if the explicit value wasn't saved + self.scheduler_g.last_epoch = checkpoint['global_step'] + + # Force scheduler to recompute its internal state + self.scheduler_g._last_lr = self.scheduler_g.get_lr() + + # Load discriminator if present + if self.discriminator is not None and 'discriminator_state_dict' in checkpoint: + if self.distributed: + self.discriminator.module.load_state_dict(checkpoint['discriminator_state_dict']) + else: + self.discriminator.load_state_dict(checkpoint['discriminator_state_dict']) + self.optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict']) + self.scheduler_d.load_state_dict(checkpoint['scheduler_d_state_dict']) + + # Restore discriminator scheduler's last_epoch + if 'scheduler_d_last_epoch' in checkpoint: + self.scheduler_d.last_epoch 
= checkpoint['scheduler_d_last_epoch'] + else: + self.scheduler_d.last_epoch = checkpoint['global_step'] + + self.scheduler_d._last_lr = self.scheduler_d.get_lr() + + # Restore training state + self.start_epoch = checkpoint['epoch'] + 1 + self.global_step = checkpoint['global_step'] + + # Verify learning rate restoration + current_lr_g = self.scheduler_g.get_last_lr()[0] + saved_lr_g = checkpoint.get('current_lr_g', None) + + print(f"\n{'='*60}") + print(f"CHECKPOINT LOADED SUCCESSFULLY") + print(f"{'='*60}") + print(f"Resumed from epoch: {checkpoint['epoch']}") + print(f"Global step: {self.global_step}") + print(f"Scheduler last_epoch: {self.scheduler_g.last_epoch}") + print(f"Current learning rate (generator): {current_lr_g:.9f}") + if saved_lr_g is not None: + print(f"Saved learning rate (generator): {saved_lr_g:.9f}") + if abs(current_lr_g - saved_lr_g) > 1e-9: + print("⚠️ WARNING: Learning rate mismatch! This might indicate improper state restoration.") + + if self.discriminator is not None: + current_lr_d = self.scheduler_d.get_last_lr()[0] + saved_lr_d = checkpoint.get('current_lr_d', None) + print(f"Current learning rate (discriminator): {current_lr_d:.9f}") + if saved_lr_d is not None: + print(f"Saved learning rate (discriminator): {saved_lr_d:.9f}") + print(f"Discriminator status: {'ACTIVE' if self.global_step >= self.args.discriminator_start_step else f'INACTIVE (starts at step {self.args.discriminator_start_step})'}") + + print(f"Next epoch: {self.start_epoch}") + print(f"Next step checkpoint at: step {((self.global_step // self.args.save_step_interval) + 1) * self.args.save_step_interval}") + print(f"{'='*60}\n") + + # Double-check by creating a fresh scheduler and comparing + if self.global_step > 0: + temp_scheduler = CosineWarmupScheduler( + self.optimizer_g, + self.args.warmup_steps, + self.total_steps, + eta_min=1e-6, + last_epoch=-1 + ) + # Step it to the current global step + for _ in range(self.global_step): + temp_scheduler.step() + expected_lr = temp_scheduler.get_last_lr()[0] + if abs(current_lr_g - expected_lr) > 1e-9: + print(f"⚠️ Learning rate verification failed!") + print(f" Expected: {expected_lr:.9f}") + print(f" Got: {current_lr_g:.9f}") + print(" The scheduler state might not be properly restored.") + else: + print(f"No checkpoint found at {checkpoint_path}, starting from scratch") + + def train(self): + """Main training loop""" + best_val_loss = float('inf') + + # Print training configuration + if self.is_main_process(): + print(f"\n{'='*50}") + print(f"Training Configuration:") + print(f"{'='*50}") + print(f"Total epochs: {self.args.num_epochs}") + print(f"Steps per epoch: {len(self.train_loader)}") + print(f"Total steps: {self.total_steps}") + print(f"Warmup steps: {self.args.warmup_steps}") + print(f"Discriminator starts at step: {self.args.discriminator_start_step}") + print(f"Checkpoint saving:") + print(f" - Every {self.args.save_interval} epochs") + print(f" - Every {self.args.save_step_interval} steps") + print(f" - Keep last {self.args.keep_last_n_steps} step checkpoints") + if self.start_epoch > 0: + print(f"RESUMING from epoch {self.start_epoch}, step {self.global_step}") + print(f"{'='*50}\n") + + for epoch in range(self.start_epoch, self.args.num_epochs): + # IMPORTANT: Set the epoch for distributed sampler when resuming + # This ensures proper data shuffling across epochs + if self.distributed and hasattr(self.train_loader.sampler, 'set_epoch'): + self.train_loader.sampler.set_epoch(epoch) + + # Train + train_metrics = self.train_epoch(epoch) 
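+            # Note: both schedulers are stepped per optimizer update inside
+            # train_epoch, so no epoch-level scheduler.step() is needed here.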
+
+            # Validate
+            val_metrics = self.validate(epoch)
+
+            # Log epoch metrics
+            if self.is_main_process():
+                print(f"\nEpoch {epoch} Summary:")
+                print(f"Train - Total: {train_metrics['total']:.4f}, Rec: {train_metrics['rec']:.4f}, "
+                      f"STFT: {train_metrics['stft']:.4f}, Mel: {train_metrics['mel']:.4f}, "
+                      f"Commit: {train_metrics['commit']:.4f}, Semantic: {train_metrics['semantic']:.4f}")
+                if self.discriminator is not None:
+                    print(f" Gen: {train_metrics['gen']:.4f}, Feat: {train_metrics['feat']:.4f}, "
+                          f"Disc: {train_metrics['disc']:.4f}")
+                    print(f" Discriminator Status: {'Active' if self.global_step >= self.args.discriminator_start_step else f'Starting at step {self.args.discriminator_start_step}'}")
+                print(f"Val - Total: {val_metrics['total']:.4f}, Rec: {val_metrics['rec']:.4f}, "
+                      f"STFT: {val_metrics['stft']:.4f}, Mel: {val_metrics['mel']:.4f}, "
+                      f"Commit: {val_metrics['commit']:.4f}, Semantic: {val_metrics['semantic']:.4f}")
+                print(f"Current Step: {self.global_step}, Next step checkpoint at: {((self.global_step // self.args.save_step_interval) + 1) * self.args.save_step_interval}")
+                print(f"Current LR: {self.scheduler_g.get_last_lr()[0]:.9f}")
+
+            # Save checkpoint
+            is_best = val_metrics['total'] < best_val_loss
+            if is_best:
+                best_val_loss = val_metrics['total']
+            self.save_checkpoint(epoch, is_best)
+
+        # Save final model
+        if self.is_main_process():
+            model_state = self.model.module.state_dict() if self.distributed else self.model.state_dict()
+
+            final_path = os.path.join(self.args.output_dir, 'checkpoints', 'final.pth')
+            torch.save({
+                'model_state_dict': model_state,
+                'config': self.config
+            }, final_path)
+
+            # Also save just the model weights in the format expected by the original code
+            model_only_path = os.path.join(self.args.output_dir, 'model.pth')
+            torch.save(model_state, model_only_path)
+
+            # Copy config
+            import shutil
+            shutil.copy(self.args.config, os.path.join(self.args.output_dir, 'config.json'))
+
+        # Cleanup
+        if self.is_main_process():
+            self.writer.close()
+        if self.distributed:
+            dist.destroy_process_group()
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Train Boson Audio Codec')
+
+    # Data arguments
+    parser.add_argument('--data_csv', type=str, required=True,
+                        help='Path to CSV file containing audio file paths')
+    parser.add_argument('--config', type=str, default='config.json',
+                        help='Path to config JSON file')
+
+    # Training arguments
+    parser.add_argument('--batch_size', type=int, default=32,
+                        help='Batch size per GPU')
+    parser.add_argument('--num_epochs', type=int, default=100,
+                        help='Number of training epochs')
+    parser.add_argument('--learning_rate', type=float, default=1e-4,
+                        help='Initial learning rate')
+    parser.add_argument('--weight_decay', type=float, default=0.01,
+                        help='Weight decay')
+    parser.add_argument('--segment_duration', type=float, default=2.,
+                        help='Audio segment duration in seconds')
+
+    # Scheduler arguments
+    parser.add_argument('--warmup_steps', type=int, default=5000,
+                        help='Number of warmup steps for cosine scheduler')
+
+    # Loss arguments
+    parser.add_argument('--use_discriminator', action='store_true',
+                        help='Use adversarial training with discriminator')
+    parser.add_argument('--discriminator_start_step', type=int, default=24_000,
+                        help='Start training discriminator after N steps')
+    parser.add_argument('--disc_interval', type=int, default=1,
+                        help='Train discriminator every N steps')
+
+    # System arguments
+    parser.add_argument('--output_dir', type=str, default='outputs',
+
help='Output directory for checkpoints and logs') + parser.add_argument('--num_workers', type=int, default=16, + help='Number of data loading workers') + parser.add_argument('--seed', type=int, default=42, + help='Random seed') + parser.add_argument('--local_rank', type=int, default=0, + help='Local rank for distributed training') + + # Logging arguments + parser.add_argument('--log_interval', type=int, default=10, + help='Log every N steps') + parser.add_argument('--save_interval', type=int, default=1, + help='Save checkpoint every N epochs') + parser.add_argument('--save_step_interval', type=int, default=1000, + help='Save checkpoint every N steps') + parser.add_argument('--keep_last_n_steps', type=int, default=5, + help='Keep only the last N step-based checkpoints (0 to keep all)') + + # Resume training + parser.add_argument('--resume', action='store_true', + help='Resume training from latest checkpoint') + + args = parser.parse_args() + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + + # Train + trainer = BosonTrainer(args) + trainer.train() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/train_boson_mixed_precision.py b/train_boson_mixed_precision.py new file mode 100644 index 0000000000000000000000000000000000000000..31f967ba62d421cec837449b9cf75ffd61e44609 --- /dev/null +++ b/train_boson_mixed_precision.py @@ -0,0 +1,977 @@ +#!/usr/bin/env python3 +""" +Training script for Boson Audio Codec with DAC-inspired losses +""" + +import os +import json +import argparse +import random +from pathlib import Path +from datetime import datetime +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.utils.tensorboard import SummaryWriter +from torch.cuda.amp import autocast, GradScaler +import torchaudio +import librosa +from tqdm import tqdm +from audiotools import AudioSignal, STFTParams + +# Import from the provided codebase +from higgs_audio_tokenizer import HiggsAudioTokenizer +from quantization.distrib import broadcast_tensors, sync_buffer, is_distributed, world_size, rank +from quantization.ddp_utils import set_random_seed, is_logging_process, get_timestamp + +# Import DAC losses and discriminator +import sys +sys.path.append('.') # Add current directory to path +from loss import L1Loss, MultiScaleSTFTLoss, MelSpectrogramLoss, GANLoss +from discriminator import Discriminator + + +class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler): + """Cosine scheduler with linear warmup""" + def __init__(self, optimizer, warmup_steps, total_steps, eta_min=1e-6, last_epoch=-1): + self.warmup_steps = warmup_steps + self.total_steps = total_steps + self.eta_min = eta_min + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch < self.warmup_steps: + # Linear warmup + warmup_factor = self.last_epoch / self.warmup_steps + return [base_lr * warmup_factor for base_lr in self.base_lrs] + else: + # Cosine annealing + progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps) + cosine_factor = 0.5 * (1 + np.cos(np.pi * progress)) + return [self.eta_min + (base_lr - self.eta_min) * cosine_factor for base_lr in self.base_lrs] + + +class AudioDataset(Dataset): + """Dataset for loading audio files from CSV""" + 
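# Expects one audio file path per row in the CSV's first column; rows
+    # whose files are missing on disk are filtered out in __init__ below.
+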
def __init__(self, csv_path, sample_rate=44100, segment_duration=2.0, is_train=True): + self.df = pd.read_csv(csv_path) + self.sample_rate = sample_rate + self.segment_duration = segment_duration + self.segment_length = int(sample_rate * segment_duration) + self.is_train = is_train + + # Filter out files that don't exist + valid_files = [] + for idx, row in self.df.iterrows(): + if os.path.exists(row.iloc[0]): + valid_files.append(row.iloc[0]) + self.audio_paths = valid_files + print(f"Found {len(self.audio_paths)} valid audio files") + + def __len__(self): + return len(self.audio_paths) + + def __getitem__(self, idx): + audio_path = self.audio_paths[idx] + + try: + # Load audio using librosa + audio, sr = librosa.load(audio_path, sr=self.sample_rate, mono=True) + + # Random segment extraction for training + if len(audio) > self.segment_length: + if self.is_train: + start = random.randint(0, len(audio) - self.segment_length) + else: + start = 0 # Always use beginning for validation + audio = audio[start:start + self.segment_length] + else: + # Pad if too short + audio = np.pad(audio, (0, self.segment_length - len(audio))) + + # Convert to tensor and add batch dimension + audio_tensor = torch.FloatTensor(audio).unsqueeze(0) + + return audio_tensor, audio_path + + except Exception as e: + print(f"Error loading {audio_path}: {e}") + # Return silence if loading fails + return torch.zeros(1, self.segment_length), audio_path + + +class BosonTrainer: + def __init__(self, args): + self.args = args + self.distributed = False + + # Check if we're in a distributed environment + if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) > 1: + self.distributed = True + self.setup_ddp() + self.device = torch.device(f'cuda:{args.local_rank}') + else: + # Single GPU mode + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + torch.cuda.set_device(0) + set_random_seed(args.seed) + + # Load config + with open(args.config, 'r') as f: + self.config = json.load(f) + + # Initialize models + self.model = self.build_model() + self.discriminator = self.build_discriminator() if args.use_discriminator else None + + # Setup data loaders + self.train_loader, self.val_loader = self.setup_data_loaders() + + # Setup optimizers + self.optimizer_g = torch.optim.AdamW( + self.model.parameters(), + lr=args.learning_rate, + betas=(0.5, 0.9), + weight_decay=args.weight_decay + ) + + if self.discriminator is not None: + self.optimizer_d = torch.optim.AdamW( + self.discriminator.parameters(), + lr=args.learning_rate * 2, # Typically discriminator learns faster + betas=(0.5, 0.9), + weight_decay=args.weight_decay + ) + + # Initialize gradient scalers for mixed precision + if args.use_mixed_precision: + self.scaler_g = GradScaler() + self.scaler_d = GradScaler() if self.discriminator is not None else None + else: + self.scaler_g = None + self.scaler_d = None + + # Calculate total training steps + self.total_steps = args.num_epochs * len(self.train_loader) + + # Setup schedulers with warmup + self.scheduler_g = CosineWarmupScheduler( + self.optimizer_g, + warmup_steps=args.warmup_steps, + total_steps=self.total_steps, + eta_min=1e-6 + ) + + if self.discriminator is not None: + self.scheduler_d = CosineWarmupScheduler( + self.optimizer_d, + warmup_steps=args.warmup_steps, + total_steps=self.total_steps, + eta_min=1e-6 + ) + + # Setup losses + self.setup_losses() + + # Setup tensorboard + if not self.distributed or rank() == 0: + self.writer = SummaryWriter( + log_dir=os.path.join(args.output_dir, 
'logs', get_timestamp())
+            )
+
+        self.global_step = 0
+        self.start_epoch = 0
+
+        # Load checkpoint if exists
+        if args.resume:
+            self.load_checkpoint()
+
+    def setup_ddp(self):
+        """Initialize DDP"""
+        if 'LOCAL_RANK' in os.environ:
+            self.args.local_rank = int(os.environ['LOCAL_RANK'])
+        dist.init_process_group(backend='nccl')
+        torch.cuda.set_device(self.args.local_rank)
+        set_random_seed(self.args.seed + rank())
+
+    def build_model(self):
+        """Build and wrap model with DDP if needed"""
+
+        print(self.config)
+        model = HiggsAudioTokenizer(
+            n_filters=self.config['n_filters'],
+            D=self.config['D'],
+            target_bandwidths=self.config['target_bandwidths'],
+            ratios=self.config['ratios'],
+            sample_rate=self.config['sample_rate'],
+            bins=self.config['bins'],
+            n_q=self.config['n_q'],
+            codebook_dim=self.config.get('codebook_dim', None),
+            semantic_techer=self.config['semantic_techer'],
+            device=self.device
+        ).to(self.device)
+
+        if self.distributed:
+            # Broadcast model parameters to ensure all ranks have same initialization
+            broadcast_tensors(model.parameters())
+            # Wrap with DDP
+            model = DDP(model, device_ids=[self.args.local_rank])
+
+        return model
+
+    def build_discriminator(self):
+        """Build discriminator with DDP if needed"""
+        discriminator = Discriminator(
+            rates=[],  # No multi-rate discriminator
+            periods=[2, 3, 5, 7, 11],
+            fft_sizes=[2048, 1024, 512],
+            sample_rate=self.config['sample_rate'],  # 44100
+        ).to(self.device)
+
+        if self.distributed:
+            broadcast_tensors(discriminator.parameters())
+            discriminator = DDP(discriminator, device_ids=[self.args.local_rank])
+
+        return discriminator
+
+    def setup_losses(self):
+        """Setup all loss functions"""
+        # Basic losses
+        self.l1_loss = L1Loss()
+        self.stft_loss = MultiScaleSTFTLoss(
+            window_lengths=[2048, 1024, 512, 256, 128],
+            loss_fn=nn.L1Loss(),
+            clamp_eps=1e-5,
+            mag_weight=1.0,
+            log_weight=1.0,
+        )
+        self.mel_loss = MelSpectrogramLoss(
+            n_mels=[150, 80],
+            window_lengths=[2048, 512],
+            mel_fmin=[0.0, 0.0],
+            mel_fmax=[None, None],
+            clamp_eps=1e-5,
+            mag_weight=1.0,
+            log_weight=1.0,
+        )
+
+        # GAN loss if using discriminator
+        if self.discriminator is not None:
+            self.gan_loss = GANLoss(self.discriminator)
+
+        # Loss weights (DAC-inspired configuration)
+        self.loss_weights = {
+            'rec': 1.,       # Waveform L1 loss
+            'stft': 1.,      # Multi-scale STFT loss
+            'mel': 45.0,     # Mel-spectrogram loss (ENABLED in this script)
+            # 'mel': 0.0,    # Mel-spectrogram loss (disabled variant)
+            'commit': 0.25,  # Commitment loss
+            'semantic': 1.,  # Semantic loss
+            'gen': 1.,       # Generator adversarial loss
+            'feat': 2.0,     # Feature matching loss
+        }
+
+    def setup_data_loaders(self):
+        """Setup data loaders (distributed or single GPU)"""
+        # Split data into train/val
+        df = pd.read_csv(self.args.data_csv)
+        n_total = len(df)
+        n_train = int(n_total * 0.9)
+
+        # Create temporary CSV files for train/val split
+        train_csv = '/tmp/train_audio.csv'
+        val_csv = '/tmp/val_audio.csv'
+
+        if not self.distributed or rank() == 0:
+            df[:n_train].to_csv(train_csv, index=False)
+
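# (rank 0 also writes the val split below; the dist.barrier()
+            # that follows keeps other ranks from reading either file early)
+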
df[n_train:].to_csv(val_csv, index=False) + + # Synchronize across processes if distributed + if self.distributed: + dist.barrier() + + # Create datasets + train_dataset = AudioDataset( + train_csv, + sample_rate=self.config['sample_rate'], + segment_duration=self.args.segment_duration, + is_train=True + ) + + val_dataset = AudioDataset( + val_csv, + sample_rate=self.config['sample_rate'], + segment_duration=self.args.segment_duration, + is_train=False + ) + + # Create samplers and loaders + if self.distributed: + train_sampler = DistributedSampler(train_dataset, shuffle=True) + val_sampler = DistributedSampler(val_dataset, shuffle=False) + else: + train_sampler = None + val_sampler = None + + train_loader = DataLoader( + train_dataset, + batch_size=self.args.batch_size, + sampler=train_sampler, + shuffle=(train_sampler is None), + num_workers=self.args.num_workers, + pin_memory=True, + drop_last=True + ) + + val_loader = DataLoader( + val_dataset, + batch_size=self.args.batch_size, + sampler=val_sampler, + shuffle=False, + num_workers=self.args.num_workers, + pin_memory=True, + drop_last=False + ) + + return train_loader, val_loader + + def is_main_process(self): + """Check if this is the main process""" + return not self.distributed or rank() == 0 + + def train_epoch(self, epoch): + """Train for one epoch""" + self.model.train() + if self.discriminator is not None: + self.discriminator.train() + + if self.distributed: + self.train_loader.sampler.set_epoch(epoch) + + total_losses = { + 'total': 0, 'rec': 0, 'stft': 0, 'mel': 0, + 'commit': 0, 'semantic': 0, 'gen': 0, 'feat': 0, 'disc': 0 + } + + pbar = tqdm(self.train_loader, desc=f'Epoch {epoch}', disable=not self.is_main_process()) + + for batch_idx, (audio, paths) in enumerate(pbar): + audio = audio.to(self.device) + + # Create AudioSignal objects for loss computation + audio_signal = AudioSignal(audio, self.config['sample_rate']) + + # Forward pass with random bandwidth + bw_idx = random.randint(0, len(self.config['target_bandwidths']) - 1) + bw = self.config['target_bandwidths'][bw_idx] + + # Use autocast for mixed precision + with autocast(dtype=torch.bfloat16, enabled=self.args.use_mixed_precision): + output, commit_loss, semantic_loss, _ = self.model(audio, bw) + recons_signal = AudioSignal(output, self.config['sample_rate']) + + # Check if discriminator should be active (after discriminator_start_step) + use_discriminator = (self.discriminator is not None and + self.global_step >= self.args.discriminator_start_step) + + # Train discriminator first if using GAN and past the start step + if use_discriminator and self.global_step % self.args.disc_interval == 0: + self.optimizer_d.zero_grad() + + with autocast(dtype=torch.bfloat16, enabled=self.args.use_mixed_precision): + disc_loss = self.gan_loss.discriminator_loss(recons_signal, audio_signal) + + if self.scaler_d is not None: + self.scaler_d.scale(disc_loss).backward() + self.scaler_d.unscale_(self.optimizer_d) + torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), 10.0) + self.scaler_d.step(self.optimizer_d) + self.scaler_d.update() + else: + disc_loss.backward() + torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), 10.0) + self.optimizer_d.step() + + self.scheduler_d.step() + total_losses['disc'] += disc_loss.item() + + # Train generator + losses = {} + + # Compute losses with autocast + with autocast(dtype=torch.bfloat16, enabled=self.args.use_mixed_precision): + # Reconstruction losses + losses['rec'] = self.l1_loss(recons_signal, audio_signal) + 
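# Note: bfloat16 shares fp32's exponent range, so GradScaler's loss
+                # scaling acts mainly as a safeguard here; spectral losses may
+                # upcast to fp32 internally under autocast.
+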
losses['stft'] = self.stft_loss(recons_signal, audio_signal) + losses['mel'] = self.mel_loss(recons_signal, audio_signal) + # losses['mel'] = torch.tensor(0.0, device=self.device) # 15. + losses['commit'] = commit_loss + losses['semantic'] = semantic_loss + + # GAN losses if discriminator is active + if use_discriminator: + gen_loss, feat_loss = self.gan_loss.generator_loss(recons_signal, audio_signal) + losses['gen'] = gen_loss + losses['feat'] = feat_loss + else: + # Set to zero for logging purposes + losses['gen'] = torch.tensor(0.0, device=self.device) + losses['feat'] = torch.tensor(0.0, device=self.device) + + # Total weighted loss + total_loss = sum(self.loss_weights.get(k, 0) * v for k, v in losses.items() + if k not in ['gen', 'feat'] or use_discriminator) + + # Backward pass + self.optimizer_g.zero_grad() + + if self.scaler_g is not None: + self.scaler_g.scale(total_loss).backward() + self.scaler_g.unscale_(self.optimizer_g) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + self.scaler_g.step(self.optimizer_g) + self.scaler_g.update() + else: + total_loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + self.optimizer_g.step() + + self.scheduler_g.step() + + # Update metrics + total_losses['total'] += total_loss.item() + for k, v in losses.items(): + total_losses[k] += v.item() + + # Update progress bar + if self.is_main_process(): + pbar.set_postfix({ + 'loss': f'{total_loss.item():.4f}', + 'rec': f'{losses["rec"].item():.4f}', + 'mel': f'{losses["mel"].item():.4f}', + 'commit_loss': f'{losses["commit"].item():.4f}', + 'semantic_loss': f'{losses["semantic"].item():.4f}', + 'lr': f'{self.scheduler_g.get_last_lr()[0]:.9f}', + 'disc': 'ON' if use_discriminator else 'OFF', + 'step': self.global_step + }) + + # Log to tensorboard + if self.is_main_process() and self.global_step % self.args.log_interval == 0: + for k, v in losses.items(): + self.writer.add_scalar(f'train/{k}_loss', v.item(), self.global_step) + self.writer.add_scalar('train/total_loss', total_loss.item(), self.global_step) + self.writer.add_scalar('train/lr', self.scheduler_g.get_last_lr()[0], self.global_step) + self.writer.add_scalar('train/bandwidth', bw, self.global_step) + self.writer.add_scalar('train/discriminator_active', float(use_discriminator), self.global_step) + if use_discriminator: + self.writer.add_scalar('train/disc_loss', total_losses['disc'] / max(1, batch_idx), self.global_step) + if self.scaler_g is not None: + self.writer.add_scalar('train/grad_scale', self.scaler_g.get_scale(), self.global_step) + + # Save checkpoint at step intervals + if self.global_step > 0 and self.global_step % self.args.save_step_interval == 0: + self.save_checkpoint_step(self.global_step) + if self.is_main_process(): + print(f"\nSaved checkpoint at step {self.global_step}") + + self.global_step += 1 + + # Return average losses + n_batches = len(self.train_loader) + return {k: v / n_batches for k, v in total_losses.items()} + + @torch.no_grad() + def validate(self, epoch): + """Validation loop""" + self.model.eval() + + total_losses = { + 'total': 0, 'rec': 0, 'stft': 0, 'mel': 0, + 'commit': 0, 'semantic': 0 + } + + # Store audio samples for tensorboard + audio_samples = {'train': [], 'val': []} + + for batch_idx, (audio, paths) in enumerate(tqdm(self.val_loader, desc='Validation', disable=not self.is_main_process())): + audio = audio.to(self.device) + audio_signal = AudioSignal(audio, self.config['sample_rate']) + + # Use medium bandwidth for validation + bw = 
self.config['target_bandwidths'][2] + + # Use autocast for validation too + with autocast(dtype=torch.bfloat16, enabled=self.args.use_mixed_precision): + output, commit_loss, semantic_loss, _ = self.model(audio, bw) + recons_signal = AudioSignal(output, self.config['sample_rate']) + + # Compute losses + losses = { + 'rec': self.l1_loss(recons_signal, audio_signal), + 'stft': self.stft_loss(recons_signal, audio_signal), + 'mel': self.mel_loss(recons_signal, audio_signal), + 'commit': commit_loss, + 'semantic': semantic_loss + } + + total_loss = sum(self.loss_weights.get(k, 0) * v for k, v in losses.items()) + + total_losses['total'] += total_loss.item() + for k, v in losses.items(): + total_losses[k] += v.item() + + # Collect audio samples for tensorboard (first 3 from validation) + if self.is_main_process() and len(audio_samples['val']) < 3: + audio_samples['val'].append({ + 'original': audio[0].cpu(), + 'reconstructed': output[0].cpu(), + 'path': paths[0] + }) + + # Get train samples for comparison + if self.is_main_process(): + self.model.eval() + for batch_idx, (audio, paths) in enumerate(self.train_loader): + if len(audio_samples['train']) >= 3: + break + audio = audio.to(self.device) + bw = self.config['target_bandwidths'][2] + with autocast(dtype=torch.bfloat16, enabled=self.args.use_mixed_precision): + output, _, _, _ = self.model(audio, bw) + audio_samples['train'].append({ + 'original': audio[0].cpu(), + 'reconstructed': output[0].cpu(), + 'path': paths[0] + }) + + # Log audio samples to tensorboard + if self.is_main_process(): + for split in ['train', 'val']: + for idx, sample in enumerate(audio_samples[split]): + self.writer.add_audio( + f'{split}/original_{idx}', + sample['original'], + epoch, + sample_rate=self.config['sample_rate'] + ) + self.writer.add_audio( + f'{split}/reconstructed_{idx}', + sample['reconstructed'], + epoch, + sample_rate=self.config['sample_rate'] + ) + + # Average losses + n_batches = len(self.val_loader) + val_metrics = {k: v / n_batches for k, v in total_losses.items()} + + # Log validation metrics + if self.is_main_process(): + for key, value in val_metrics.items(): + self.writer.add_scalar(f'val/{key}_loss', value, epoch) + + return val_metrics + + def save_checkpoint(self, epoch, is_best=False): + """Save model checkpoint (epoch-based)""" + if not self.is_main_process(): + return + + model_state = self.model.module.state_dict() if self.distributed else self.model.state_dict() + + # Get current learning rates for verification + current_lr_g = self.scheduler_g.get_last_lr()[0] + + checkpoint = { + 'epoch': epoch, + 'global_step': self.global_step, + 'model_state_dict': model_state, + 'optimizer_g_state_dict': self.optimizer_g.state_dict(), + 'scheduler_g_state_dict': self.scheduler_g.state_dict(), + 'scheduler_g_last_epoch': self.scheduler_g.last_epoch, # Explicitly save this + 'current_lr_g': current_lr_g, # Save for verification + 'config': self.config, + 'args': self.args + } + + # Save gradient scaler states if using mixed precision + if self.scaler_g is not None: + checkpoint['scaler_g_state_dict'] = self.scaler_g.state_dict() + + if self.discriminator is not None: + disc_state = self.discriminator.module.state_dict() if self.distributed else self.discriminator.state_dict() + current_lr_d = self.scheduler_d.get_last_lr()[0] + checkpoint['discriminator_state_dict'] = disc_state + checkpoint['optimizer_d_state_dict'] = self.optimizer_d.state_dict() + checkpoint['scheduler_d_state_dict'] = self.scheduler_d.state_dict() + 
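# The scheduler state_dict above already stores last_epoch; the
+            # explicit key below is a redundant guard that load_checkpoint
+            # prefers when present.
+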
checkpoint['scheduler_d_last_epoch'] = self.scheduler_d.last_epoch + checkpoint['current_lr_d'] = current_lr_d + + if self.scaler_d is not None: + checkpoint['scaler_d_state_dict'] = self.scaler_d.state_dict() + + # Save latest checkpoint + checkpoint_path = os.path.join(self.args.output_dir, 'checkpoints', 'latest.pth') + os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) + torch.save(checkpoint, checkpoint_path) + + # Save best checkpoint + if is_best: + best_path = os.path.join(self.args.output_dir, 'checkpoints', 'best.pth') + torch.save(checkpoint, best_path) + + # Save periodic checkpoint + if epoch % self.args.save_interval == 0: + epoch_path = os.path.join(self.args.output_dir, 'checkpoints', f'epoch_{epoch}.pth') + torch.save(checkpoint, epoch_path) + + + def save_checkpoint_step(self, step): + """Save model checkpoint (step-based)""" + if not self.is_main_process(): + return + + # Get current epoch from training loop + current_epoch = step // len(self.train_loader) + + model_state = self.model.module.state_dict() if self.distributed else self.model.state_dict() + + # Get current learning rates for verification + current_lr_g = self.scheduler_g.get_last_lr()[0] + + checkpoint = { + 'epoch': current_epoch, + 'global_step': step, + 'model_state_dict': model_state, + 'optimizer_g_state_dict': self.optimizer_g.state_dict(), + 'scheduler_g_state_dict': self.scheduler_g.state_dict(), + 'scheduler_g_last_epoch': self.scheduler_g.last_epoch, # Explicitly save this + 'current_lr_g': current_lr_g, # Save for verification + 'config': self.config, + 'args': self.args + } + + # Save gradient scaler states if using mixed precision + if self.scaler_g is not None: + checkpoint['scaler_g_state_dict'] = self.scaler_g.state_dict() + + if self.discriminator is not None: + disc_state = self.discriminator.module.state_dict() if self.distributed else self.discriminator.state_dict() + current_lr_d = self.scheduler_d.get_last_lr()[0] + checkpoint['discriminator_state_dict'] = disc_state + checkpoint['optimizer_d_state_dict'] = self.optimizer_d.state_dict() + checkpoint['scheduler_d_state_dict'] = self.scheduler_d.state_dict() + checkpoint['scheduler_d_last_epoch'] = self.scheduler_d.last_epoch + checkpoint['current_lr_d'] = current_lr_d + + if self.scaler_d is not None: + checkpoint['scaler_d_state_dict'] = self.scaler_d.state_dict() + + # Create checkpoint directory if it doesn't exist + checkpoint_dir = os.path.join(self.args.output_dir, 'checkpoints') + os.makedirs(checkpoint_dir, exist_ok=True) + + # Save step-based checkpoint + step_path = os.path.join(self.args.output_dir, 'checkpoints', f'step_{step}.pth') + torch.save(checkpoint, step_path) + + # Also update latest checkpoint + latest_path = os.path.join(self.args.output_dir, 'checkpoints', 'latest.pth') + torch.save(checkpoint, latest_path) + + # Keep only the last N step-based checkpoints to save disk space + if self.args.keep_last_n_steps > 0: + checkpoint_dir = os.path.join(self.args.output_dir, 'checkpoints') + step_checkpoints = sorted([f for f in os.listdir(checkpoint_dir) if f.startswith('step_')]) + if len(step_checkpoints) > self.args.keep_last_n_steps: + for old_checkpoint in step_checkpoints[:-self.args.keep_last_n_steps]: + os.remove(os.path.join(checkpoint_dir, old_checkpoint)) + + + def load_checkpoint(self): + """Load checkpoint with proper state restoration""" + checkpoint_path = os.path.join(self.args.output_dir, 'checkpoints', 'latest.pth') + if os.path.exists(checkpoint_path): + print(f"Loading checkpoint from 
{checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False) + + # Load model state + if self.distributed: + self.model.module.load_state_dict(checkpoint['model_state_dict']) + else: + self.model.load_state_dict(checkpoint['model_state_dict']) + + # Load optimizer state + self.optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict']) + + # Load scheduler state + self.scheduler_g.load_state_dict(checkpoint['scheduler_g_state_dict']) + + # Restore scheduler's last_epoch from checkpoint + if 'scheduler_g_last_epoch' in checkpoint: + self.scheduler_g.last_epoch = checkpoint['scheduler_g_last_epoch'] + else: + # Fallback: use global_step if the explicit value wasn't saved + self.scheduler_g.last_epoch = checkpoint['global_step'] + + # Force scheduler to recompute its internal state + self.scheduler_g._last_lr = self.scheduler_g.get_lr() + + # Load gradient scaler state if using mixed precision + if self.scaler_g is not None and 'scaler_g_state_dict' in checkpoint: + self.scaler_g.load_state_dict(checkpoint['scaler_g_state_dict']) + + # Load discriminator if present + if self.discriminator is not None and 'discriminator_state_dict' in checkpoint: + if self.distributed: + self.discriminator.module.load_state_dict(checkpoint['discriminator_state_dict']) + else: + self.discriminator.load_state_dict(checkpoint['discriminator_state_dict']) + self.optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict']) + self.scheduler_d.load_state_dict(checkpoint['scheduler_d_state_dict']) + + # Restore discriminator scheduler's last_epoch + if 'scheduler_d_last_epoch' in checkpoint: + self.scheduler_d.last_epoch = checkpoint['scheduler_d_last_epoch'] + else: + self.scheduler_d.last_epoch = checkpoint['global_step'] + + self.scheduler_d._last_lr = self.scheduler_d.get_lr() + + # Load discriminator gradient scaler state if using mixed precision + if self.scaler_d is not None and 'scaler_d_state_dict' in checkpoint: + self.scaler_d.load_state_dict(checkpoint['scaler_d_state_dict']) + + # Restore training state + self.start_epoch = checkpoint['epoch'] + 1 + self.global_step = checkpoint['global_step'] + + # Verify learning rate restoration + current_lr_g = self.scheduler_g.get_last_lr()[0] + saved_lr_g = checkpoint.get('current_lr_g', None) + + print(f"\n{'='*60}") + print(f"CHECKPOINT LOADED SUCCESSFULLY") + print(f"{'='*60}") + print(f"Resumed from epoch: {checkpoint['epoch']}") + print(f"Global step: {self.global_step}") + print(f"Scheduler last_epoch: {self.scheduler_g.last_epoch}") + print(f"Current learning rate (generator): {current_lr_g:.9f}") + print(f"Mixed precision: {'ENABLED' if self.args.use_mixed_precision else 'DISABLED'}") + if saved_lr_g is not None: + print(f"Saved learning rate (generator): {saved_lr_g:.9f}") + if abs(current_lr_g - saved_lr_g) > 1e-9: + print("⚠️ WARNING: Learning rate mismatch! 
This might indicate improper state restoration.") + + if self.discriminator is not None: + current_lr_d = self.scheduler_d.get_last_lr()[0] + saved_lr_d = checkpoint.get('current_lr_d', None) + print(f"Current learning rate (discriminator): {current_lr_d:.9f}") + if saved_lr_d is not None: + print(f"Saved learning rate (discriminator): {saved_lr_d:.9f}") + print(f"Discriminator status: {'ACTIVE' if self.global_step >= self.args.discriminator_start_step else f'INACTIVE (starts at step {self.args.discriminator_start_step})'}") + + print(f"Next epoch: {self.start_epoch}") + print(f"Next step checkpoint at: step {((self.global_step // self.args.save_step_interval) + 1) * self.args.save_step_interval}") + print(f"{'='*60}\n") + + # Double-check by creating a fresh scheduler and comparing + if self.global_step > 0: + temp_scheduler = CosineWarmupScheduler( + self.optimizer_g, + self.args.warmup_steps, + self.total_steps, + eta_min=1e-6, + last_epoch=-1 + ) + # Step it to the current global step + for _ in range(self.global_step): + temp_scheduler.step() + expected_lr = temp_scheduler.get_last_lr()[0] + if abs(current_lr_g - expected_lr) > 1e-9: + print(f"⚠️ Learning rate verification failed!") + print(f" Expected: {expected_lr:.9f}") + print(f" Got: {current_lr_g:.9f}") + print(" The scheduler state might not be properly restored.") + else: + print(f"No checkpoint found at {checkpoint_path}, starting from scratch") + + def train(self): + """Main training loop""" + best_val_loss = float('inf') + + # Print training configuration + if self.is_main_process(): + print(f"\n{'='*50}") + print(f"Training Configuration:") + print(f"{'='*50}") + print(f"Total epochs: {self.args.num_epochs}") + print(f"Steps per epoch: {len(self.train_loader)}") + print(f"Total steps: {self.total_steps}") + print(f"Warmup steps: {self.args.warmup_steps}") + print(f"Mixed precision training: {'ENABLED (bfloat16)' if self.args.use_mixed_precision else 'DISABLED'}") + print(f"Discriminator starts at step: {self.args.discriminator_start_step}") + print(f"Checkpoint saving:") + print(f" - Every {self.args.save_interval} epochs") + print(f" - Every {self.args.save_step_interval} steps") + print(f" - Keep last {self.args.keep_last_n_steps} step checkpoints") + if self.start_epoch > 0: + print(f"RESUMING from epoch {self.start_epoch}, step {self.global_step}") + print(f"{'='*50}\n") + + for epoch in range(self.start_epoch, self.args.num_epochs): + # IMPORTANT: Set the epoch for distributed sampler when resuming + # This ensures proper data shuffling across epochs + if self.distributed and hasattr(self.train_loader.sampler, 'set_epoch'): + self.train_loader.sampler.set_epoch(epoch) + + # Train + train_metrics = self.train_epoch(epoch) + + # Validate + val_metrics = self.validate(epoch) + + # Log epoch metrics + if self.is_main_process(): + print(f"\nEpoch {epoch} Summary:") + print(f"Train - Total: {train_metrics['total']:.4f}, Rec: {train_metrics['rec']:.4f}, " + f"STFT: {train_metrics['stft']:.4f}, Mel: {train_metrics['mel']:.4f}, " + f"Commit: {train_metrics['commit']:.4f}, Semantic: {train_metrics['semantic']:.4f}") + if self.discriminator is not None: + print(f" Gen: {train_metrics['gen']:.4f}, Feat: {train_metrics['feat']:.4f}, " + f"Disc: {train_metrics['disc']:.4f}") + print(f" Discriminator Status: {'Active' if self.global_step >= self.args.discriminator_start_step else f'Starting at step {self.args.discriminator_start_step}'}") + print(f"Val - Total: {val_metrics['total']:.4f}, Rec: {val_metrics['rec']:.4f}, " + f"STFT: 
{val_metrics['stft']:.4f}, Mel: {val_metrics['mel']:.4f}, " + f"Commit: {val_metrics['commit']:.4f}, Semantic: {val_metrics['semantic']:.4f}") + print(f"Current Step: {self.global_step}, Next step checkpoint at: {((self.global_step // self.args.save_step_interval) + 1) * self.args.save_step_interval}") + print(f"Current LR: {self.scheduler_g.get_last_lr()[0]:.9f}") + + # Save checkpoint + is_best = val_metrics['total'] < best_val_loss + if is_best: + best_val_loss = val_metrics['total'] + self.save_checkpoint(epoch, is_best) + + # Save final model + if self.is_main_process(): + model_state = self.model.module.state_dict() if self.distributed else self.model.state_dict() + + final_path = os.path.join(self.args.output_dir, 'checkpoints', 'final.pth') + torch.save({ + 'model_state_dict': model_state, + 'config': self.config + }, final_path) + + # Also save just the model weights in the format expected by the original code + model_only_path = os.path.join(self.args.output_dir, 'model.pth') + torch.save(model_state, model_only_path) + + # Copy config + import shutil + shutil.copy(self.args.config, os.path.join(self.args.output_dir, 'config.json')) + + # Cleanup + if self.is_main_process(): + self.writer.close() + if self.distributed: + dist.destroy_process_group() + + +def main(): + parser = argparse.ArgumentParser(description='Train Boson Audio Codec') + + # Data arguments + parser.add_argument('--data_csv', type=str, required=True, + help='Path to CSV file containing audio file paths') + parser.add_argument('--config', type=str, default='config.json', + help='Path to config JSON file') + + # Training arguments + parser.add_argument('--batch_size', type=int, default=28, + help='Batch size per GPU') + parser.add_argument('--num_epochs', type=int, default=100, + help='Number of training epochs') + parser.add_argument('--learning_rate', type=float, default=1e-4, + help='Initial learning rate') + parser.add_argument('--weight_decay', type=float, default=0.01, + help='Weight decay') + parser.add_argument('--segment_duration', type=float, default=2., + help='Audio segment duration in seconds') + + # Mixed precision training + parser.add_argument('--use_mixed_precision', action='store_true', + help='Use bfloat16 mixed precision training') + + # Scheduler arguments + parser.add_argument('--warmup_steps', type=int, default=5000, + help='Number of warmup steps for cosine scheduler') + + # Loss arguments + parser.add_argument('--use_discriminator', action='store_true', + help='Use adversarial training with discriminator') + parser.add_argument('--discriminator_start_step', type=int, default=25_000, + help='Start training discriminator after N steps') + parser.add_argument('--disc_interval', type=int, default=1, + help='Train discriminator every N steps') + + # System arguments + parser.add_argument('--output_dir', type=str, default='outputs_mp_cqt', + help='Output directory for checkpoints and logs') + parser.add_argument('--num_workers', type=int, default=16, + help='Number of data loading workers') + parser.add_argument('--seed', type=int, default=42, + help='Random seed') + parser.add_argument('--local_rank', type=int, default=0, + help='Local rank for distributed training') + + # Logging arguments + parser.add_argument('--log_interval', type=int, default=10, + help='Log every N steps') + parser.add_argument('--save_interval', type=int, default=1, + help='Save checkpoint every N epochs') + parser.add_argument('--save_step_interval', type=int, default=1000, + help='Save checkpoint every N steps') + 
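# Step-checkpoint retention: save_checkpoint_step() prunes step_*.pth
+    # files beyond this count after every step-interval save.
+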
parser.add_argument('--keep_last_n_steps', type=int, default=5, + help='Keep only the last N step-based checkpoints (0 to keep all)') + + # Resume training + parser.add_argument('--resume', action='store_true', + help='Resume training from latest checkpoint') + + args = parser.parse_args() + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + + # Train + trainer = BosonTrainer(args) + trainer.train() + + +if __name__ == '__main__': + torch.set_float32_matmul_precision('high') + main() \ No newline at end of file