PyTorch implementation of "A simple neural network module for relational reasoning" (Relational Networks)

Overview

PyTorch implementation of Relational Networks, from the paper "A simple neural network module for relational reasoning".

Implemented and tested on the Sort-of-CLEVR task.

Sort-of-CLEVR

Sort-of-CLEVR is a simplified version of CLEVR. It is composed of 10000 images and 20 questions (10 relational questions and 10 non-relational questions) per image. Each of 6 colors (red, green, blue, orange, gray, yellow) is assigned to a randomly chosen shape (square or circle), and the resulting objects are placed in an image.

Non-relational questions are composed of 3 subtypes:

  1. Shape of a certain colored object
  2. Horizontal location of a certain colored object: whether it is on the left side or the right side of the image
  3. Vertical location of a certain colored object: whether it is in the upper half or the lower half of the image

These questions are "non-relational" because the agent only needs to focus on a single object.
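As a rough illustration of the three non-relational subtypes, the sketch below answers each one from a list of objects. The object layout `(color, (x, y), shape)`, the `IMG_SIZE` value, and the function name are assumptions made for this example, not the generator's actual data structures.

```python
# Sketch only: the object layout (color, (x, y) center, shape) and the
# 75-pixel image size are assumptions made for this example.
IMG_SIZE = 75


def non_relational_answer(objects, color, subtype):
    """Answer a non-relational question about the single object of `color`."""
    _, (x, y), shape = next(o for o in objects if o[0] == color)
    if subtype == 0:  # shape of the colored object
        return shape
    if subtype == 1:  # is it on the left side of the image?
        return "yes" if x < IMG_SIZE / 2 else "no"
    if subtype == 2:  # is it in the upper half? (y is assumed to grow downward)
        return "yes" if y < IMG_SIZE / 2 else "no"
    raise ValueError(f"unknown subtype: {subtype}")


objects = [("red", (10, 60), "circle"), ("green", (50, 20), "square")]
print(non_relational_answer(objects, "red", 0))    # shape of the red object
print(non_relational_answer(objects, "green", 1))  # is green on the left?
```

Only the queried object is ever inspected, which is what makes these questions non-relational.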

Relational questions are composed of 3 subtypes:

  1. Shape of the object which is closest to the certain colored object
  2. Shape of the object which is furthest from the certain colored object
  3. Number of objects which have the same shape as the certain colored object

These questions are "relational" because the agent has to consider the relations between objects.
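The relational subtypes can be sketched in the same spirit. Here every answer depends on comparing the queried object with the others. The object layout, the function name, and the choice to exclude the queried object from the count are assumptions for illustration.

```python
import math


def relational_answer(objects, color, subtype):
    """Answer a relational question; objects are (color, (x, y), shape) tuples."""
    target = next(o for o in objects if o[0] == color)
    others = [o for o in objects if o[0] != color]

    def distance(o):
        # Euclidean distance between object centers
        return math.hypot(o[1][0] - target[1][0], o[1][1] - target[1][1])

    if subtype == 0:  # shape of the closest object
        return min(others, key=distance)[2]
    if subtype == 1:  # shape of the furthest object
        return max(others, key=distance)[2]
    if subtype == 2:  # number of other objects with the same shape (assumed to exclude itself)
        return sum(1 for o in others if o[2] == target[2])
    raise ValueError(f"unknown subtype: {subtype}")


scene = [
    ("red", (0, 0), "circle"),
    ("green", (1, 0), "square"),
    ("blue", (10, 0), "circle"),
    ("yellow", (5, 5), "circle"),
]
print(relational_answer(scene, "red", 0))  # shape of the object closest to red
```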

Questions are encoded into a vector of size 11: 6 for a one-hot vector of the queried color among the 6 colors, 2 for a one-hot vector of the question type (relational or non-relational), and 3 for a one-hot vector of the 3 subtypes.
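A minimal sketch of this 11-dimensional encoding follows. The 6 + 2 + 3 split comes from the description above; the slot order inside each group and the function name are assumptions.

```python
COLORS = ["red", "green", "blue", "orange", "gray", "yellow"]


def encode_question(color, relational, subtype):
    """Build an 11-dim question vector: 6 (color) + 2 (type) + 3 (subtype)."""
    vec = [0] * 11
    vec[COLORS.index(color)] = 1           # slots 0-5: one-hot color
    vec[6 + (1 if relational else 0)] = 1  # slots 6-7: question type (order assumed)
    vec[8 + subtype] = 1                   # slots 8-10: one-hot subtype
    return vec


# "What is the shape of the red object?" (non-relational, subtype 0)
print(encode_question("red", relational=False, subtype=0))
```

Exactly three entries are set in every encoded question, one per group.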

For example, with the sample image shown, we can generate non-relational questions like:

  1. What is the shape of the red object? => circle (even though it does not really look like a circle...)
  2. Is the green object placed on the left side of the image? => yes
  3. Is the orange object placed in the upper half of the image? => no

And relational questions:

  1. What is the shape of the object closest to the red object? => square
  2. What is the shape of the object furthest from the orange object? => circle
  3. How many objects have the same shape as the blue object? => 3

Setup

Create a conda environment from the environment.yml file

$ conda env create -f environment.yml

Activate the environment

$ conda activate RN3

If you don't use conda, install Python 3 as usual and use pip to install the remaining dependencies. The list of dependencies can be found in the environment.yml file.

Usage

$ ./run.sh

or

$ python sort_of_clevr_generator.py

to generate the Sort-of-CLEVR dataset, and

$ python main.py

to train the binary RN model. Alternatively, use

$ python main.py --relation-type=ternary

to train the ternary RN model.
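For readers unfamiliar with how such a switch is usually wired up, here is a small `argparse` sketch. Only the `--relation-type` flag and its `ternary` value appear in the commands above; the `binary` choice name, the default, and the function name are assumptions, and this is not a copy of main.py.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the --relation-type flag used above."""
    parser = argparse.ArgumentParser(description="Relational Networks sketch")
    parser.add_argument(
        "--relation-type",
        choices=["binary", "ternary"],
        default="binary",  # assumed: plain `python main.py` trains the binary model
        help="which RN variant to train",
    )
    return parser


args = build_parser().parse_args(["--relation-type=ternary"])
print(args.relation_type)
```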

Modifications

In the original paper, the Sort-of-CLEVR task used a different model from the CLEVR task. However, because the model used for CLEVR requires much less time to compute (the network is much smaller), that model is used here for the Sort-of-CLEVR task.

Result

| | Relational Networks (20th epoch) | CNN + MLP (without RN, 100th epoch) |
| --- | --- | --- |
| Non-relational question | 99% | 66% |
| Relational question | 89% | 66% |

CNN + MLP overfitted to the training data.

Relational Networks show far better results than CNN + MLP on both relational and non-relational questions.

Contributions

@gngdb sped up the model by a factor of 10.

Comments
  • Thanks. I have reproduced your result, but I wonder about the result in the paper

    Train Epoch: 18 [193280/196000 (99%)] Relations accuracy: 95% | Non-relations accuracy: 100%
    Train Epoch: 18 [194560/196000 (99%)] Relations accuracy: 86% | Non-relations accuracy: 100%
    Train Epoch: 18 [195840/196000 (100%)] Relations accuracy: 89% | Non-relations accuracy: 100%

    Test set: Relation accuracy: 89% | Non-relation accuracy: 100%
    Train Epoch: 19 [192000/196000 (98%)] Relations accuracy: 94% | Non-relations accuracy: 100%
    Train Epoch: 19 [193280/196000 (99%)] Relations accuracy: 80% | Non-relations accuracy: 100%
    Train Epoch: 19 [194560/196000 (99%)] Relations accuracy: 89% | Non-relations accuracy: 100%
    Train Epoch: 19 [195840/196000 (100%)] Relations accuracy: 91% | Non-relations accuracy: 100%

    Test set: Relation accuracy: 90% | Non-relation accuracy: 99%
    Train Epoch: 20 [193280/196000 (99%)] Relations accuracy: 91% | Non-relations accuracy: 100%
    Train Epoch: 20 [194560/196000 (99%)] Relations accuracy: 97% | Non-relations accuracy: 100%
    Train Epoch: 20 [195840/196000 (100%)] Relations accuracy: 95% | Non-relations accuracy: 100%

    Test set: Relation accuracy: 89% | Non-relation accuracy: 99%

    opened by robotzheng 5
  • About coord_tensor and np_coord_tensor part in model.py

    Hi, the code is almost completely self-explanatory. However, I couldn't understand this part. Could you explain it? Why are you creating coord_tensor and np_coord_tensor, and what is the number 25 there? I would also like to hear about lines 48-49-50. Edit: I also have another question about the implementation. Even though it is completely different, I don't want to open one more issue :) I guess the translator.py file and the function inside it aren't used anywhere in the code. What was the aim of that file?

    opened by bazingasherlock 4
  • Train Relational Networks for 10 epochs

    I have trained RN for 10 epochs. The final test set accuracy is 73% for relational question and 72% for non-relational question. It seems that there is no significant improvement for relational questions.

    opened by xunhuang1995 4
  • Using Dataset

    Hi, I'm trying to find a source for the Sort-of-CLEVR Dataset. The provided code in this repository seems to be what I'm looking for, but I need help understanding how to set it up for training, validation, and testing. Could you provide a brief example of how the included code could be used to generate a training, validation, and testing set, and then from this, iterate through these datasets in batches of chosen size?

    opened by slerman12 3
  • AMI to run the code

    Hi, is there any public AMI on Amazon so that we can try the code? I could not run your code on a couple of different PyTorch AMIs, since I got a couple of errors inside the main.py file.

    opened by bazingasherlock 3
  • Object coordinates missing

    From the article in the "Dealing with pixels" case:

    So, after convolving the image, each of the d^2 k-dimensional cells in the d × d feature maps was tagged with an arbitrary coordinate indicating its relative spatial position, and was treated as an object for the RN.

    Also, the author (/u/asantoro) confirmed on reddit that objects were of the form: [x, y, v_1, v_2, ..., v_k] where k is the number of kernels and the range of the coordinates x,y doesn't matter. (Reddit link)

    So I think in the model, object coordinates should be added to oi and oj. https://github.com/kimhc6028/relational-networks/blob/master/model.py#L53

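    for i in range(25):
        oi = x[:, :, i // 5, i % 5]  # i-th cell of the 5x5 feature map (integer division for Python 3)
        for j in range(25):
            oj = x[:, :, j // 5, j % 5]  # j-th cell of the 5x5 feature map
            x_ = torch.cat((oi, oj, qst), 1)  # object pair concatenated with the question
            x_ = self.g_fc1(x_)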
    

    I believe this should improve performance on questions where the spatial relationship between objects is important (closest, furthest, ...).

    opened by thomashenn 3
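The coordinate tagging proposed in the issue above can be sketched without any framework. The function name, the nested-list feature map, and the choice to scale coordinates to [-1, 1] are assumptions for illustration; the quoted passage only says the coordinate is arbitrary and that objects take the form [x, y, v_1, ..., v_k].

```python
def tag_with_coordinates(feature_map):
    """Turn a d x d grid of k-dim cells into d*d objects of the form [x, y, v_1..v_k].

    feature_map[row][col] is a list of k floats (assumed layout, d > 1).
    """
    d = len(feature_map)
    objects = []
    for i in range(d * d):
        row, col = i // d, i % d  # integer division, as needed in Python 3
        # Coordinates scaled to [-1, 1]; any consistent range would do.
        coord = [2 * row / (d - 1) - 1, 2 * col / (d - 1) - 1]
        objects.append(coord + feature_map[row][col])
    return objects


# Tiny 2 x 2 map with k = 1 so the result is easy to read.
tiny_map = [[[float(r * 2 + c)] for c in range(2)] for r in range(2)]
for obj in tag_with_coordinates(tiny_map):
    print(obj)
```

Each object grows by two entries, so a pair of objects fed to the relation function grows by four.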
  • Add CNN-MLP Model

    To have a clean A/B kind of model, this PR factors out the CNN input stage and the final FC stage. This highlights the differences between the two models clearly.

    The specific CNN-MLP implementation has been taken from yangky11's fork. However, since my branch was taken after the '10x speedup' patch, it includes that too.

    RN @ epoch 20 :  
       Test set: Relation accuracy: 87% | Non-relation accuracy: 99%
    CNN-MLP  @ epoch 100 :  
       Test set: Relation accuracy: 66% | Non-relation accuracy: 69%
    
    

    It's also now Python3-ready, and there are some other minor tweaks.

    What do you think?

    opened by mdda 2
  • Variables and their requires_grad flags

    Why is the requires_grad argument set to False for all of your variables here and here? You set the requires_grad parameter of the coord_tensor variable to False and then concatenate it, at line 78, with the output of the CNN (defined as x), whose requires_grad is True by default, I guess. In this case, what is the requires_grad value of the concatenated variable (the output of line 78, which is also x_flat)?

    opened by bazingasherlock 2
  • Dealing with pixels misunderstanding ?

    Hello sir, I am reading the paper and your code. However, there is one thing I don't understand in your implementation. In Section 4 of the paper (Dealing with pixels), the authors say: "So after convolving the image, each of the d^2 k-dimensional cells in the d x d feature maps was tagged with an arbitrary coordinate ...". So my question is: which part of your code refers to this? Would you please explain more about it, since it is kind of difficult for me to understand. Thanks!

    opened by phongnhhn92 2
  • Magic Number Question

    Hi, in line 21 I had a hard time understanding the calculation (24+2)*2+11. The input of the linear layer should be two objects, each encoded as a vector of size 24. Where did the +2 and +11 come from?

    Thanks!

    opened by talbaumel 2
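One reading of (24+2)*2+11 that is consistent with the rest of this page: each object has 24 feature values plus 2 coordinate values, two objects are concatenated, and the 11-dimensional question vector is appended. The variable names below are illustrative, and the interpretation of the +2 as coordinates is an inference from the coord_tensor discussion, not something the source states outright.

```python
n_features = 24  # per-object feature size, as stated in the question above
n_coords = 2     # assumed: an (x, y) coordinate tag appended to each object
n_question = 11  # question vector size from the Sort-of-CLEVR description

# Two tagged objects concatenated with the question vector
g_input_size = (n_features + n_coords) * 2 + n_question
print(g_input_size)
```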
  • Training details

    Regarding the training procedure for the entire CLEVR, how did you manage to train pixel and state description stages? i.e., did you train end-to-end the whole system (LSTMs, ConvInputModel, and RN)? Or was it by stages?

    Another question off the topic: What is the purpose of coord_oi and coord_oj

    Thank you! Great implementation by the way.

    opened by affromero 1
  • Failure to replicate results

    I wanted to report that I didn't manage to replicate the results in the paper or in the repo. Relation accuracy: 80% Non-Relation accuracy: 93%

    trained for 20 epochs, with all default arguments

    opened by ValerioB88 0
  • Question on sort_of_clevr_generator "count+4"

    Hi, thanks for your work and for sharing the code! I have one question on the data generation part. I know the questions and answers are represented as one-hot vectors, where questions = 2 x (6 for one-hot vector of color), 3 for question type (binary, ternary, norel), 3 for question subtype, and answers = yes, no, rectangle, circle, r, g, b, o, k, y.

    My question is why you use count+4 here, in binary question subtype 3, which is as follows:

      elif subtype == 2:
          """count->1~6"""
          my_obj = objects[color][2]
          count = -1
          for obj in objects:
              if obj[2] == my_obj:
                  count +=1 
          answer = count+4
    

    As I understand it, count is already the number of objects which have the same shape as the certain colored object. Does the +4 in [yes, no, rectangle, circle, r, g, b, o, k, y] refer to the colors?

    Any help would be appreciated and thanks for your time

    opened by lizhenstat 1
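One plausible reading of count+4, not confirmed anywhere on this page: the first four answer slots are taken by yes, no, rectangle, and circle, so a count is shifted by 4 to land in the six remaining slots. The sketch below mirrors the quoted loop under that assumption; the constant and function names are hypothetical.

```python
FIRST_COUNT_SLOT = 4  # assumed: slots 0-3 hold yes, no, rectangle, circle


def count_answer_index(shapes, color_idx):
    """Return the answer slot for 'how many objects share this object's shape'."""
    my_shape = shapes[color_idx]
    count = -1  # start at -1 so the object's match with itself cancels out
    for shape in shapes:
        if shape == my_shape:
            count += 1
    return count + FIRST_COUNT_SLOT  # count in 0..5 maps to slots 4..9


shapes = ["circle", "square", "circle", "circle", "square", "square"]
print(count_answer_index(shapes, 0))
```

With six objects the count ranges over six values, which matches the six slots after the first four.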
  • Theoretical question: can RN generalize to shape-color combinations not previously seen in training data?

    Say, if in the training data the RN saw many green circles, and many yellow triangles, but it never saw green triangles or yellow circles, would the RN perform well if asked a question about those shape-color combinations it never saw in the training data? In other words, can the RN learn the abstract concept of "color", the abstract concept of "shape" and generalize those concepts to understand new questions involving novel color-shape combinations?

    opened by PabloMessina 0
  • Question about input to relational network

    I do not understand well how you pair the objects. My question is whether each pair of objects is fed as input once or twice. If it is fed once, then how do you choose the order? To be independent of the order, each pair should be fed twice, right? Otherwise there is some break of symmetry in the input pair. I could not understand this from the original article, and the code is still too obscure for me to work it out. Thanks! UPDATE: I think I got it now: you input all pairs twice (in a different order each time), and also each object with itself, thus n^2 instead of n(n-1)/2. Is this correct?

    opened by fernande2000 0
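The counting in the last question can be checked directly. The nested `range(25)` loops quoted earlier on this page visit every ordered pair, including each object paired with itself; the sketch below compares that with the unordered alternative. The variable names are illustrative.

```python
from itertools import combinations, product

n = 25  # number of objects: the cells of a 5 x 5 feature map

# Ordered pairs, self-pairs included, as produced by two nested range(n) loops
ordered_pairs = list(product(range(n), repeat=2))

# Unordered pairs of distinct objects, for comparison
unordered_pairs = list(combinations(range(n), 2))

print(len(ordered_pairs))    # n ** 2
print(len(unordered_pairs))  # n * (n - 1) // 2
```

Feeding both (i, j) and (j, i) means the summed relation output does not depend on which object happens to come first.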
Owner
Kim Heecheol
University of Tokyo, Intelligent systems & Informatics Lab.