import torch
import math
import random
import numpy as np
from PIL import Image

def random_insert_latent_frame(
    image_latent: torch.Tensor,
    noisy_model_input: torch.Tensor,
    target_latents: torch.Tensor,
    input_intervals: torch.Tensor,
    output_intervals: torch.Tensor,
    special_info
):
    """
    Inserts latent frames into noisy input, pads targets, and builds flattened intervals with flags.

    Args:
        image_latent:     [B, latent_count, C, H, W]
        noisy_model_input:[B, F, C, H, W]
        target_latents:   [B, F, C, H, W]
        input_intervals:  [B, N, frames_per_latent, L]
        output_intervals: [B, M, frames_per_latent, L]

    For each sample randomly choose:
    Mode A (50%):
        - Insert two image_latent frames at start of noisy input and targets.
        - Pad target_latents by prepending two zero-frames.
        - Pad input_intervals by repeating its last group once.
    Mode B (50%):
        - Insert one image_latent frame at start and repeat last noisy frame at end.
        - Pad target_latents by prepending one one-frame and appending last target frame.
        - Pad output_intervals by repeating its last group once.

    After padding intervals, flatten each group from [frames_per_latent, L] to [frames_per_latent * L],
    then append a 4-element flag (1 for input groups, 0 for output groups).

    Returns:
        outputs:     Tensor [B, F+2, C, H, W]
        new_targets: Tensor [B, F+2, C, H, W]
        masks:       Tensor [B, F+2] bool mask of latent inserts
        intervals:   Tensor [B, N+M+1, fpl * L + 4]
    """
    B, F, C, H, W = noisy_model_input.shape
    _, N, fpl, L = input_intervals.shape
    _, M, _, _ = output_intervals.shape
    device = noisy_model_input.device

    new_F = F + 1 if special_info == "just_one" else F + 2
    outputs = torch.empty((B, new_F, C, H, W), device=device,
                          dtype=noisy_model_input.dtype)
    masks = torch.zeros((B, new_F), dtype=torch.bool, device=device)
    # "just_one" adds no interval group; Modes A and B each replicate one.
    combined_groups = N + M if special_info == "just_one" else N + M + 1
    feature_len = fpl * L
    intervals = torch.empty((B, combined_groups, feature_len), device=device,
                            dtype=input_intervals.dtype)
    new_targets = torch.empty((B, new_F, C, H, W), device=device,
                              dtype=target_latents.dtype)

    for b in range(B):
        latent = image_latent[b, 0]
        frames = noisy_model_input[b]
        tgt = target_latents[b]

        # "use_a" makes the threshold unreachable-high, forcing Mode A below.
        limit = 10 if special_info == "use_a" else 0.5
        if special_info == "just_one":
            # "just_one": single latent insert, sentinel-prefixed target
            outputs[b, 0] = latent
            masks[b, :1] = True
            outputs[b, 1:] = frames

            # pad targets: one sentinel frame (to be ignored by the loss)
            large_number = torch.ones_like(tgt[0]) * 10000
            new_targets[b, 0] = large_number
            new_targets[b, 1:] = tgt

            # intervals are used as-is; no group is replicated in this branch
            in_groups = input_intervals[b]
            out_groups = output_intervals[b]
        elif random.random() < limit:  # always true when special_info == "use_a"
            # Mode A: two latent inserts, sentinel-prefixed targets
            outputs[b, 0] = latent
            outputs[b, 1] = latent
            masks[b, :2] = True
            outputs[b, 2:] = frames

            # pad targets: two sentinel frames (to be ignored by the loss)
            large_number = torch.ones_like(tgt[0]) * 10000
            new_targets[b, 0] = large_number
            new_targets[b, 1] = large_number
            new_targets[b, 2:] = tgt

            # pad intervals: input + replicated last input group
            pad_group = input_intervals[b, -1:].clone()
            in_groups = torch.cat([input_intervals[b], pad_group], dim=0)
            out_groups = output_intervals[b]
        else:
            # Mode B: one latent insert & last-frame repeat; zero-prefixed targets with the last frame appended
            outputs[b, 0] = latent
            masks[b, 0] = True
            outputs[b, 1:new_F-1] = frames
            outputs[b, new_F-1] = frames[-1]

            # pad targets: one zero frame, then the originals, then the last frame
            zero = torch.zeros_like(tgt[0])
            new_targets[b, 0] = zero
            new_targets[b, 1:new_F-1] = tgt
            new_targets[b, new_F-1] = tgt[-1]

            # pad intervals: output + replicated last output group
            in_groups = input_intervals[b]
            pad_group = output_intervals[b, -1:].clone()
            out_groups = torch.cat([output_intervals[b], pad_group], dim=0)

        # flatten groups (the input/output flag columns were dropped)
        flat_in = in_groups.reshape(-1, feature_len)
        flat_out = out_groups.reshape(-1, feature_len)

        intervals[b] = torch.cat([flat_in, flat_out], dim=0)

    return outputs, new_targets, masks, intervals
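
# Hedged usage sketch for random_insert_latent_frame. All shapes below are
# illustrative assumptions, not values taken from the training pipeline; run
# this file directly to execute it (see the __main__ block at the end).
def _demo_random_insert_latent_frame():
    B, F, C, H, W = 2, 16, 4, 8, 8
    N, M, fpl, L = 1, 17, 4, 2
    outs, tgts, masks, ivs = random_insert_latent_frame(
        image_latent=torch.randn(B, 1, C, H, W),
        noisy_model_input=torch.randn(B, F, C, H, W),
        target_latents=torch.randn(B, F, C, H, W),
        input_intervals=torch.randn(B, N, fpl, L),
        output_intervals=torch.randn(B, M, fpl, L),
        special_info="use_a",  # forces Mode A for every sample
    )
    # Mode A prepends two latent frames and replicates one input group:
    print(outs.shape, tgts.shape, masks.shape, ivs.shape)
    # -> [2, 18, 4, 8, 8], [2, 18, 4, 8, 8], [2, 18], [2, 19, 8]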




def transform_intervals(
    intervals: torch.Tensor,
    frames_per_latent: int = 4,
    repeat_first: bool = True
) -> torch.Tensor:
    """
    Pad and reshape intervals into [B, num_latent_frames, frames_per_latent, L].

    Args:
        intervals: Tensor of shape [B, N, L]
        frames_per_latent: number of frames per latent group (e.g., 4)
        repeat_first: if True, pad at the beginning by repeating the first row; otherwise pad at the end by repeating the last row.

    Returns:
        Tensor of shape [B, num_latent_frames, frames_per_latent, L]
    """
    B, N, L = intervals.shape
    num_latent = math.ceil(N / frames_per_latent)
    target_N = num_latent * frames_per_latent
    pad_count = target_N - N

    if pad_count > 0:
        # choose row to repeat
        pad_row = intervals[:, :1, :] if repeat_first else intervals[:, -1:, :]
        # replicate pad_row pad_count times
        pad = pad_row.repeat(1, pad_count, 1)
        # pad at beginning or end
        if repeat_first:
            expanded = torch.cat([pad, intervals], dim=1)
        else:
            expanded = torch.cat([intervals, pad], dim=1)
    else:
        expanded = intervals[:, :target_N, :]

    # reshape into latent-frame groups
    return expanded.view(B, num_latent, frames_per_latent, L)
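
# Hedged usage sketch for transform_intervals: 5 interval rows are padded to
# the next multiple of frames_per_latent by repeating the first row, then
# grouped. Values are illustrative.
def _demo_transform_intervals():
    iv = torch.arange(10, dtype=torch.float).view(1, 5, 2)  # [B=1, N=5, L=2]
    grouped = transform_intervals(iv, frames_per_latent=4, repeat_first=True)
    print(grouped.shape)     # torch.Size([1, 2, 4, 2])
    print(grouped[0, 0, 0])  # tensor([0., 1.]) -- the repeated first row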



def build_blur(frame_paths, gamma=2.2):
    """
    Simulate motion blur using inverse-gamma (linear-light) summation:
    - Load each image, convert to float32 sRGB [0,255]
    - Linearize via inverse gamma: linear = (img/255)^gamma
    - Sum linear values, average, then re-encode via gamma: (linear_avg)^(1/gamma)*255
    Returns a uint8 numpy array.
    """
    acc_lin = None
    for p in frame_paths:
        img = np.array(Image.open(p).convert('RGB'), dtype=np.float32)
        # normalize to [0,1] then linearize
        lin = np.power(img / 255.0, gamma)
        acc_lin = lin if acc_lin is None else acc_lin + lin
    # average in linear domain
    avg_lin = acc_lin / len(frame_paths)
    # gamma-encode back to sRGB domain
    srgb = np.power(avg_lin, 1.0 / gamma) * 255.0
    return np.clip(srgb, 0, 255).astype(np.uint8)
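
# Hedged usage sketch for build_blur. The helper below is not part of the
# original pipeline; it writes synthetic frames to a temp directory so the
# sketches in this file can run without a real high-speed frame dump.
import os
import tempfile


def _write_dummy_frames(dirname, n, size=(8, 8)):
    """Write n solid-gray gradient frames to dirname; return their paths."""
    paths = []
    for i in range(n):
        val = int(255 * i / max(n - 1, 1))
        arr = np.full((size[0], size[1], 3), val, dtype=np.uint8)
        p = os.path.join(dirname, f"frame_{i:04d}.png")
        Image.fromarray(arr).save(p)
        paths.append(p)
    return paths


def _demo_build_blur():
    with tempfile.TemporaryDirectory() as d:
        paths = _write_dummy_frames(d, 2)  # one black frame, one white frame
        blurred = build_blur(paths)
        # Linear-light averaging gives ~186, not 127, since
        # 0.5 ** (1 / 2.2) * 255 ≈ 186 -- the point of the gamma round-trip.
        print(blurred.shape, int(blurred[0, 0, 0]))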

def generate_1x_sequence(frame_paths, window_max=16, output_len=17, base_rate=1, start=None):
    """
    1× mode at arbitrary base_rate (units of 1/240s):
      - Treat each output step as the sum of `base_rate` consecutive raw frames.
      - Pick window size W ∈ [1, min(output_len, window_max, N // base_rate)]
      - Randomly choose a start index (if none is given) so W*base_rate frames fit
      - Group raw frames into W groups of length base_rate
      - Build blur image over all W*base_rate frames for input
      - For each group, build a blurred output frame by summing its base_rate frames
      - Pad sequence of W blurred frames to output_len by repeating last blurred frame
      - Input interval always [-0.5, 0.5]
      - Output intervals reflect each group’s coverage within [-0.5,0.5]
    """
    N = len(frame_paths)
    max_w = min(output_len, N // base_rate)
    max_w = min(max_w, window_max)
    W = random.randint(1, max_w)
    if start is not None:
        # caller-provided start: check that W * base_rate frames fit from it
        assert start + W * base_rate <= N, \
            f"Not enough frames for base_rate={base_rate}: need start + W*base_rate = {start + W * base_rate}, got N = {N}"
    else:
        start = random.randint(0, N - W * base_rate)

    # group start indices
    group_starts = [start + i * base_rate for i in range(W)]
    # flatten raw frame paths for blur input
    blur_paths = []
    for gs in group_starts:
        blur_paths.extend(frame_paths[gs:gs + base_rate])
    blur_img = build_blur(blur_paths)

    # build blurred output frames per group
    seq = []
    for gs in group_starts:
        group = frame_paths[gs:gs + base_rate]
        seq.append(build_blur(group))
    # pad with last blurred frame
    seq += [seq[-1]] * (output_len - len(seq))

    input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)
    # each group covers interval of length 1/W
    step = 1.0 / W
    intervals = [[-0.5 + i * step, -0.5 + (i + 1) * step] for i in range(W)]
    num_frames = len(intervals)
    intervals += [intervals[-1]] * (output_len - W)
    output_intervals = torch.tensor(intervals, dtype=torch.float)

    return blur_img, seq, input_interval, output_intervals, num_frames

def generate_2x_sequence(frame_paths, window_max=16, output_len=17, base_rate=1):
    """
    2× mode:
      - Logical window of W output-steps so that 2*W ≤ output_len
      - Raw window spans W*base_rate frames
      - Build blur only over that raw window (flattened) for input
      - before_count = W//2, after_count = W - before_count
      - Define groups for before, during, and after each of length base_rate
      - Build blurred frames for each group
      - Pad sequence of 2*W blurred frames to output_len by repeating last
      - Input interval always [-0.5,0.5]
      - Output intervals relative to window: each group’s center
    """
    N = len(frame_paths)
    max_w = min(output_len // 2, N // base_rate)
    max_w = min(max_w, window_max)
    W = random.randint(1, max_w)
    before_count = W // 2
    after_count = W - before_count
    # choose start so that before and after stay within bounds
    min_start = before_count * base_rate
    max_start = N - (W + after_count) * base_rate
    # ensure we can pick a valid start, else fail
    assert max_start >= min_start, f"Cannot satisfy before/after window for W={W}, base_rate={base_rate}, N={N}"
    start = random.randint(min_start, max_start)

    # window group starts
    window_starts = [start + i * base_rate for i in range(W)]
    # flatten for blur input
    blur_paths = []
    for gs in window_starts:
        blur_paths.extend(frame_paths[gs:gs + base_rate])

    blur_img = build_blur(blur_paths)

    # define before/after group starts (before_count/after_count from above)
    before_starts = [max(0, start - (i + 1) * base_rate) for i in range(before_count)][::-1]
    after_starts  = [min(N - base_rate, start + W * base_rate + i * base_rate) for i in range(after_count)]

    # all group starts in sequence
    group_starts = before_starts + window_starts + after_starts
    # build blurred frames per group
    seq = []
    for gs in group_starts:
        group = frame_paths[gs:gs + base_rate]
        seq.append(build_blur(group))
    # pad blurred frames to output_len
    seq += [seq[-1]] * (output_len - len(seq))

    input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)
    # each group covers 1/(2W) around its center within [-0.5,0.5]
    half = 0.5 / W
    centers = [((gs - start) / (W * base_rate)) - 0.5 + half
               for gs in group_starts]
    intervals = [[c - half, c + half] for c in centers]
    num_frames = len(intervals)
    intervals += [intervals[-1]] * (output_len - len(intervals))
    output_intervals = torch.tensor(intervals, dtype=torch.float)

    return blur_img, seq, input_interval, output_intervals, num_frames


def generate_large_blur_sequence(frame_paths, window_max=16, output_len=17, base_rate=1):
    """
    Large blur mode (fixed output_len=25) with instantaneous outputs:
      - Raw window spans 25 * base_rate consecutive frames
      - Build blur over that full raw window for input
      - For output sequence:
          • Pick 1 raw frame every `base_rate` (group_starts)
          • Each output frame is the instantaneous frame at that raw index
      - Input interval always [-0.5, 0.5]
      - Output intervals reflect each 1-frame slice’s coverage within the blur window,
        leaving gaps between.
    """
    N = len(frame_paths)
    total_raw = window_max * base_rate
    assert N >= total_raw, f"Not enough frames for base_rate={base_rate}, need {total_raw}, got {N}"
    start = random.randint(0, N - total_raw)

    # build blur input over the full raw block
    raw_block = frame_paths[start:start + total_raw]
    blur_img = build_blur(raw_block)

    # output sequence: instantaneous frames at each group_start
    seq = []
    group_starts = [start + i * base_rate for i in range(window_max)]
    for gs in group_starts:
        img = np.array(Image.open(frame_paths[gs]).convert('RGB'), dtype=np.uint8)
        seq.append(img)
    # pad instantaneous frames to output_len
    seq += [seq[-1]] * (output_len - len(seq))

    # compute intervals for each instantaneous frame:
    # each covers [gs, gs+1) over total_raw, normalized to [-0.5, 0.5]
    intervals = []
    for gs in group_starts:
        t0 = (gs - start) / total_raw - 0.5
        t1 = (gs + 1 - start) / total_raw - 0.5
        intervals.append([t0, t1])
    num_frames = len(intervals)
    intervals += [intervals[-1]] * (output_len - len(intervals))
    output_intervals = torch.tensor(intervals, dtype=torch.float)

    # input interval
    input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)
    return blur_img, seq, input_interval, output_intervals, num_frames
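
# Hedged sketch exercising the three sequence generators on synthetic frames.
# The frame count, window_max, output_len, and base_rate are illustrative.
def _demo_sequence_generators():
    with tempfile.TemporaryDirectory() as d:
        paths = _write_dummy_frames(d, 64)
        for gen in (generate_1x_sequence,
                    generate_2x_sequence,
                    generate_large_blur_sequence):
            blur_img, seq, in_iv, out_iv, n = gen(
                paths, window_max=4, output_len=5, base_rate=2)
            print(gen.__name__, blur_img.shape, len(seq), out_iv.shape, n)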

def generate_test_case(frame_paths,
                       window_max=16,
                       output_len=17,
                       in_start=None,
                       in_end=None,
                       out_start=None,
                       out_end=None,
                       center=None,  # currently unused
                       mode="1x",
                       fps=240):
    """
    Generate blurred input + a target sequence + normalized intervals.

    Args:
        frame_paths: list of all frame filepaths
        window_max: number of groups/bins W
        output_len: desired length of the output sequence
        in_start, in_end: integer indices defining the raw window [in_start, in_end)
        mode: one of "1x", "2x", or "lb"
        fps: frames-per-second (only used to override mode=="2x" if fps==120)

    Returns:
        blur_img: np.ndarray of the global blur over the window
        seq: list of np.ndarray, length = output_len (blured groups or raw frames)
        input_interval: torch.Tensor [[-0.5, 0.5]]
        output_intervals: torch.Tensor shape [output_len, 2], normalized in [-0.5,0.5]
    """
    # 1) slice and blur
    raw_paths = frame_paths[in_start:in_end]

    blur_img = build_blur(raw_paths)

    # 2) build the sequence
    # one target per frame
    seq = [
        np.array(Image.open(p).convert("RGB"), dtype=np.uint8)
        for p in frame_paths[out_start:out_end]
    ]

    # 3) compute normalized intervals
    input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)

    base_rate = 240 // fps

    # 4) define the raw intervals in absolute frame indices
    if mode == "1x":
        assert in_start == out_start and in_end == out_end
        #assert fps == 240, "haven't implemented 120fps in 1x yet"
        W = (out_end - out_start) // base_rate
        # one frame per window
        group_starts = [out_start + i * base_rate for i in range(W)]
        group_ends   = [out_start + (i + 1) * base_rate for i in range(W)]

    elif mode == "2x":
        W = (out_end - out_start) // base_rate
        # every base_rate frames, starting at out_start
        group_starts = [out_start + i * base_rate for i in range(W)]
        group_ends   = [out_start + (i + 1) * base_rate for i in range(W)]

    elif mode == "lb":
        W = (out_end - out_start) // base_rate
        # sparse “key‐frame” windows from the raw input range
        group_starts = [in_start + i * base_rate for i in range(W)]
        group_ends   = [s + 1 for s in group_starts]

    else:
        raise ValueError(f"Unsupported mode: {mode}")

    # --- after the mode switch, with raw group_starts & group_ends in hand ---
    # 5) build a summed video sequence by blurring each interval
    summed_seq = []
    for s, e in zip(group_starts, group_ends):
        # make sure indices lie in [in_start, in_end)
        s_clamped = max(in_start, min(s, in_end-1))
        e_clamped = max(s_clamped+1,    min(e, in_end))
        # sum/blur the frames in [s_clamped:e_clamped)
        summed = build_blur(frame_paths[s_clamped:e_clamped])
        summed_seq.append(summed)

    # pad to output_len
    if len(summed_seq) < output_len:
        summed_seq += [summed_seq[-1]] * (output_len - len(summed_seq))

    # 6) normalize the intervals into [-0.5, 0.5]
    def normalize(x):
        return (x - in_start) / (in_end - in_start) - 0.5

    intervals = [[normalize(s), normalize(e)] for s, e in zip(group_starts, group_ends)]
    num_frames = len(intervals)
    if len(intervals) < output_len:
        intervals += [intervals[-1]] * (output_len - len(intervals))
    
    output_intervals = torch.tensor(intervals, dtype=torch.float)

    # final return now also includes summed_seq
    return blur_img, summed_seq, input_interval, output_intervals, seq, num_frames
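
# Hedged sketch for generate_test_case: a 1x window over 8 synthetic frames.
# The index choices are illustrative assumptions.
def _demo_generate_test_case():
    with tempfile.TemporaryDirectory() as d:
        paths = _write_dummy_frames(d, 8)
        blur_img, summed_seq, in_iv, out_iv, seq, n = generate_test_case(
            paths, output_len=9, in_start=0, in_end=8,
            out_start=0, out_end=8, mode="1x", fps=240)
        # 8 real groups padded to output_len=9; seq holds the 8 raw frames
        print(blur_img.shape, len(summed_seq), out_iv.shape, len(seq), n)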


def get_conditioning(
    output_len=17,
    in_start=None,
    in_end=None,
    out_start=None,
    out_end=None,
    mode="1x",
    fps=240,
):
    """
    Generate normalized intervals conditioning singals. Just like the above function but without
    loading any images (for inference only).

    Args:
        output_len: desired length of the output sequence
        in_start, in_end: integer indices defining the raw window [in_start, in_end)
        mode: one of "1x", "2x", or "lb"
        fps: frames-per-second (only used to override mode=="2x" if fps==120)

    Returns:
        input_interval: torch.Tensor [[-0.5, 0.5]]
        output_intervals: torch.Tensor shape [output_len, 2], normalized in [-0.5,0.5]
    """

    # 1) compute the (fixed) input interval
    input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)

    base_rate = 240 // fps

    # 2) define the raw intervals in absolute frame indices
    if mode == "1x":
        assert in_start == out_start and in_end == out_end
        #assert fps == 240, "haven't implemented 120fps in 1x yet"
        W = (out_end - out_start) // base_rate
        # one frame per window
        group_starts = [out_start + i * base_rate for i in range(W)]
        group_ends   = [out_start + (i + 1) * base_rate for i in range(W)]

    elif mode == "2x":
        W = (out_end - out_start) // base_rate
        # every base_rate frames, starting at out_start
        group_starts = [out_start + i * base_rate for i in range(W)]
        group_ends   = [out_start + (i + 1) * base_rate for i in range(W)]

    elif mode == "lb":
        W = (out_end - out_start) // base_rate
        # sparse “key‐frame” windows from the raw input range
        group_starts = [in_start + i * base_rate for i in range(W)]
        group_ends   = [s + 1 for s in group_starts]

    else:
        raise ValueError(f"Unsupported mode: {mode}")

    # 3) normalize the intervals into [-0.5, 0.5]
    def normalize(x):
        return (x - in_start) / (in_end - in_start) - 0.5

    intervals = [[normalize(s), normalize(e)] for s, e in zip(group_starts, group_ends)]
    num_frames = len(intervals)
    if len(intervals) < output_len:
        intervals += [intervals[-1]] * (output_len - len(intervals))
    
    output_intervals = torch.tensor(intervals, dtype=torch.float)

    return input_interval, output_intervals, num_frames
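
# Hedged sketch for get_conditioning (no image I/O), plus a driver that runs
# all of the illustrative demos above when this file is executed directly.
def _demo_get_conditioning():
    in_iv, out_iv, n = get_conditioning(
        output_len=17, in_start=0, in_end=16,
        out_start=0, out_end=16, mode="1x", fps=240)
    print(in_iv, out_iv.shape, n)  # [[-0.5, 0.5]], torch.Size([17, 2]), 16


if __name__ == "__main__":
    _demo_random_insert_latent_frame()
    _demo_transform_intervals()
    _demo_build_blur()
    _demo_sequence_generators()
    _demo_generate_test_case()
    _demo_get_conditioning()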