Overview

R-YOLOv4

This is a PyTorch-based R-YOLOv4 implementation that combines the YOLOv4 model with the loss function from R3Det for arbitrary-oriented object detection. (Final project for the NCKU Introduction to Artificial Intelligence course)

Introduction

The objective of this project is to adapt the YOLOv4 model to detect oriented objects, which requires modifying the model's original loss function. I obtained good results by increasing the number of anchor boxes with different rotation angles and by folding the smooth-L1 IoU loss function proposed in R3Det: Refined Single-Stage Detector with Feature Refinement for Rotating Object into the original bounding-box loss.

Features


Loss Function (only for x, y, w, h, theta)

(figure: loss function)

(figure: angle definition)


Scheduler

Cosine Annealing with Warmup (Reference: Cosine Annealing with Warmup for PyTorch)
(figure: scheduler)
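
A minimal sketch of how this scheduler can be wired into a training loop, assuming the CosineAnnealingWarmupRestarts class from this repo's lib/scheduler.py (the same arguments appear in the training code quoted in the comments below); the step counts are illustrative:

import torch
from lib.scheduler import CosineAnnealingWarmupRestarts  # this repo's scheduler

model = torch.nn.Linear(4, 2)  # stand-in for the YOLO model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

total_steps = 1000  # illustrative: epochs * iterations_per_epoch / subdivisions
scheduler = CosineAnnealingWarmupRestarts(
    optimizer,
    first_cycle_steps=total_steps,          # one cosine cycle over the whole run
    max_lr=1e-4,                            # peak learning rate after warmup
    min_lr=1e-5,                            # floor of the cosine decay
    warmup_steps=round(total_steps * 0.1),  # linear warmup for the first 10%
    cycle_mult=1,
    gamma=1,
)

for step in range(total_steps):
    optimizer.step()   # ...after the usual forward/backward pass
    scheduler.step()   # advance the warmup/cosine schedule once per update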


Recall

(figure: recall)

As the paper suggests, I got better results with **f(ariou) = exp(1 - ariou) - 1**, so I used it in my loss function.
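
As a rough illustration of how that modulation can be combined with a smooth-L1 term (the actual implementation lives in model/yololayer.py; the helper below is a hypothetical sketch, not the repo's code):

import torch
import torch.nn.functional as F

def smooth_l1_iou_loss(pred_angle, target_angle, ariou):
    # Weight the smooth-L1 angle loss by f(ariou) = exp(1 - ariou) - 1 so
    # that boxes with poor rotated-IoU overlap contribute a larger loss.
    l1 = F.smooth_l1_loss(pred_angle, target_angle, reduction='none')
    modulation = torch.exp(1.0 - ariou) - 1.0
    return (l1 * modulation).mean()

# Toy usage: two predictions, the second with worse overlap.
pred = torch.tensor([0.10, -0.40])
target = torch.tensor([0.15, -0.30])
ariou = torch.tensor([0.90, 0.50])  # rotated IoU of predicted vs. target boxes
print(smooth_l1_iou_loss(pred, target, ariou))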

Usage

  1. Clone and Set Up the Environment

    $ git clone https://github.com/kunnnnethan/R-YOLOv4.git
    $ cd R-YOLOv4/
    

    Create a Conda Environment

    $ conda env create -f environment.yml
    

    Or Create a Python Virtual Environment

    $ python3.8 -m venv (your environment name)
    $ source ~/your-environment-name/bin/activate
    $ pip3 install torch torchvision torchaudio
    $ pip install -r requirements.txt
    
  2. Download pretrained weights
    (download link: weights)

  3. Make sure your file arrangement looks like the following
    Note that each dataset folder under data should be split into three subfolders, namely train, test, and detect.

    R-YOLOv4/
    ├── train.py
    ├── test.py
    ├── detect.py
    ├── xml2txt.py
    ├── environment.yml
    ├── requirements.txt
    ├── model/
    ├── datasets/
    ├── lib/
    ├── outputs/
    ├── weights/
        ├── pretrained/ (for training)
        └── UCAS-AOD/ (for testing and detection)
    └── data/
        └── UCAS-AOD/
            ├── class.names
            ├── train/
                ├── ...png
                └── ...txt
            ├── test/
                ├── ...png
                └── ...txt
            └── detect/
                └── ...png
    
  4. Train, Test, and Detect
    Please refer to lib/options.py to check out all the arguments.

Train

I have implemented methods to load and train three different datasets: UCAS-AOD, DOTA, and a custom dataset. You can check out how I load those datasets into the model at /datasets. The angle of each bounding box is limited to (-pi/2, pi/2], and the height of each bounding box is always longer than its width.
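
For intuition, a hypothetical helper that enforces that convention on a raw (w, h, theta) triple (the loaders in /datasets do their own handling; this is only a sketch):

import math

def normalize_rbox(w, h, theta):
    # Make the height the longer side and wrap the angle into (-pi/2, pi/2],
    # matching the convention described above.
    if w > h:
        w, h = h, w
        theta += math.pi / 2  # a 90-degree rotation compensates the side swap
    while theta > math.pi / 2:
        theta -= math.pi
    while theta <= -math.pi / 2:
        theta += math.pi
    return w, h, theta

print(normalize_rbox(4.0, 2.0, 0.3))  # -> (2.0, 4.0, -1.2707...)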

You can run experiments/display_inputs.py to visualize whether your data is loaded successfully.

UCAS-AOD dataset

Please refer to this repository to rearrange the files so that they can be loaded and trained by this model.
You can download the weights that I trained on UCAS-AOD.

While training, please specify which dataset you are using.
$ python train.py --dataset UCAS_AOD

DOTA dataset

Download the official dataset from here. The original files should load and train with this model as-is.

While training, please specify which dataset you are using.
$ python train.py --dataset DOTA

Train with custom dataset

  1. Use labelImg2 to help label your data. labelImg2 is capable of labeling rotated objects.
  2. Move your data folder into the R-YOLOv4/data folder.
  3. Run xml2txt.py
    1. generate txt files: python xml2txt.py --data_folder your-path --action gen_txt
    2. delete xml files: python xml2txt.py --data_folder your-path --action del_xml

A custom trash dataset that I made and the weights trained on it are provided for your convenience.

While training, please specify which dataset you are using.
$ python train.py --dataset custom

Training Log

---- [Epoch 2/2] ----
+---------------+--------------------+---------------------+---------------------+----------------------+
| Step: 596/600 | loss               | reg_loss            | conf_loss           | cls_loss             |
+---------------+--------------------+---------------------+---------------------+----------------------+
| YoloLayer1    | 0.4302629232406616 | 0.32991039752960205 | 0.09135108441114426 | 0.009001442231237888 |
| YoloLayer2    | 0.7385762333869934 | 0.5682911276817322  | 0.15651139616966248 | 0.013773750513792038 |
| YoloLayer3    | 1.5002599954605103 | 1.1116538047790527  | 0.36262497305870056 | 0.025981156155467033 |
+---------------+--------------------+---------------------+---------------------+----------------------+
Total Loss: 2.669099, Runtime: 404.888372

TensorBoard

If you would like to use TensorBoard to track the training process:

  • Open an additional terminal in the same folder where you are running the program.
  • Run command $ tensorboard --logdir='weights/your_model_name/logs' --port=6006
  • Go to http://localhost:6006/

Results

UCAS_AOD

Method                  Plane   Car     mAP
YOLOv4 (smoothL1-iou)   98.05   92.05   95.05

(figure: car detections)

(figure: plane detections)

DOTA

DOTA has not been tested yet. (It is quite difficult to test because of the large resolution of the images.)

(figures: DOTA)

trash (custom dataset)

Method                  Plane    Car      mAP
YOLOv4 (smoothL1-iou)   100.00   100.00   100.00

(figure: garbage1)

(figure: garbage2)

TODO

  • Mosaic Augmentation
  • Mixup Augmentation

References

yangxue0827/RotationDetection
eriklindernoren/PyTorch-YOLOv3
Tianxiaomo/pytorch-YOLOv4
ultralytics/yolov5

YOLOv4: Optimal Speed and Accuracy of Object Detection

Alexey Bochkovskiy, Chien-Yao Wang, Hong-Yuan Mark Liao

Abstract There are a huge number of features which are said to improve Convolutional Neural Network (CNN) accuracy. Practical testing of combinations of such features on large datasets, and theoretical justification of the result, is required. Some features operate on certain models exclusively and for certain problems exclusively, or only for small-scale datasets; while some features, such as batch-normalization and residual-connections, are applicable to the majority of models, tasks, and datasets...

@article{yolov4,
  title={YOLOv4: Optimal Speed and Accuracy of Object Detection},
  author={Alexey Bochkovskiy and Chien-Yao Wang and Hong-Yuan Mark Liao},
  journal = {arXiv},
  year={2020}
}

R3Det: Refined Single-Stage Detector with Feature Refinement for Rotating Object

Xue Yang, Junchi Yan, Ziming Feng, Tao He

Abstract Rotation detection is a challenging task due to the difficulties of locating the multi-angle objects and separating them effectively from the background. Though considerable progress has been made, for practical settings, there still exist challenges for rotating objects with large aspect ratio, dense distribution and category extremely imbalance. In this paper, we propose an end-to-end refined single-stage rotation detector for fast and accurate object detection by using a progressive regression approach from coarse to fine granularity...

@article{r3det,
  title={R3Det: Refined Single-Stage Detector with Feature Refinement for Rotating Object},
  author={Xue Yang and Junchi Yan and Ziming Feng and Tao He},
  journal = {arXiv},
  year={2019}
}
Comments
  • Extra classes

    Do you think it would be possible to train this on my own dataset or add extra classes to this? What would I need to change to have more than plane and car classes?

    Thank you

    opened by candyotter 2
  • RuntimeError

    Please help me! What should I change? Thanks a million, and happy Chinese New Year in advance!

    Traceback (most recent call last):
      File "/home/zero/pjpompom/R-YOLOv4-main(test)/test.py", line 65, in <module>
        sample_metrics += get_batch_statistics(outputs, targets, iou_threshold=args.iou_thres)
      File "/home/zero/pjpompom/R-YOLOv4-main(test)/tools/utils.py", line 191, in get_batch_statistics
        if pred_label not in target_labels:
      File "/home/zero/anaconda3/envs/pjpompom/lib/python3.8/site-packages/torch/tensor.py", line 646, in __contains__
        return (element == self).any().item()  # type: ignore[union-attr]
    RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
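
    For reference, the traceback shows a CUDA tensor being tested for membership in a CPU tensor. A minimal reproduction with hypothetical stand-in values; moving both tensors onto the same device before the `in` comparison avoids the error:

    import torch

    # Hypothetical stand-ins: in test.py the predictions come back on cuda:0
    # while the targets stay on the CPU, so `pred_label in target_labels` fails.
    pred_label = torch.tensor(1)          # imagine: moved to 'cuda' by the model
    target_labels = torch.tensor([0, 1])  # targets left on the CPU
    if pred_label.cpu() not in target_labels.cpu():  # compare on one device
        print("prediction label not among the targets")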

    opened by PJPomPom 1
  • What dataset were your YOLOv4 pretrained weights trained on?

    By the way, a small suggestion: tools/logger.py uses TensorFlow, but I personally would rather not install TensorFlow just for this. You might want to look into torch.utils.tensorboard.

    # Official example code
    from torch.utils.tensorboard import SummaryWriter
    import numpy as np
    
    writer = SummaryWriter()
    
    for n_iter in range(100):
        writer.add_scalar('Loss/train', np.random.random(), n_iter)
    
    opened by UnlightedOtaku 1
  • Visualize boxes

    This is great!

    How did you visualize the bounding boxes on the images at the end of detect.py? Should I be getting a txt file out with the labels?

    Thank you!
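
    For reference, a hypothetical sketch of drawing one rotated box with OpenCV (detect.py's own plotting helpers live in lib/plot.py; the values below are made up):

    import cv2 as cv
    import numpy as np

    # Draw a rotated box given (cx, cy, w, h, theta in radians).
    img = np.zeros((416, 416, 3), dtype=np.uint8)
    cx, cy, w, h, theta = 208, 208, 120, 60, 0.4
    pts = cv.boxPoints(((cx, cy), (w, h), np.rad2deg(theta)))  # 4 corner points
    cv.polylines(img, [pts.astype(np.int32)], isClosed=True, color=(0, 255, 0), thickness=2)
    cv.imwrite("rotated_box.png", img)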

    opened by candyotter 1
  • After training for 100 epochs, every AP is 0 at validation and inference is mostly wrong

    Hi, could you help me take a look? I trained on my own dataset (200 images). After labeling it and converting the annotations to txt files with the xml2txt.py you provided, I ran display_inputs.py to check the training images and they looked fine. After training, the losses shown in TensorBoard had basically converged, and the precision and recall for each class looked normal. But when I run test.py for validation, every AP is 0, and detect.py produces nothing but wrong, cluttered boxes that are much smaller than my annotated ground truth. After downloading your trash dataset and training for just 10 epochs, the AP was already very high, and my training and validation parameters were basically identical to yours, with no changes.

    I suspect it is a dataset-conversion problem, but I cannot find where exactly it goes wrong. Could you take a look? Thanks! (Screenshots: mAP equal to 0, some PR curves, display_inputs.py output, and the wrong inference boxes.)

    opened by sakamoto111 0
  • Reg loss calculation

    The regression loss considered is smooth_l1_loss (for angle) - ciou (for xywh).

    Considering that, shouldn't the ciou loss calculation be ciou = bbox_loss_scale * (0.0 - ciou) instead of ciou = bbox_loss_scale * (1.0 - ciou) in the line below? https://github.com/kunnnnethan/R-YOLOv4/blob/ab85440b135cd029c8151d6a0d120632db0b35bc/model/yololayer.py#L113

    https://github.com/kunnnnethan/R-YOLOv4/blob/ab85440b135cd029c8151d6a0d120632db0b35bc/model/yololayer.py#L119

    https://github.com/kunnnnethan/R-YOLOv4/blob/ab85440b135cd029c8151d6a0d120632db0b35bc/model/yololayer.py#L197

    With the current implementation, the effective regression loss calculation is smooth_l1_loss (for angle) + 1 - ciou (for xywh).
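
    For what it's worth, the two variants differ only by a constant offset of 1, so they produce identical gradients; a quick check:

    import torch

    # The +1 is constant with respect to the network outputs, so it shifts the
    # reported loss value but leaves the gradients unchanged.
    ciou = torch.tensor([0.7, 0.9], requires_grad=True)
    grad_a, = torch.autograd.grad((1.0 - ciou).sum(), ciou)
    grad_b, = torch.autograd.grad((0.0 - ciou).sum(), ciou)
    print(torch.equal(grad_a, grad_b))  # True: both are -1 everywhere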

    opened by sonalrpatel 1
  • Problem training on a custom dataset

    Hi, when I run train.py to train on my own dataset, I get the following error: RuntimeError: [enforce fail at inline_container.cc:145] . PytorchStreamReader failed reading zip archive: failed finding central directory. My PyTorch version is 1.7.0; could this be a version issue?

    opened by ciomphzx 9
  • Training while labeling with label-studio

    Hi! I'm trying to implement your project as an ML backend for label-studio and I'm having some trouble. Predicting labels works without any problems, and even training works the first time. But when I try to train a second time, I get the following error:

    RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 7.79 GiB total capacity; 2.62 GiB already allocated; 37.62 MiB free; 2.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

    This is my implementation of the ML backend:

    import os, sys
    currentdir = os.path.dirname(os.path.realpath(__file__))
    parentdir = os.path.dirname(currentdir)
    sys.path.append(parentdir)
    
    from label_studio_ml.model import LabelStudioMLBase
    from label_studio_ml.utils import get_image_size, get_single_tag_keys, is_skipped
    from label_studio.core.utils.io import json_load, get_data_dir 
    from label_studio.core.settings.base import DATA_UNDEFINED_NAME
    
    import time
    import random
    import numpy as np
    import torch
    import shutil
    import json
    from terminaltables import AsciiTable
    import glob
    
    from model.yolo import Yolo
    from lib.utils import load_class_names
    from lib.scheduler import CosineAnnealingWarmupRestarts
    from lib.post_process import post_process
    from lib.logger import *
    from lib.options import LabelStudioOptions
    from lib.plot import rescale_boxes
    import label_studio_sdk
    from datasets.label_studio_dataset import ImageDataset, LabelStudioDataset, get_transformed_image
    import cv2 as cv
    
    from urllib.parse import urlparse
    
    from PIL import Image
    
    print("LabelStudioSdk Version: ", label_studio_sdk.__version__)
    
    LABEL_STUDIO_HOST = os.getenv('LABEL_STUDIO_HOST', 'http://localhost:8080')
    LABEL_STUDIO_API_KEY = os.getenv('LABEL_STUDIO_API_KEY', '4c23feec13e2118e053b9a9940f73ed96c0e0841')
    
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    def weights_init_normal(m):
        if isinstance(m, torch.nn.Conv2d):
            torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif isinstance(m, torch.nn.BatchNorm2d):
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)
    
    def init():
        random.seed(42)
        np.random.seed(42)
        torch.manual_seed(42)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    
    class RotBBoxModel(object):
        def __init__(self, num_classes, args):
            self.args = args
            
            self.model = Yolo(n_classes=num_classes)
            self.model = self.model.to(device)
    
            self.logger = None
            self.model_path = None
    
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)
    
        def log(self, total_loss, num_epochs, epoch, global_step, total_step, start_time):
            log = "\n---- [Epoch %d/%d] ----\n" % (epoch + 1, num_epochs)
    
            tensorboard_log = {}
            loss_table_name = ["Step: %d/%d" % (global_step, total_step),
                                "loss", "reg_loss", "conf_loss", "cls_loss"]
            loss_table = [loss_table_name]
    
            temp = ["YoloLayer1"]
            for name, metric in self.model.yolo1.metrics.items():
                if name in loss_table_name:
                    temp.append(metric)
                tensorboard_log[f"{name}_1"] = metric
            loss_table.append(temp)
    
            temp = ["YoloLayer2"]
            for name, metric in self.model.yolo2.metrics.items():
                if name in loss_table_name:
                    temp.append(metric)
                tensorboard_log[f"{name}_2"] = metric
            loss_table.append(temp)
    
            temp = ["YoloLayer3"]
            for name, metric in self.model.yolo3.metrics.items():
                if name in loss_table_name:
                    temp.append(metric)
                tensorboard_log[f"{name}_3"] = metric
            loss_table.append(temp)
    
            tensorboard_log["total_loss"] = total_loss
            self.logger.list_of_scalars_summary(tensorboard_log, global_step)
    
            log += AsciiTable(loss_table).table
            log += "\nTotal Loss: %f, Runtime: %f\n" % (total_loss, time.time() - start_time)
            print(log)
    
        def save(self, path):
            print("Model saved in: ", path)
            torch.save(self.model.state_dict(), path)
    
        def load(self, path, train=False):
            print("Loading model...")
            if not train:
                print("Loading model for prediction...")
                self.model_path = path
                if os.path.exists(self.model_path):
                    weight_path = glob.glob(os.path.join(self.model_path, "*.pth"))
                    if len(weight_path) == 0:
                        assert False, "Model weight not found"
                    elif len(weight_path) > 1:
                        assert False, "Multiple weights are found. Please keep only one weight in your model directory"
                    else:
                        weight_path = weight_path[0]
                else:
                    assert False, "Model is not exist"
                pretrained_dict = torch.load(weight_path, map_location=device)
                self.model.load_state_dict(pretrained_dict)
                self.model.eval()
            else:
                print("Loading model for training...")
                # if os.path.exists(path):
                #     weight_path = glob.glob(os.path.join(path, "*.pth"))[0]
                #     print("weight_path: ", weight_path)
                # else:
                #     print("Path does not exist")
                weight_path = "weights/pretrained/yolov4.pth"
                pretrained_dict = torch.load(weight_path, map_location=device)
                model_dict = self.model.state_dict()
    
                # 1. filter out unnecessary keys
                # pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
                pretrained_dict = {k: v for i, (k, v) in enumerate(pretrained_dict.items()) if i < 552}
                # 2. overwrite entries in the existing state dict
                model_dict.update(pretrained_dict)
                # 3. load the new state dict
                self.model.apply(weights_init_normal)
                self.model.load_state_dict(model_dict)
                self.model.eval()
    
        def predict(self, image_urls):
            images = torch.stack([get_transformed_image(url, self.args.img_size) for url in image_urls]).to(device)
            
            with torch.no_grad():
                temp = time.time()
                output, _ = self.model(images)  # batch=1 -> [1, n, n], batch=3 -> [3, n, n]
                temp1 = time.time()
                boxes = post_process(output, self.args.conf_thres, self.args.nms_thres)
                temp2 = time.time()
                print('-----------------------------------')
                num = 0
                for b in boxes:
                    if b is None:
                        break
                    num += len(b)
                print("{} objects found".format(num))
                print("Inference time : ", round(temp1 - temp, 5))
                print("Post-processing time : ", round(temp2 - temp1, 5))
                print('-----------------------------------')
                return boxes
    
        def train(self, dataloader, num_epochs=5):
            init()
            if(self.model_path == None):
                self.model_path = os.path.join("weights", self.args.model_name)
            self.logger = Logger(os.path.join(self.model_path, "logs"))
    
            num_iters_per_epoch = len(dataloader)
            scheduler_iters = round(num_epochs * len(dataloader) / self.args.subdivisions)
            total_step = num_iters_per_epoch * num_epochs
    
            scheduler = CosineAnnealingWarmupRestarts(self.optimizer,
                                                    first_cycle_steps=round(scheduler_iters),
                                                    max_lr=self.args.lr,
                                                    min_lr=1e-5,
                                                    warmup_steps=round(scheduler_iters * 0.1),
                                                    cycle_mult=1,
                                                    gamma=1)
    
            start_time = time.time()
            self.model.train()
            for epoch in range(num_epochs):
                print('Epoch {}/{}'.format(epoch, num_epochs - 1))
                print('-' * 10)
    
                for batch, (_, imgs, targets) in enumerate(dataloader):
                    global_step = num_iters_per_epoch * epoch + batch + 1
                    imgs = imgs.to(device)
                    targets = targets.to(device)
    
                    outputs, loss = self.model(imgs, targets)
    
                    loss.backward()
                    total_loss = loss.detach().item()
    
                    if global_step % self.args.subdivisions == 0:
                        self.optimizer.step()
                        self.optimizer.zero_grad()
                        scheduler.step()
    
                    self.log(total_loss, num_epochs, epoch, global_step, total_step, start_time)
    
            print()
    
            time_elapsed = time.time() - start_time
            print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    
            return self.model
    
    class RotBBoxModelApi(LabelStudioMLBase):
        
        def __init__(self, **kwargs):
            # don't forget to initialize base class...
            super(RotBBoxModelApi, self).__init__(**kwargs)
            
            parser = LabelStudioOptions()
            self.args = parser.parse()
            
            self.from_name, self.to_name, self.value, self.labels_in_config = get_single_tag_keys(
                self.parsed_label_config, 'RectangleLabels', 'Image'
            )
    
            print("from_name: ", self.from_name)
            print("to_name: ", self.to_name)
            print("value: ", self.value)
            print("labels_in_config: ", self.labels_in_config)
            print("parsed_label_config: ", self.parsed_label_config)
            print("train_output: ", self.train_output)
    
            # self.model = RotBBoxModel(len(self.labels_in_config), self.args)
            # self.model_path = os.path.join("weights", self.args.model_name)
            # print(self.model_path)
            # self.model.load(self.model_path)
    
            if self.train_output:
                self.model = RotBBoxModel(len(self.labels_in_config), self.args)
                self.model.load(self.train_output['model_path'], self.train_output)
            else:
                self.model = RotBBoxModel(len(self.labels_in_config), self.args)
                model_path = os.path.join("weights", self.args.model_name)
                print(model_path)
                self.model.load(model_path)
    
        def reset_model(self):
            self.model = RotBBoxModel(len(self.labels_in_config), self.args)
            self.model_path = os.path.join("weights", self.args.model_name)
            self.model.load(self.model_path)
       
        def predict(self, tasks, **kwargs):
            """ This is where inference happens:
                model returns the list of predictions based on input list of tasks
    
                :param tasks: Label Studio tasks in JSON format
            """
            image_urls = [task['data'][self.value] for task in tasks]
            print(image_urls)
            model_results = self.model.predict(image_urls)
            results = []
            all_scores = []
            avg_score = 0
    
            for i, (url, box) in enumerate(zip(image_urls, model_results)):
                if box is not None:
                    image_path = self.get_local_path(url)
                    # image_shape = get_image_shape(url)
                    img_width, img_height = get_image_size(image_path)
                    boxes = rescale_boxes(box, self.args.img_size, (img_height, img_width))
                    boxes = np.array(boxes)
    
                    for i in range(len(boxes)):
                        bbox = boxes[i]
                        center_x, center_y, w, h, theta = bbox[0], bbox[1], bbox[2], bbox[3], bbox[4]
                        score = round(bbox[5] * bbox[6], 2)
                        cls_id = np.squeeze(int(bbox[7]))
    
                        # Calculate top left corner of rotated bbox (box center is origin)
                        left_local = -w/2
                        top_local = -h/2
                        rotated_left_local = np.cos(theta) * left_local - np.sin(theta) * top_local
                        rotated_top_local = np.sin(theta) * left_local + np.cos(theta) * top_local
                        rotated_left = center_x + rotated_left_local
                        rotated_top = center_y + rotated_top_local
    
                        x_percent = ( (rotated_left / img_width) * 100.0).item()
                        y_percent = ( (rotated_top / img_height) * 100.0).item()
                        w_percent = ( (w / img_width) * 100.0).item()
                        h_percent = ( (h / img_height) * 100.0).item()
    
                        results.append({
                            'from_name': self.from_name,
                            'to_name': self.to_name,
                            'type': 'rectanglelabels',
                            'value': {
                                'rectanglelabels': [self.labels_in_config[cls_id]],
                                'x': x_percent,
                                'y': y_percent,
                                'width': w_percent,
                                'height': h_percent,
                                'rotation': np.rad2deg(theta).item()
                            },
                            'score': score.item()
                        })
                        all_scores.append(score)
                    avg_score = sum(all_scores) / max(len(all_scores), 1)
            if(avg_score != 0):
                avg_score = avg_score.item()
            return [{
                'result': results,
                'score': avg_score
            }]
    
        def download_tasks(self, project):
            """
            Download all labeled tasks from project using the Label Studio SDK.
            Read more about SDK here https://labelstud.io/sdk/
            :param project: project ID
            :return:
            """
            ls = label_studio_sdk.Client(LABEL_STUDIO_HOST, LABEL_STUDIO_API_KEY)
            project = ls.get_project(id=project)
            tasks = project.get_labeled_tasks()
            return tasks
    
        def fit(self, tasks, workdir=None, batch_size=32, num_epochs=10, **kwargs):
            """
            This method is called each time an annotation is created or updated
            :param kwargs: contains "data" and "event" key, that could be used to retrieve project ID and annotation event type
                            (read more in https://labelstud.io/guide/webhook_reference.html#Annotation-Created)
            :return: dictionary with trained model artefacts that could be used further in code with self.train_output
            """
            if 'data' not in kwargs:
                raise KeyError(f'Project is not identified. Go to Project Settings -> Webhooks, and ensure you have "Send Payload" enabled')
            
            data = kwargs['data']
            project = data['project']['id']
            tasks = self.download_tasks(project)
            if len(tasks) > 0:
                print(f'{len(tasks)} labeled tasks downloaded for project {project}')
                
                image_urls, image_labels = [], []
                print('Collecting annotations...')
                for task in tasks:
                    
                    if is_skipped(task):
                        continue
    
                    filepath = self.get_local_path(task['data'][self.value])
                    image_urls.append(filepath)
                    image_labels.append(task['annotations'][0]['result'])
    
    
                # augment = False if self.args.no_augmentation else True
                # mosaic = False if self.args.no_mosaic else True
                # multiscale = False if self.args.no_multiscale else True
    
                augment = False
                mosaic = False
                multiscale = False
    
                print(f'Creating dataset with {len(image_urls)} images...')
                dataset = LabelStudioDataset(image_urls, image_labels, self.labels_in_config, 
                                             self.args.img_size, self.args.sample_size,
                                             augment=augment, mosaic=mosaic, multiscale=multiscale)
                dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True, collate_fn=dataset.collate_fn)
                
                print('Train model...')
                self.reset_model()
                self.model.train(dataloader, num_epochs=num_epochs)
    
                print('Save model...')
                # model_path = os.path.join(workdir, 'model.pt')
                model_path = os.path.join(self.model_path, "ryolov4.pth")
                self.model.save(model_path)
    
                return {
                    'model_path': model_path, 
                    'labels': image_labels
                }
    
            else:
                print('No labeled tasks found: make some annotations...')
                return {}
    

    This is basically just your code combined from detect.py and train.py.

    The testing is performed with the trash dataset and a model that was also trained on it. I'm not really familiar with PyTorch and don't know if I implemented it correctly for this kind of application. I guess the out-of-memory error is caused by reloading the model without clearing some old variables first? I have no idea which ones, though.

    Could you please take a look at it if you have the time? Maybe I'm just loading the model the wrong way.
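
    A minimal sketch of the cleanup that hypothesis points at: releasing the previous model and emptying the CUDA cache inside reset_model before building a new one (untested; names follow the snippet above):

    import os
    import torch

    def reset_model(self):
        # Hypothetical tweak to the reset_model above: free the previous
        # model's GPU memory before constructing and loading a fresh one.
        if getattr(self, 'model', None) is not None:
            del self.model
            torch.cuda.empty_cache()  # release cached blocks back to the GPU
        self.model = RotBBoxModel(len(self.labels_in_config), self.args)
        self.model_path = os.path.join("weights", self.args.model_name)
        self.model.load(self.model_path)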

    opened by Levaru 3
  • Could you teach me how to write a test.py?

    Hello, I am a student from Malaysia. You are one of the few YOLOv4 authors I have seen who is still active, so I would like to ask you how to write a test.py. I am following Tianxiaomo's YOLOv4, but it does not include a test.py, and my professor asked me to write one. No matter how much I look around (at his train.py, at your train.py and test.py, and at other people's code), I really do not understand what I am looking at. (crying)

    opened by MheadHero 3