Respair commited on
Commit
1b8ddf6
·
verified ·
1 Parent(s): 0983f95

Update train_boson_mixed_precision.py

Browse files
Files changed (1) hide show
  1. train_boson_mixed_precision.py +3 -4
train_boson_mixed_precision.py CHANGED
@@ -20,14 +20,13 @@ import librosa
20
  from tqdm import tqdm
21
  from audiotools import AudioSignal, STFTParams
22
 
23
- # Import from the provided codebase
24
  from higgs_audio_tokenizer import HiggsAudioTokenizer
25
  from quantization.distrib import broadcast_tensors, sync_buffer, is_distributed, world_size, rank
26
  from quantization.ddp_utils import set_random_seed, is_logging_process, get_timestamp
27
 
28
- # Import DAC losses and discriminator
29
  import sys
30
- sys.path.append('.') # Add current directory to path
31
  from loss import L1Loss, MultiScaleSTFTLoss, MelSpectrogramLoss, GANLoss
32
  from discriminator import Discriminator
33
 
@@ -711,7 +710,7 @@ class BosonTrainer:
711
  print(f"Loading checkpoint from {checkpoint_path}")
712
  checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
713
 
714
- # Load model state
715
  if self.distributed:
716
  self.model.module.load_state_dict(checkpoint['model_state_dict'])
717
  else:
 
20
  from tqdm import tqdm
21
  from audiotools import AudioSignal, STFTParams
22
 
23
+
24
  from higgs_audio_tokenizer import HiggsAudioTokenizer
25
  from quantization.distrib import broadcast_tensors, sync_buffer, is_distributed, world_size, rank
26
  from quantization.ddp_utils import set_random_seed, is_logging_process, get_timestamp
27
 
 
28
  import sys
29
+ sys.path.append('.')
30
  from loss import L1Loss, MultiScaleSTFTLoss, MelSpectrogramLoss, GANLoss
31
  from discriminator import Discriminator
32
 
 
710
  print(f"Loading checkpoint from {checkpoint_path}")
711
  checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
712
 
713
+
714
  if self.distributed:
715
  self.model.module.load_state_dict(checkpoint['model_state_dict'])
716
  else: