| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Model class for Cifar10 Dataset.""" |
| from __future__ import division, print_function |
| |
| import model_base |
| import tensorflow as tf |
| |
| |
class ResNetCifar10(model_base.ResNet):
    """Cifar10 model with ResNetV1 and basic residual block."""

    def __init__(
        self,
        num_layers,
        is_training,
        batch_norm_decay,
        batch_norm_epsilon,
        data_format="channels_first",
    ):
        """Configure network depth, per-stage widths and per-stage strides.

        Args:
          num_layers: total network depth. Each stage gets
            ``(num_layers - 2) // 6`` residual blocks; presumably callers pass
            a value of the form 6n + 2 (e.g. 20, 32, 44, 56). Any remainder is
            silently dropped by the integer division.
          is_training: forwarded unchanged to ``model_base.ResNet``.
          batch_norm_decay: forwarded unchanged to ``model_base.ResNet``.
          batch_norm_epsilon: forwarded unchanged to ``model_base.ResNet``.
          data_format: layout used for the internal computation,
            ``"channels_first"`` (default) or ``"channels_last"``.
        """
        super(ResNetCifar10, self).__init__(
            is_training, data_format, batch_norm_decay, batch_norm_epsilon
        )
        # Number of residual blocks in every stage.
        self.n = (num_layers - 2) // 6
        # Add one in case label starts with 1. No impact if label starts with 0.
        self.num_classes = 10 + 1
        # filters[0] is the width of the stem convolution; filters[i + 1] is the
        # output width of stage i, so this list has one more entry than strides.
        self.filters = [16, 16, 32, 64]
        # strides[i] is applied by the first block of stage i; one entry per stage.
        self.strides = [1, 2, 2]

    def forward_pass(self, x, input_data_format="channels_last"):
        """Build the core model within the graph.

        Args:
          x: batch of images as a rank-4 tensor laid out according to
            ``input_data_format``. Pixel values are presumably in [0, 255];
            TODO confirm against the input pipeline.
          input_data_format: layout of ``x``. Either ``"channels_last"``
            (NHWC, the default) or ``"channels_first"`` (NCHW).

        Returns:
          Output of the final fully connected layer with ``self.num_classes``
          units (presumably unnormalized logits).
        """
        if self._data_format != input_data_format:
            if input_data_format == "channels_last":
                # Computation requires channels_first: NHWC -> NCHW.
                x = tf.transpose(x, [0, 3, 1, 2])
            else:
                # Computation requires channels_last: NCHW -> NHWC.
                x = tf.transpose(x, [0, 2, 3, 1])

        # Image standardization: maps [0, 255] onto roughly [-1, 1) if the
        # input really is in that range.
        x = x / 128 - 1

        # Stem convolution. Its width is filters[0], the input width of stage 0,
        # so it stays consistent with the first residual block below.
        x = self._conv(x, 3, self.filters[0], 1)
        x = self._batch_norm(x)
        x = self._relu(x)

        # Use basic (non-bottleneck) block and ResNet V1 (post-activation).
        res_func = self._residual_v1

        # One stage of block stacking per entry in self.strides (3 by default).
        for i, stage_stride in enumerate(self.strides):
            in_filters = self.filters[i]
            out_filters = self.filters[i + 1]
            with tf.name_scope("stage"):
                for j in range(self.n):
                    if j == 0:
                        # First block in a stage, filters and strides may change.
                        x = res_func(x, 3, in_filters, out_filters, stage_stride)
                    else:
                        # Following blocks in a stage, constant filters and unit stride.
                        x = res_func(x, 3, out_filters, out_filters, 1)

        x = self._global_avg_pool(x)
        x = self._fully_connected(x, self.num_classes)

        return x