Generative Handwriting using LSTM Mixture Density Network with TensorFlow

Overview

Generative Handwriting Demo using TensorFlow


An attempt to implement the random handwriting generation portion of Alex Graves' paper, Generating Sequences With Recurrent Neural Networks.

See my blog post at blog.otoro.net for more information.

How to use

I tested the implementation on TensorFlow r0.11 and Python 3. I also used the following libraries to help:

svgwrite
IPython.display.SVG
IPython.display.display
xml.etree.ElementTree
argparse
pickle

Training

You will need permission from these wonderful people to get the IAM On-Line Handwriting data. Unzip lineStrokes-all.tar.gz into the data subdirectory, so that you end up with data/lineStrokes/a01, data/lineStrokes/a02, etc. Afterwards, running python train.py will start the training process.

A number of flags can be set for training if you wish to experiment with the parameters. The default values are in train.py; an example invocation is shown after the list of flags.

--rnn_size RNN_SIZE             size of RNN hidden state
--num_layers NUM_LAYERS         number of layers in the RNN
--model MODEL                   rnn, gru, or lstm
--batch_size BATCH_SIZE         minibatch size
--seq_length SEQ_LENGTH         RNN sequence length
--num_epochs NUM_EPOCHS         number of epochs
--save_every SAVE_EVERY         save frequency
--grad_clip GRAD_CLIP           clip gradients at this value
--learning_rate LEARNING_RATE   learning rate
--decay_rate DECAY_RATE         decay rate for rmsprop
--num_mixture NUM_MIXTURE       number of Gaussian mixtures
--data_scale DATA_SCALE         factor to scale raw data down by
--keep_prob KEEP_PROB           dropout keep probability
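
For example, to train a larger model than the defaults (values here are illustrative only):

python train.py --rnn_size 400 --num_layers 3 --batch_size 64 --num_epochs 30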

Generating a Handwriting Sample

I've included a pretrained model in /save so it should work out of the box. Running python sample.py --filename example_name --sample_length 1000 will generate 4 .svg files for each example, with 1000 points.

IPython interactive session.

If you wish to experiment with this code interactively, just run %run -i sample.py in an IPython console; the following code is an example of how to generate samples and display them inside IPython.

[strokes, params] = model.sample(sess, 800)
draw_strokes(strokes, factor=8, svg_filename='sample.normal.svg')
draw_strokes_random_color(strokes, factor=8, svg_filename='sample.color.svg')
draw_strokes_random_color(strokes, factor=8, per_stroke_mode=False, svg_filename='sample.multi_color.svg')
draw_strokes_eos_weighted(strokes, params, factor=8, svg_filename='sample.eos.svg')
draw_strokes_pdf(strokes, params, factor=8, svg_filename='sample.pdf.svg')
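
If you want to post-process a sample yourself: assuming strokes is an [n, 3] array of (dx, dy, eos) triples, in the same x, y, eos layout as the training data, a minimal sketch for recovering absolute pen coordinates is:

import numpy as np

coords = np.cumsum(strokes[:, 0:2], axis=0)  # accumulate offsets into absolute positions
pen_up = strokes[:, 2] > 0.5                 # end-of-stroke flags mark pen lifts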


Have fun!

License

MIT

Comments
  • upgrade train to python 3 and tensorflow r0.11


    For Python 3:

• add parentheses to print
• change cPickle to pickle and open files in binary mode
• change xrange to range

    For TensorFlow r0.11:

    • force concatenate instead of tuple for the LSTM states
    • train and regenerate checkpoints

So only minor updates. Note that the concatenated LSTM state will no longer be supported at some point, so that will surface as an issue in a future TensorFlow release.
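
A minimal sketch of the pickle change (the config path and variable name are assumed from sample.py):

import pickle

# Python 3: cPickle is gone; use pickle and open the saved config in binary mode.
with open('save/config.pkl', 'rb') as f:
    saved_args = pickle.load(f)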

    opened by ghost 4
  • Unable to experiment with sample.py


    When I run %run -i sample.py in IPython console or just python sample.py, I get the following error:

In [1]: %run -i sample.py
WARNING:tensorflow:<tensorflow.python.ops.rnn_cell_impl.BasicLSTMCell object at 0x7f093d719390>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~/Downloads/write-rnn-tensorflow/sample.py in <module>()
     34     saved_args = pickle.load(f)
     35
---> 36 model = Model(saved_args, True)
     37 sess = tf.InteractiveSession()
     38 #saver = tf.train.Saver(tf.all_variables())

~/Downloads/write-rnn-tensorflow/model.py in __init__(self, args, infer)
     46     # inputs = tf.split(axis=1, num_or_size_splits=args.seq_length, value=self.input_data)
     47     # inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
---> 48     inputs = tf.unpack(self.input_data, axis=1)
     49
     50     outputs, state_out = tf.contrib.legacy_seq2seq.rnn_decoder(inputs, self.state_in, cell, loop_function=None, scope='rnnlm')

AttributeError: module 'tensorflow' has no attribute 'unpack'

Does this have something to do with the latest version of TensorFlow? I am using v1.3.0, if that helps.
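
(For reference: tf.unpack was renamed to tf.unstack in TensorFlow 1.0, so a likely fix is to update line 48 of model.py accordingly:)

inputs = tf.unstack(self.input_data, axis=1)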

    opened by ritiek 2
  • fix link to on-line database


    I accidentally registered to access the wrong database. This pull request fixes the link to point to the on-line database, which has the "data/lineStrokes-all.tar.gz" file.

    opened by bskaggs 2
  • EOS Data


    Hello,

    Just wanted to clarify something. The data is arranged in x, y, eos - but you have the following line of code: z_eos = z[:, 0:1]

This grabs the first column. However, shouldn't it be something like z_eos = z[:, 2]?

    Do you somehow rearrange the data around?
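
(An aside for readers, not from the original thread: z in that line is plausibly the network's output vector rather than the raw data, in which case its layout need not match the x, y, eos ordering of the input. Note also that the two slices differ in shape, as a tiny NumPy example shows:)

import numpy as np

z = np.arange(12).reshape(4, 3)
z[:, 0:1]  # first column, kept 2-D with shape (4, 1)
z[:, 2]    # third column, flattened to shape (4,)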

    opened by johnlarkin1 1
  • Opnames


Hey, after the TF1 pull request, merge this if you want. It structures the graph so that the main ops needed for inference can be accessed by name (data_in, state_in, all the data_out ops for the MDN params, eos, state_out).

If you look at sample_frozen.py you'll see what I mean. I also have it working in openFrameworks; I will post that example separately.

Note that if you look at the diff before merging TF1, you'll also see a bunch of TensorFlow 1 compatibility changes.
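
A hypothetical sketch of how the named ops could be looked up after loading the frozen graph (tensor names are assumed from the list above, not verified against sample_frozen.py):

graph = tf.get_default_graph()
data_in = graph.get_tensor_by_name('data_in:0')
state_in = graph.get_tensor_by_name('state_in:0')
state_out = graph.get_tensor_by_name('state_out:0')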

    opened by memo 1
  • Added validation split, compute validation loss in training


Quick attempt to add a validation split and compute the validation loss when training. The general idea is to get a sense of whether the model starts to overfit to the data it has seen. This also provides a more deterministic metric for choosing hyperparameters across runs.

This could probably be cleaned up and I'm not sure if it follows TensorFlow idioms, etc., but it seems to be working in the current form. I'm curious to get feedback on whether others think this change could be a useful metric during training.
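
A minimal sketch of the idea, assuming the loaded data is a list of per-line stroke arrays (names are hypothetical):

import random

random.seed(42)               # deterministic split across runs
random.shuffle(stroke_lines)  # stroke_lines: list of [n_i, 3] arrays
split = int(0.9 * len(stroke_lines))
train_set, val_set = stroke_lines[:split], stroke_lines[split:]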

    opened by dribnet 1
  • TF 0.9: import of rnn_cell and seq2seq deprecated


    With TF 0.9, this gives the errors mentioned in the title. This can be resolved by replacing rnn_cell with tf.nn.rnn_cell and seq2seq with tf.nn.seq2seq.
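
In other words, something like the following (the old import path is shown from memory; treat it as approximate):

# Before TF 0.9:
#   from tensorflow.models.rnn import rnn_cell, seq2seq
# With TF 0.9, use the namespaced modules directly:
cell = tf.nn.rnn_cell.BasicLSTMCell(args.rnn_size)
outputs, states = tf.nn.seq2seq.rnn_decoder(inputs, initial_state, cell, scope='rnnlm')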

    opened by edwin-de-jong 1
• Fix for TensorFlow 0.7.1


This should fix Issue #2, NotImplementedError: Negative indices are currently unsupported. TensorFlow 0.7.1 changes the behavior of seq2seq. The issue was first noticed in char-rnn-tensorflow: https://github.com/sherjilozair/char-rnn-tensorflow/issues/10

    opened by rajshah4 1
  • lineStrokes-all.tar.gz cannot be downloaded


I have registered an account at http://www.fki.inf.unibe.ch/, but lineStrokes-all.tar.gz still cannot be downloaded; it seems that the file has been deleted.

Can you give me a valid address to download lineStrokes-all.tar.gz? Thank you!

    opened by jassentang 0
  • add --model_dir option to scripts


This adds an option --model_dir to train.py and sample.py allowing the model directory to be set explicitly. I find this useful for keeping multiple models around with different hyperparameters, etc. Backwards compatible: if the option is not given, it defaults to save/ as before. When training, the model_dir will be created if it does not already exist.

This has been tested and is working for me with TensorFlow 0.12.1.

    opened by dribnet 0
  • Can you achieve training loss of -1000 nats


Hi, thanks for the nice work. In the paper, the training loss can go down to -1000. I can only get to -5. I wonder what your training loss curve looks like.

    Thanks
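
(One possible explanation, not verified against the paper's exact protocol: Graves reports log-likelihood summed over each sequence, while this implementation averages the loss per point, so the magnitudes are not directly comparable. For instance, -5 nats per point over a sequence of 200 to 300 points already sums to roughly -1000 to -1500 nats per sequence, the same ballpark as the paper. The data_scale factor also shifts the density by a constant log term.)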

    opened by leinxx 0
  • NotFoundError, Tensor name [...] not found in checkpoint files save/model.ckpt-11000


    First, thanks a lot for this very cool code!

    Running the pretrained model with the suggested command: python sample.py --filename example_name --sample_length 1000

    Produces this error:

WARNING:tensorflow:<tensorflow.python.ops.rnn_cell_impl.BasicLSTMCell object at 0x7f7a3c53e6a0>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.
WARNING:tensorflow:<tensorflow.python.ops.rnn_cell_impl.BasicLSTMCell object at 0x7f7a3b5250f0>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.
WARNING:tensorflow:From /home/javier/repos/write-rnn-tensorflow/model.py:137: calling reduce_max (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.
Instructions for updating: keep_dims is deprecated, use keepdims instead
WARNING:tensorflow:From /home/javier/repos/write-rnn-tensorflow/model.py:141: calling reduce_sum (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.
Instructions for updating: keep_dims is deprecated, use keepdims instead
2018-03-15 03:59:51.217146: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2
loading model: save/model.ckpt-11000
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1350, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1329, in _run_fn
    status, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.NotFoundError: Tensor name "rnnlm/multi_rnn_cell/cell_0/basic_lstm_cell/bias/Adam" not found in checkpoint files save/model.ckpt-11000
	 [[Node: save/RestoreV2_4 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2_4/tensor_names, save/RestoreV2_4/shape_and_slices)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "sample.py", line 42, in <module>
    saver.restore(sess, ckpt.model_checkpoint_path)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1686, in restore
    {self.saver_def.filename_tensor_name: save_path})
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1128, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1344, in _do_run
    options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1363, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Tensor name "rnnlm/multi_rnn_cell/cell_0/basic_lstm_cell/bias/Adam" not found in checkpoint files save/model.ckpt-11000
	 [[Node: save/RestoreV2_4 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2_4/tensor_names, save/RestoreV2_4/shape_and_slices)]]

Caused by op 'save/RestoreV2_4', defined at:
  File "sample.py", line 37, in <module>
    saver = tf.train.Saver()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1239, in __init__
    self.build()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1248, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1284, in _build
    build_save=build_save, build_restore=build_restore)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 765, in _build_internal
    restore_sequentially, reshape)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 428, in _AddRestoreOps
    tensors = self.restore_op(filename_tensor, saveable, preferred_shard)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 268, in restore_op
    [spec.tensor.dtype])[0])
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_io_ops.py", line 1031, in restore_v2
    shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 3160, in create_op
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1625, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

NotFoundError (see above for traceback): Tensor name "rnnlm/multi_rnn_cell/cell_0/basic_lstm_cell/bias/Adam" not found in checkpoint files save/model.ckpt-11000
	 [[Node: save/RestoreV2_4 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2_4/tensor_names, save/RestoreV2_4/shape_and_slices)]]

    Please advise! Thank you in advance.
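
(A hedged workaround, not from the original thread: the error suggests the graph contains Adam optimizer slot variables that the bundled checkpoint predates. Restoring only the model weights should sidestep it, e.g. in sample.py:)

# Restore only the trainable model variables, skipping the missing optimizer slots.
saver = tf.train.Saver(tf.trainable_variables())
saver.restore(sess, ckpt.model_checkpoint_path)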

    opened by lk251 3
  • Generated samples look bad, but training seems ok


I'm getting crappy-looking samples from networks that seem to train properly. Loss converges nicely, but samples don't compare with @hardmaru's results.

Here are some example outputs (sample images omitted). Any ideas for what could be causing this?

    This is after training for 10920 iterations (30 * 364) with python train.py --rnn_size 400 --num_layers 3. Default args produce similar results.

Training and validation loss look fine (loss curves omitted).

Has anyone been able to train and produce great results with recent TensorFlow? I wonder if some defaults have changed, or if interface changes have resulted in a bad setup for the loss function.

I think the loss function is the issue, because the loss values "look good" but the actual results look bad; it is probably optimizing for the wrong thing.

    Any ideas appreciated. My next steps are to review how the loss is defined and maybe compare it with other implementations (https://greydanus.github.io/2016/08/21/handwriting/, https://github.com/snowkylin/rnn-handwriting-generation)

    opened by grisaitis 12
  • ValueError: Trying to share variable rnnlm/multi_rnn_cell/cell_0/basic_lstm_cell/kernel, but specified shape (512, 1024) and found shape (259, 1024).


When I run train.py, I get this error. I think it is some kind of version problem; my TF version is 1.3.

    File "/home/lxt/tf_project/HyperNetwork/write-rnn-tensorflow/model.py", line 50, in init outputs, state_out = tf.contrib.legacy_seq2seq.rnn_decoder(inputs, self.state_in, cell, loop_function=None, scope='rnnlm') File "/home/lxt/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py", line 152, in rnn_decoder output, state = cell(inp, state)

    opened by lxtGH 3