Update train_boson_mixed_precision.py
Browse files
train_boson_mixed_precision.py
CHANGED
|
@@ -20,14 +20,13 @@ import librosa
|
|
| 20 |
from tqdm import tqdm
|
| 21 |
from audiotools import AudioSignal, STFTParams
|
| 22 |
|
| 23 |
-
|
| 24 |
from higgs_audio_tokenizer import HiggsAudioTokenizer
|
| 25 |
from quantization.distrib import broadcast_tensors, sync_buffer, is_distributed, world_size, rank
|
| 26 |
from quantization.ddp_utils import set_random_seed, is_logging_process, get_timestamp
|
| 27 |
|
| 28 |
-
# Import DAC losses and discriminator
|
| 29 |
import sys
|
| 30 |
-
sys.path.append('.')
|
| 31 |
from loss import L1Loss, MultiScaleSTFTLoss, MelSpectrogramLoss, GANLoss
|
| 32 |
from discriminator import Discriminator
|
| 33 |
|
|
@@ -711,7 +710,7 @@ class BosonTrainer:
|
|
| 711 |
print(f"Loading checkpoint from {checkpoint_path}")
|
| 712 |
checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
|
| 713 |
|
| 714 |
-
|
| 715 |
if self.distributed:
|
| 716 |
self.model.module.load_state_dict(checkpoint['model_state_dict'])
|
| 717 |
else:
|
|
|
|
| 20 |
from tqdm import tqdm
|
| 21 |
from audiotools import AudioSignal, STFTParams
|
| 22 |
|
| 23 |
+
|
| 24 |
from higgs_audio_tokenizer import HiggsAudioTokenizer
|
| 25 |
from quantization.distrib import broadcast_tensors, sync_buffer, is_distributed, world_size, rank
|
| 26 |
from quantization.ddp_utils import set_random_seed, is_logging_process, get_timestamp
|
| 27 |
|
|
|
|
| 28 |
import sys
|
| 29 |
+
sys.path.append('.')
|
| 30 |
from loss import L1Loss, MultiScaleSTFTLoss, MelSpectrogramLoss, GANLoss
|
| 31 |
from discriminator import Discriminator
|
| 32 |
|
|
|
|
| 710 |
print(f"Loading checkpoint from {checkpoint_path}")
|
| 711 |
checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
|
| 712 |
|
| 713 |
+
|
| 714 |
if self.distributed:
|
| 715 |
self.model.module.load_state_dict(checkpoint['model_state_dict'])
|
| 716 |
else:
|