代码:
"""Run downstream classification"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
import utils.optimizer as optimizer
import epl
FLAGS = tf.flags.FLAGS

# --- Cluster layout and data locations --------------------------------------
# NOTE(review): task_index / worker_hosts are not in the PAI submit command;
# presumably the launcher injects them — confirm.
tf.flags.DEFINE_integer("task_index", None, "Worker or server index")
tf.flags.DEFINE_string("worker_hosts", "", "worker hosts")
tf.flags.DEFINE_string("buckets", "",
                       "Base path that checkpoint_dir is joined onto")
tf.flags.DEFINE_string("train_table", "", "Training table to read")
tf.flags.DEFINE_string("val_table", "", "Validation table to read")
tf.flags.DEFINE_string("checkpoint_dir", '',
                       """Path to checkpoint folder (joined onto --buckets)""")

# --- Training schedule -------------------------------------------------------
# Help strings state the real defaults (they previously quoted stale values).
tf.flags.DEFINE_integer("num_epochs", 100,
                        """Number of training epochs (default: 100)""")
tf.flags.DEFINE_integer("max_steps", 10000,
                        "Maximum number of training steps (default: 10000)")
tf.flags.DEFINE_integer("batch_size", 256, """Batch size (default: 256)""")
tf.flags.DEFINE_integer("display_step", 200,
                        """Number of steps to display log into TensorBoard (default: 200)""")
tf.flags.DEFINE_integer("save_checkpoints_steps", 1000,
                        "How often to save the model checkpoint.")

# --- Optimization ------------------------------------------------------------
tf.flags.DEFINE_float("learning_rate", 0.001,
                      """Learning rate (default: 0.001)""")
tf.flags.DEFINE_float("max_grad_norm", 5.0,
                      """Maximum value of the global norm of the gradients for clipping (default: 5.0)""")

# --- EPL parallelism ---------------------------------------------------------
tf.flags.DEFINE_integer("num_pipe_stages", 1, "number of pipeline stages")
tf.flags.DEFINE_integer("num_micro_batch", 1, "number of pipeline micro batches")
def str2list(str_in, shape, separator=' ', dtype=tf.int32):
    """Parse a delimited string tensor into a numeric tensor of `shape`.

    Args:
      str_in: string tensor holding the delimited values; it is wrapped in a
        one-element list before splitting (presumably a scalar — confirm).
      shape: target shape for `tf.reshape`; its element count must equal the
        number of parsed tokens.
      separator: delimiter handed to `tf.string_split`.
      dtype: numeric output type for `tf.string_to_number`.

    Returns:
      A `dtype` tensor reshaped to `shape`.
    """
    tokens = tf.string_split([str_in], delimiter=separator)
    numbers = tf.string_to_number(tokens.values, out_type=dtype)
    return tf.reshape(numbers, shape)
def file_based_input_fn_builder(input_file, slice_id, slice_count, is_training, drop_remainder):
    """Create an `input_fn` closure to be passed to `tf.estimator.Estimator`.

    Rows of `input_file` are read with `tf.data.TableRecordDataset` as five
    columns (cert_no, coll_case_no, embedding, dt, label). The embedding column
    is split on the \\x02 character and reshaped to a float32 tensor of
    shape [512].

    Args:
      input_file: table to read.
      slice_id: index of the table slice this reader consumes. `None` (the
        default of --task_index) disables slicing and reads the whole table.
      slice_count: number of slices the table is divided into; values below 1
        are treated as 1.
      is_training: if True, repeat the data FLAGS.num_epochs times and shuffle.
      drop_remainder: whether to drop a final batch smaller than
        FLAGS.batch_size.

    Returns:
      `input_fn(params)` that returns a batched `tf.data.Dataset` of dicts
      with keys 'input_embed', 'label', 'dt', 'cert_no', 'coll_case_no'.
    """
    # slice_id/slice_count used to be accepted but ignored, so every worker
    # read (and trained on) the entire table. Normalise and honour them.
    if slice_id is None:
        slice_id, slice_count = 0, 1
    slice_count = max(1, slice_count)

    def _decode_record(*record):
        """Decode one table row into a feature dict."""
        (cert_no, coll_case_no, embedding, dt, label) = record
        embedding = str2list(embedding, shape=[512], separator='\002', dtype=tf.float32)
        return {'input_embed': embedding,
                'label': label,
                'dt': dt,
                'cert_no': cert_no,
                'coll_case_no': coll_case_no}

    def input_fn(params):
        """The actual input function; `params` is accepted but unused."""
        d = tf.data.TableRecordDataset(
            [input_file],
            record_defaults=['', '', '', '', 0],
            slice_id=slice_id,
            slice_count=slice_count)
        if is_training:
            d = d.repeat(FLAGS.num_epochs)
            # NOTE: shuffling after repeat lets examples of adjacent epochs mix.
            d = d.shuffle(buffer_size=1000)
        # Each dataset element is a 5-tuple, which map_and_batch unpacks into
        # _decode_record's positional arguments.
        d = d.apply(tf.contrib.data.map_and_batch(
            _decode_record,
            batch_size=FLAGS.batch_size,
            drop_remainder=drop_remainder))
        return d

    return input_fn
def create_model(input_embed, label):
    """Build a 2-class softmax head on `input_embed`, with its loss and metrics.

    Args:
      input_embed: float feature tensor fed to a dense layer (presumably
        (batch, 512) from the input pipeline — TODO confirm).
      label: integer class ids, one-hot encoded with depth 2.

    Returns:
      Tuple `(loss, acc, auc)`. `loss` is the softmax cross-entropy tensor.
      `acc` and `auc` are the `(value, update_op)` pairs returned by
      `tf.metrics.accuracy` and `tf.metrics.auc`; AUC is scored on the
      probability of the last class.
    """
    with tf.variable_scope("loss", reuse=tf.AUTO_REUSE):
        with tf.variable_scope("cls"):
            logits = tf.layers.dense(
                inputs=input_embed,
                units=2,
                activation=None,
                kernel_initializer=tf.truncated_normal_initializer())
        targets = tf.one_hot(label, depth=2, dtype=tf.float32)
        loss = tf.losses.softmax_cross_entropy(targets, logits)

        probabilities = tf.nn.softmax(logits, axis=-1)
        predicted_class = tf.argmax(probabilities, axis=-1, output_type=tf.int32)
        positive_prob = probabilities[:, -1]

        accuracy = tf.metrics.accuracy(label, predicted_class)
        auc = tf.metrics.auc(label, positive_prob)
    return loss, accuracy, auc
def model_fn_builder(checkpoint_dir, learning_rate):
    """Return a `model_fn` closure for `tf.estimator.Estimator`.

    Args:
      checkpoint_dir: accepted for interface compatibility; the returned
        model_fn does not read it.
      learning_rate: learning rate for `tf.train.RMSPropOptimizer`.
    """

    def model_fn(features, mode):
        """Build the `EstimatorSpec` for TRAIN or EVAL mode.

        Raises:
          ValueError: for any other mode (e.g. PREDICT).
        """
        # Loss and metric ops are created in both supported modes.
        loss, acc, auc = create_model(features['input_embed'], features["label"])

        if mode == tf.estimator.ModeKeys.EVAL:
            return tf.estimator.EstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metric_ops={'Accuracy': acc, "AUC": auc})
        if mode != tf.estimator.ModeKeys.TRAIN:
            raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))

        # TRAIN: RMSProp on gradients clipped by their global norm.
        train_vars = tf.trainable_variables()
        raw_grads = tf.gradients(loss, train_vars)
        clipped_grads, grad_norm = tf.clip_by_global_norm(raw_grads, FLAGS.max_grad_norm)
        tf.summary.scalar('global_grad_norm', grad_norm)
        step = tf.train.get_or_create_global_step()
        # Named `rmsprop` so it does not shadow the module-level `optimizer` import.
        rmsprop = tf.train.RMSPropOptimizer(learning_rate)
        train_op = rmsprop.apply_gradients(
            zip(clipped_grads, train_vars),
            name='train_op',
            global_step=step)
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    return model_fn
def main(_):
    """Train and evaluate the downstream classifier with tf.estimator.

    Reads configuration from FLAGS and the EPL environment and prints the
    derived batch layout. Then runs `tf.estimator.train_and_evaluate` with
    checkpoints under `os.path.join(FLAGS.buckets, FLAGS.checkpoint_dir)`.

    Raises:
      ValueError: if --num_pipe_stages or the EPL num_micro_batch is below 1,
        or fewer devices are available than pipeline stages. Previously these
        cases ended in a bare ZeroDivisionError or printed negative sizes.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info("############## Start #####################")
    checkpoint_dir = os.path.join(FLAGS.buckets, FLAGS.checkpoint_dir)
    train_file = FLAGS.train_table
    val_file = FLAGS.val_table
    worker_spec = FLAGS.worker_hosts.split(",")
    worker_count = len(worker_spec)
    task_index = FLAGS.task_index

    epl_env = epl.Env.get()
    total_device = len(epl_env.cluster.available_devices)
    num_micro_batch = epl_env.config.pipeline.num_micro_batch
    # Validate the layout before dividing by it.
    if (FLAGS.num_pipe_stages < 1 or num_micro_batch < 1
            or total_device < FLAGS.num_pipe_stages):
        raise ValueError(
            "Invalid EPL layout: total_device={}, num_pipe_stages={}, "
            "num_micro_batch={}".format(
                total_device, FLAGS.num_pipe_stages, num_micro_batch))
    num_replica = total_device // FLAGS.num_pipe_stages
    # Per-replica micro-batch size. Diagnostic only: the input pipeline
    # batches by FLAGS.batch_size directly.
    micro_batch = FLAGS.batch_size // num_micro_batch // num_replica
    print("total_batch: {}, num_micro_batch: {}, num_replica: {}, micro_batch: {}".format(
        FLAGS.batch_size,
        num_micro_batch,
        num_replica,
        micro_batch))
    print("task_index:", task_index)
    print("total_device:", total_device)

    model_fn = model_fn_builder(checkpoint_dir, FLAGS.learning_rate)
    train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        slice_id=task_index,
        slice_count=worker_count,
        is_training=True,
        drop_remainder=True
    )
    val_input_fn = file_based_input_fn_builder(
        input_file=val_file,
        slice_id=task_index,
        slice_count=worker_count,
        is_training=False,
        drop_remainder=False
    )

    sess_config = tf.ConfigProto(allow_soft_placement=True)
    config = tf.estimator.RunConfig(session_config=sess_config,
                                    save_checkpoints_steps=FLAGS.save_checkpoints_steps)
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        config=config,
        model_dir=checkpoint_dir)
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=FLAGS.max_steps)
    eval_spec = tf.estimator.EvalSpec(input_fn=val_input_fn, start_delay_secs=6, throttle_secs=1)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
    tf.logging.info("################# All process done. ########################")
if __name__ == '__main__':
    # Echo the cluster spec from the environment for debugging.
    print(os.environ.get('TF_CONFIG'))
    epl.init(epl.Config({"pipeline.num_micro_batch": FLAGS.num_micro_batch}))
    if FLAGS.num_pipe_stages == 1:
        # Replicate the whole model, one device per replica.
        # NOTE(review): no default strategy is set when num_pipe_stages != 1,
        # and the model code here has no explicit stage annotations — confirm
        # that multi-stage runs are configured elsewhere.
        epl.set_default_strategy(epl.replicate(device_count=1))
    tf.app.run()
训练提交worker sql:
pai -name tensorflow1120_py3
-Dscript="***/resources/***.tar.gz"
-DentryFile="train_downstream_cls.py"
-Dbuckets="***"
-DuserDefinedParameters="--num_epochs=10 --max_steps=100000 --buckets=*** --checkpoint_dir=*** --train_table=*** --val_table=***"
-Dtables="***, ***"
-Dcluster="{\"worker\":{\"count\":8,\"cpu\":400,\"gpu\":100}}"