razmars commited on
Commit
dcf99e3
·
verified ·
1 Parent(s): 43cc2dd

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +40 -28
modeling_super_linear.py CHANGED
@@ -295,50 +295,62 @@ class SparseNoisyMoE(nn.Module):
295
  def get_periodogram(self, inputs, ker_len=50, con=1, n=10000):
296
  n_fft = 128
297
  ker_len =12
298
- if inputs.ndim == 2: # (B, L) → (B, L, 1)
299
- x = inputs.unsqueeze(2)
300
- else: # already (B, L, C)
301
- x = inputs
 
 
 
 
 
 
 
 
 
 
 
 
 
302
 
303
- B, L, C = x.shape
304
  x = x - x.mean(dim=1, keepdim=True)
305
 
306
- # ------------------------------------------------------------------ parameters
307
  if n_fft is None:
308
- n_fft = 1 << (L - 1).bit_length() # next power-of-two ≥ L
309
  if ker_len is None:
310
- ker_len = min(L // 4, 50) # never larger than the signal
311
  ker_half = ker_len // 2
312
 
313
- # ------------------------------------------------------------------ detrend
314
  if con and ker_len > 0:
315
- # (B, L, C) (B, C, L) for conv1d
316
- x_perm = x.permute(0, 2, 1)
317
  ker = torch.ones(1, 1, ker_len, device=x.device) / ker_len
318
  trend = F.conv1d(x_perm, ker, padding="same")
319
- # Clamp boundary copies so we don’t index out of range for short signals
320
- left = min(ker_half, L - 1)
321
- right = min(ker_half, L - 1)
322
- trend[:, :, :left] = trend[:, :, left:left+1]
323
- trend[:, :, -right:] = trend[:, :, -(right+1):-right]
324
- x_detrended = x_perm - trend
325
- x = x_detrended.permute(0, 2, 1) # back to (B, L, C)
326
 
327
- # ------------------------------------------------------------------ FFT
328
- dft = torch.fft.fft(x, n=n_fft, dim=1) / np.sqrt(n_fft)
329
- dft = dft[:, : n_fft // 2, :] # keep positive freqs
330
- I = torch.abs(dft) ** 2 # periodogram
331
 
332
- # ------------------------------------------------------------------ normalise
333
  I_sum = I.sum(dim=1, keepdim=True)
334
- I_sum[I_sum == 0] = 1 # avoid /0
335
  I /= I_sum
336
 
337
- # ------------------------------------------------------------------ squeeze back if original was 2-D
338
- if inputs.ndim == 2:
339
- I = I.squeeze(2)
340
 
341
- return I
 
 
 
342
 
343
 
344
  def fourier_interp_dim1(self,x, target_len: int = 512):
 
295
  def get_periodogram(self, inputs, ker_len=50, con=1, n=10000):
296
  n_fft = 128
297
  ker_len =12
298
+ if inputs.ndim == 2: # (B, L)
299
+ B, L = inputs.shape
300
+ C = 1
301
+ x = inputs.unsqueeze(2) # → (B, L, 1)
302
+ time_first = True # time is dim-1
303
+ elif inputs.ndim == 3:
304
+ B, d1, d2 = inputs.shape
305
+ if d1 < d2: # (B, L, C)
306
+ L, C = d1, d2
307
+ x = inputs
308
+ time_first = True
309
+ else: # (B, C, L)
310
+ C, L = d1, d2
311
+ x = inputs.transpose(1, 2) # → (B, L, C)
312
+ time_first = False
313
+ else:
314
+ raise ValueError("Input must be (B,L), (B,L,C) or (B,C,L)")
315
 
316
+ # ---------- centre the signal ----------
317
  x = x - x.mean(dim=1, keepdim=True)
318
 
319
+ # ---------- parameter defaults ----------
320
  if n_fft is None:
321
+ n_fft = 1 << (L - 1).bit_length()
322
  if ker_len is None:
323
+ ker_len = min(L // 4, 50)
324
  ker_half = ker_len // 2
325
 
326
+ # ---------- high-pass detrend ----------
327
  if con and ker_len > 0:
328
+ x_perm = x.permute(0, 2, 1) # (B, C, L)
 
329
  ker = torch.ones(1, 1, ker_len, device=x.device) / ker_len
330
  trend = F.conv1d(x_perm, ker, padding="same")
331
+ left = min(ker_half, L - 1)
332
+ right = min(ker_half, L - 1)
333
+ trend[:, :, :left] = trend[:, :, left:left+1]
334
+ trend[:, :, -right:] = trend[:, :, -(right+1):-right]
335
+ x = (x_perm - trend).permute(0, 2, 1) # back to (B, L, C)
 
 
336
 
337
+ # ---------- FFT ----------
338
+ dft = torch.fft.fft(x, n=n_fft, dim=1) / np.sqrt(n_fft)
339
+ I = (dft[:, : n_fft//2, :]).abs() ** 2
 
340
 
341
+ # ---------- normalise ----------
342
  I_sum = I.sum(dim=1, keepdim=True)
343
+ I_sum[I_sum == 0] = 1
344
  I /= I_sum
345
 
346
+ # ---------- restore original layout ----------
347
+ if inputs.ndim == 2: # wanted (B, … )
348
+ return I.squeeze(2)
349
 
350
+ if time_first: # original was (B, L, C)
351
+ return I # already (B, F, C)
352
+ else: # original was (B, C, L) → (B, C, F)
353
+ return I.transpose(1, 2)
354
 
355
 
356
  def fourier_interp_dim1(self,x, target_len: int = 512):