The first public PyTorch implementation of Attentive Recurrent Comparators

Overview

arc-pytorch

PyTorch implementation of Attentive Recurrent Comparators by Shyam et al.

A blog explaining Attentive Recurrent Comparators

Visualizing Attention

On Same Characters

On Different Characters

How to run?

Download data

python download_data.py

A one-time 52MB download. Shouldn't take more than a few minutes.

Train

python train.py --cuda

Let it train until the accuracy rises to at least 80%. Early stopping is not implemented yet, so you will have to kill the process manually.
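If you'd rather not babysit the process, a minimal early-stopping wrapper along these lines could be bolted on. This is only a sketch, not part of this repository; run_epoch and evaluate are placeholders for whatever training and validation steps train.py performs.

    import torch

    def train_with_early_stopping(model, run_epoch, evaluate,
                                  max_epochs=200, patience=10,
                                  path="saved_models/best.pt"):
        """Call run_epoch(model) repeatedly and stop once the validation accuracy
        returned by evaluate(model) has not improved for `patience` epochs."""
        best_acc, bad_epochs = 0.0, 0
        for _ in range(max_epochs):
            run_epoch(model)
            acc = evaluate(model)
            if acc > best_acc:
                best_acc, bad_epochs = acc, 0
                torch.save(model.state_dict(), path)  # keep the best checkpoint so far
            else:
                bad_epochs += 1
                if bad_epochs >= patience:
                    break
        return best_acc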

Visualize

python viz.py --cuda --load 0.13591022789478302 --same

Run with exactly the same parameters as train.py and specify the model to load. Pass "--same" if you want to generate a sample with the same character in both images. The script dumps images to a directory inside visualization. The directory name is taken from the --name parameter if specified; otherwise it is derived from the network's parameters.

Comments
  • type error?

            X = np.zeros((2 * batch_size, image_size, image_size), dtype='uint8')
            ... fill in X ...
            if part == 'train':
                X = self.augmentor.augment_batch(X)
            else:
                X = X / 255.0
    
    Wouldn't X -> 0's if dividing by 255.0?
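    (For reference, NumPy uses true division here and promotes the uint8 array to float, so the values end up in [0, 1] rather than collapsing to zeros. A quick standalone check, independent of the repo's code:)

        import numpy as np

        X = np.arange(256, dtype='uint8').reshape(16, 16)
        Y = X / 255.0                     # NumPy promotes uint8 to float64 here
        print(Y.dtype, Y.min(), Y.max())  # float64 0.0 1.0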
    opened by phobrain 2
  • download_data.py problems

    Python 2.7.10

    1. Maybe a Python 3 thing? Or a 'force you to read the code, idiot' sort of thing? :-)

        python download_data.py
          File "download_data.py", line 28
            def extract() -> None:
                          ^
        SyntaxError: invalid syntax

    Deleted the '-> None' pattern after ().

    2. urllib problems - Python 3 req't?

    File "download_data.py", line 2, in import urllib.request ImportError: No module named request (tensorflow) priot arc-pytorch% pip install urllib Collecting urllib Downloading urllib-1.21.1.tar.gz (226kB) ... Successfully installed urllib-1.21.1 (tensorflow) priot arc-pytorch% !py python download_data.py Traceback (most recent call last): File "download_data.py", line 2, in import urllib.request ImportError: No module named request

    opened by phobrain 2
  • delta_caps

    Hi Sanyam

    Let me begin by thanking you for taking the time to implement this!

    I've been playing around with porting your code to TensorFlow when I noticed something. The delta_caps variable in the method _get_filterbanks wants delta in the range [-1, 1] (as in zoom out, zoom in), but then in the method you take the absolute value of it.

    Is there any reason for doing it like this?
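    (For context, here is a minimal sketch of a DRAW-style 1-D filterbank of the kind _get_filterbanks builds; the names and signature below are assumptions, not the repo's code. A negative delta would mirror the grid of glimpse centres:)

        import torch

        def filterbank_1d(gx, delta, sigma2, image_size, glimpse_size):
            # Map the centre gx from [-1, 1] to pixel coordinates.
            gx = (image_size - 1) * (gx + 1.0) / 2.0
            # Glimpse-cell centres spaced by delta; a negative delta mirrors the grid.
            i = torch.arange(glimpse_size, dtype=torch.float32)
            mu = gx + (i - glimpse_size / 2.0 + 0.5) * delta
            # Gaussian filters F[i, a] = exp(-(a - mu_i)^2 / (2 * sigma^2)), row-normalised.
            a = torch.arange(image_size, dtype=torch.float32)
            F = torch.exp(-((a.unsqueeze(0) - mu.unsqueeze(1)) ** 2) / (2.0 * sigma2))
            return F / (F.sum(dim=1, keepdim=True) + 1e-8)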

    opened by Dammi87 1
  • Did you do the Contextual ARC implementation too?

    Hi, thanks for the repo, it's great!

    I was curious if this repo includes the 21-way Contextual ARC, or is this only the Naive ARC approach? I believe the contextual ARC has a bidirectional LSTM that uses the 21 pairwise embeddings, but I'm not sure how to implement that.

    Thanks
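    (For what it's worth, the bidirectional LSTM over pairwise embeddings described above could be sketched roughly as follows; shapes and names are assumptions, not code from this repo:)

        import torch
        import torch.nn as nn

        num_pairs, embed_dim, hidden_dim, batch = 21, 128, 64, 4

        # Hypothetical sequence of the 21 pairwise ARC embeddings, shaped (seq_len, batch, embed_dim).
        pair_embeddings = torch.randn(num_pairs, batch, embed_dim)

        bilstm = nn.LSTM(input_size=embed_dim, hidden_size=hidden_dim, bidirectional=True)
        context, _ = bilstm(pair_embeddings)  # (num_pairs, batch, 2 * hidden_dim)

        # One score per pair, e.g. to pick out the matching support example.
        scorer = nn.Linear(2 * hidden_dim, 1)
        scores = scorer(context).squeeze(-1)  # (num_pairs, batch)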

    opened by exnx 0
  • RGB image as input?

    Hi, thanks for your great implementation of ARC. Does the current version also support 3-channel RGB images as input?

    I found that the implementation works fine with grayscale input images, such as the Omniglot dataset, but when I tried a different dataset with RGB images, it doesn't seem to work.

    opened by jmkim0309 1
  • Scaling up

    First issue in scaling up:

    Looks like one needs to analyze all the photos one will train with in order to get preprocess() factors like in this example:

    http://blog.outcome.io/pytorch-quick-start-classifying-an-image/

    Seems someone would have written a program by now to take a set of images and output the numbers for

    normalize = transforms.Normalize(
       mean=[0.485, 0.456, 0.406],
       std=[0.229, 0.224, 0.225]
    )
    preprocess = transforms.Compose([
       transforms.Scale(256),
       transforms.CenterCrop(224),
       transforms.ToTensor(),
       normalize
    ])
    

    Here's a stab at it:

    import numpy as np
    from PIL import Image
    
    pic_dir = '~/images/prun299'
    fileList = pic_dir + '/files'
    
    pixCount = 0
    RGB = [0.0, 0.0, 0.0]
    
    with open(fileList) as fp:
        for line in fp:
            file = pic_dir + "/" + line.rstrip()
            try:
                im = Image.open(file)
            except Exception, e:
                print "None: %s %s" % (file, str(e))
                continue
    
            for x in range(im.width):
                for y in range(im.height):
                    pix = im.getpixel((x, y))
                    RGB[0] += pix[0]
                    RGB[1] += pix[1]
                    RGB[2] += pix[2]
            pixCount += im.width * im.height
            im.close()
    
    RGB[0] /= pixCount
    RGB[1] /= pixCount
    RGB[2] /= pixCount
    
    DEV = [0.0, 0.0, 0.0]
    
    print('pass 2')
    
    with open(fileList) as fp:
        for line in fp:
            #print('line ' + line)
            file = pic_dir + "/" + line.rstrip()
            try:
                im = Image.open(file)
            except:
                continue
    
            #print('file ' + file)
            for x in range(im.width):
                for y in range(im.height):
                    pix = im.getpixel((x, y))
                    d = RGB[0] - pix[0]
                    DEV[0] += d * d
                    d = RGB[1] - pix[1]
                    DEV[1] += d * d
                    d = RGB[2] - pix[2]
                    DEV[2] += d * d
            im.close()
    
    DEV[0] /= pixCount
    DEV[1] /= pixCount
    DEV[2] /= pixCount
    DEV = np.sqrt(DEV)
    
    RGB[0] /= 255
    RGB[1] /= 255
    RGB[2] /= 255
    
    DEV[0] /= 255
    DEV[1] /= 255
    DEV[2] /= 255
    
    print('mean=[' + str(RGB[0]) + ', ' + str(RGB[1]) + ', ' + str(RGB[2]) + '],')
    print('std=[' + str(DEV[0]) + ', ' + str(DEV[1]) + ', ' + str(DEV[2]) + ']')
    
    #  6764 files:
    # mean=[0.3876046197, 0.3751385941, 0.3667266388],
    # std=[0.2649736267, 0.2584158245, 0.2701408752]
    

    Resulting in this initial loader; the Keras version works, this one is untested.

    from torchvision import models, transforms
    from PIL import Image
    
    pair_dir = '~/pb'
    pic_dir = '~/images/prun299'
    image_size = 299
    
    normalize = transforms.Normalize(
     mean=[0.3876046197, 0.3751385941, 0.3667266388],
     std=[0.2649736267, 0.2584158245, 0.2701408752]
    )
    preprocess = transforms.Compose([
       transforms.Scale(256),
       transforms.CenterCrop(image_size),
       transforms.ToTensor(),
       normalize
    ])
    
    file_map = {}
    
    def load_preproc():
        print('Loading pics')
        scan_file(pair_dir + '/test.neg')
        scan_file(pair_dir + '/test.pos')
        scan_file(pair_dir + '/train.pos')
        scan_file(pair_dir + '/train.neg')
    
    def scan_file(fname):
        print('Scan file: ' + fname)
        ct = 0
        ct2 = 0
        with open(fname) as fp:
            for line in fp:
                fname1, fname2 = line.split()
                if file_map.get(fname1) is None:
                    ct += 1
                    img_pil = Image.open(pic_dir + '/' + fname1)
                    img_tensor = preprocess(img_pil)
                    img_tensor.unsqueeze_(0)
                    file_map[fname1] = img_tensor
                else:
                    ct2 += 1
    
                if file_map.get(fname2) is None:
                    ct += 1
                    img_pil = Image.open(pic_dir + '/' + fname2)
                    img_tensor = preprocess(img_pil)
                    img_tensor.unsqueeze_(0)
                    file_map[fname2] = img_tensor
                else:
                    ct2 += 1
    
        print('    loaded: ' + str(ct) + ' already loaded: ' + str(ct2))
    

    Here's how the keras file-load-preproc portion looks:

    from keras.preprocessing import image
    from keras.applications.inception_v3 import preprocess_input
    
                    im1 = image.load_img(pic_dir + '/' + fname1, target_size=input_dim)
                    x = image.img_to_array(im1)
                    x = np.expand_dims(x, axis=0)
                    x = preprocess_input(x)[0]
                    file_map[fname1] = x
    
    opened by phobrain 1
  • viz.py no result

    After making the same 2.7 os.makedirs() change to viz.py as in #2,

        % python viz.py --load 0.677756249905 --same
        /Users/priot/anaconda/lib/python2.7/site-packages/matplotlib/font_manager.py:273: UserWarning: Matplotlib is building the font cache using fc-list. This may take a moment.
          warnings.warn('Matplotlib is building the font cache using fc-list. This may take a moment.')
        %

    And no obvious action.

        ls -l saved_models/6_8_128_cpu/0.677756249905
        -rw-r--r-- 1 priot staff 433348 Aug 29 19:21 saved_models/6_8_128_cpu/0.677756249905

    OS X (Darwin) / MacBook Pro

    opened by phobrain 2
  • other python 2.7

    Lots of changes similar to those in download issue #1, which I closed with a fix, plus this pattern:

    -- train.py

        < def get_pct_accuracy(pred: Variable, target) :
        > def get_pct_accuracy(pred, target) :

    opened by phobrain 9
Owner
Sanyam Agarwal
Visiting Research Scholar, Machine Learning Perception Lab at Georgia Institute of Technology