Code for our paper 'Generalized Category Discovery'

Overview

Generalized Category Discovery

This repo is a placeholder for code for our paper: Generalized Category Discovery

Abstract: In this paper, we consider a highly general image recognition setting wherein, given a labelled and unlabelled set of images, the task is to categorize all images in the unlabelled set. Here, the unlabelled images may come from labelled classes or from novel ones. Existing recognition methods are not able to deal with this setting, because they make several restrictive assumptions, such as the unlabelled instances only coming from known --- or unknown --- classes and the number of unknown classes being known a-priori. We address the more unconstrained setting, naming it `Generalized Category Discovery', and challenge all these assumptions. We first establish strong baselines by taking state-of-the-art algorithms from novel category discovery and adapting them for this task. Next, we propose the use of vision transformers with contrastive representation learning for this open world setting. We then introduce a simple yet effective semi-supervised $k$-means method to cluster the unlabelled data into seen and unseen classes automatically, substantially outperforming the baselines. Finally, we also propose a new approach to estimate the number of classes in the unlabelled data. We thoroughly evaluate our approach on public datasets for generic object classification including CIFAR10, CIFAR100 and ImageNet-100, and for fine-grained visual recognition including CUB, Stanford Cars and Herbarium19, benchmarking on this new setting to foster future research.

Code Coming Soon!

Comments
  • The same "Evaluation Metric" for seen classes and novel classes

    Hi, thanks for sharing the great work! I noticed that the same evaluation metric, i.e., Hungarian-algorithm matching, is used for both seen classes and novel classes. I think the standard "classification accuracy" should instead be used for the seen classes, to avoid the mismatching issue on them.
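The distinction the commenter draws can be illustrated with a minimal sketch (a generic Hungarian-matching accuracy, not the repo's exact evaluation code; `scipy` is assumed available):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def cluster_acc(y_true, y_pred):
    """Clustering accuracy with Hungarian matching: find the best one-to-one
    mapping between predicted cluster ids and ground-truth labels, then score."""
    d = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((d, d), dtype=np.int64)
    for t, p in zip(y_true, y_pred):
        w[p, t] += 1                                   # contingency table: cluster x label
    row, col = linear_sum_assignment(w.max() - w)      # assignment maximising matched counts
    return w[row, col].sum() / y_pred.size

y_true = np.array([0, 0, 1, 1])
y_pred = np.array([1, 1, 0, 0])            # correct clusters, but with permuted ids
print(cluster_acc(y_true, y_pred))         # 1.0: Hungarian matching forgives the permutation
print((y_true == y_pred).mean())           # 0.0: plain classification accuracy does not
```

For seen classes, where predicted ids can be tied directly to labels, the plain accuracy is the stricter and arguably more appropriate metric; that is the mismatch the issue raises.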

    opened by jingzhengli 6
  • Range-based masks are incompatible with SSB splits

    In multiple places in the codebase, the mask identifying whether a class comes from the old or the new classes is computed as:

    • https://github.com/sgvaze/generalized-category-discovery/blob/262519540c3452c6b8276dc203b4e9067378c552/methods/clustering/k_means.py#L48-L49
    • https://github.com/sgvaze/generalized-category-discovery/blob/262519540c3452c6b8276dc203b4e9067378c552/methods/contrastive_training/contrastive_training.py#L313-L314
    • https://github.com/sgvaze/generalized-category-discovery/blob/262519540c3452c6b8276dc203b4e9067378c552/methods/estimate_k/estimate_k.py#L49-L50
    • https://github.com/sgvaze/generalized-category-discovery/blob/262519540c3452c6b8276dc203b4e9067378c552/methods/estimate_k/estimate_k.py#L116-L117

    However, this is only correct when the old classes are consecutive (e.g., 0, 1, ..., n_old_classes - 1). For the SSB splits, where the old class indices are non-consecutive, it can produce incorrect old/new splits.

    Impact of this bug: the old/new ACCs on CUB & Stanford Cars are unreliable (when use_ssb_splits == True), but the all-class ACC is unaffected.

    Possible fix: replace range(len(args.train_classes)) with args.train_classes.
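A minimal sketch of the bug and the proposed fix (the values here are hypothetical, chosen only to show the behaviour):

```python
import numpy as np

targets = np.array([3, 7, 12, 55])   # ground-truth labels of some samples
train_classes = [3, 7, 55]           # labelled ("old") classes; non-consecutive, as in SSB

# Range-based mask (buggy for SSB): marks classes 0..len(train_classes)-1 as old
mask_buggy = np.isin(targets, list(range(len(train_classes))))

# Proposed fix: test membership in the actual class list
mask_fixed = np.isin(targets, train_classes)

print(mask_buggy)  # [False False False False] -- every sample wrongly counted as "new"
print(mask_fixed)  # [ True  True False  True]
```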

    opened by xwen99 4
  • one question in contrastive_train.py

    Dear Sagar Vaze, I am reading the released code for your interesting paper, Generalized Category Discovery. However, I am confused about one implementation in the file 'contrastive_training.py'. In your function 'test_kmeans', why does n_clusters equal the ground-truth class number when running k-means? Do you assume access to the ground-truth cluster number here, or have I misunderstood the pipeline? Could you kindly explain this?

    Thanks for any help you can give.

    opened by eatamath 3
  • Can't find the file: dino_vitbase16_pretrain.pth

    Hi, thanks for sharing.

    But I ran into a problem: I can't find the file dino_vitbase16_pretrain.pth in the project folder. Could you please upload this file?

    Thanks a lot.

    opened by CYDping 2
  • class ids for imagenet-100

    Hi, thank you for providing the code for your interesting work. I am trying to reproduce the results on ImageNet-100; however, the class ids are not provided in this repository. In data/imagenet.py the class selection appears to be random. Could you please share the 100 class ids (ImageNet folder names) that you used to generate your results?

    opened by sesmae 2
  • 'small_train' and 'small_validation' of herbarium19

    Hi, I was looking into the Herbarium19 dataset. After downloading it, I can only see three splits: 'train', 'validation' and 'test'. May I know what the 'small_train' and 'small_validation' splits used in herbarium_19.py are? Also, each training epoch on Herbarium19 takes a very long time to run. Is that why you used a subset of it?

    opened by sesmae 1
  • table 2,3 number clarification

    Hi, great work.

    May I know whether Tables 2 and 3 are based on the clustering method described in Section 3.1.2 or the one in Section 3.2?

    The numbers in Table 3, specifically on Herbarium19, show (35.4 | 51.0 | 27.0), which is the same as Table 5 (6), suggesting semi-supervised k-means was used for Table 3. Meanwhile, the Table 2 results on CIFAR100 show (70.8 | 77.6 | 57.0), which differ from Table 5. Thanks in advance.

    opened by cassie101 1
  • slow semi-supervised clustering

    Hi, thanks for releasing the code for your interesting work. I am trying to run semi-supervised k-means with bash bash_scripts/k_means.sh, but it seems to be pretty slow. Is this expected?

    opened by sesmae 0
  • CVE-2007-4559 Patch

    Patching CVE-2007-4559

    Hi, we are security researchers from the Advanced Research Center at Trellix. We have begun a campaign to patch a widespread bug named CVE-2007-4559. CVE-2007-4559 is a 15-year-old bug in the Python tarfile package. By using extract() or extractall() on a tarfile object without sanitizing input, a maliciously crafted .tar file could perform a directory path traversal attack. We found at least one unsanitized extractall() in your codebase and are providing a patch for you via pull request. The patch essentially checks whether all tarfile members will be extracted safely and throws an exception otherwise. We encourage you to use this patch or your own solution to secure against CVE-2007-4559. Further technical information about the vulnerability can be found in this blog.
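The kind of guard such a patch adds can be sketched as follows (a generic check against CVE-2007-4559, not necessarily the exact code in the pull request):

```python
import os
import tarfile

def safe_extractall(tar: tarfile.TarFile, path: str = ".") -> None:
    """Extract all members, refusing any whose resolved target would
    escape `path` (the CVE-2007-4559 path-traversal issue)."""
    base = os.path.realpath(path)
    for member in tar.getmembers():
        target = os.path.realpath(os.path.join(path, member.name))
        if os.path.commonpath([base, target]) != base:
            raise RuntimeError(f"Blocked path traversal in tar member: {member.name}")
    tar.extractall(path)
```

On recent Python versions (3.12+), `tar.extractall(path, filter="data")` provides a comparable built-in safeguard.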

    If you have further questions, you may contact us through this project's lead researcher, Kasimir Schulz.

    opened by TrellixVulnTeam 0
  • Some Questions about Binary_Search in estimate_k.py

    Thanks for your brilliant work! I wonder whether the binary search function is valid. As far as I know, binary search can be used to estimate the minimum of a function by searching for the root of f'(x) = 0, but here only f(x) = ACC is given, and f'(x) is not. So I have some questions about it and hope to receive your response.
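For context, the minimum of a unimodal function can be located from evaluations of f alone, with no derivative, e.g. by ternary search; a generic sketch (not the repo's estimate_k implementation):

```python
def ternary_search_min(f, lo, hi, iters=100):
    """Minimise a unimodal f on [lo, hi] using only function evaluations:
    each iteration discards the third of the bracket that cannot contain
    the minimum, so f'(x) is never needed."""
    for _ in range(iters):
        m1 = lo + (hi - lo) / 3
        m2 = hi - (hi - lo) / 3
        if f(m1) < f(m2):
            hi = m2   # minimum lies in [lo, m2]
        else:
            lo = m1   # minimum lies in [m1, hi]
    return (lo + hi) / 2

# f(k) = (k - 42)^2 is unimodal with its minimum at k = 42
print(round(ternary_search_min(lambda k: (k - 42) ** 2, 0, 100)))  # 42
```

Whether the accuracy curve over k is actually unimodal is the substantive question the commenter is raising.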

    opened by ascetic-monk 0
  • about the value of K

    Thanks for the interesting research. Are the k values used in the benchmarks (Tables 2, 3) estimated, or ground truth? In my opinion, accuracy will differ depending on the k value, but in the code it seems that performance was measured with the ground-truth k. Which results were used in the paper? Also, is there a result measuring how performance changes with k?

    opened by Backdrop9019 0
  • Could you release the code or checkpoints for reimplementing UNO+ and Rankstat+?

    Reimplementing these two methods in the GCD setting is pretty cumbersome, as there are too many details to align. It would be greatly helpful if you could provide the code for implementing them or the trained models' checkpoints (or the models' predictions) for evaluation. Many thanks!

    opened by xwen99 0
  • The reproduced result

    I ran the Table 2 experiment using this repo, and the final all-class accuracies on CIFAR10 and CIFAR100 are only 60.28 and 38.60, respectively, which are much lower than the reported values.

    To reproduce the paper, I ran the following three scripts sequentially: "contrastive_train.sh", "extract_features.sh", and "k_means.sh". Following the implementation details in the original paper, I used the ViT-B-16 backbone with DINO pre-trained weights (downloaded from this repo: https://github.com/facebookresearch/dino) and fine-tuned the final transformer block. The total number of clusters for k-means clustering is set to the total number of classes.

    The default settings of this repo seem to be for Stanford Cars, so could you give the detailed parameters for CIFAR10/CIFAR100?

    opened by snow12345 0
  • Are the experiment results in Tables 2 and 3 computed on the unlabelled data of the training set?

    It looks like k_means.py in your code does just that.

    opened by ryylcc 1