import datetime
import itertools
import os
import random
import warnings
from dataclasses import dataclass
from typing import Tuple, List, Dict, Any, Union, Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

from .dataset import ChatTSTimeRCDPretrainDataset
from .ts_encoder_bi_bias import TimeSeriesEncoder
from .time_rcd_config import TimeRCDConfig, default_config

warnings.filterwarnings("ignore")

@dataclass
class PretrainBatch:
    """Batch structure for pretraining tasks."""
    time_series: torch.Tensor
    labels: torch.Tensor
    masked_time_series: torch.Tensor
    mask_indices: torch.Tensor

class TimeSeriesPretrainModel(nn.Module):
    """Model for time series pretraining with masked reconstruction and anomaly detection."""

    def __init__(self, config: TimeRCDConfig):
        super().__init__()
        self.config = config

        # Build the time series encoder from the config.
        ts_config = config.ts_config
        self.ts_encoder = TimeSeriesEncoder(
            d_model=ts_config.d_model,
            d_proj=ts_config.d_proj,
            patch_size=ts_config.patch_size,
            num_layers=ts_config.num_layers,
            num_heads=ts_config.num_heads,
            d_ff_dropout=ts_config.d_ff_dropout,
            use_rope=ts_config.use_rope,
            num_features=ts_config.num_features,
            activation=ts_config.activation,
        )
        # Masked reconstruction head: maps each local embedding back to a scalar value.
        self.reconstruction_head = nn.Sequential(
            nn.Linear(ts_config.d_proj, ts_config.d_proj * 4),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(ts_config.d_proj * 4, ts_config.d_proj * 4),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(ts_config.d_proj * 4, 1),  # (B, seq_len, num_features, 1)
        )

        # Anomaly detection head: binary classification per timestep and feature.
        self.anomaly_head = nn.Sequential(
            nn.Linear(ts_config.d_proj, ts_config.d_proj // 2),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(ts_config.d_proj // 2, 2),  # (B, seq_len, num_features, 2)
        )
    def forward(self, time_series: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward pass through the encoder, returning per-timestep local embeddings."""
        local_embeddings = self.ts_encoder(time_series, mask)  # (B, seq_len, num_features, d_proj)
        return local_embeddings
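
    # Shape note (illustrative, assuming the dimensions in default_config): an
    # input of shape (B, seq_len, num_features) with an optional (B, seq_len)
    # mask yields local embeddings of shape (B, seq_len, num_features, d_proj),
    # which both loss heads below consume.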
    def masked_reconstruction_loss(
        self,
        local_embeddings: torch.Tensor,      # (B, seq_len, num_features, d_proj)
        original_time_series: torch.Tensor,  # (B, seq_len, num_features)
        mask: torch.Tensor,                  # (B, seq_len)
    ) -> torch.Tensor:
        """Compute MSE reconstruction loss over masked timesteps only."""
        batch_size, seq_len, num_features = original_time_series.shape
        mask = mask.bool()

        # Project each local embedding back to a scalar and drop the trailing dim.
        reconstructed = self.reconstruction_head(local_embeddings)  # (B, seq_len, num_features, 1)
        reconstructed = reconstructed.view(batch_size, seq_len, num_features)

        # Broadcast the timestep mask across features and compare only masked positions.
        mask_expanded = mask.unsqueeze(-1).expand(-1, -1, num_features)  # (B, seq_len, num_features)
        reconstruction_loss = F.mse_loss(
            reconstructed[mask_expanded],
            original_time_series[mask_expanded],
        )
        return reconstruction_loss
    def anomaly_detection_loss(
        self,
        local_embeddings: torch.Tensor,  # (B, seq_len, num_features, d_proj)
        labels: torch.Tensor,            # (B, seq_len)
    ) -> torch.Tensor:
        """Compute anomaly detection loss for each timestep."""
        # Project local embeddings to logits, then average over features.
        logits = self.anomaly_head(local_embeddings)  # (B, seq_len, num_features, 2)
        logits = torch.mean(logits, dim=-2)           # (B, seq_len, 2)

        # Flatten for cross-entropy.
        logits = logits.view(-1, 2)  # (B*seq_len, 2)
        labels = labels.view(-1)     # (B*seq_len,)

        # Identify padding (-1) BEFORE binarizing; otherwise padded positions
        # would be folded into the "normal" class and included in the loss.
        valid_mask = labels != -1
        labels = (labels > 0.5).long()

        # Compute loss only on valid timesteps.
        if valid_mask.sum() > 0:
            anomaly_loss = F.cross_entropy(logits[valid_mask], labels[valid_mask])
        else:
            anomaly_loss = torch.tensor(0.0, device=logits.device)
        return anomaly_loss
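
# Illustrative use of the two loss heads (a sketch; the tensor sizes are
# invented and default_config is assumed to provide matching dimensions):
#
#   model = TimeSeriesPretrainModel(default_config)
#   x = torch.randn(4, 512, 1)                      # (B, seq_len, num_features)
#   attn = torch.ones(4, 512, dtype=torch.bool)     # all timesteps valid
#   mask = (torch.rand(4, 512) < 0.15).float()      # ~15% of timesteps masked
#   labels = torch.randint(0, 2, (4, 512)).float()  # per-timestep anomaly labels
#
#   emb = model(x, attn)                            # (4, 512, 1, d_proj)
#   loss = model.masked_reconstruction_loss(emb, x, mask) \
#        + model.anomaly_detection_loss(emb, labels)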

def create_random_mask(
    time_series: torch.Tensor,     # (B, max_seq_len, num_features)
    attention_mask: torch.Tensor,  # (B, max_seq_len)
    mask_ratio: float = 0.15,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Create a random patch-level mask, masking only the valid (unpadded) part of each sequence."""
    batch_size, seq_len, num_features = time_series.shape
    patch_size = default_config.ts_config.patch_size

    mask = torch.zeros(batch_size, seq_len)  # (B, max_seq_len)
    for i in range(batch_size):
        # Number of patches covering the valid part of this sample.
        valid_length = int(attention_mask[i].sum().item())
        num_valid_patches = (valid_length - 1) // patch_size + 1
        num_masked = int(num_valid_patches * mask_ratio)
        if num_masked > 0:
            # Sample patches uniformly at random from the valid region.
            masked_patches = torch.randperm(num_valid_patches)[:num_masked]
            for j in masked_patches.tolist():
                start_idx = j * patch_size
                end_idx = min((j + 1) * patch_size, valid_length)  # don't exceed valid length
                mask[i, start_idx:end_idx] = 1

    # Replace masked positions with small Gaussian noise; never touch padding.
    masked_time_series = time_series.clone()
    mask_indices = mask.bool() & attention_mask
    mask_expanded = mask_indices.unsqueeze(-1).expand(-1, -1, num_features)  # (B, max_seq_len, num_features)
    masked_time_series[mask_expanded] = torch.randn_like(masked_time_series[mask_expanded]) * 0.1

    # Zero out any mask bits that fall on padding.
    mask = mask * attention_mask.float()
    return masked_time_series, mask  # (B, max_seq_len, num_features), (B, max_seq_len)
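
# Quick sanity check for the masking helper (illustrative; shapes invented):
#
#   ts = torch.randn(2, 64, 1)
#   attn = torch.zeros(2, 64, dtype=torch.bool)
#   attn[0, :] = True      # first sample fully valid
#   attn[1, :40] = True    # second sample padded after 40 steps
#   masked_ts, mask = create_random_mask(ts, attn, mask_ratio=0.15)
#   assert not mask[1, 40:].any()  # padding is never masked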

def collate_fn(batch):
    """Collate function for the pretraining dataset."""
    time_series_list, normal_time_series_list, labels_list, attribute_list = zip(*batch)

    # Ensure every series has an explicit feature dimension.
    if time_series_list[0].ndim == 1:
        time_series_tensors = [ts.unsqueeze(-1) for ts in time_series_list]
        normal_time_series_tensors = [nts.unsqueeze(-1) for nts in normal_time_series_list]
    else:
        time_series_tensors = list(time_series_list)
        normal_time_series_tensors = list(normal_time_series_list)

    # Standardize using statistics pooled over the whole batch.
    concatenated = torch.cat(time_series_tensors, dim=0)  # (total_length, num_features)
    mean = concatenated.mean(dim=0, keepdim=True)         # (1, num_features)
    std = concatenated.std(dim=0, keepdim=True) + 1e-4    # epsilon avoids division by zero
    time_series_tensors = [(ts - mean) / std for ts in time_series_tensors]
    normal_time_series_tensors = [(nts - mean) / std for nts in normal_time_series_tensors]

    labels = list(labels_list)
    # Pad all series in the batch to the same length.
    padded_time_series = torch.nn.utils.rnn.pad_sequence(
        time_series_tensors, batch_first=True, padding_value=0.0
    )  # (B, max_seq_len, num_features)
    padded_normal_time_series = torch.nn.utils.rnn.pad_sequence(
        normal_time_series_tensors, batch_first=True, padding_value=0.0
    )  # (B, max_seq_len, num_features)
    padded_labels = torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=-1
    )  # (B, max_seq_len)

    # Build the attention mask from the original (unpadded) lengths.
    sequence_lengths = [ts.size(0) for ts in time_series_tensors]
    B, max_seq_len, num_features = padded_time_series.shape
    attention_mask = torch.zeros(B, max_seq_len, dtype=torch.bool)  # (B, max_seq_len)
    for i, length in enumerate(sequence_lengths):
        attention_mask[i, :length] = True

    # Random masks for the reconstruction task; only valid timesteps are masked.
    masked_time_series, mask = create_random_mask(padded_time_series, attention_mask)

    return {
        'time_series': padded_time_series,
        'normal_time_series': padded_normal_time_series,
        'masked_time_series': masked_time_series,
        'mask': mask,                      # for the reconstruction task
        'labels': padded_labels,
        'attention_mask': attention_mask,  # marks valid (unpadded) timesteps
        'attribute': attribute_list,
    }
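
# Typical wiring of the collate function into a DataLoader (a sketch; the
# ChatTSTimeRCDPretrainDataset constructor arguments are assumed, not taken
# from this file):
#
#   dataset = ChatTSTimeRCDPretrainDataset(...)  # yields (ts, normal_ts, labels, attribute)
#   loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
#   batch = next(iter(loader))
#   batch['time_series'].shape     # (B, max_seq_len, num_features)
#   batch['attention_mask'].shape  # (B, max_seq_len)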

def test_collate_fn(batch):
    """Collate function for the evaluation dataset, where all samples share one length."""
    # batch is a list of (time_series, mask) tuples.
    time_series_list, mask_list = zip(*batch)

    # Stack (rather than pad) into batch form; this assumes equal-length samples.
    batched_time_series = torch.stack(time_series_list, dim=0)  # (B, seq_len, num_features)
    batched_mask = torch.stack(mask_list, dim=0)                # (B, seq_len)

    return {
        'time_series': batched_time_series,
        'attention_mask': batched_mask,  # marks valid timesteps
    }
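
# The distributed imports above (dist, mp, DDP, DistributedSampler) suggest this
# module is driven by a DDP training loop. A minimal sketch of one (illustrative
# only; dataset construction, loss weighting, and checkpointing are assumed):
#
#   def train(rank, world_size):
#       dist.init_process_group("nccl", rank=rank, world_size=world_size)
#       dataset = ChatTSTimeRCDPretrainDataset(...)  # constructor args assumed
#       sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
#       loader = DataLoader(dataset, batch_size=32, sampler=sampler, collate_fn=collate_fn)
#       model = DDP(TimeSeriesPretrainModel(default_config).to(rank), device_ids=[rank])
#       optimizer = optim.AdamW(model.parameters(), lr=1e-4)
#       for batch in loader:
#           emb = model(batch['masked_time_series'].to(rank),
#                       batch['attention_mask'].to(rank))
#           loss = model.module.masked_reconstruction_loss(
#                      emb, batch['time_series'].to(rank), batch['mask'].to(rank)) \
#                + model.module.anomaly_detection_loss(emb, batch['labels'].to(rank))
#           optimizer.zero_grad()
#           loss.backward()
#           optimizer.step()
#       dist.destroy_process_group()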