# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility script to perform net surgery on a model.

This script will perform net surgery on DeepLab models trained on a source
dataset and create a new checkpoint for the target dataset.
"""

from typing import Any, Dict, Text, Tuple

from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf

from google.protobuf import text_format
from deeplab2 import common
from deeplab2 import config_pb2
from deeplab2.data import dataset
from deeplab2.model import deeplab

FLAGS = flags.FLAGS

flags.DEFINE_string('source_dataset', 'cityscapes',
                    'Dataset name on which the model has been pretrained. '
                    'Supported datasets: `cityscapes`.')
flags.DEFINE_string('target_dataset', 'motchallenge_step',
                    'Dataset name for conversion. Supported datasets: '
                    '`motchallenge_step`.')
flags.DEFINE_string('input_config_path', None,
                    'Path to a config file that defines the DeepLab model and '
                    'the checkpoint path.')
flags.DEFINE_string('output_checkpoint_path', None,
                    'Output filename for the generated checkpoint file.')

_SUPPORTED_SOURCE_DATASETS = {'cityscapes'}
_SUPPORTED_TARGET_DATASETS = {'motchallenge_step'}
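
# Mapping used for net surgery: each entry is a Cityscapes train ID, and its
# position in the tuple is the corresponding MOTChallenge-STEP class index,
# e.g. channel 0 of the target semantic head is initialized from Cityscapes
# channel 1 (sidewalk).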
_CITYSCAPES_TO_MOTCHALLENGE_STEP = (
    1,  # sidewalk
    2,  # building
    8,  # vegetation
    10,  # sky
    11,  # pedestrian
    12,  # rider
    18,  # bicycle
)

_DATASET_TO_INFO = {
    'cityscapes': dataset.CITYSCAPES_PANOPTIC_INFORMATION,
    'motchallenge_step': dataset.MOTCHALLENGE_STEP_INFORMATION,
}
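
# Dummy input shape (height, width, channels) used to build the model
# variables; 1025x2049 is presumably the Cityscapes resolution (1024x2048)
# plus one pixel, matching the crop-size convention of DeepLab models.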
_INPUT_SIZE = (1025, 2049, 3)


def _load_model(
    config_path: Text,
    source_dataset: Text) -> Tuple[deeplab.DeepLab,
                                   config_pb2.ExperimentOptions]:
  """Loads a DeepLab model based on the config and dataset."""
  options = config_pb2.ExperimentOptions()
  with tf.io.gfile.GFile(config_path) as f:
    text_format.Parse(f.read(), options)
  options.model_options.panoptic_deeplab.semantic_head.output_channels = (
      _DATASET_TO_INFO[source_dataset].num_classes)
  model = deeplab.DeepLab(options,
                          _DATASET_TO_INFO[source_dataset])
  return model, options


def _convert_bias(input_tensor: np.ndarray,
                  label_list: Tuple[int, ...]) -> np.ndarray:
  """Converts a 1D bias tensor w.r.t. the label list.

  We select the subset of the input_tensor based on the label_list. We assume
  input_tensor has shape = [num_classes], where input_tensor is the bias
  weights trained on the source dataset, and num_classes is the number of
  classes in the source dataset.

  Args:
    input_tensor: A numpy array with ndim == 1.
    label_list: A tuple of labels used for net surgery.

  Returns:
    A numpy array with shape = [len(label_list)] containing the selected bias
    values.

  Raises:
    ValueError: If input_tensor's ndim != 1.
  """
  if input_tensor.ndim != 1:
    raise ValueError('The bias tensor should have ndim == 1.')
  num_elements = len(label_list)
  output_tensor = np.zeros(num_elements, dtype=np.float32)
  for i, label in enumerate(label_list):
    output_tensor[i] = input_tensor[label]
  return output_tensor
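
# A minimal sketch of the selection above (not executed by this script): for a
# hypothetical 19-class source bias,
#   bias = np.arange(19, dtype=np.float32)
#   _convert_bias(bias, _CITYSCAPES_TO_MOTCHALLENGE_STEP)
# returns array([1., 2., 8., 10., 11., 12., 18.], dtype=float32), i.e. the
# seven entries indexed by the label list.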


def _convert_kernels(input_tensor: np.ndarray,
                     label_list: Tuple[int, ...]) -> np.ndarray:
  """Converts a 4D kernel tensor w.r.t. the label list.

  We select the subset of the input_tensor based on the label_list. We assume
  input_tensor has shape = [h, w, input_dim, num_classes], where input_tensor
  is the kernel weights trained on the source dataset, and num_classes is the
  number of classes in the source dataset.

  Args:
    input_tensor: A numpy array with ndim == 4.
    label_list: A tuple of labels used for net surgery.

  Returns:
    A numpy array with shape = [h, w, input_dim, len(label_list)] containing
    the selected kernel weights.

  Raises:
    ValueError: If input_tensor's ndim != 4.
  """
  if input_tensor.ndim != 4:
    raise ValueError('The kernels tensor should have ndim == 4.')
  num_elements = len(label_list)
  kernel_height, kernel_width, input_dim, _ = input_tensor.shape
  output_tensor = np.zeros(
      (kernel_height, kernel_width, input_dim, num_elements), dtype=np.float32)
  for i, label in enumerate(label_list):
    output_tensor[:, :, :, i] = input_tensor[:, :, :, label]
  return output_tensor
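
# Shape sketch (assuming a hypothetical 1x1 convolution with 256 input
# channels and 19 Cityscapes classes): an input kernel of shape
# [1, 1, 256, 19] becomes [1, 1, 256, 7], keeping only the output channels
# listed in _CITYSCAPES_TO_MOTCHALLENGE_STEP.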


def _restore_checkpoint(restore_dict: Dict[Any, Any],
                        options: config_pb2.ExperimentOptions
                        ) -> tf.train.Checkpoint:
  """Reads the provided dict items from the checkpoint specified in options.

  Args:
    restore_dict: A mapping from checkpoint item names to the objects to
      restore.
    options: An experiment configuration containing the checkpoint location.

  Returns:
    The loaded checkpoint.
  """
  ckpt = tf.train.Checkpoint(**restore_dict)
  if tf.io.gfile.isdir(options.model_options.initial_checkpoint):
    path = tf.train.latest_checkpoint(
        options.model_options.initial_checkpoint)
    status = ckpt.restore(path)
  else:
    status = ckpt.restore(options.model_options.initial_checkpoint)
  status.expect_partial().assert_existing_objects_matched()
  return ckpt
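
# Usage sketch: the keys in restore_dict must match the names under which the
# checkpoint was written; e.g., _restore_checkpoint(model.checkpoint_items,
# options) restores every item the model exposes for checkpointing, while
# deleting a key first (as done in main() below) skips that item.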


def main(_) -> None:
  if FLAGS.source_dataset not in _SUPPORTED_SOURCE_DATASETS:
    raise ValueError('Source dataset is not supported. Use --help to get the '
                     'list of supported datasets.')
  if FLAGS.target_dataset not in _SUPPORTED_TARGET_DATASETS:
    raise ValueError('Target dataset is not supported. Use --help to get the '
                     'list of supported datasets.')
  logging.info('Loading DeepLab model from config %s', FLAGS.input_config_path)
  source_model, options = _load_model(FLAGS.input_config_path,
                                      FLAGS.source_dataset)
  logging.info('Loading pretrained checkpoint.')
  _restore_checkpoint(source_model.checkpoint_items, options)
  source_model(tf.keras.Input(_INPUT_SIZE), training=False)
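  # Calling the model once on a symbolic input creates its variables, which
  # lets the deferred checkpoint restoration above take effect before the
  # semantic-head weights are read below.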
  logging.info('Performing net surgery.')
  semantic_weights = (
      source_model._decoder._semantic_head.final_conv.get_weights())  # pylint: disable=protected-access
  if (FLAGS.source_dataset == 'cityscapes' and
      FLAGS.target_dataset == 'motchallenge_step'):
    # Kernels.
    semantic_weights[0] = _convert_kernels(semantic_weights[0],
                                           _CITYSCAPES_TO_MOTCHALLENGE_STEP)
    # Bias.
    semantic_weights[1] = _convert_bias(semantic_weights[1],
                                        _CITYSCAPES_TO_MOTCHALLENGE_STEP)
  logging.info('Loading target model without the last semantic layer.')
  target_model, _ = _load_model(FLAGS.input_config_path, FLAGS.target_dataset)
  restore_dict = target_model.checkpoint_items
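  # The last semantic layer has a different number of output channels for the
  # target dataset, so we skip restoring it from the source checkpoint and
  # instead set the converted weights manually below.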
  del restore_dict[common.CKPT_SEMANTIC_LAST_LAYER]
  _restore_checkpoint(restore_dict, options)
  target_model(tf.keras.Input(_INPUT_SIZE), training=False)
  target_model._decoder._semantic_head.final_conv.set_weights(semantic_weights)  # pylint: disable=protected-access
  logging.info('Saving checkpoint to output path: %s',
               FLAGS.output_checkpoint_path)
  ckpt = tf.train.Checkpoint(**target_model.checkpoint_items)
  ckpt.save(FLAGS.output_checkpoint_path)


if __name__ == '__main__':
  flags.mark_flags_as_required(
      ['input_config_path', 'output_checkpoint_path'])
  app.run(main)
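
# Example invocation (a sketch; the paths and the script filename are
# placeholders):
#   python net_surgery.py \
#     --input_config_path=/path/to/panoptic_deeplab_cityscapes.textproto \
#     --output_checkpoint_path=/path/to/motchallenge_step_init/ckpt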