Spaces:
Runtime error
Runtime error
| import torch | |
| try: | |
| from apex.normalization import FusedLayerNorm | |
| except ImportError as e: | |
| try: | |
| from xformers.triton import FusedLayerNorm | |
| except ImportError as e: | |
| FusedLayerNorm = None | |
| def replace_all_layernorms(model): | |
| if FusedLayerNorm is None: | |
| print("WARNING: apex.normalization & xformers.triton.FusedLayerNorm is not found, \ | |
| skip using FusedLayerNorm") | |
| return model | |
| for name, module in model.named_children(): | |
| if isinstance(module, torch.nn.LayerNorm): | |
| setattr(model, name, FusedLayerNorm( | |
| module.normalized_shape, module.eps, module.elementwise_affine)) | |
| else: | |
| replace_all_layernorms(module) | |
| return model | |
| def replace_all_groupnorms(model): | |
| try: | |
| from apex.contrib.group_norm import GroupNorm | |
| except ImportError as e: | |
| print("WARNING: apex.contrib.group_norm is not found, skip using apex groupnorm") | |
| return model | |
| for name, module in model.named_children(): | |
| if isinstance(module, torch.nn.GroupNorm): | |
| setattr(model, name, GroupNorm( | |
| module.num_groups, module.num_channels, | |
| eps=module.eps, affine=module.affine)) | |
| else: | |
| replace_all_groupnorms(module) | |
| return model |