blob: fe8dd38331f15775d424d9f4d47e32dfbf6a9042 [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model class for Cifar10 Dataset."""
from __future__ import division, print_function
import model_base
import tensorflow as tf
class ResNetCifar10(model_base.ResNet):
"""Cifar10 model with ResNetV1 and basic residual block."""
def __init__(
self,
num_layers,
is_training,
batch_norm_decay,
batch_norm_epsilon,
data_format="channels_first",
):
super(ResNetCifar10, self).__init__(is_training, data_format, batch_norm_decay, batch_norm_epsilon)
self.n = (num_layers - 2) // 6
# Add one in case label starts with 1. No impact if label starts with 0.
self.num_classes = 10 + 1
self.filters = [16, 16, 32, 64]
self.strides = [1, 2, 2]
def forward_pass(self, x, input_data_format="channels_last"):
"""Build the core model within the graph."""
if self._data_format != input_data_format:
if input_data_format == "channels_last":
# Computation requires channels_first.
x = tf.transpose(x, [0, 3, 1, 2])
else:
# Computation requires channels_last.
x = tf.transpose(x, [0, 2, 3, 1])
# Image standardization.
x = x / 128 - 1
x = self._conv(x, 3, 16, 1)
x = self._batch_norm(x)
x = self._relu(x)
# Use basic (non-bottleneck) block and ResNet V1 (post-activation).
res_func = self._residual_v1
# 3 stages of block stacking.
for i in range(3):
with tf.name_scope("stage"):
for j in range(self.n):
if j == 0:
# First block in a stage, filters and strides may change.
x = res_func(x, 3, self.filters[i], self.filters[i + 1], self.strides[i])
else:
# Following blocks in a stage, constant filters and unit stride.
x = res_func(x, 3, self.filters[i + 1], self.filters[i + 1], 1)
x = self._global_avg_pool(x)
x = self._fully_connected(x, self.num_classes)
return x