Spaces:
Sleeping
Sleeping
| import torchvision.transforms as T | |
class AugmentationPipeline:
    """
    Data augmentation and preprocessing transformations for CropGuard.

    Three torchvision ``Compose`` pipelines are built once at construction:

    * ``train_transforms``: random horizontal and vertical flips (p=0.5 each),
      a random rotation drawn from [-30, 30] degrees and colour jitter,
      followed by tensor conversion and per-channel normalization.
    * ``val_transforms`` and ``test_transforms``: tensor conversion and
      normalization only, so evaluation contains no randomness.

    NOTE(review): no resize or crop step is applied, so images presumably
    already share one spatial size before they are batched -- confirm
    against the dataset/loader code.
    """

    # Per-channel normalization statistics from ImageNet. They are used when
    # the caller does not supply dataset-specific values.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(
        self,
        mean: "list[float] | tuple[float, ...] | None" = None,
        std: "list[float] | tuple[float, ...] | None" = None,
    ):
        """
        Build the train/val/test transformation pipelines.

        Args:
            mean: Per-channel means passed to ``Normalize``. Defaults to the
                ImageNet values in ``IMAGENET_MEAN``.
            std: Per-channel standard deviations passed to ``Normalize``.
                Defaults to the ImageNet values in ``IMAGENET_STD``.

        Raises:
            ValueError: If ``mean``/``std`` are empty or differ in length.
        """
        # Stored as lists so the public ``mean``/``std`` attributes keep the
        # same type they always had, whatever sequence the caller passed.
        self.mean = list(self.IMAGENET_MEAN if mean is None else mean)
        self.std = list(self.IMAGENET_STD if std is None else std)

        # Validate eagerly. Otherwise a bad value would only surface when
        # Normalize is first applied to an image, far from the mistake.
        if not self.mean or len(self.mean) != len(self.std):
            raise ValueError(
                "mean and std must be non-empty sequences of equal length, "
                f"got {len(self.mean)} and {len(self.std)} values."
            )

        # Geometric and photometric augmentations run on the input image
        # before it is converted to a tensor and normalized.
        self.train_transforms = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
            T.RandomRotation(degrees=30),
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            T.ToTensor(),
            T.Normalize(mean=self.mean, std=self.std),
        ])
        # Validation and test share the same deterministic preprocessing. Each
        # phase still gets its own Compose instance, as it did before.
        self.val_transforms = self._build_eval_transforms()
        self.test_transforms = self._build_eval_transforms()

    def _build_eval_transforms(self) -> "T.Compose":
        """Return a fresh non-random pipeline: tensor conversion + normalization."""
        return T.Compose([
            T.ToTensor(),
            T.Normalize(mean=self.mean, std=self.std),
        ])

    def get_transforms(self, phase: str = "train") -> "T.Compose":
        """
        Return the transformation pipeline for the given phase.

        Args:
            phase: One of ``"train"``, ``"val"`` or ``"test"``.

        Returns:
            The matching pipeline attribute (``train_transforms``,
            ``val_transforms`` or ``test_transforms``).

        Raises:
            ValueError: If ``phase`` is not one of the three names above.
        """
        if phase == "train":
            return self.train_transforms
        if phase == "val":
            return self.val_transforms
        if phase == "test":
            return self.test_transforms
        raise ValueError(f"Unknown phase: {phase}. Use 'train', 'val', or 'test'.")