This is a super simple visualization toolbox (script) for transformer attention visualization ✌

Overview

Trans_attention_vis


(Figure: input images)

1. How to prepare your attention matrix?

Just convert it to NumPy: a list with one array per layer, each of shape (case_num, head_num, token_num, token_num), like this 👇

import numpy as np
from transformer_attention_visualization import make_attention_map_mh

# build attention matrices shaped like PyTorch output
token_num = 6
case_num = 3
layer_num = 2
head_num = 4
# 2 layers; each array holds 3 cases x 4 heads of 6x6 maps (identical for every case)
attention_map_mhml = [np.stack([make_attention_map_mh(head_num, token_num)] * case_num, 0) for _ in range(layer_num)]
for attention_map in attention_map_mhml:
    print(attention_map.shape)

# Output:
# (3, 4, 6, 6)
# (3, 4, 6, 6)
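If you want to try the toolbox without a trained model, the sketch below builds data in the same layout (a list with one array per layer, each shaped (case, head, token, token), matching the printed shapes above) using only NumPy. It is a hypothetical stand-in: `make_random_attention`, `build_attention_list`, and `check_attention_list` are illustrative names, not part of the toolbox. Unlike the example above, it draws a different random map for every case, and each row is a softmax over tokens so it sums to 1, as real attention weights do. If your attention comes from a PyTorch model instead, a tensor `t` can be converted with `t.detach().cpu().numpy()`.

```python
import numpy as np


def make_random_attention(head_num, token_num, rng):
    """Hypothetical stand-in for make_attention_map_mh: random multi-head
    attention of shape (head_num, token_num, token_num) where every row is a
    softmax over the key tokens, so each row sums to 1."""
    logits = rng.normal(size=(head_num, token_num, token_num))
    logits -= logits.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


def build_attention_list(case_num, layer_num, head_num, token_num, seed=0):
    """Return one array per layer, each shaped (case_num, head_num, token_num, token_num)."""
    rng = np.random.default_rng(seed)
    return [
        np.stack([make_random_attention(head_num, token_num, rng) for _ in range(case_num)], axis=0)
        for _ in range(layer_num)
    ]


def check_attention_list(attention_layers):
    """Assert every layer is 4-D with square (token x token) maps; return the shapes."""
    for layer in attention_layers:
        assert layer.ndim == 4, "expected (case, head, token, token)"
        assert layer.shape[-1] == layer.shape[-2], "attention maps must be square"
    return [layer.shape for layer in attention_layers]


attention_map_mhml = build_attention_list(case_num=3, layer_num=2, head_num=4, token_num=6)
print(check_attention_list(attention_map_mhml))  # [(3, 4, 6, 6), (3, 4, 6, 6)]
```

The shape check is cheap and catches the most common mistake: passing a single 3-D array or a tensor with the batch and head axes swapped.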

2. Just try the following lines of code 👇

# import functions
import numpy as np
from transformer_attention_visualization import *

# build a blank canvas (120*scale rows x 60*scale columns)
scale = 3
canvas = np.zeros([120*scale, 60*scale], dtype=np.float64)

# build attention matrices shaped like PyTorch output
token_num = 6
case_num = 3
layer_num = 2
head_num = 4
# 2 layers; each array holds 3 cases x 4 heads of 6x6 maps (identical for every case)
attention_map_mhml = [np.stack([make_attention_map_mh(head_num, token_num)] * case_num, 0) for _ in range(layer_num)]

# draw the visualization onto the canvas and time it
import datetime
tic = datetime.datetime.now()
attention_vis2 = draw_attention_map_multihead_multilayer(canvas, attention_map_mhml, line_width=0.007)
toc = datetime.datetime.now()
h, remainder = divmod((toc - tic).seconds, 3600)
m, s = divmod(remainder, 60)
time_str2 = "Cost Time %02d h:%02d m:%02d s" % (h, m, s)
print(time_str2)


# show the visualization result
import matplotlib.pyplot as plt

def show2D(img2D, mode=None):
    """Display a 2-D array: grayscale by default, jet colormap if mode is given."""
    cmap = plt.cm.gray if mode is None else plt.cm.jet
    plt.imshow(img2D, cmap=cmap)
    plt.show()

case_index = 1
layer_index = 1
beta = 2  # larger beta gives much higher contrast

# show indices 0..head_num-1 (presumably one image per head)
for head_index in range(head_num):
    show2D(attention_vis2[layer_index][case_index][head_index][0] ** beta)
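Why does a larger `beta` raise the contrast? For values in [0, 1], raising to a power greater than 1 shrinks weak responses much more than strong ones: 0.5 drops to 0.25, while 0.9 only drops to 0.81. The minimal sketch below makes that explicit. `apply_contrast` is a hypothetical helper, not part of the toolbox, and its min-max normalization to [0, 1] is an added assumption (the example above applies `** beta` directly to the toolbox output).

```python
import numpy as np


def apply_contrast(image, beta=2.0):
    """Hypothetical helper: min-max normalize `image` to [0, 1], then raise it
    to the power `beta`. With beta > 1, faint lines fade toward 0 faster than
    strong lines, so the strongest attention links stand out."""
    image = np.asarray(image, dtype=np.float64)
    lo, hi = image.min(), image.max()
    if hi == lo:  # flat image: nothing to stretch, return all zeros
        return np.zeros_like(image)
    normalized = (image - lo) / (hi - lo)
    return normalized ** beta


values = np.array([0.0, 0.5, 0.9, 1.0])
print(apply_contrast(values, beta=1))  # unchanged: already spans [0, 1]
print(apply_contrast(values, beta=2))  # 0.5 -> 0.25, 0.9 -> 0.81
```

Normalizing first keeps the endpoints fixed at 0 and 1, so changing `beta` only reshapes the middle of the range instead of also darkening the whole image.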
Owner: Mingyu Wang