# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """Tests for sample_generator.""" | |
| import os | |
| from absl import flags | |
| import numpy as np | |
| from PIL import Image | |
| import tensorflow as tf | |
| from deeplab2 import common | |
| from deeplab2.data import data_utils | |
| from deeplab2.data import dataset | |
| from deeplab2.data import sample_generator | |
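
# Alias used by _get_groundtruth_image below to convert loaded PIL images to
# numpy arrays when comparing against golden PNG targets.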
image_utils = tf.keras.preprocessing.image

flags.DEFINE_string(
    'panoptic_annotation_data',
    'deeplab2/data/testdata/',
    'Path to annotated test image.')
flags.DEFINE_bool('update_golden_data', False,
                  'Whether or not to update the golden data for testing.')

FLAGS = flags.FLAGS

_FILENAME_PREFIX = 'dummy_000000_000000'
_IMAGE_FOLDER = 'leftImg8bit/'
_TARGET_FOLDER = 'targets/'
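

# Golden-data helpers: with --update_golden_data, the computed arrays are
# written over the stored targets and returned unchanged; otherwise the stored
# targets are loaded and returned for comparison. To regenerate the golden
# files, rerun this test with --update_golden_data (the exact invocation
# depends on your build setup).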
def _get_groundtruth_image(computed_image_array, groundtruth_image_filename):
  if FLAGS.update_golden_data:
    image = Image.fromarray(tf.squeeze(computed_image_array).numpy())
    with tf.io.gfile.GFile(groundtruth_image_filename, mode='wb') as fp:
      image.save(fp)
    return computed_image_array
  with tf.io.gfile.GFile(groundtruth_image_filename, mode='rb') as fp:
    image = data_utils.read_image(fp.read())
  # If loaded image has 3 channels, the returned shape is [height, width, 3].
  # If loaded image has 1 channel, the returned shape is [height, width].
  image = np.squeeze(image_utils.img_to_array(image))
  return image


def _get_groundtruth_array(computed_image_array, groundtruth_image_filename):
  if FLAGS.update_golden_data:
    with tf.io.gfile.GFile(groundtruth_image_filename, mode='wb') as fp:
      np.save(fp, computed_image_array)
    return computed_image_array
  with tf.io.gfile.GFile(groundtruth_image_filename, mode='rb') as fp:
    # If loaded data has C > 1 channels, the returned shape is
    # [height, width, C]. If loaded data has 1 channel, the returned shape is
    # [height, width].
    array = np.squeeze(np.load(fp))
  return array
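

# The tests below run PanopticSampleGenerator on a single dummy frame stored in
# Cityscapes-style folders (leftImg8bit/ plus a packed panoptic groundtruth
# PNG) and compare the generated tensors against golden targets in targets/.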
class PanopticSampleGeneratorTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    self._test_img_data_dir = os.path.join(
        FLAGS.test_srcdir,
        FLAGS.panoptic_annotation_data,
        _IMAGE_FOLDER)
    self._test_gt_data_dir = os.path.join(
        FLAGS.test_srcdir,
        FLAGS.panoptic_annotation_data)
    self._test_target_data_dir = os.path.join(
        FLAGS.test_srcdir,
        FLAGS.panoptic_annotation_data,
        _TARGET_FOLDER)
    image_path = self._test_img_data_dir + _FILENAME_PREFIX + '_leftImg8bit.png'
    with tf.io.gfile.GFile(image_path, 'rb') as image_file:
      rgb_image = data_utils.read_image(image_file.read())
    self._rgb_image = tf.convert_to_tensor(np.array(rgb_image))
    label_path = self._test_gt_data_dir + 'dummy_gt_for_vps.png'
    with tf.io.gfile.GFile(label_path, 'rb') as label_file:
      label = data_utils.read_image(label_file.read())
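    # The dot product with [1, 256, 256 * 256] packs the three PNG channels
    # into a single integer panoptic label per pixel, i.e.
    # R + 256 * G + 256 * 256 * B.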
    self._label = tf.expand_dims(tf.convert_to_tensor(
        np.dot(np.array(label), [1, 256, 256 * 256])), -1)

  def test_input_generator(self):
    tf.random.set_seed(0)
    np.random.seed(0)
    small_instances = {'threshold': 4096, 'weight': 3.0}
    generator = sample_generator.PanopticSampleGenerator(
        dataset.CITYSCAPES_PANOPTIC_INFORMATION._asdict(),
        focus_small_instances=small_instances,
        is_training=True,
        crop_size=[769, 769],
        thing_id_mask_annotations=True)
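    # A 769 x 769 training crop is requested, and thing_id_mask_annotations
    # asks the generator to also produce the thing ID mask/class targets
    # checked below.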
    input_sample = {
        'image': self._rgb_image,
        'image_name': 'test_image',
        'label': self._label,
        'height': 800,
        'width': 800
    }
    sample = generator(input_sample)
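    # In training mode, the generated sample should contain the cropped image,
    # the groundtruth maps, and the per-pixel loss weights asserted below.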
    self.assertIn(common.IMAGE, sample)
    self.assertIn(common.GT_SEMANTIC_KEY, sample)
    self.assertIn(common.GT_PANOPTIC_KEY, sample)
    self.assertIn(common.GT_INSTANCE_CENTER_KEY, sample)
    self.assertIn(common.GT_INSTANCE_REGRESSION_KEY, sample)
    self.assertIn(common.GT_IS_CROWD, sample)
    self.assertIn(common.GT_THING_ID_MASK_KEY, sample)
    self.assertIn(common.GT_THING_ID_CLASS_KEY, sample)
    self.assertIn(common.SEMANTIC_LOSS_WEIGHT_KEY, sample)
    self.assertIn(common.CENTER_LOSS_WEIGHT_KEY, sample)
    self.assertIn(common.REGRESSION_LOSS_WEIGHT_KEY, sample)
    self.assertListEqual(sample[common.IMAGE].shape.as_list(), [769, 769, 3])
    self.assertListEqual(sample[common.GT_SEMANTIC_KEY].shape.as_list(),
                         [769, 769])
    self.assertListEqual(sample[common.GT_PANOPTIC_KEY].shape.as_list(),
                         [769, 769])
    self.assertListEqual(sample[common.GT_INSTANCE_CENTER_KEY].shape.as_list(),
                         [769, 769])
    self.assertListEqual(
        sample[common.GT_INSTANCE_REGRESSION_KEY].shape.as_list(),
        [769, 769, 2])
    self.assertListEqual(sample[common.GT_IS_CROWD].shape.as_list(), [769, 769])
    self.assertListEqual(sample[common.GT_THING_ID_MASK_KEY].shape.as_list(),
                         [769, 769])
    self.assertListEqual(sample[common.GT_THING_ID_CLASS_KEY].shape.as_list(),
                         [128])
    self.assertListEqual(
        sample[common.SEMANTIC_LOSS_WEIGHT_KEY].shape.as_list(), [769, 769])
    self.assertListEqual(sample[common.CENTER_LOSS_WEIGHT_KEY].shape.as_list(),
                         [769, 769])
    self.assertListEqual(
        sample[common.REGRESSION_LOSS_WEIGHT_KEY].shape.as_list(),
        [769, 769])
    gt_sem = sample[common.GT_SEMANTIC_KEY]
    gt_pan = sample[common.GT_PANOPTIC_KEY]
    gt_center = tf.cast(sample[common.GT_INSTANCE_CENTER_KEY] * 255, tf.uint8)
    gt_is_crowd = sample[common.GT_IS_CROWD]
    gt_thing_id_mask = sample[common.GT_THING_ID_MASK_KEY]
    gt_thing_id_class = sample[common.GT_THING_ID_CLASS_KEY]
    image = tf.cast(sample[common.IMAGE], tf.uint8)
    # Semantic weights can be in the range [0, 3] in this example.
    semantic_weights = tf.cast(sample[common.SEMANTIC_LOSS_WEIGHT_KEY] * 85,
                               tf.uint8)
    center_weights = tf.cast(sample[common.CENTER_LOSS_WEIGHT_KEY] * 255,
                             tf.uint8)
    offset_weights = tf.cast(sample[common.REGRESSION_LOSS_WEIGHT_KEY] * 255,
                             tf.uint8)
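    # The scaling above maps the float tensors into uint8 range so that they
    # can be stored and compared as PNG golden images (85 is roughly 255 / 3,
    # matching the [0, 3] semantic-weight range noted above).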
    np.testing.assert_almost_equal(
        image.numpy(),
        _get_groundtruth_image(
            image,
            self._test_target_data_dir + 'rgb_target.png'))
    np.testing.assert_almost_equal(
        gt_sem.numpy(),
        _get_groundtruth_image(
            gt_sem,
            self._test_target_data_dir + 'semantic_target.png'))
    # Save gt as png. Pillow is currently unable to save the image correctly as
    # 32-bit and instead uses 16-bit, which overflows, so the comparison below
    # uses the .npy golden file instead.
    _ = _get_groundtruth_image(
        gt_pan, self._test_target_data_dir + 'panoptic_target.png')
    np.testing.assert_almost_equal(
        gt_pan.numpy(),
        _get_groundtruth_array(
            gt_pan,
            self._test_target_data_dir + 'panoptic_target.npy'))
    np.testing.assert_almost_equal(
        gt_thing_id_mask.numpy(),
        _get_groundtruth_array(
            gt_thing_id_mask,
            self._test_target_data_dir + 'thing_id_mask_target.npy'))
    np.testing.assert_almost_equal(
        gt_thing_id_class.numpy(),
        _get_groundtruth_array(
            gt_thing_id_class,
            self._test_target_data_dir + 'thing_id_class_target.npy'))
    np.testing.assert_almost_equal(
        gt_center.numpy(),
        _get_groundtruth_image(
            gt_center,
            self._test_target_data_dir + 'center_target.png'))
    np.testing.assert_almost_equal(
        sample[common.GT_INSTANCE_REGRESSION_KEY].numpy(),
        _get_groundtruth_array(
            sample[common.GT_INSTANCE_REGRESSION_KEY].numpy(),
            self._test_target_data_dir + 'offset_target.npy'))
    np.testing.assert_array_equal(
        gt_is_crowd.numpy(),
        _get_groundtruth_array(gt_is_crowd.numpy(),
                               self._test_target_data_dir + 'is_crowd.npy'))
    np.testing.assert_almost_equal(
        semantic_weights.numpy(),
        _get_groundtruth_image(
            semantic_weights,
            self._test_target_data_dir + 'semantic_weights.png'))
    np.testing.assert_almost_equal(
        center_weights.numpy(),
        _get_groundtruth_image(
            center_weights,
            self._test_target_data_dir + 'center_weights.png'))
    np.testing.assert_almost_equal(
        offset_weights.numpy(),
        _get_groundtruth_image(
            offset_weights,
            self._test_target_data_dir + 'offset_weights.png'))

  def test_input_generator_eval(self):
    tf.random.set_seed(0)
    np.random.seed(0)
    small_instances = {'threshold': 4096, 'weight': 3.0}
    generator = sample_generator.PanopticSampleGenerator(
        dataset.CITYSCAPES_PANOPTIC_INFORMATION._asdict(),
        focus_small_instances=small_instances,
        is_training=False,
        crop_size=[800, 800])
    input_sample = {
        'image': self._rgb_image,
        'image_name': 'test_image',
        'label': self._label,
        'height': 800,
        'width': 800
    }
    sample = generator(input_sample)
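    # In evaluation mode the generator additionally returns the raw (uncropped)
    # groundtruth, asserted below at the full 800 x 800 input resolution.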
    self.assertIn(common.GT_SEMANTIC_RAW, sample)
    self.assertIn(common.GT_PANOPTIC_RAW, sample)
    self.assertIn(common.GT_IS_CROWD_RAW, sample)
    gt_sem_raw = sample[common.GT_SEMANTIC_RAW]
    gt_pan_raw = sample[common.GT_PANOPTIC_RAW]
    gt_is_crowd_raw = sample[common.GT_IS_CROWD_RAW]
    self.assertListEqual(gt_sem_raw.shape.as_list(), [800, 800])
    self.assertListEqual(gt_pan_raw.shape.as_list(), [800, 800])
    self.assertListEqual(gt_is_crowd_raw.shape.as_list(), [800, 800])
    np.testing.assert_almost_equal(
        gt_sem_raw.numpy(),
        _get_groundtruth_image(
            gt_sem_raw,
            self._test_target_data_dir + 'eval_semantic_target.png'))
    np.testing.assert_almost_equal(
        gt_pan_raw.numpy(),
        _get_groundtruth_array(
            gt_pan_raw,
            self._test_target_data_dir + 'eval_panoptic_target.npy'))
    np.testing.assert_almost_equal(
        gt_is_crowd_raw.numpy(),
        _get_groundtruth_array(
            gt_is_crowd_raw,
            self._test_target_data_dir + 'eval_is_crowd.npy'))


if __name__ == '__main__':
  tf.test.main()