# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains a variant of the LeNet model definition."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

slim = tf.contrib.slim
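# TF-Slim supplies the lightweight layer and arg_scope helpers used below;
# the module-level alias keeps the call sites short.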


def lenet(images, num_classes=10, is_training=False,
          dropout_keep_prob=0.5,
          prediction_fn=slim.softmax,
          scope='LeNet'):
  """Creates a variant of the LeNet model.

  Note that since the output is a set of 'logits', the values fall in the
  interval of (-infinity, infinity). Consequently, to convert the outputs to a
  probability distribution over the characters, one will need to convert them
  using the softmax function:

        logits = lenet.lenet(images, is_training=False)
        probabilities = tf.nn.softmax(logits)
        predictions = tf.argmax(logits, 1)

  Args:
    images: A batch of `Tensors` of size [batch_size, height, width, channels].
    num_classes: the number of classes in the dataset.
    is_training: specifies whether or not we're currently training the model.
      This variable will determine the behaviour of the dropout layer.
    dropout_keep_prob: the probability that each activation value is retained
      by the dropout layer.
    prediction_fn: a function to get predictions out of logits.
    scope: Optional variable_scope.

  Returns:
    logits: the pre-softmax activations, a tensor of size
      [batch_size, `num_classes`]
    end_points: a dictionary from components of the network to the corresponding
      activation.
  """
  end_points = {}

  with tf.variable_scope(scope, 'LeNet', [images, num_classes]):
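    # Shape walk-through, assuming the canonical 28x28x1 MNIST input
    # (slim.conv2d defaults to 'SAME' padding, so the 5x5 convolutions
    # preserve spatial size, while each 2x2/stride-2 pool halves it):
    #   conv1: 28x28x32 -> pool1: 14x14x32
    #   conv2: 14x14x64 -> pool2: 7x7x64 -> flatten: 3136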
    net = slim.conv2d(images, 32, [5, 5], scope='conv1')
    net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
    net = slim.conv2d(net, 64, [5, 5], scope='conv2')
    net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
    net = slim.flatten(net)
    end_points['Flatten'] = net

    net = slim.fully_connected(net, 1024, scope='fc3')
    net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                       scope='dropout3')
    logits = slim.fully_connected(net, num_classes, activation_fn=None,
                                  scope='fc4')
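    # `activation_fn=None` keeps fc4 linear, so `logits` holds the raw
    # pre-softmax scores promised by the docstring.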

  end_points['Logits'] = logits
  end_points['Predictions'] = prediction_fn(logits, scope='Predictions')

  return logits, end_points
lenet.default_image_size = 28
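# 28x28 is the MNIST input size; TF-Slim's training/eval scripts typically
# consult a network's `default_image_size` attribute when sizing their input
# preprocessing.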


def lenet_arg_scope(weight_decay=0.0):
  """Defines the default lenet argument scope.

  Args:
    weight_decay: The weight decay to use for regularizing the model.

  Returns:
    An `arg_scope` to use for the lenet model.
  """
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
      activation_fn=tf.nn.relu) as sc:
    return sc
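

# A minimal usage sketch, assuming a TF 1.x runtime and a hypothetical
# `mnist_images` tensor of shape [batch_size, 28, 28, 1]:
#
#   with slim.arg_scope(lenet_arg_scope(weight_decay=0.0)):
#     logits, end_points = lenet(mnist_images, num_classes=10,
#                                is_training=True)
#   predictions = tf.argmax(logits, 1)
#   probabilities = end_points['Predictions']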