1. Code reference:
xxxxx://github.com/kubeflow/training-operator/blob/master/examples/tensorflow/dist-mnist/dist_mnist.py
2. TensorFlow distributed training concepts
2.1 Why use distributed training, and what are the goals?
- Why
1 The right hardware configuration can significantly reduce training time
2 Shorter training time allows faster iteration toward the modeling goal
- Goals of tf.distribute.Strategy
1 Easy to use
2 Good performance out of the box
3 Works across different hardware
2.2 TF distribution modes
- Data parallelism
- Applicable to almost any model architecture
- Principle: slice the data so that each replica trains on its own shard
- For example: see the NumPy sketch after this list
- Model parallelism
- Suited to models with independent parts of the computation that can run in parallel (similar to pipeline parallelism: each device holds a different part of the model)
- Principle (figure omitted; upper right: model parallelism, lower right: model parallelism combined with data parallelism):
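To make the data-slicing idea concrete, here is a minimal sketch in plain NumPy (illustrative only; it does not use any particular TF API, and the model is a toy linear layer). Each replica holds a full copy of the parameters, computes gradients on its own shard of the global batch, and the shard gradients are averaged, which for equally sized shards equals the full-batch gradient:
import numpy as np

num_replicas = 4
batch_x = np.random.rand(64, 784).astype(np.float32)   # one global batch of inputs
batch_y = np.random.rand(64, 10).astype(np.float32)    # toy targets
w = np.zeros((784, 10), dtype=np.float32)               # parameters replicated on every device

def local_gradient(w, x, y):
    # gradient of the mean of 0.5 * ||x @ w - y||^2 over this shard
    return x.T @ (x @ w - y) / len(x)

# data parallelism: slice the global batch into one shard per replica
shards_x = np.array_split(batch_x, num_replicas)
shards_y = np.array_split(batch_y, num_replicas)
grads = [local_gradient(w, xs, ys) for xs, ys in zip(shards_x, shards_y)]

# synchronous update: average the per-replica gradients, then apply one step
w -= 0.01 * np.mean(grads, axis=0)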
2.3 TF distributed parallelism strategies: synchronous vs. asynchronous
Synchronous and asynchronous parallelism refer to how the parameters are updated.
- Synchronous parallelism (the whole cluster is capped by the speed of the slowest machine)
- How does every GPU end up with the sum of the gradients of all other GPUs (a reduction-based pattern)? The Ring All-Reduce algorithm (see the simulation sketch after this list):
- Reduce-scatter phase
- All-gather phase
On a single machine this synchronizes the gradients of all GPUs on that machine; with multiple machines it synchronizes the gradients of all GPUs on all machines in the cluster.
After 2*(n-1) communication steps, every worker holds the fully reduced gradients. All nodes play an equal role, so no single node becomes a load bottleneck, and the traffic per node stays roughly constant as workers are added. The speedup is therefore close to linear in the number of machines, with no obvious bottleneck.
- Asynchronous parallelism (workers do not depend on each other and run independently; drawback: updates may be computed from stale parameter values, which can slow convergence)
- Comparison with synchronous parallelism:
- Common implementation: parameter server
- Coordinator: creates resources, dispatches training tasks, writes checkpoints, and handles task failures
- worker: does not store parameters; independently runs training on its own subset of the training samples
- ps: stores and updates the parameters
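To make the reduce-scatter and all-gather phases concrete, below is a small NumPy simulation of Ring All-Reduce (a sketch for illustration only; real implementations live in NCCL/TensorFlow, and the chunk-indexing scheme used here is just one common convention). With n workers each phase takes n-1 steps, so the full reduction takes 2*(n-1) steps, and in every step a worker exchanges only 1/n of the gradient with its ring neighbour, which is why per-worker traffic stays roughly flat as workers are added:
import numpy as np

def ring_allreduce(grads):
    """grads[i] is worker i's local gradient (1-D array); returns the per-worker results."""
    n = len(grads)
    # each worker splits its gradient into n chunks
    chunks = [np.array_split(g.copy(), n) for g in grads]

    # Reduce-scatter: after n-1 steps, worker i owns the complete sum of chunk (i+1) % n
    for step in range(n - 1):
        for i in range(n):
            c = (i - step) % n                       # chunk worker i sends this step
            chunks[(i + 1) % n][c] = chunks[(i + 1) % n][c] + chunks[i][c]

    # All-gather: circulate the completed chunks so every worker ends up with every sum
    for step in range(n - 1):
        for i in range(n):
            c = (i + 1 - step) % n                   # completed chunk worker i forwards
            chunks[(i + 1) % n][c] = chunks[i][c]

    return [np.concatenate(c) for c in chunks]

# Example: 4 workers, each with its own random gradient vector.
local_grads = [np.random.rand(8) for _ in range(4)]
reduced = ring_allreduce(local_grads)
assert np.allclose(reduced[0], sum(local_grads))     # every worker now holds the same sum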
2.4 How parameter-server (PS) distribution works
- Parallelism strategy: data parallelism with asynchronous updates
- Components: ps and worker
- worker: independently runs training on its own subset of the training samples
- ps: stores and updates the parameters
- Distributed API
- A cluster is created by starting one service per task (a worker service or a master/chief service). The tasks can be spread over different machines, or several tasks can run on the same machine using different GPUs.
- Create a tf.train.ClusterSpec that describes every task in the cluster; the description must be identical for all tasks.
- tf.train.Server(cluster, job_name, task_index) creates a service (worker/chief) and runs the computation of the corresponding job; the task runs on the machine designated by task_index.
- tf.device(device_name_or_function): run tensor operations on the specified CPU or GPU device.
# Method 1: explicitly pin the ops to the machine running ps task 0
with tf.device("/job:ps/task:0"):
  weights_1 = tf.Variable(...)
  biases_1 = tf.Variable(...)
# Method 2: let the device setter assign devices automatically as the ops are constructed
with tf.device(
    tf.train.replica_device_setter(
        worker_device=worker_device,
        ps_device="/job:ps/cpu:0",
        cluster=cluster)):
  weights_1 = tf.Variable(...)
  biases_1 = tf.Variable(...)
- Distributed training code skeleton
# Command-line flags describe the cluster: ps_hosts, worker_hosts, job_name and task_index
tf.app.flags.DEFINE_string("ps_hosts", "", "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "", "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
FLAGS = tf.app.flags.FLAGS
ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts.split(",")
# tf.train.ClusterSpec builds a cluster object holding the host:port addresses of the parameter servers and workers
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
# tf.train.Server creates the server for the current task; job_name and task_index identify this node's role and index
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
# A parameter server ("ps") just calls server.join() to keep serving
if FLAGS.job_name == "ps":
  server.join()
# A worker builds the TensorFlow graph model
elif FLAGS.job_name == "worker":
  # build tensorflow graph model
  # Create a tf.train.Supervisor to manage and monitor the training process
  sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0), logdir="/tmp/train_logs")
  # prepare_or_wait_for_session creates a session connected to server.target
  sess = sv.prepare_or_wait_for_session(server.target)
  while not sv.should_stop():
    pass  # model training loop goes here
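As a usage sketch (the script name train.py and the hostnames/ports below are placeholders, not taken from the reference code), the same script is started once per task, with each process receiving its own job_name and task_index:
python train.py --ps_hosts=ps0:2222 --worker_hosts=worker0:2223,worker1:2224 --job_name=ps --task_index=0
python train.py --ps_hosts=ps0:2222 --worker_hosts=worker0:2223,worker1:2224 --job_name=worker --task_index=0
python train.py --ps_hosts=ps0:2222 --worker_hosts=worker0:2223,worker1:2224 --job_name=worker --task_index=1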
3. Experiment
3.1 Environment:
NVIDIA-A100-SXM4-40GB
CUDA11.7
tensorflow[and-cuda]==2.14.0
3.2 Distributed MNIST training code (mnist_test.py):
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import json
import math
import os
import sys
import tempfile
import time
import input_data
import tensorflow as tf2
#from tensorflow.examples.tutorials.mnist import input_data
import tensorflow.compat.v1 as tf
tf.compat.v1.disable_eager_execution()
# Command-line flag definitions: flags.DEFINE_xxx defines flags of different types (string, integer, boolean, float, ...)
flags = tf.app.flags
# Path where the MNIST data file is stored
flags.DEFINE_string("data_dir", "/PS/mnist.npz",
"Directory for storing mnist data")
# Directory where checkpoints and logs are written
flags.DEFINE_string("train_dir", "/PS/output",
"Directory for storing mnist checkpint and logs")
# Whether to only download the data without training
flags.DEFINE_boolean("download_only", False,
"Only perform downloading of data; Do not proceed to "
"session preparation, model definition or training")
# Number of units in the hidden layer of the network
flags.DEFINE_integer("hidden_units", 100,
"Number of units in the hidden layer of the NN")
# Number of global training steps to perform
flags.DEFINE_integer("train_steps", 20000,
"Number of (global) training steps to perform")
# Number of samples per training batch
flags.DEFINE_integer("batch_size", 100, "Training batch size")
# Learning rate used for training
flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
# Index of the worker task; index 0 is the chief worker, which performs variable initialization; other workers use an integer >= 0
flags.DEFINE_integer("task_index", None,
"Worker task index, should be >= 0. task_index=0 is "
"the master worker task the performs the variable "
"initialization ")
# Number of GPUs used on each machine; set to 0 if no GPU is used
flags.DEFINE_integer("num_gpus", 1, "Total number of gpus for each machine."
"If you don't use GPU, please set it to '0'")
# Number of replicas to aggregate before a parameter update is applied (sync_replicas mode only; defaults to the number of workers)
flags.DEFINE_integer("replicas_to_aggregate", None,
"Number of replicas to aggregate before parameter update"
"is applied (For sync_replicas mode only; default: "
"num_workers)")
# Whether to use the sync_replicas (synchronized replicas) mode, in which the workers' parameter updates are aggregated before being applied, to avoid stale gradients
flags.DEFINE_boolean(
"sync_replicas", False,
"Use the sync_replicas (synchronized replicas) mode, "
"wherein the parameter updates from workers are aggregated "
"before applied to avoid stale gradients")
# Whether servers already exist. If True, the worker hosts are reached via their gRPC URLs; otherwise an in-process TensorFlow server is created
flags.DEFINE_boolean(
"existing_servers", False, "Whether servers already exists. If True, "
"will use the worker hosts via their GRPC URLs (one client process "
"per worker host). Otherwise, will create an in-process TensorFlow "
"server.")
# Comma-separated hostname:port pairs of the parameter servers
flags.DEFINE_string("ps_hosts", "localhost:2222",
"Comma-separated list of hostname:port pairs")
# Comma-separated hostname:port pairs of the workers
flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
"Comma-separated list of hostname:port pairs")
# Job name, either "worker" or "ps"
flags.DEFINE_string("job_name", None, "job name: worker or ps")
FLAGS = flags.FLAGS
IMAGE_PIXELS = 28
def batch_generator(x, y, batch_size):
num_samples = len(x)
while True:
indices = np.random.choice(num_samples, batch_size, replace=False)
batch_x = x[indices%num_samples]
batch_y = y[indices%num_samples]
yield batch_x, batch_y
# Example:
# cluster = {'ps': ['host1:2222', 'host2:2222'],
# 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
# os.environ['TF_CONFIG'] = json.dumps(
# {'cluster': cluster,
# 'task': {'type': 'worker', 'index': 1}})
def main(unused_argv):
# Parse environment variable TF_CONFIG to get job_name and task_index
# If not explicitly specified in the constructor and the TF_CONFIG
# environment variable is present, load cluster_spec from TF_CONFIG.
# Parse the TF_CONFIG environment variable and copy the result into FLAGS, so each task in the distributed cluster knows its own role
tf_config = json.loads(os.environ.get('TF_CONFIG') or '{}')
task_config = tf_config.get('task', {})
task_type = task_config.get('type')
task_index = task_config.get('index')
FLAGS.job_name = task_type
FLAGS.task_index = task_index
#mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
# Load the MNIST dataset via the Keras API and preprocess it
(x_train, y_train), (x_test, y_test) = tf2.keras.datasets.mnist.load_data(FLAGS.data_dir)
#print('Training data shape:', x_train.shape, y_train.shape)
#print('Testing data shape:', x_test.shape, y_test.shape)
# Scale pixel values to the range 0-1
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
# One-hot encode the labels
y_train = np.eye(10)[y_train]
y_test = np.eye(10)[y_test]
# Reshape the data
x_train = x_train.reshape(-1, 784)
x_test = x_test.reshape(-1, 784)
y_train = y_train.astype('float32').reshape(-1, 10)
y_test = y_test.astype('float32').reshape(-1, 10)
#print('Training data shape:', x_train.shape, y_train.shape)
#print('Testing data shape:', x_test.shape, y_test.shape)
# Check that FLAGS.job_name and FLAGS.task_index have been set;
# otherwise raise, because every task in the distributed cluster needs a unique job name and index so other tasks can reach it.
if FLAGS.job_name is None or FLAGS.job_name == "":
raise ValueError("Must specify an explicit `job_name`")
if FLAGS.task_index is None or FLAGS.task_index == "":
raise ValueError("Must specify an explicit `task_index`")
print("job name = %s" % FLAGS.job_name)
print("task index = %d" % FLAGS.task_index)
# Read the cluster configuration (parameter-server and worker addresses) from TF_CONFIG
cluster_config = tf_config.get('cluster', {})
ps_hosts = cluster_config.get('ps')
worker_hosts = cluster_config.get('worker')
ps_hosts_str = ','.join(ps_hosts)
worker_hosts_str = ','.join(worker_hosts)
FLAGS.ps_hosts = ps_hosts_str
FLAGS.worker_hosts = worker_hosts_str
# Construct the cluster and start the server
# Split the comma-separated strings back into lists for building the ClusterSpec
ps_spec = FLAGS.ps_hosts.split(",")
worker_spec = FLAGS.worker_hosts.split(",")
# Get the number of workers.
num_workers = len(worker_spec)
# Build the ClusterSpec object, which describes the topology of the whole distributed cluster.
# The topology description is the same for every task.
# "ps" and "worker" are job names; ps_spec and worker_spec are the ip:port lists of the tasks in each job.
cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec})
# Create a server and run the computation of the corresponding job
if not FLAGS.existing_servers:
# Not using existing servers. Create an in-process server.
# Create an in-process server (chief or worker service) that runs the actual computation of its job
server = tf.train.Server(
cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
# Parameter-server task: stores and updates parameters; join() blocks so the server keeps serving the cluster
if FLAGS.job_name == "ps":
server.join()
# The first worker is the chief; the chief saves checkpoints and writes logs
is_chief = (FLAGS.task_index == 0)
# Choose the worker's device placement depending on FLAGS.num_gpus
if FLAGS.num_gpus > 0:
# Avoid gpu allocation conflict: now allocate task_num -> #gpu
# for each worker in the corresponding machine
# Avoid several workers sharing the same GPU: each worker on a machine gets a different GPU
gpu = (FLAGS.task_index % FLAGS.num_gpus)
worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
elif FLAGS.num_gpus == 0:
# Just allocate the CPU to worker server
cpu = 0
worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
print("worker_device:",worker_device)
# The device setter will automatically place Variables ops on separate
# parameter servers (ps). The non-Variable ops will be placed on the workers.
# The ps use CPU and workers use corresponding GPU
# Set up TensorFlow device placement for the distributed setting
with tf.device(
tf.train.replica_device_setter(
worker_device=worker_device,
ps_device="/job:ps/cpu:0",
cluster=cluster)):
# Build the model / the computation graph to be trained.
# global_step counts the number of training steps performed
global_step = tf.Variable(0, name="global_step", trainable=False)
# Variables of the hidden layer
# Hidden-layer weights hid_w, initialized from a truncated normal distribution
hid_w = tf.Variable(
tf.truncated_normal(
[IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
stddev=1.0 / IMAGE_PIXELS),
name="hid_w")
# Hidden-layer biases hid_b, initialized to zeros
hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")
# Variables of the softmax layer
# Softmax-layer weights sm_w, initialized from a truncated normal distribution
sm_w = tf.Variable(
tf.truncated_normal(
[FLAGS.hidden_units, 10],
stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
name="sm_w")
# Softmax-layer biases sm_b, initialized to zeros
sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
# Ops: located on the worker specified with FLAGS.task_index
# Placeholders: x for the input images, y_ for the labels
x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
y_ = tf.placeholder(tf.float32, [None, 10])
# Linear transform of the hidden layer
hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
# ReLU activation on the hidden layer
hid = tf.nn.relu(hid_lin)
# Softmax output of the final layer
y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
# Cross-entropy loss
cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
# Adam optimizer
opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
# If synchronous-replica training is enabled
if FLAGS.sync_replicas:
if FLAGS.replicas_to_aggregate is None:
replicas_to_aggregate = num_workers
else:
replicas_to_aggregate = FLAGS.replicas_to_aggregate
# Wrap the optimizer for synchronous training
opt = tf.train.SyncReplicasOptimizer(
opt,
replicas_to_aggregate=replicas_to_aggregate,
total_num_replicas=num_workers,
name="mnist_sync_replicas")
# Training op: minimize the cross-entropy loss with the optimizer
train_step = opt.minimize(cross_entropy, global_step=global_step)
if FLAGS.sync_replicas:
# local_step_init_op initializes each task's local step counter to 0
local_init_op = opt.local_step_init_op
if is_chief:
# Initialization op run only by the chief task (is_chief=True); it sets up the chief-specific state needed for synchronous training
local_init_op = opt.chief_init_op
ready_for_local_init_op = opt.ready_for_local_init_op
# Initial token and chief queue runners required by the sync_replicas mode
chief_queue_runner = opt.get_chief_queue_runner()
sync_init_op = opt.get_init_tokens_op()
# Initializer for the global variables
init_op = tf.global_variables_initializer()
#train_dir = tempfile.mkdtemp()
train_dir = FLAGS.train_dir
# Create a tf.train.Supervisor to manage and oversee the training process
if FLAGS.sync_replicas:
sv = tf.train.Supervisor(
is_chief=is_chief,
logdir=train_dir,
init_op=init_op,
local_init_op=local_init_op,
ready_for_local_init_op=ready_for_local_init_op,
recovery_wait_secs=1,
global_step=global_step)
else:
sv = tf.train.Supervisor(
is_chief=is_chief,
logdir=train_dir,
init_op=init_op,
recovery_wait_secs=1,
global_step=global_step)
# Session configuration
sess_config = tf.ConfigProto(
allow_soft_placement=True,
log_device_placement=False,
device_filters=["/job:ps",
"/job:worker/task:%d" % FLAGS.task_index])
# The chief worker (task_index==0) session will prepare the session,
# while the remaining workers will wait for the preparation to complete.
if is_chief:
print("Worker %d: Initializing session..." % FLAGS.task_index)
else:
print("Worker %d: Waiting for session to be initialized..." %
FLAGS.task_index)
if FLAGS.existing_servers:
server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
print("Using existing server at: %s" % server_grpc_url)
# prepare_or_wait_for_session is used together with tf.train.Supervisor;
# the supervisor handles session initialization and restoring the model from checkpoints
sess = sv.prepare_or_wait_for_session(server_grpc_url, config=sess_config)
else:
sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
print("Worker %d: Session initialization complete." % FLAGS.task_index)
if FLAGS.sync_replicas and is_chief:
# Chief worker will start the chief queue runner and call the init op.
# The chief runs sess.run(sync_init_op) to create the initial tokens needed for synchronous training
# and starts the chief_queue_runner via sv.start_queue_runners()
sess.run(sync_init_op)
sv.start_queue_runners(sess, [chief_queue_runner])
# Perform training
time_begin = time.time()
print("Training begins @ %f" % time_begin)
# Generator that yields training batches
train_generator = batch_generator(x_train, y_train, FLAGS.batch_size)
local_step = 0
while True:
# Training loop
# Training feed
#batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
batch_xs, batch_ys = next(train_generator)
train_feed = {x: batch_xs, y_: batch_ys}
# Run one training step and fetch the current global step
_, step = sess.run([train_step, global_step], feed_dict=train_feed)
local_step += 1
now = time.time()
print("%f: Worker %d: training step %d done (global step: %d)" %
(now, FLAGS.task_index, local_step, step))
# Stop once the global step reaches the configured number of training steps
if step >= FLAGS.train_steps:
break
time_end = time.time()
print("Training ends @ %f" % time_end)
training_time = time_end - time_begin
print("Training elapsed time: %f s" % training_time)
# Validation feed
#val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
val_feed = {x: x_test, y_: y_test}
# Evaluate the cross-entropy loss on the validation (test) set
val_xent = sess.run(cross_entropy, feed_dict=val_feed)
# Print the validation cross-entropy after training
print("After %d training step(s), validation cross entropy = %g" %
(FLAGS.train_steps, val_xent))
if __name__ == "__main__":
# tf.app.run() is the standard entry point: it parses the command-line flags and then calls main()
tf.app.run()
3.3 YAML manifest for launching the distributed TFJob:
apiVersion: "kubeflow.org/v1"
kind: "TFJob"
metadata:
  name: "dist-mnist-for-e2e-test"
spec:
  tfReplicaSpecs:
    PS:
      replicas: 1
      restartPolicy: Never
      template:
        metadata:
          labels:
            job_name: "ps"
        spec:
          nodeSelector:
            kubernetes.io/hostname: hostname # label key and value of the node to run on
          containers:
            - name: tensorflow
              image: tensorflow_mnist:v1
              workingDir: /PS
              command: ["sh", "-c", "python3 mnist_test.py --num_gpus 2 --batch_size 100"] # training command
              resources:
                limits:
                  nvidia.com/gpu: 2 # number of GPUs to use
              env:
                - name: NVIDIA_VISIBLE_DEVICES
                  value: "all"
    Worker:
      replicas: 1
      restartPolicy: Never
      template:
        metadata:
          labels:
            job_name: "worker"
        spec:
          nodeSelector:
            kubernetes.io/hostname: hostname # label key and value of the node to run on
          containers:
            - name: tensorflow
              image: tensorflow_mnist:v1
              workingDir: /PS
              command: ["sh", "-c", "python3 mnist_test.py --num_gpus 2 --batch_size 100"] # training command
              resources:
                limits:
                  nvidia.com/gpu: 2 # number of GPUs to use
              env:
                - name: NVIDIA_VISIBLE_DEVICES
                  value: "all"
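Assuming the Kubeflow training-operator is already installed in the cluster (an assumption, not shown above), this manifest can be submitted with kubectl apply -f <your-tfjob>.yaml. The operator then creates one pod per PS/Worker replica and injects a TF_CONFIG environment variable (the cluster spec plus the pod's own task type and index), which is exactly what mnist_test.py parses at the start of main().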
3.4 input_data.py (from tutorials/mnist/input_data.py in TensorFlow 1.x)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# xxxx://xxx.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for downloading and reading MNIST data (deprecated).
This module and all its submodules are deprecated.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import gzip
import os
import numpy
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.platform import gfile
from tensorflow.python.util.deprecation import deprecated
_Datasets = collections.namedtuple('_Datasets', ['train', 'validation', 'test'])
# CVDF mirror of xxx://yann.lecun.com/exdb/mnist/
DEFAULT_SOURCE_URL = 'xxxxx://storage.googleapis.com/cvdf-datasets/mnist/'
def _read32(bytestream):
dt = numpy.dtype(numpy.uint32).newbyteorder('>')
return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]
@deprecated(None, 'Please use tf.data to implement this functionality.')
def _extract_images(f):
"""Extract the images into a 4D uint8 numpy array [index, y, x, depth].
Args:
f: A file object that can be passed into a gzip reader.
Returns:
data: A 4D uint8 numpy array [index, y, x, depth].
Raises:
ValueError: If the bytestream does not start with 2051.
"""
print('Extracting', f.name)
with gzip.GzipFile(fileobj=f) as bytestream:
magic = _read32(bytestream)
if magic != 2051:
raise ValueError('Invalid magic number %d in MNIST image file: %s' %
(magic, f.name))
num_images = _read32(bytestream)
rows = _read32(bytestream)
cols = _read32(bytestream)
buf = bytestream.read(rows * cols * num_images)
data = numpy.frombuffer(buf, dtype=numpy.uint8)
data = data.reshape(num_images, rows, cols, 1)
return data
@deprecated(None, 'Please use tf.one_hot on tensors.')
def _dense_to_one_hot(labels_dense, num_classes):
"""Convert class labels from scalars to one-hot vectors."""
num_labels = labels_dense.shape[0]
index_offset = numpy.arange(num_labels) * num_classes
labels_one_hot = numpy.zeros((num_labels, num_classes))
labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
return labels_one_hot
@deprecated(None, 'Please use tf.data to implement this functionality.')
def _extract_labels(f, one_hot=False, num_classes=10):
"""Extract the labels into a 1D uint8 numpy array [index].
Args:
f: A file object that can be passed into a gzip reader.
one_hot: Does one hot encoding for the result.
num_classes: Number of classes for the one hot encoding.
Returns:
labels: a 1D uint8 numpy array.
Raises:
ValueError: If the bytestream doesn't start with 2049.
"""
print('Extracting', f.name)
with gzip.GzipFile(fileobj=f) as bytestream:
magic = _read32(bytestream)
if magic != 2049:
raise ValueError('Invalid magic number %d in MNIST label file: %s' %
(magic, f.name))
num_items = _read32(bytestream)
buf = bytestream.read(num_items)
labels = numpy.frombuffer(buf, dtype=numpy.uint8)
if one_hot:
return _dense_to_one_hot(labels, num_classes)
return labels
class _DataSet(object):
"""Container class for a _DataSet (deprecated).
THIS CLASS IS DEPRECATED.
"""
@deprecated(None, 'Please use alternatives such as official/mnist/_DataSet.py'
' from tensorflow/models.')
def __init__(self,
images,
labels,
fake_data=False,
one_hot=False,
dtype=dtypes.float32,
reshape=True,
seed=None):
"""Construct a _DataSet.
one_hot arg is used only if fake_data is true. `dtype` can be either
`uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
`[0, 1]`. Seed arg provides for convenient deterministic testing.
Args:
images: The images
labels: The labels
fake_data: Ignore images and labels, use fake data.
one_hot: Bool, return the labels as one hot vectors (if True) or ints (if
False).
dtype: Output image dtype. One of [uint8, float32]. `uint8` output has
range [0,255]. float32 output has range [0,1].
reshape: Bool. If True returned images are returned flattened to vectors.
seed: The random seed to use.
"""
seed1, seed2 = random_seed.get_seed(seed)
# If op level seed is not set, use whatever graph level seed is returned
numpy.random.seed(seed1 if seed is None else seed2)
dtype = dtypes.as_dtype(dtype).base_dtype
if dtype not in (dtypes.uint8, dtypes.float32):
raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
dtype)
if fake_data:
self._num_examples = 10000
self.one_hot = one_hot
else:
assert images.shape[0] == labels.shape[0], (
'images.shape: %s labels.shape: %s' % (images.shape, labels.shape))
self._num_examples = images.shape[0]
# Convert shape from [num examples, rows, columns, depth]
# to [num examples, rows*columns] (assuming depth == 1)
if reshape:
assert images.shape[3] == 1
images = images.reshape(images.shape[0],
images.shape[1] * images.shape[2])
if dtype == dtypes.float32:
# Convert from [0, 255] -> [0.0, 1.0].
images = images.astype(numpy.float32)
images = numpy.multiply(images, 1.0 / 255.0)
self._images = images
self._labels = labels
self._epochs_completed = 0
self._index_in_epoch = 0
@property
def images(self):
return self._images
@property
def labels(self):
return self._labels
@property
def num_examples(self):
return self._num_examples
@property
def epochs_completed(self):
return self._epochs_completed
def next_batch(self, batch_size, fake_data=False, shuffle=True):
"""Return the next `batch_size` examples from this data set."""
if fake_data:
fake_image = [1] * 784
if self.one_hot:
fake_label = [1] + [0] * 9
else:
fake_label = 0
return [fake_image for _ in xrange(batch_size)
], [fake_label for _ in xrange(batch_size)]
start = self._index_in_epoch
# Shuffle for the first epoch
if self._epochs_completed == 0 and start == 0 and shuffle:
perm0 = numpy.arange(self._num_examples)
numpy.random.shuffle(perm0)
self._images = self.images[perm0]
self._labels = self.labels[perm0]
# Go to the next epoch
if start + batch_size > self._num_examples:
# Finished epoch
self._epochs_completed += 1
# Get the rest examples in this epoch
rest_num_examples = self._num_examples - start
images_rest_part = self._images[start:self._num_examples]
labels_rest_part = self._labels[start:self._num_examples]
# Shuffle the data
if shuffle:
perm = numpy.arange(self._num_examples)
numpy.random.shuffle(perm)
self._images = self.images[perm]
self._labels = self.labels[perm]
# Start next epoch
start = 0
self._index_in_epoch = batch_size - rest_num_examples
end = self._index_in_epoch
images_new_part = self._images[start:end]
labels_new_part = self._labels[start:end]
return numpy.concatenate((images_rest_part, images_new_part),
axis=0), numpy.concatenate(
(labels_rest_part, labels_new_part), axis=0)
else:
self._index_in_epoch += batch_size
end = self._index_in_epoch
return self._images[start:end], self._labels[start:end]
@deprecated(None, 'Please write your own downloading logic.')
def _maybe_download(filename, work_directory, source_url):
"""Download the data from source url, unless it's already here.
Args:
filename: string, name of the file in the directory.
work_directory: string, path to working directory.
source_url: url to download from if file doesn't exist.
Returns:
Path to resulting file.
"""
if not gfile.Exists(work_directory):
gfile.MakeDirs(work_directory)
filepath = os.path.join(work_directory, filename)
if not gfile.Exists(filepath):
urllib.request.urlretrieve(source_url, filepath)
with gfile.GFile(filepath) as f:
size = f.size()
print('Successfully downloaded', filename, size, 'bytes.')
return filepath
@deprecated(None, 'Please use alternatives such as:'
' tensorflow_datasets.load(\'mnist\')')
def read_data_sets(train_dir,
fake_data=False,
one_hot=False,
dtype=dtypes.float32,
reshape=True,
validation_size=5000,
seed=None,
source_url=DEFAULT_SOURCE_URL):
if fake_data:
def fake():
return _DataSet([], [],
fake_data=True,
one_hot=one_hot,
dtype=dtype,
seed=seed)
train = fake()
validation = fake()
test = fake()
return _Datasets(train=train, validation=validation, test=test)
if not source_url: # empty string check
source_url = DEFAULT_SOURCE_URL
train_images_file = 'train-images-idx3-ubyte.gz'
train_labels_file = 'train-labels-idx1-ubyte.gz'
test_images_file = 't10k-images-idx3-ubyte.gz'
test_labels_file = 't10k-labels-idx1-ubyte.gz'
local_file = _maybe_download(train_images_file, train_dir,
source_url + train_images_file)
with gfile.Open(local_file, 'rb') as f:
train_images = _extract_images(f)
local_file = _maybe_download(train_labels_file, train_dir,
source_url + train_labels_file)
with gfile.Open(local_file, 'rb') as f:
train_labels = _extract_labels(f, one_hot=one_hot)
local_file = _maybe_download(test_images_file, train_dir,
source_url + test_images_file)
with gfile.Open(local_file, 'rb') as f:
test_images = _extract_images(f)
local_file = _maybe_download(test_labels_file, train_dir,
source_url + test_labels_file)
with gfile.Open(local_file, 'rb') as f:
test_labels = _extract_labels(f, one_hot=one_hot)
if not 0 <= validation_size <= len(train_images):
raise ValueError(
'Validation size should be between 0 and {}. Received: {}.'.format(
len(train_images), validation_size))
validation_images = train_images[:validation_size]
validation_labels = train_labels[:validation_size]
train_images = train_images[validation_size:]
train_labels = train_labels[validation_size:]
options = dict(dtype=dtype, reshape=reshape, seed=seed)
train = _DataSet(train_images, train_labels, **options)
validation = _DataSet(validation_images, validation_labels, **options)
test = _DataSet(test_images, test_labels, **options)
return _Datasets(train=train, validation=validation, test=test)