Collection of tasks for fast prototyping, baselining, finetuning and solving problems with deep learning.



News

Read our launch blogpost


Installation

Pip / conda

pip install lightning-flash -U

Pip from source

# with git
pip install git+https://github.com/PyTorchLightning/lightning-flash.git@master
# OR from an archive
pip install https://github.com/PyTorchLightning/lightning-flash/archive/master.zip

From source using setuptools

# clone flash repository locally
git clone https://github.com/PyTorchLightning/lightning-flash.git
cd lightning-flash
# install in editable mode
pip install -e .

What is Flash

Flash is a framework of tasks for fast prototyping, baselining, finetuning and solving business and scientific problems with deep learning. It is focused on:

  • Predictions
  • Finetuning
  • Task-based training

It is built for data scientists, machine learning practitioners, and applied researchers.

Scalability

Flash is built on top of PyTorch Lightning (by the Lightning team), which is a thin organizational layer on top of PyTorch. If you know PyTorch, you know PyTorch Lightning and Flash already!

As a result, Flash can scale up across any hardware (GPUs, TPUs) with zero changes to your code. It also has the best practices in AI research embedded into each task, so you don't have to be a deep learning PhD to leverage its power :)
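
For example, moving the same run from CPU to multiple GPUs or to TPU cores is just a Trainer argument (the flags below come from the underlying PyTorch Lightning Trainer); a minimal sketch, reusing the hymenoptera data from the finetuning example below:

import flash
from flash.vision import ImageClassificationData, ImageClassifier

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    valid_folder="data/hymenoptera_data/val/",
)
model = ImageClassifier(num_classes=datamodule.num_classes)

# Runs on CPU by default...
trainer = flash.Trainer(max_epochs=1)
# ...on 2 GPUs: flash.Trainer(max_epochs=1, gpus=2)
# ...or on 8 TPU cores: flash.Trainer(max_epochs=1, tpu_cores=8)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")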

Predictions

# import our libraries
from flash.text import TextClassifier

# 1. Load finetuned task
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")

# 2. Perform inference from list of sequences
predictions = model.predict([
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado."
    "Very, very afraid"
    "This guy has done a great job with this movie!",
])
print(predictions)

Finetuning

First, finetune:

import flash
from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 2. Load the data
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    valid_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)

# 3. Build the model
model = ImageClassifier(num_classes=datamodule.num_classes, backbone="resnet18")

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 5. Finetune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Save it!
trainer.save_checkpoint("image_classification_model.pt")

Then use the finetuned model:

# load the finetuned model
classifier = ImageClassifier.load_from_checkpoint('image_classification_model.pt')

# predict!
predictions = classifier.predict('data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg')
print(predictions)

Tasks

Flash is built as a collection of community-built tasks. A task is highly opinionated and laser-focused on solving a single problem well, using state-of-the-art methods.

Example 1: Image classification

Flash has an ImageClassification task to tackle any image classification problem.

To illustrate, let's say we wanted to develop a model that can classify ants vs. bees.

import flash
from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 2. Load the data
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    valid_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)

# 3. Build the model
model = ImageClassifier(num_classes=datamodule.num_classes)

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 5. Finetune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze_unfreeze")

# 6. Test the model
trainer.test()

# 7. Predict!
predictions = model.predict([
    "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
    "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
    "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
])
print(predictions)

To run the example:

python flash_examples/finetuning/image_classifier.py

Example 2: Text Classification

Flash has a TextClassification task to tackle any text classification problem.

To illustrate, say you wanted to classify movie reviews as positive or negative.

import flash
from flash import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')

# 2. Load the data
datamodule = TextClassificationData.from_files(
    train_file="data/imdb/train.csv",
    valid_file="data/imdb/valid.csv",
    test_file="data/imdb/test.csv",
    input="review",
    target="sentiment",
    batch_size=512
)

# 3. Build the model
model = TextClassifier(num_classes=datamodule.num_classes)

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze_unfreeze")

# 6. Test model
trainer.test()

# 7. Classify a few sentences! How was the movie?
predictions = model.predict([
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado."
    "Very, very afraid"
    "This guy has done a great job with this movie!",
])
print(predictions)

To run the example:

python flash_examples/finetuning/classify_text.py

Example 3: Tabular Classification

Flash has a TabularClassification task to tackle any tabular classification problem.

To illustrate, say we want to build a model to predict if a passenger survived on the Titanic.

from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall
import flash
from flash import download_data
from flash.tabular import TabularClassifier, TabularData

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')

# 2. Load the data
datamodule = TabularData.from_csv(
    "./data/titanic/titanic.csv",
    test_csv="./data/titanic/test.csv",
    categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
    numerical_input=["Fare"],
    target="Survived",
    val_size=0.25,
)

# 3. Build the model
model = TabularClassifier.from_data(datamodule, metrics=[Accuracy(), Precision(), Recall()])

# 4. Create the trainer. Run 10 times on data
trainer = flash.Trainer(max_epochs=10)

# 5. Train the model
trainer.fit(model, datamodule=datamodule)

# 6. Test model
trainer.test()

# 7. Predict!
predictions = model.predict("data/titanic/titanic.csv")
print(predictions)

To run the example:

python flash_examples/finetuning/tabular_data.py

A general task

Flash comes prebuilt with a task to handle a huge portion of deep learning problems.

import flash
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, datasets
import pytorch_lightning as pl

# model
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# data
dataset = datasets.MNIST('./data_folder', download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])

# task
classifier = flash.Task(model, loss_fn=nn.functional.cross_entropy, optimizer=optim.Adam)

# train
flash.Trainer().fit(classifier, DataLoader(train), DataLoader(val))

Infinitely customizable

Tasks can be built in just a few minutes because Flash is built on top of PyTorch Lightning LightningModules, which are infinitely extensible and let you train across GPUs, TPUs, etc. without any code changes.

from typing import Callable, Mapping, Sequence, Type, Union

import torch
import torch.nn.functional as F
from pytorch_lightning.metrics import Accuracy

from flash.core.classification import ClassificationTask

class LinearClassifier(ClassificationTask):
    def __init__(
        self,
        num_inputs,
        num_classes,
        loss_fn: Callable = F.cross_entropy,
        optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
        metrics: Union[Callable, Mapping, Sequence, None] = [Accuracy()],
        learning_rate: float = 1e-3,
    ):
        super().__init__(
            model=None,
            loss_fn=loss_fn,
            optimizer=optimizer,
            metrics=metrics,
            learning_rate=learning_rate,
        )
        self.save_hyperparameters()

        self.linear = torch.nn.Linear(num_inputs, num_classes)

    def forward(self, x):
        return self.linear(x)

classifier = LinearClassifier(num_inputs=28 * 28, num_classes=10)  # e.g. flattened MNIST
...

When you reach the limits of the flexibility provided by tasks, seamlessly transition to PyTorch Lightning, which gives you the most flexibility because it is simply organized PyTorch.
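
As a rough sketch of what that transition looks like (standard Lightning boilerplate, not a Flash API), the linear classifier above could be written as a plain LightningModule:

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class LitLinearClassifier(pl.LightningModule):
    def __init__(self, num_inputs: int, num_classes: int, learning_rate: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.linear = torch.nn.Linear(num_inputs, num_classes)

    def forward(self, x):
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # full control over the training loop lives here
        loss = F.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.hparams.learning_rate)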

Contribute!

The Lightning + Flash team is hard at work building more tasks for common deep-learning use cases. But we're looking for incredible contributors like you to submit new tasks!

Join our Slack to get help becoming a contributor!

Community

For help or questions, join our huge community on Slack!

Citations

We're excited to continue the strong legacy of open-source software and have been inspired over the years by Caffe, Theano, Keras, PyTorch, torchbearer, and fast.ai. When/if a paper is written about this, we'll be happy to cite these frameworks and the corresponding authors.

License

Please observe the Apache 2.0 license that is listed in this repository. In addition, the Lightning framework is Patent Pending.

Issues
  • Adding support for loading datasets and visualizing model predictions via FiftyOne

    What does this PR do?

    Integrates Lightning Flash with FiftyOne, the open source dataset and model analysis library!

    Loading FiftyOne data into Flash

    This PR adds FiftyOneDataSources for image/video classification, object detection, semantic segmentation, and image embedding tasks that load FiftyOne Datasets into Flash.

    Loading Flash predictions into FiftyOne

    This PR adds Serializer implementations that can convert classification/detection/segmentation model outputs into the appropriate FiftyOne label types so that they can be added to FiftyOne datasets and visualized.

    Note

    This PR requires a source install of FiftyOne on this branch https://github.com/voxel51/fiftyone/pull/1059 in order to function.

    git clone https://github.com/voxel51/fiftyone
    cd fiftyone
    git checkout --track origin/flash-video
    bash install.bash
    

    The above branch also contains a parallel integration that enables FiftyOne users to add predictions from any Flash model to their datasets 😄

    Points of discussion

    1. It'd be great if these examples could be integrated into the Flash documentation/README in the appropriate places 😄

    2. The new FiftyoneDataSource classes introduced in this PR require a label_field argument to specify which field of the FiftyOne dataset should be used as the label field. To enable this, we added **data_source_kwargs to Flash's processor interface. Perhaps there's a better way to support this?

    3. When serializing object detections, Flash models seem to return bounding boxes in absolute coordinates, but FiftyOne expects bounding boxes in relative coordinates. Is it possible for FiftyOneDetectionLabels to access the dimensions of the current image when serialize() is called? Perhaps using set_state() as is done for class labels? The current implementation requires fiftyone.utils.flash.normalize_detections() to be manually called to convert to relative coordinates for import into FiftyOne, but it would be much cleaner if this could be done natively within FiftyOneDetectionLabels... (A sketch of the conversion follows below.)
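
    For reference, the conversion itself is simple arithmetic; a minimal sketch of what such a normalization does per box (the helper below is illustrative, assuming absolute [x1, y1, x2, y2] pixel boxes in and FiftyOne's relative [top-left-x, top-left-y, width, height] format out):

    def normalize_bbox(box, img_width, img_height):
        # absolute pixel corners -> relative [x, y, w, h] in [0, 1]
        x1, y1, x2, y2 = box
        return [
            x1 / img_width,
            y1 / img_height,
            (x2 - x1) / img_width,
            (y2 - y1) / img_height,
        ]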

    Basic patterns

    The following subsections show the basic patterns enabled by this integration. See the next section for concrete examples of each task type.

    Loading data from FiftyOne into Flash

    FiftyOne users can load their datasets into Flash Data Sources via the pattern below:

    from flash.image import ImageClassificationData
    
    import fiftyone as fo
    
    train_dataset = fo.Dataset.from_dir(
        "/path/to/train",
        fo.types.ImageClassificationDirectoryTree,
        label_field="ground_truth",
    )
    
    val_dataset = fo.Dataset.from_dir(
        "/path/to/val",
        fo.types.ImageClassificationDirectoryTree,
        label_field="ground_truth",
    )
    
    datamodule = ImageClassificationData.from_fiftyone(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        label_field="ground_truth",
    )
    

    Visualizing Flash predictions in FiftyOne

    Flash users can swap out the serializer on their model with the corresponding FiftyOne serializer for the task type, and then visualize their predictions in the FiftyOne App via the pattern below:

    from flash import Trainer
    from flash.core.classification import FiftyOneLabels
    from flash.core.integrations.fiftyone import visualize
    from flash.video import VideoClassificationData, VideoClassifier
    
    classifier = VideoClassifier.load_from_checkpoint(...)
    
    # Option 1: Generate predictions using a Trainer and datamodule
    datamodule = VideoClassificationData.from_folders(
        predict_folder="/path/to/folder",
        ...
    )
    trainer = Trainer()
    classifier.serializer = FiftyOneLabels(return_filepath=True)
    predictions = trainer.predict(classifier, datamodule=datamodule)
    
    session = visualize(predictions) # Launch FiftyOne
    
    # Option 2: Generate predictions from model using filepaths
    filepaths = ["list", "of", "filepaths"]
    predictions = classifier.predict(filepaths)
    classifier.serializer = FiftyOneLabels()
    
    session = visualize(predictions, filepaths=filepaths) # Launch FiftyOne
    

    Applying Flash models to FiftyOne datasets

    In addition to this PR, https://github.com/voxel51/fiftyone/pull/1059 adds a parallel integration in the FiftyOne library that enables FiftyOne users to add predictions from any Flash model to their datasets via the pattern below:

    from flash.image import ObjectDetector
    
    import fiftyone as fo
    import fiftyone.zoo as foz
    
    dataset = foz.load_zoo_dataset("quickstart", max_samples=10)
    
    model = ObjectDetector.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/object_detection_model.pt")
    
    dataset.apply_model(model, label_field="predictions")
    
    session = fo.launch_app(dataset)
    

    Task examples

    The subsections below demonstrate both (a) FiftyOne dataset -> Flash, and (b) Flash predictions -> FiftyOne for each task type.

    Video classification

    from torch.utils.data.sampler import RandomSampler
    
    import flash
    from flash.core.classification import FiftyOneLabels
    from flash.core.data.utils import download_data
    from flash.video import VideoClassificationData, VideoClassifier
    
    import fiftyone as fo
    
    # 1. Download data
    download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip")
    
    # 2. Load data into FiftyOne
    # Here we use different datasets for each split, but you can also
    # use views into the same dataset
    train_dataset = fo.Dataset.from_dir(
        "data/kinetics/train",
        fo.types.VideoClassificationDirectoryTree,
        label_field="ground_truth",
        max_samples=5,
    )
    
    val_dataset = fo.Dataset.from_dir(
        "data/kinetics/val",
        fo.types.VideoClassificationDirectoryTree,
        label_field="ground_truth",
        max_samples=5,
    )
    
    predict_dataset = fo.Dataset.from_dir(
        "data/kinetics/predict",
        fo.types.VideoDirectory,
        max_samples=5,
    )
    
    # 3. Finetune a model
    classifier = VideoClassifier.load_from_checkpoint(
      "https://flash-weights.s3.amazonaws.com/video_classification.pt",
      pretrained=False,
    )
    
    datamodule = VideoClassificationData.from_fiftyone(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        predict_dataset=predict_dataset,
        label_field="ground_truth",
        batch_size=8,
        clip_sampler="uniform",
        clip_duration=1,
        video_sampler=RandomSampler,
        decode_audio=False,
        num_workers=8,
    )
    
    trainer = flash.Trainer(max_epochs=1, fast_dev_run=1)
    trainer.finetune(classifier, datamodule=datamodule)
    trainer.save_checkpoint("video_classification.pt")
    
    # 4. Predict from checkpoint
    classifier = VideoClassifier.load_from_checkpoint(
      "https://flash-weights.s3.amazonaws.com/video_classification.pt",
      pretrained=False,
    )
    
    classifier.serializer = FiftyOneLabels()
    
    filepaths = predict_dataset.values("filepath")
    predictions = classifier.predict(filepaths)
    
    predict_dataset.set_values("predictions", predictions)
    
    # 5. Visualize in FiftyOne App
    session = fo.launch_app(predict_dataset)
    

    Image classification

    from itertools import chain
    
    import fiftyone as fo
    import fiftyone.zoo as foz
    
    from flash import Trainer
    from flash.core.classification import FiftyOneLabels
    from flash.core.finetuning import FreezeUnfreeze
    from flash.image import ImageClassificationData, ImageClassifier
    
    # 1. Load your FiftyOne dataset
    # Here we use views into one dataset, but you can also create a
    # different dataset for each split
    dataset = foz.load_zoo_dataset("cifar10", split="test", max_samples=40)
    train_dataset = dataset.shuffle(seed=51)[:20]
    test_dataset = dataset.shuffle(seed=51)[20:25]
    val_dataset = dataset.shuffle(seed=51)[25:30]
    predict_dataset = dataset.shuffle(seed=51)[30:40]
    
    # 2. Load the Datamodule
    datamodule = ImageClassificationData.from_fiftyone(
        train_dataset = train_dataset,
        test_dataset = test_dataset,
        val_dataset = val_dataset,
        predict_dataset = predict_dataset,
        label_field = "ground_truth",
        batch_size=4,
        num_workers=4,
    )
    
    # 3. Build the model
    model = ImageClassifier(
        backbone="resnet18",
        num_classes=datamodule.num_classes,
        serializer=FiftyOneLabels(),
    )
    
    # 4. Create the trainer
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=1,
    )
    
    # 5. Finetune the model
    trainer.finetune(
        model,
        datamodule=datamodule,
        strategy=FreezeUnfreeze(unfreeze_epoch=1),
    )
    
    # 6. Save it!
    trainer.save_checkpoint("image_classification_model.pt")
    
    # 7. Generate predictions
    model = ImageClassifier.load_from_checkpoint(
      "https://flash-weights.s3.amazonaws.com/image_classification_model.pt"
    )
    model.serializer = FiftyOneLabels()
    
    predictions = trainer.predict(model, datamodule=datamodule)
    
    predictions = list(chain.from_iterable(predictions)) # flatten batches
    
    # 8. Add predictions to dataset and analyze
    predict_dataset.set_values("flash_predictions", predictions)
    session = fo.launch_app(view=predict_dataset)
    

    Object detection

    from itertools import chain
    
    import fiftyone as fo
    import fiftyone.zoo as foz
    
    from flash import Trainer
    from flash.image import ObjectDetectionData, ObjectDetector
    from flash.image.detection.serialization import FiftyOneDetectionLabels
    
    # 1. Load your FiftyOne dataset
    # Here we use views into one dataset, but you can also create a
    # different dataset for each split
    dataset = foz.load_zoo_dataset("quickstart", max_samples=40)
    train_dataset = dataset.shuffle(seed=51)[:20]
    test_dataset = dataset.shuffle(seed=51)[20:25]
    val_dataset = dataset.shuffle(seed=51)[25:30]
    predict_dataset = dataset.shuffle(seed=51)[30:40]
    
    # 2. Load the Datamodule
    datamodule = ObjectDetectionData.from_fiftyone(
        train_dataset = train_dataset,
        test_dataset = test_dataset,
        val_dataset = val_dataset,
        predict_dataset = predict_dataset,
        label_field = "ground_truth",
        batch_size=4,
        num_workers=4,
    )
    
    # 3. Build the model
    model = ObjectDetector(
        model="retinanet",
        num_classes=datamodule.num_classes,
        serializer=FiftyOneDetectionLabels(),
    )
    
    # 4. Create the trainer
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=1,
    )
    
    # 5. Finetune the model
    trainer.finetune(model, datamodule=datamodule)
    
    # 6. Save it!
    trainer.save_checkpoint("object_detection_model.pt")
    
    # 7. Generate predictions
    model = ObjectDetector.load_from_checkpoint(
      "https://flash-weights.s3.amazonaws.com/object_detection_model.pt"
    )
    model.serializer = FiftyOneDetectionLabels()
    
    predictions = trainer.predict(model, datamodule=datamodule)
    
    predictions = list(chain.from_iterable(predictions)) # flatten batches
    
    # 8. Add predictions to dataset and analyze
    predict_dataset.set_values("flash_predictions", predictions)
    session = fo.launch_app(view=predict_dataset)
    

    Semantic segmentation

    from itertools import chain
    
    import fiftyone as fo
    import fiftyone.zoo as foz
    
    from flash import Trainer
    from flash.core.data.utils import download_data
    from flash.image import SemanticSegmentation, SemanticSegmentationData
    from flash.image.segmentation.serialization import FiftyOneSegmentationLabels
    
    # 1. Load your FiftyOne dataset
    # This is a dataset with semantic segmentation labels generated via the
    # CARLA self-driving simulator. The data was generated as part of the
    # Lyft Udacity Challenge. More info here:
    # https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge
    download_data(
      "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
      "data/"
    )
    
    # Here we use views into one dataset, but you can also create a
    # different dataset for each split
    dataset = fo.Dataset.from_dir(
        dataset_dir = "data",
        data_path = "CameraRGB",
        labels_path = "CameraSeg",
        max_samples = 40,
        force_grayscale = True,
        dataset_type=fo.types.ImageSegmentationDirectory,
    )
    train_dataset = dataset.shuffle(seed=51)[:20]
    test_dataset = dataset.shuffle(seed=51)[20:25]
    val_dataset = dataset.shuffle(seed=51)[25:30]
    predict_dataset = dataset.shuffle(seed=51)[30:40]
    
    # 2. Load the Datamodule
    datamodule = SemanticSegmentationData.from_fiftyone(
        train_dataset = train_dataset,
        test_dataset = test_dataset,
        val_dataset = val_dataset,
        predict_dataset = predict_dataset,
        label_field = "ground_truth",
        batch_size=4,
        num_workers=4,
        num_classes=21,
    )
    
    # 3. Build the model
    model = SemanticSegmentation(
        backbone="resnet50",
        num_classes=datamodule.num_classes,
        serializer=FiftyOneSegmentationLabels(),
    )
    
    # 4. Create the trainer
    trainer = Trainer(
        max_epochs=1,
        fast_dev_run=1,
    )
    
    # 5. Finetune the model
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")
    
    # 6. Save it!
    trainer.save_checkpoint("semantic_segmentation_model.pt")
    
    # 7. Generate predictions
    model = SemanticSegmentation.load_from_checkpoint(
      "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
    )
    model.serializer = FiftyOneSegmentationLabels()
    
    predictions = trainer.predict(model, datamodule=datamodule)
    
    predictions = list(chain.from_iterable(predictions)) # flatten batches
    
    # 8. Add predictions to dataset and analyze
    predict_dataset.set_values("flash_predictions", predictions)
    session = fo.launch_app(view=predict_dataset)
    

    Image embeddings

    import numpy as np
    import torch
    
    from flash.core.data.utils import download_data
    from flash.image import ImageEmbedder
    
    import fiftyone as fo
    import fiftyone.brain as fob
    
    # 1 Download data
    download_data(
        "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip"
    )
    
    # 2 Load data into FiftyOne
    dataset = fo.Dataset.from_dir(
        "data/hymenoptera_data/test/",
        fo.types.ImageClassificationDirectoryTree,
    )
    
    # 3 Load model
    embedder = ImageEmbedder(backbone="swav-imagenet", embedding_dim=128)
    
    # 4 Generate embeddings
    filepaths = dataset.values("filepath")
    embeddings = np.stack(embedder.predict(filepaths))
    
    # 5 Visualize in FiftyOne App
    results = fob.compute_visualization(dataset, embeddings=embeddings)
    
    session = fo.launch_app(dataset)
    
    plot = results.visualize(labels="ground_truth.label")
    plot.show()
    

    Before submitting

    • [X] (This PR was discussed face-to-face) Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
    • [X] Did you read the contributor guideline, Pull Request section?
    • [X] Did you make sure your PR does only one thing, instead of bundling different changes together?
    • [ ] Did you make sure to update the documentation with your changes?
    • [X] Did you write any new necessary tests? [not needed for typos/docs]
    • [X] Did you verify new and existing tests pass locally with your changes?
    • [X] If you made a notable change (that affects users), did you update the CHANGELOG?

    PR review

    • [X] Is this pull request ready for review? (if not, please submit in draft mode)

    Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

    Did you have fun?

    Make sure you had fun coding 🙃

    opened by ehofesmann 23
  • TypeError: list indices must be integers or slices, not DefaultDataKeys when training Object Detection Model

    🐛 Bug

    I've spent days trying to make data augmentation work for Object Detection, but errors keep popping up. I don't know if I'm reinventing the wheel or if you are missing a lot in terms of data preparation/augmentation documentation for object detection. I'm about to give up...

    Following #409 (still not resolved), I've created a custom data augmentation transformation using albumentations. However, it fails with a weird message when training starts (once we fix the error, I can make a PR for integrating albumentations with PyTorch Lightning Flash):

    File "train.py", line 93, in train
        trainer.finetune(model, datamodule=datamodule)
      File "/home/ubuntu/.local/lib/python3.8/site-packages/flash/core/trainer.py", line 148, in finetune
        return super().fit(model, train_dataloader, val_dataloaders, datamodule)
      File "/home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 458, in fit
        self._run(model)
      File "/home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 756, in _run
        self.dispatch()
      File "/home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 797, in dispatch
        self.accelerator.start_training(self)
      File "/home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
        self.training_type_plugin.start_training(trainer)
      File "/home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
        self._results = trainer.run_stage()
      File "/home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 807, in run_stage
        return self.run_train()
      File "/home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 842, in run_train
        self.run_sanity_check(self.lightning_module)
      File "/home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1107, in run_sanity_check
        self.run_evaluation()
      File "/home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 962, in run_evaluation
        output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)
      File "/home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 174, in evaluation_step
        output = self.trainer.accelerator.validation_step(args)
      File "/home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 226, in validation_step
        return self.training_type_plugin.validation_step(*args)
      File "/home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 322, in validation_step
        return self.model(*args, **kwargs)
      File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/usr/lib/python3/dist-packages/torch/nn/parallel/distributed.py", line 705, in forward
        output = self.module(*inputs[0], **kwargs[0])
      File "/usr/lib/python3/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 57, in forward
        output = self.module.validation_step(*inputs, **kwargs)
      File "/home/ubuntu/.local/lib/python3.8/site-packages/flash/image/detection/model.py", line 179, in validation_step
        images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]
    TypeError: list indices must be integers or slices, not DefaultDataKeys
    

    Before that, it was failing with RuntimeError: each element in list of batch should be of equal size, but this torchvision tip of a custom collate (lambda x: x) "fixes" it: https://github.com/pytorch/vision/issues/2624
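
    (For context, the identity collate keeps each batch as a plain list of per-sample dicts instead of stacking tensors of unequal sizes; a minimal standalone illustration, not Flash-specific:)

    import torch
    from torch.utils.data import DataLoader, Dataset

    class VariableBoxes(Dataset):
        # toy dataset where each sample has a different number of boxes
        def __len__(self):
            return 4

        def __getitem__(self, i):
            return {"input": torch.zeros(3, 8, 8), "boxes": torch.zeros(i + 1, 4)}

    # the default collate would try to stack the (i + 1, 4) box tensors and fail;
    # the identity collate returns each batch as a plain list of samples
    loader = DataLoader(VariableBoxes(), batch_size=2, collate_fn=lambda batch: batch)
    for batch in loader:
        print([sample["boxes"].shape for sample in batch])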

    What is going on?

    To Reproduce

    
    import albumentations as A
    from albumentations.pytorch.transforms import ToTensorV2
    from PIL import Image
    import cv2
    
    import flash
    from flash.core.data.utils import download_data
    from flash.image import ObjectDetectionData, ObjectDetector
    from pytorch_lightning import seed_everything
    import numpy
    
    import logging
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger(__name__)
    
    seed_everything(42)
    
        image_size = 1024
    
        train_transform = A.Compose(
            [
                A.Resize(height=image_size, width=image_size, p=1),
                A.OneOf([
                    A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=20,
                                         val_shift_limit=20, p=0.5),
                    A.RandomBrightnessContrast(brightness_limit=0.2,
                                               contrast_limit=0.2, p=0.5),
                ], p=0.9),
                A.ToGray(p=0.01),
                A.VerticalFlip(p=0.5),
                A.HorizontalFlip(p=0.5),
                A.ShiftScaleRotate(p=0.5),
                A.Cutout(num_holes=10, max_h_size=32, max_w_size=32, fill_value=0, p=0.5),
                ToTensorV2(p=1)
            ],
            p=1.0,
            bbox_params=A.BboxParams(
                format='pascal_voc',
                min_area=0,
                min_visibility=0,
                label_fields=['labels']
            )
        )
    
        valid_transform = A.Compose(
            [
                A.Resize(height=image_size, width=image_size, p=1),
                ToTensorV2(p=1)
            ],
            p=1.0,
            bbox_params=A.BboxParams(
                format='pascal_voc',
                min_area=0,
                min_visibility=0,
                label_fields=['labels']
            )
        )
    
        test_transform = A.Compose(
            [
                A.Resize(height=image_size, width=image_size, p=1),
                ToTensorV2(p=1)
            ],
            p=1.0,
            bbox_params=A.BboxParams(
                format='pascal_voc',
                min_area=0,
                min_visibility=0,
                label_fields=['labels']
            )
        )
    
        datamodule = ObjectDetectionData.from_coco(
            train_folder="data_coco/train",
            train_ann_file="data_coco/train/_annotations.coco.json",
            train_transform={
                'pre_tensor_transform': lambda sample: transform_using_albu(sample, train_transform),
                'collate' : lambda x: x
            },
            val_transform={
             'pre_tensor_transform': lambda sample: transform_using_albu(sample, valid_transform),
              'collate': lambda x: x
            },
            test_transform={
             'pre_tensor_transform': lambda sample: transform_using_albu(sample, test_transform),
             'collate': lambda x: x
           },
            val_split=0.2,
            batch_size=8,
            num_workers=4,
        )
    
        model = ObjectDetector(model="retinanet", backbone="resnet101", num_classes=datamodule.num_classes, fpn=True)
    
        # 4. Create the trainer
        trainer = flash.Trainer(max_epochs=1, gpus=2, accelerator='ddp', limit_train_batches=1, limit_val_batches=1, checkpoint_callback=True)
    
        # 5. Finetune the model
        trainer.finetune(model, datamodule=datamodule)
    
    
    def transform_using_albu(sample, train_transform):
            labels = sample['target']['labels']
            image = to_cv(sample['input'])
            transformed = train_transform(image=image, bboxes=sample['target']['boxes'], labels=sample['target']['labels'])
            trans_bboxes = [list(boxes) for boxes in transformed["bboxes"]]
            area = [calculate_area(boxes) for boxes in trans_bboxes]
            return {
                'input': transformed["image"],
                'target': {
                  'boxes': trans_bboxes,
                  'labels': labels,
                  'image_id': sample['target']['image_id'],
                  'area': area,
                  'iscrowd': [0 for _ in trans_bboxes]}
                }
    
    

    Environment

    • PyTorch Version: 1.8
    • OS (e.g., Linux): MacOS
    • How you installed PyTorch: pip
    • Python version: 3.7
    • CUDA/cuDNN version: 11
    • GPU models and configuration: 2 × A6000

    Additional context

    It is necessary to provide a clear and working example of augmenting and resizing images for object detection using torchvision transforms or albumentations.

    bug / fix help wanted 
    opened by hzitoun 18
  • Additional Backbones for Object Detection Task

    The Object Detection task currently only seems to support FasterRCNN; it would be good to have additional backbones such as RetinaNet or YOLOv5.

    enhancement help wanted 
    opened by aribornstein 17
  • feat: Add Detection Task

    What does this PR do?

    Add support for detection

    Before submitting

    • [x] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
    • [x] Did you read the contributor guideline, Pull Request section?
    • [x] Did you make sure your PR does only one thing, instead of bundling different changes together?
    • [ ] Did you make sure to update the documentation with your changes?
    • [x] Did you write any new necessary tests? [not needed for typos/docs]
    • [x] Did you verify new and existing tests pass locally with your changes?
    • [ ] If you made a notable change (that affects users), did you update the CHANGELOG?

    PR review

    • [ ] Is this pull request ready for review? (if not, please submit in draft mode)

    Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

    Did you have fun?

    Make sure you had fun coding 🙃

    task 
    opened by kaushikb11 17
  • simplify examples

    What does this PR do?

    Since we do not run tests on the examples anyway, there is no reason to have them wrapped in main, especially since there is nothing other than main anyway...

    Before submitting

    • [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
    • [x] Did you read the contributor guideline, Pull Request section?
    • [x] Did you make sure your PR does only one thing, instead of bundling different changes together?
    • [x] Did you make sure to update the documentation with your changes?
    • [x] Did you write any new necessary tests? [not needed for typos/docs]
    • [x] Did you verify new and existing tests pass locally with your changes?
    • [ ] If you made a notable change (that affects users), did you update the CHANGELOG?

    PR review

    • [x] Is this pull request ready for review? (if not, please submit in draft mode)

    Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

    Did you have fun?

    Make sure you had fun coding 🙃

    documentation 
    opened by Borda 15
  • try minimal requirements

    What does this PR do?

    unfreeze requirements, cc: @SeanNaren

    Before submitting

    • [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
    • [x] Did you read the contributor guideline, Pull Request section?
    • [x] Did you make sure your PR does only one thing, instead of bundling different changes together?
    • [x] Did you make sure to update the documentation with your changes?
    • [x] Did you write any new necessary tests? [not needed for typos/docs]
    • [x] Did you verify new and existing tests pass locally with your changes?
    • [ ] If you made a notable change (that affects users), did you update the CHANGELOG?

    PR review

    • [x] Is this pull request ready for review? (if not, please submit in draft mode)

    Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

    Did you have fun?

    Make sure you had fun coding 🙃

    enhancement 
    opened by Borda 15
  • TypeError: list indices must be integers or slices, not DefaultDataKey

    🐛 Bug

    Hi guys, Flash seems to expect data loaders to return dictionaries rather than the plain tuples used in 99% of cases.
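
    As a hedged illustration of that expectation, a tuple-based dataset can be adapted by wrapping each sample in the dict format Flash's steps index into (DictWrapper is a hypothetical helper; the DefaultDataKeys import path is the one used by the Flash version in this traceback):

    from flash.core.data.data_source import DefaultDataKeys
    from torch.utils.data import Dataset

    class DictWrapper(Dataset):
        # wraps a dataset yielding (input, target) tuples into Flash's dict format
        def __init__(self, ds):
            self.ds = ds

        def __len__(self):
            return len(self.ds)

        def __getitem__(self, i):
            x, y = self.ds[i]
            return {DefaultDataKeys.INPUT: x, DefaultDataKeys.TARGET: y}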

    To Reproduce

    Steps to reproduce the behavior: just run the sample code below.

    
    (venv) [email protected]:~/gust-torchvision$ /home/zuppif/gust-torchvision/venv/bin/python3 /home/zuppif/gust-torchvision/playground.py
    eehh
    GPU available: True, used: True
    TPU available: False, using: 0 TPU cores
    IPU available: False, using: 0 IPUs
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]
    Traceback (most recent call last):
      File "/home/zuppif/gust-torchvision/playground.py", line 60, in <module>
        trainer.finetune(classifier, datamodule=dm, strategy="freeze")
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/flash/core/trainer.py", line 165, in finetune
        return super().fit(model, train_dataloader, val_dataloaders, datamodule)
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 553, in fit
        self._run(model)
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 912, in _run
        self._pre_dispatch()
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 941, in _pre_dispatch
        self._log_hyperparams()
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 970, in _log_hyperparams
        self.logger.save()
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py", line 48, in wrapped_fn
        return fn(*args, **kwargs)
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/loggers/tensorboard.py", line 249, in save
        save_hparams_to_yaml(hparams_file, self.hparams)
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 413, in save_hparams_to_yaml
        with fs.open(config_yaml, "w", newline="") as fp:
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/fsspec/spec.py", line 972, in open
        self.open(path, mode, block_size, **kwargs), **text_kwargs
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/fsspec/spec.py", line 976, in open
        f = self._open(
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/fsspec/implementations/local.py", line 145, in _open
        return LocalFileOpener(path, mode, fs=self, **kwargs)
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/fsspec/implementations/local.py", line 236, in __init__
        self._open()
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/fsspec/implementations/local.py", line 241, in _open
        self.f = open(self.path, mode=self.mode)
    FileNotFoundError: [Errno 2] No such file or directory: '/home/zuppif/gust-torchvision/lightning_logs/version_2/hparams.yaml'
    (venv) [email protected]:~/gust-torchvision$ /home/zuppif/gust-torchvision/venv/bin/python3 /home/zuppif/gust-torchvision/playground.py
    eehh
    GPU available: True, used: True
    TPU available: False, using: 0 TPU cores
    IPU available: False, using: 0 IPUs
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]
    
      | Name          | Type       | Params
    ---------------------------------------------
    0 | train_metrics | ModuleDict | 0     
    1 | val_metrics   | ModuleDict | 0     
    2 | backbone      | Sequential | 11.2 M
    3 | head          | Sequential | 5.1 K 
    ---------------------------------------------
    14.7 K    Trainable params
    11.2 M    Non-trainable params
    11.2 M    Total params
    44.727    Total estimated model params size (MB)
    Validation sanity check: 0it [00:00, ?it/s]/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/flash/core/model.py:397: LightningDeprecationWarning: The `LightningModule.datamodule` property is deprecated in v1.3 and will be removed in v1.5. Access the datamodule through using `self.trainer.datamodule` instead.
      if self.datamodule is not None and getattr(self.datamodule, "data_pipeline", None) is not None:
    Validation sanity check:   0%|                                                                                                                                      | 0/2 [00:00<?, ?it/s]Traceback (most recent call last):
      File "/home/zuppif/gust-torchvision/playground.py", line 60, in <module>
        trainer.finetune(classifier, datamodule=dm, strategy="freeze")
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/flash/core/trainer.py", line 165, in finetune
        return super().fit(model, train_dataloader, val_dataloaders, datamodule)
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 553, in fit
        self._run(model)
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 918, in _run
        self._dispatch()
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _dispatch
        self.accelerator.start_training(self)
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
        self.training_type_plugin.start_training(trainer)
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
        self._results = trainer.run_stage()
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 996, in run_stage
        return self._run_train()
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1031, in _run_train
        self._run_sanity_check(self.lightning_module)
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/flash/core/trainer.py", line 93, in _run_sanity_check
        super()._run_sanity_check(ref_model)
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1115, in _run_sanity_check
        self._evaluation_loop.run()
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
        self.advance(*args, **kwargs)
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 110, in advance
        dl_outputs = self.epoch_loop.run(
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
        self.advance(*args, **kwargs)
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 110, in advance
        output = self.evaluation_step(batch, batch_idx, dataloader_idx)
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 154, in evaluation_step
        output = self.trainer.accelerator.validation_step(step_kwargs)
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 211, in validation_step
        return self.training_type_plugin.validation_step(*step_kwargs.values())
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 178, in validation_step
        return self.model.validation_step(*args, **kwargs)
      File "/home/zuppif/gust-torchvision/venv/lib/python3.8/site-packages/flash/image/classification/model.py", line 121, in validation_step
        batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
    TypeError: list indices must be integers or slices, not DefaultDataKeys
    (venv) [email protected]:~/gust-torchvision$ 
    

    Code sample

    import torch
    import torchmetrics as M  # the sample below uses M.* metrics
    from pytorch_lightning import LightningDataModule
    from torch.utils.data import DataLoader, TensorDataset

    from flash import Trainer
    from flash.image import ImageClassifier

    class FakeData(LightningDataModule):
    
        def __init__(self, n: int = 512, n_classes: int = 10):
            super().__init__()
            self.n = n
            self.n_classes = n_classes
    
        def dataloader(self):
            imgs, labels = torch.randn((self.n, 3, 224, 224)), torch.randint(0, self.n_classes, size=(self.n, 1))
            ds = TensorDataset(imgs, labels)
            return DataLoader(ds, batch_size=32, num_workers=8)
    
        def train_dataloader(self):
            return self.dataloader()
    
        def val_dataloader(self):
            return self.dataloader()
    
        def test_dataloader(self):
            return self.dataloader()
    
    dm = FakeData()
    
    dl = dm.train_dataloader()
    num_classes = dm.n_classes
    
    metrics = M.MetricCollection({
        'accuracy': M.Accuracy(num_classes=num_classes),
        'recall': M.Recall(num_classes=num_classes),
        'f1': M.F1(num_classes=num_classes),
        'precision': M.Precision(num_classes=num_classes)
    })
    
    classifier = ImageClassifier(backbone='resnet18', num_classes=dm.n_classes, metrics=metrics)
    
    trainer = Trainer(gpus=1, max_epochs=1)
    trainer.finetune(classifier, datamodule=dm, strategy="freeze")
    
    trainer.save_checkpoint('./checkpoint.pt')
    

    Expected behavior

    It should work

    Environment

    • PyTorch Version (e.g., 1.0):
    • OS (e.g., Linux):
    • How you installed PyTorch (conda, pip, source):
    • Build command you used (if compiling from source):
    • Python version:
    • CUDA/cuDNN version:
    • GPU models and configuration:
    • Any other relevant information:

    Additional context

    bug / fix help wanted 
    opened by FrancescoSaverioZuppichini 14
  • List of models that can be used as a backbone

    🚀 Feature

    There is no central list of models that can be used as a backbone in the docs, which makes it difficult to search for which model can be used for which task. The backbone info here in the docs isn't clear either, and the same is true in a lot of other places. Also, in the README, the first code example for TranslationTask loads its weights from a checkpoint using an AWS URL. How does one find that URL and model name?
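
    As a partial workaround, each task can at least list its registered backbones programmatically (assuming a Flash version that ships the available_backbones classmethod):

    from flash.image import ImageClassifier

    # prints the names of all backbones registered for this task
    print(ImageClassifier.available_backbones())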

    Motivation

    This is really annoying when you are trying to build a quick model for any task and have to browse through the Flash docs as well as Hugging Face's (or any other library's) just to check whether a model will work. This might be similar to #593.

    Pitch

    Add a central list or table separated by task that contains the name of all the backbones or models that can be used for that particular task.

    documentation enhancement help wanted 
    opened by bamblebam 12
  • Inconsistency in F1 metric between manual eval and Trainer.test() run

    🐛 Bug

    When training a multi-label image classifier as described in the docs (original link: https://lightning-flash.readthedocs.io/en/latest/reference/multi_label_classification.html),

    import os.path as osp
    from typing import List, Tuple
    
    import pandas as pd
    from torchmetrics import F1
    
    import flash
    from flash.core.classification import Labels
    from flash.core.data.utils import download_data
    from flash.image import ImageClassificationData, ImageClassifier
    from flash.image.classification.data import ImageClassificationPreprocess
    
    # 1. Download the data
    # This is a subset of the movie poster genre prediction data set from the paper
    # “Movie Genre Classification based on Poster Images with Deep Neural Networks” by Wei-Ta Chu and Hung-Jui Guo.
    # Please consider citing their paper if you use it. More here: https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/
    download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "data/")
    
    # 2. Load the data
    genres = ["Action", "Romance", "Crime", "Thriller", "Adventure"]
    
    
    def load_data(data: str, root: str = 'data/movie_posters') -> Tuple[List[str], List[List[int]]]:
        metadata = pd.read_csv(osp.join(root, data, "metadata.csv"))
        return ([osp.join(root, data, row['Id'] + ".jpg") for _, row in metadata.iterrows()],
                [[int(row[genre]) for genre in genres] for _, row in metadata.iterrows()])
    
    
    train_files, train_targets = load_data('train')
    test_files, test_targets = load_data('test')
    
    datamodule = ImageClassificationData.from_files(
        train_files=train_files,
        train_targets=train_targets,
        test_files=test_files,
        test_targets=test_targets,
        val_split=0.1,  # Use 10 % of the train dataset to generate validation one.
        image_size=(128, 128),
    )
    
    # 3. Build the model
    model = ImageClassifier(
        backbone="resnet18",
        num_classes=len(genres),
        multi_label=True,
        metrics=F1(num_classes=len(genres)),
    )
    
    # 4. Create the trainer. Train on 2 gpus for 10 epochs.
    trainer = flash.Trainer(max_epochs=10)
    
    # 5. Train the model
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")
    
    # 6. Predict what's on a few images!
    # Serialize predictions as labels, low threshold to see more predictions.
    model.serializer = Labels(genres, multi_label=True, threshold=0.25)
    
    predictions = model.predict([
        "data/movie_posters/predict/tt0085318.jpg",
        "data/movie_posters/predict/tt0089461.jpg",
        "data/movie_posters/predict/tt0097179.jpg",
    ])
    
    print(predictions)
    
    # 7. Save it!
    trainer.save_checkpoint("image_classification_multi_label_model.pt")
    

    I get different F1 metrics for the test set depending on how I run the evaluation:

    # Run test with trainer:
    
    trainer.test(model, datamodule=datamodule)
    
    # stdout:
    # {'test_binary_cross_entropy_with_logits': 0.5449734330177307,
    # 'test_f1': 0.46086955070495605}
    
    # Run test manually:
    
    import torch
    from flash.core.data.data_source import DefaultDataKeys

    metric = F1(num_classes=len(genres))
    
    for batch in datamodule.test_dataloader():
        image_tensor = batch[DefaultDataKeys.INPUT]
        target = batch[DefaultDataKeys.TARGET]
        with torch.no_grad():
            y_hat = model(image_tensor)
        prediction = model.to_metrics_format(y_hat)
        metric(prediction, target)
    
    print(metric.compute())
    
    # stdout:
    # tensor(0.3891)
    

    To Reproduce

    Steps to reproduce the behavior:

    1. Copy paste the example training code from the link above
    2. Add the test evaluation code above
    3. Save and run the script
    4. See error

    Expected behavior

    The two F1 metrics should be identical

    Environment

    • PyTorch Version: 1.8.0
    • PyTorch-Lightning: 1.3.5
    • Lightning-Flash: 0.3.2
    • Torchmetrics: 0.3.2
    • OS (e.g., Linux): macOS
    • How you installed PyTorch (conda, pip, source): pip
    • Python version: 3.8.8
    • CUDA/cuDNN version: N/A
    • GPU models and configuration: None
    • Any other relevant information: None

    Additional context

    None

    bug / fix help wanted won't fix 
    opened by lillekemiker 11
  • Update lightning version to v1.2

    What does this PR do?

    Fixes #132

    Before submitting

    • [x] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
    • [x] Did you read the contributor guideline, Pull Request section?
    • [x] Did you make sure your PR does only one thing, instead of bundling different changes together?
    • [x] Did you make sure to update the documentation with your changes?
    • [ ] Did you write any new necessary tests? [not needed for typos/docs]
    • [x] Did you verify new and existing tests pass locally with your changes?
    • [ ] If you made a notable change (that affects users), did you update the CHANGELOG?

    PR review

    • [x] Is this pull request ready for review? (if not, please submit in draft mode)

    Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

    Did you have fun?

    Make sure you had fun coding 🙃

    bug / fix Priority 
    opened by kaushikb11 10
  • One requirements file per task

    Users shouldn't have to install requirements they aren't using; a packaging sketch follows the list below. Should make the following changes:

    • one requirements file per task
    • allow the requires decorator to reference a requirements file (i.e. remove current manual approach)
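
    A hedged packaging sketch of the idea (the dependency lists below are illustrative, not the repository's actual setup.py): per-task extras would let users opt in with e.g. pip install "lightning-flash[image]".

    from setuptools import setup

    setup(
        name="lightning-flash",
        # per-task extras so users only pull in what they need
        extras_require={
            "image": ["torchvision", "timm"],
            "text": ["transformers", "datasets"],
            "tabular": ["pandas", "pytorch-tabnet"],
            "video": ["pytorchvideo"],
        },
    )
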
    enhancement 
    opened by ethanwharris 0
  • Flash CLI not instantiating classes in subcommand config

    Flash CLI not instantiating classes in subcommand config

    🐛 Bug

    I tried to use class instantiation in a subcommand in the Flash CLI, only to find out it was failing.

    After some digging, I found that this bug is caused by the subcommand configuration not being instantiated with self.parser.instantiate_classes before it is actually used to generate the corresponding object.

    To Reproduce

    Try to run any flash task using a class instantiation syntax in the subcommand.

    Code sample:

    flash image_classification --model.num_classes=10 from_datasets --train_dataset "{'class_path':'torchvision.datasets.CIFAR10','init_args':{'root':'/path/to/CIFAR10'}}"
    

    Output with traceback:

    GPU available: True, used: False
    TPU available: False, using: 0 TPU cores
    IPU available: False, using: 0 IPUs
    <redacted>/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1579: UserWarning: GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`.
      rank_zero_warn(
    <redacted>/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:90: LightningDeprecationWarning: Method `on_train_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0. Please use `train_dataloader()` directly.
      rank_zero_deprecation(
    <redacted>/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:118: UserWarning: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
      rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")
    <redacted>/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:125: LightningDeprecationWarning: Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0. Please use `val_dataloader()` directly.
      rank_zero_deprecation(
    
      | Name          | Type           | Params
    -------------------------------------------------
    0 | train_metrics | ModuleDict     | 0     
    1 | val_metrics   | ModuleDict     | 0     
    2 | test_metrics  | ModuleDict     | 0     
    3 | adapter       | DefaultAdapter | 11.2 M
    -------------------------------------------------
    14.7 K    Trainable params
    11.2 M    Non-trainable params
    11.2 M    Total params
    44.727    Total estimated model params size (MB)
    <redacted>/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:110: UserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
      rank_zero_warn(
    <redacted>/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:406: UserWarning: The number of training samples (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
      rank_zero_warn(
    Epoch 0:   0%|                                                                                                                                                            | 0/1 [00:00<?, ?it/s]Traceback (most recent call last):
      File "<redacted>/bin/flash", line 8, in <module>
        sys.exit(main())
      File "<redacted>/lib/python3.9/site-packages/click/core.py", line 1128, in __call__
        return self.main(*args, **kwargs)
      File "<redacted>/lib/python3.9/site-packages/click/core.py", line 1053, in main
        rv = self.invoke(ctx)
      File "<redacted>/lib/python3.9/site-packages/click/core.py", line 1659, in invoke
        return _process_result(sub_ctx.command.invoke(sub_ctx))
      File "<redacted>/lib/python3.9/site-packages/click/core.py", line 1395, in invoke
        return ctx.invoke(self.callback, **ctx.params)
      File "<redacted>/lib/python3.9/site-packages/click/core.py", line 754, in invoke
        return __callback(*args, **kwargs)
      File "<redacted>/lib/python3.9/site-packages/flash/__main__.py", line 38, in wrapper
        command()
      File "<redacted>/lib/python3.9/site-packages/flash/image/classification/cli.py", line 58, in image_classification
        cli = FlashCLI(
      File "<redacted>/lib/python3.9/site-packages/flash/core/utilities/flash_cli.py", line 162, in __init__
        super().__init__(
      File "<redacted>/lib/python3.9/site-packages/flash/core/utilities/lightning_cli.py", line 309, in __init__
        self.fit()
      File "<redacted>/lib/python3.9/site-packages/flash/core/utilities/flash_cli.py", line 280, in fit
        self.trainer.finetune(**self.fit_kwargs)
      File "<redacted>/lib/python3.9/site-packages/flash/core/trainer.py", line 198, in finetune
        return super().fit(model, train_dataloader, val_dataloaders, datamodule)
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in fit
        self._call_and_handle_interrupt(
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
        return trainer_fn(*args, **kwargs)
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
        self._run(model, ckpt_path=ckpt_path)
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
        self._dispatch()
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _dispatch
        self.training_type_plugin.start_training(self)
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
        self._results = trainer.run_stage()
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1284, in run_stage
        return self._run_train()
      File "<redacted>/lib/python3.9/site-packages/flash/core/trainer.py", line 128, in _run_train
        self.fit_loop.run()
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 145, in run
        self.advance(*args, **kwargs)
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
        self.epoch_loop.run(data_fetcher)
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 140, in run
        self.on_run_start(*args, **kwargs)
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 141, in on_run_start
        self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_idx + 1)
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/loops/utilities.py", line 121, in _update_dataloader_iter
        dataloader_iter = enumerate(data_fetcher, batch_idx)
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/utilities/fetching.py", line 199, in __iter__
        self.prefetching(self.prefetch_batches)
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/utilities/fetching.py", line 258, in prefetching
        self._fetch_next_batch()
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/utilities/fetching.py", line 300, in _fetch_next_batch
        batch = next(self.dataloader_iter)
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/trainer/supporters.py", line 550, in __next__
        return self.request_next_batch(self.loader_iters)
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/trainer/supporters.py", line 562, in request_next_batch
        return apply_to_collection(loader_iters, Iterator, next)
      File "<redacted>/lib/python3.9/site-packages/pytorch_lightning/utilities/apply_func.py", line 92, in apply_to_collection
        return function(data, *args, **kwargs)
      File "<redacted>/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 517, in __next__
        data = self._next_data()
      File "<redacted>/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 557, in _next_data
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
      File "<redacted>/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "<redacted>/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "<redacted>/lib/python3.9/site-packages/flash/core/data/auto_dataset.py", line 98, in __getitem__
        return self._call_load_sample(self.data[index])
    KeyError: 0
    Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s]
    

    Expected behavior

    Execution working normally.

    Environment

    • PyTorch Version (e.g., 1.0): 1.8.1
    • OS (e.g., Linux): Manjaro Linux x86_64
    • How you installed PyTorch (conda, pip, source): pip in a conda environment
    • Build command you used (if compiling from source):
    • Python version: 3.9.7
    • CUDA/cuDNN version: 10.1.243 from conda-forge
    • GPU models and configuration: RTX 2080Ti, but using CPU here
    • Any other relevant information: PyTorch Lightning 1.5.2, Flash 0.6.0.dev0 (master branch up to date 3-4 days ago), JSONArgParse 4.0.0

    Additional context

    A quick fix that I put together is to substitute the lines

    def instantiate_classes(self) -> None:
        """Instantiates the classes using settings from self.config."""
        sub_config = self.config.get("subcommand")
        self.datamodule = self._subcommand_builders[sub_config](**self.config.get(sub_config))
    

    with

    def instantiate_classes(self) -> None:
        """Instantiates the classes using settings from self.config."""
        sub_config = self.config.get("subcommand")
        config_init = self.parser.instantiate_classes(self.config)
        self.datamodule = self._subcommand_builders[sub_config](**config_init.get(sub_config))
    

    in the class FlashCLI in the file flash/core/utilities/flash_cli.py.

    I think this solution might be a bit overkill as it would instantiate the configurations again for using the trainer and model arguments later on, but it fixed the problem for me.

    A more efficient solution might be to instantiate only the related subcommand arguments, use them for the datamodule, and then instantiate the remainder of the configuration later on for the model. However, I don't have enough expertise with JSONArgParse or Flash CLI to write this more efficient solution myself.

    Thank you very much for the awesome tool :D

    Edit: on a side note, I didn't find any mention in the documentation of how to use a YAML config file with different subcommands, since the CLI always tries to instantiate the default subcommand, e.g. from_hymenoptera for Image Classification. After some digging, I found that it is enough to add 'subcommand': 'from_datasets' to the config file, but it would be helpful to have this mentioned in the documentation.
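    
    For illustration, a config file selecting a non-default subcommand might look like the following (the exact schema is inferred from the workaround above, so treat it as an assumption):
    
    subcommand: from_datasets
    model:
      num_classes: 10
    from_datasets:
      train_dataset:
        class_path: torchvision.datasets.CIFAR10
        init_args:
          root: /path/to/CIFAR10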

    bug / fix help wanted 
    opened by Alexei95 0
  • Reformat text classification data

    Reformat text classification data

    What does this PR do?

    Reformats the text classification data.

    opened by pietrolesci 1
  • [WIP] Refactor image inputs and update to new input object

    [WIP] Refactor image inputs and update to new input object

    What does this PR do?

    Fixes # (issue)

    Before submitting

    • [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
    • [ ] Did you read the contributor guideline, Pull Request section?
    • [ ] Did you make sure your PR does only one thing, instead of bundling different changes together?
    • [ ] Did you make sure to update the documentation with your changes?
    • [ ] Did you write any new necessary tests? [not needed for typos/docs]
    • [ ] Did you verify new and existing tests pass locally with your changes?
    • [ ] If you made a notable change (that affects users), did you update the CHANGELOG?

    PR review

    • [ ] Is this pull request ready for review? (if not, please submit in draft mode)

    Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

    Did you have fun?

    Make sure you had fun coding 🙃

    opened by ethanwharris 0
  • St embeddings

    St embeddings

    First step to #760

    Generated Sentence Embeddings using SentenceTransformers.

    from flash.text.embeddings import SentenceEmbedder
    model = SentenceEmbedder()
    print(model.generate_embeddings(["This is a sentence", "This is another sentence"]))
    
    

    @ethanwharris Updated Task 1 as per our conversation.

    opened by abhijithneilabraham 1
  • Data loading error.

    Data loading error.

    I am trying to finetune a wav2vec model for code switching. My JSON data file contains some characters from a different language (Hindi). While loading it I get this error.

    Using custom data configuration default-e6037235814b16b2
    100%|██████████| 1/1 [00:00<?, ?it/s]
    100%|██████████| 1/1 [00:00<?, ?it/s]
    Failed to read file 'D:\pycharmprojects\wav2vec_finetuneing\train_data.json' with error <class 'pyarrow.lib.ArrowInvalid'>: JSON parse error: Invalid escape character in string. in row 2
    Traceback (most recent call last):
      File "D:\pycharmprojects\wav2vec_finetuneing\venv\lib\site-packages\datasets\packaged_modules\json\json.py", line 140, in _generate_tables
        dataset = json.load(f)
      File "C:\Users\syeda\AppData\Local\Programs\Python\Python37\Lib\json\__init__.py", line 296, in load
        parse_constant=parse_constant, object_pairs_hook=object_pairs_hook, **kw)
      File "C:\Users\syeda\AppData\Local\Programs\Python\Python37\Lib\json\__init__.py", line 348, in loads
        return _default_decoder.decode(s)
      File "C:\Users\syeda\AppData\Local\Programs\Python\Python37\Lib\json\decoder.py", line 340, in decode
        raise JSONDecodeError("Extra data", s, end)
    json.decoder.JSONDecodeError: Extra data: line 2 column 1 (char 4185)
    
    During handling of the above exception, another exception occurred:
    
    Traceback (most recent call last):
      File "D:/pycharmprojects/wav2vec_finetuneing/train.py", line 17, in <module>
        train_file=r"D:\pycharmprojects\wav2vec_finetuneing\train_data.json",
      File "D:\pycharmprojects\wav2vec_finetuneing\venv\lib\site-packages\flash\core\data\data_module.py", line 1035, in from_json
        **input_transform_kwargs,
      File "D:\pycharmprojects\wav2vec_finetuneing\venv\lib\site-packages\flash\core\data\data_module.py", line 589, in from_input
        predict_data,
      File "D:\pycharmprojects\wav2vec_finetuneing\venv\lib\site-packages\flash\core\data\io\input.py", line 307, in to_datasets
        train_dataset = self.generate_dataset(train_data, RunningStage.TRAINING)
      File "D:\pycharmprojects\wav2vec_finetuneing\venv\lib\site-packages\flash\core\data\io\input.py", line 342, in generate_dataset
        data = load_data(data, mock_dataset)
      File "D:\pycharmprojects\wav2vec_finetuneing\venv\lib\site-packages\flash\audio\speech_recognition\data.py", line 95, in load_data
        dataset_dict = load_dataset(self.filetype, data_files={stage: str(file)})
      File "D:\pycharmprojects\wav2vec_finetuneing\venv\lib\site-packages\datasets\load.py", line 1637, in load_dataset
        use_auth_token=use_auth_token,
      File "D:\pycharmprojects\wav2vec_finetuneing\venv\lib\site-packages\datasets\builder.py", line 608, in download_and_prepare
        dl_manager=dl_manager, verify_infos=verify_infos, **download_and_prepare_kwargs
      File "D:\pycharmprojects\wav2vec_finetuneing\venv\lib\site-packages\datasets\builder.py", line 697, in _download_and_prepare
        self._prepare_split(split_generator, **prepare_split_kwargs)
      File "D:\pycharmprojects\wav2vec_finetuneing\venv\lib\site-packages\datasets\builder.py", line 1157, in _prepare_split
        generator, unit=" tables", leave=False, disable=True  # bool(logging.get_verbosity() == logging.NOTSET)
      File "D:\pycharmprojects\wav2vec_finetuneing\venv\lib\site-packages\tqdm\std.py", line 1168, in __iter__
        for obj in iterable:
      File "D:\pycharmprojects\wav2vec_finetuneing\venv\lib\site-packages\datasets\packaged_modules\json\json.py", line 142, in _generate_tables
        raise e
      File "D:\pycharmprojects\wav2vec_finetuneing\venv\lib\site-packages\datasets\packaged_modules\json\json.py", line 119, in _generate_tables
        io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
      File "pyarrow\_json.pyx", line 246, in pyarrow._json.read_json
      File "pyarrow\error.pxi", line 143, in pyarrow.lib.pyarrow_internal_check_status
      File "pyarrow\error.pxi", line 99, in pyarrow.lib.check_status
    pyarrow.lib.ArrowInvalid: JSON parse error: Invalid escape character in string. in row 2
    

    Any suggestions?

    Thanks
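    
    Not from the original thread, but one quick way to locate the offending rows is to parse the file line by line and print which lines fail. In JSON, a backslash (as in Windows paths) must be escaped as \\, and an unescaped one produces exactly this "Invalid escape character" error:
    
    import json
    
    # Report every line of the JSON-lines file that is not valid JSON.
    path = r"D:\pycharmprojects\wav2vec_finetuneing\train_data.json"
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            try:
                json.loads(line)
            except json.JSONDecodeError as err:
                print(f"line {lineno}: {err}")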

    bug / fix 
    opened by BakingBrains 2
  • Initial commit of node classification

    Initial commit of node classification

    What does this PR do?

    This PR creates a graph node classification task. It could also be used for node regression.

    Fixes #985

    Before submitting

    • [x] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
    • [x] Did you read the contributor guideline, Pull Request section?
    • [x] Did you make sure your PR does only one thing, instead of bundling different changes together?
    • [x] Did you make sure to update the documentation with your changes?
    • [ ] Did you write any new necessary tests? [not needed for typos/docs]
    • [ ] Did you verify new and existing tests pass locally with your changes?
    • [x] If you made a notable change (that affects users), did you update the CHANGELOG?

    PR review

    • [ ] Is this pull request ready for review? (if not, please submit in draft mode)

    Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

    Did you have fun?

    Make sure you had fun coding 🙃

    enhancement task PyTorch Geometric 
    opened by PabloAMC 2
  • Graph node classification and regression task

    Graph node classification and regression task

    🚀 Feature

    Addition of graph node classification and regression task.

    Motivation

    An important graph task is node classification/regression.

    Pitch

    A flash implementation of this task.

    Alternatives

    This can be done using the PyTorch Geometric library.

    Additional context

    Not sure yet how to handle masks.
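    
    For reference (this is plain PyTorch Geometric usage, not proposed Flash code): in PyG, masks are boolean per-node tensors that select which nodes contribute to each split, so a minimal node-classification loop looks roughly like this:
    
    import torch
    import torch.nn.functional as F
    from torch_geometric.datasets import Planetoid
    from torch_geometric.nn import GCNConv
    
    dataset = Planetoid(root="data/Cora", name="Cora")
    data = dataset[0]  # one graph with train_mask / val_mask / test_mask
    
    class GCN(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = GCNConv(dataset.num_features, 16)
            self.conv2 = GCNConv(16, dataset.num_classes)
    
        def forward(self, x, edge_index):
            return self.conv2(F.relu(self.conv1(x, edge_index)), edge_index)
    
    model = GCN()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    
    for epoch in range(5):
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        # The mask restricts the loss to training nodes only
        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()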

    enhancement help wanted 
    opened by PabloAMC 0
  • `TextClassifier` with `load_from_checkpoint` still downloads data

    `TextClassifier` with `load_from_checkpoint` still downloads data

    🐛 Bug

    When I train a model I want to use it offline, so I save it; but when I load it from the saved checkpoint it still pulls the backbone from the hub: https://github.com/PyTorchLightning/lightning-flash/blob/a0c97a39f2083b5344a08d248ccab7e5bfa91df4/flash/text/classification/model.py#L90
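    
    A possible workaround until this is fixed (an assumption added here, not from the thread): if the backbone has already been downloaded once, the Hugging Face cache can be forced offline before Flash (and hence transformers) is imported:
    
    import os
    
    # Must be set before transformers is imported; Flash imports it internally.
    os.environ["TRANSFORMERS_OFFLINE"] = "1"
    
    from flash.text import TextClassifier
    
    model = TextClassifier.load_from_checkpoint("text_classification_model.pt")  # your saved checkpoint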

    To Reproduce

    https://www.kaggle.com/jirkaborovec/toxic-comments-with-lightning-flash-inference?scriptVersionId=80368862

    Additional context

    cc: @rohitgr7 @karthikrangasai

    enhancement good first issue help wanted 
    opened by Borda 3
  • text classify prediction - memory overflow

    text classify prediction - memory overflow

    🐛 Bug

    It seems that prediction does not honor the batch size and tries to load everything at once. I have this simple prediction with backbone="xlm-roberta-base", running predictions on https://www.kaggle.com/c/jigsaw-toxic-severity-rating/data?select=comments_to_score.csv

    predictions = model.predict(df_comments["text"])
    

    This crashes the Kaggle kernel, but if I run it in batch-sized chunks it works fine (just very slow):

    predictions = []
    batch_size = datamodule.batch_size
    # step through the texts in batch-sized chunks; a final partial chunk is included
    for start in range(0, len(df_comments), batch_size):
        predictions += model.predict(df_comments["text"][start:start + batch_size])
    

    To Reproduce

    https://www.kaggle.com/jirkaborovec/toxic-comments-with-lightning-flash?scriptVersionId=80223034

    Additional context

    cc: @karthikrangasai @rohitgr7

    bug / fix help wanted 
    opened by Borda 2
Releases(0.5.2)
  • 0.5.2(Nov 5, 2021)

    [0.5.2] - 2021-11-05

    Added

    • Added a TabularForecaster task based on PyTorch Forecasting (#647)
    • Added a TabularRegressor task (#892)

    Fixed

    • Fixed a bug where test metrics were not logged correctly with active learning (#879)
    • Fixed a bug where validation metrics could be aggregated together with test metrics in some cases (#900)
    • Fixed a bug where the latest versions of torchmetrics and Lightning Flash could not be installed together (#902)
    • Fixed compatibility with PyTorch-Lightning 1.5 (#933)

    Contributors

    @aniketmaurya @awaelchli @Borda @Dref360 @ethanwharris @pietrolesci @sumanmichael @twsl

    If we forgot someone due to not matching commit email with GitHub account, let us know :]

    Source code(tar.gz)
    Source code(zip)
  • 0.5.1(Oct 26, 2021)

    [0.5.1] - 2021-10-26

    Added

    • Added LabelStudio integration (#554)
    • Added support for the learn2learn training_strategy for ImageClassifier (#737)
    • Added vissl training_strategies for ImageEmbedder (#682)
    • Added support for from_data_frame to TextClassificationData (#785)
    • Added FastFace integration (#606)
    • Added support for from_lists to TextClassificationData (#805)

    Changed

    • Changed the default num_workers on linux to 0 (matching the default for other OS) (#759)
    • Optimizer and LR Scheduler registry are used to get the respective inputs to the Task using a string (or a callable). (#777)

    Fixed

    • Fixed a bug where additional kwargs (e.g. sampler) passed to tabular data would be ignored (#792)
    • Fixed a bug where loading text data with additional non-numeric columns (not input or target) would give an error (#888)

    New Contributors

    • @bamblebam made their first contribution in https://github.com/PyTorchLightning/lightning-flash/pull/735
    • @dlangerm made their first contribution in https://github.com/PyTorchLightning/lightning-flash/pull/765
    • @pietrolesci made their first contribution in https://github.com/PyTorchLightning/lightning-flash/pull/767
    • @gianscarpe made their first contribution in https://github.com/PyTorchLightning/lightning-flash/pull/776
    • @kingyiusuen made their first contribution in https://github.com/PyTorchLightning/lightning-flash/pull/785
    • @Isaac-Flath made their first contribution in https://github.com/PyTorchLightning/lightning-flash/pull/799
    • @KonstantinKorotaev made their first contribution in https://github.com/PyTorchLightning/lightning-flash/pull/554
    • @borhanMorphy made their first contribution in https://github.com/PyTorchLightning/lightning-flash/pull/606
    • @EStorm21 made their first contribution in https://github.com/PyTorchLightning/lightning-flash/pull/837
    • @parmidaatg made their first contribution in https://github.com/PyTorchLightning/lightning-flash/pull/822
    • @Darktex made their first contribution in https://github.com/PyTorchLightning/lightning-flash/pull/824
    • @Dref360 made their first contribution in https://github.com/PyTorchLightning/lightning-flash/pull/861

    PR List

    • Bump version to 0.5.1dev by @ethanwharris in https://github.com/PyTorchLightning/lightning-flash/pull/749
    • Docs for backbones by @bamblebam in https://github.com/PyTorchLightning/lightning-flash/pull/735
    • Refactor unnecessary else / elif when if block has a return statement by @deepsource-autofix in https://github.com/PyTorchLightning/lightning-flash/pull/751
    • Clean up docs by @ethanwharris in https://github.com/PyTorchLightning/lightning-flash/pull/754
    • Set logo to have a white background for dark mode by @SeanNaren in https://github.com/PyTorchLightning/lightning-flash/pull/757
    • New README by @ethanwharris in https://github.com/PyTorchLightning/lightning-flash/pull/756
    • Speed up and fix graph tests by @ethanwharris in https://github.com/PyTorchLightning/lightning-flash/pull/759
    • Feature/output data keys by @dlangerm in https://github.com/PyTorchLightning/lightning-flash/pull/765
    • Fix logo spacing by @SeanNaren in https://github.com/PyTorchLightning/lightning-flash/pull/766
    • Move text backbones in separate module by @pietrolesci in https://github.com/PyTorchLightning/lightning-flash/pull/767
    • Speed up question answering tests by @ethanwharris in https://github.com/PyTorchLightning/lightning-flash/pull/775
    • [PoC] Add MetaLearning support through learn2learn by @tchaton in https://github.com/PyTorchLightning/lightning-flash/pull/737
    • [Readme] Add training strategies by @tchaton in https://github.com/PyTorchLightning/lightning-flash/pull/780
    • VISSL initial integration by @ananyahjha93 in https://github.com/PyTorchLightning/lightning-flash/pull/682
    • Add thumbnails to card items by @SeanNaren in https://github.com/PyTorchLightning/lightning-flash/pull/787
    • Document object detector augmentations by @gianscarpe in https://github.com/PyTorchLightning/lightning-flash/pull/776
    • Fix RTD build by @ethanwharris in https://github.com/PyTorchLightning/lightning-flash/pull/789
    • VISSL collate function/transforms restructure by @ananyahjha93 in https://github.com/PyTorchLightning/lightning-flash/pull/786
    • TextClassificationData from_dataframe by @kingyiusuen in https://github.com/PyTorchLightning/lightning-flash/pull/785
    • [Doc] Add learn2learn integrations documentation by @tchaton in https://github.com/PyTorchLightning/lightning-flash/pull/788
    • [Doc] VISSL docs by @ananyahjha93 in https://github.com/PyTorchLightning/lightning-flash/pull/794
    • Add sampler argument to tabular data by @ethanwharris in https://github.com/PyTorchLightning/lightning-flash/pull/792
    • [Feat] Add ActiveLearning Loop Customization v2 by @tchaton in https://github.com/PyTorchLightning/lightning-flash/pull/779
    • Update Readme by @Isaac-Flath in https://github.com/PyTorchLightning/lightning-flash/pull/799
    • Update flash_zero.rst by @williamFalcon in https://github.com/PyTorchLightning/lightning-flash/pull/796
    • Add question answering thumbnail by @SeanNaren in https://github.com/PyTorchLightning/lightning-flash/pull/810
    • enable persistent workers for train and val dataloaders by @dlangerm in https://github.com/PyTorchLightning/lightning-flash/pull/812
    • Adding integration with Label Studio by @KonstantinKorotaev in https://github.com/PyTorchLightning/lightning-flash/pull/554
    • Add from_lists to TextClassificationData by @kingyiusuen in https://github.com/PyTorchLightning/lightning-flash/pull/805
    • [Doc] by @tchaton in https://github.com/PyTorchLightning/lightning-flash/pull/813
    • Face Detection Task (task-a-thon) by @borhanMorphy in https://github.com/PyTorchLightning/lightning-flash/pull/606
    • add Kaggle links by @Borda in https://github.com/PyTorchLightning/lightning-flash/pull/826
    • [pre-commit.ci] pre-commit suggestions by @pre-commit-ci in https://github.com/PyTorchLightning/lightning-flash/pull/831
    • Add val_loss and test_loss calculation and logging for QnA task by @karthikrangasai in https://github.com/PyTorchLightning/lightning-flash/pull/832
    • Fix typo in learn2learn example by @EStorm21 in https://github.com/PyTorchLightning/lightning-flash/pull/837
    • HotFix for doc build on master by @SeanNaren in https://github.com/PyTorchLightning/lightning-flash/pull/849
    • Bump version to 0.5.1rc0 by @SeanNaren in https://github.com/PyTorchLightning/lightning-flash/pull/850
    • Add FlashDataset by @tchaton in https://github.com/PyTorchLightning/lightning-flash/pull/851
    • Add FlashDataset update by @tchaton in https://github.com/PyTorchLightning/lightning-flash/pull/853
    • Missing docstring on methods by @SkafteNicki in https://github.com/PyTorchLightning/lightning-flash/pull/854
    • Add PreprocessTransform by @tchaton in https://github.com/PyTorchLightning/lightning-flash/pull/852
    • added query_size, and initial_num_labels. removed num_labels_randomly… by @parmidaatg in https://github.com/PyTorchLightning/lightning-flash/pull/822
    • [bugfix] Change to torchmetrics instead of PL by @tchaton in https://github.com/PyTorchLightning/lightning-flash/pull/858
    • [bugfix] Resolve bug with Lightning 1.5.0rc0 by @tchaton in https://github.com/PyTorchLightning/lightning-flash/pull/859
    • Freeze structlog version by @tchaton in https://github.com/PyTorchLightning/lightning-flash/pull/860
    • Add support for PreprocessTransform to FlashDatasets by @tchaton in https://github.com/PyTorchLightning/lightning-flash/pull/856
    • Fix VideoClassificationData.from_files() not working by @Darktex in https://github.com/PyTorchLightning/lightning-flash/pull/824
    • Fix predict DataLoader in Active learning by @Dref360 in https://github.com/PyTorchLightning/lightning-flash/pull/861
    • Fix inference for instance segmentation by @SeanNaren in https://github.com/PyTorchLightning/lightning-flash/pull/857
    • 2/n Add Custom Data Loading Tutorial + API improvement. by @tchaton in https://github.com/PyTorchLightning/lightning-flash/pull/855
    • Rename PreprocessTransform to InputTransform by @tchaton in https://github.com/PyTorchLightning/lightning-flash/pull/868
    • Add Serving to RunningStage by @tchaton in https://github.com/PyTorchLightning/lightning-flash/pull/872
    • Refactor text data loading by @pietrolesci in https://github.com/PyTorchLightning/lightning-flash/pull/870
    • PoC: Revamp optimizer and scheduler experience using registries by @karthikrangasai in https://github.com/PyTorchLightning/lightning-flash/pull/777
    • VISSL datapipeline fix by @ananyahjha93 in https://github.com/PyTorchLightning/lightning-flash/pull/880
    • Fix RTD Build by @ethanwharris in https://github.com/PyTorchLightning/lightning-flash/pull/887
    • Fix text classification data loading by @ethanwharris in https://github.com/PyTorchLightning/lightning-flash/pull/888
    • Update docutils package version in requirements by @awaelchli in https://github.com/PyTorchLightning/lightning-flash/pull/891
    • Bump version to 0.5.1 by @ethanwharris in https://github.com/PyTorchLightning/lightning-flash/pull/890
    Source code(tar.gz)
    Source code(zip)
  • 0.5.1rc0(Oct 11, 2021)

  • 0.5.0(Sep 7, 2021)

    [0.5.0] - 2021-09-07

    Added

    • Added support for (input, target) style datasets (e.g. torchvision) to the from_datasets method (#552) (see the sketch after this list)
    • Added support for from_csv and from_data_frame to ImageClassificationData (#556)
    • Added SimCLR, SwAV, Barlow-twins pretrained weights for resnet50 backbone in ImageClassifier task (#560)
    • Added support for Semantic Segmentation backbones and heads from segmentation-models.pytorch (#562)
    • Added support for nesting of Task objects (#575)
    • Added PointCloudSegmentation Task (#566)
    • Added PointCloudObjectDetection Task (#600)
    • Added a GraphClassifier task (#73)
    • Added the option to pass pretrained as a string to SemanticSegmentation to change pretrained weights to load from segmentation-models.pytorch (#587)
    • Added support for the field parameter for loading JSON-based datasets in text tasks (#585)
    • Added AudioClassificationData and an example for classifying audio spectrograms (#594)
    • Added a SpeechRecognition task for speech to text using Wav2Vec (#586)
    • Added Flash Zero, a zero code command line ML platform built with flash (#611)
    • Added support for .npy and .npz files to ImageClassificationData and AudioClassificationData (#651)
    • Added support for from_csv to the AudioClassificationData (#651)
    • Added option to pass a resolver to the from_csv and from_pandas methods of ImageClassificationData, which is used to resolve filenames given IDs (#651)
    • Added integration with IceVision for the ObjectDetector (#608)
    • Added keypoint detection task (#608)
    • Added instance segmentation task (#608)
    • Added Torch ORT support to Transformer based tasks (#667)
    • Added support for flash zero with the InstanceSegmentation and KeypointDetector tasks (#672)
    • Added support for in_chans argument to the flash ResNet to control the expected number of input channels (#673)
    • Added a QuestionAnswering task for extractive question answering (#607)
    • Added automatic unwrapping of IceVision prediction objects (#727)
    • Added support for the ObjectDetector with FiftyOne (#727)
    • Added support for MP3 files to the SpeechRecognition task with librosa (#726)
    • Added support for from_numpy and from_tensors to AudioClassificationData (#745)
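    
    To illustrate the first item in the list above, a torchvision dataset that yields (input, target) pairs can now be passed straight to from_datasets. A minimal sketch based on the item's description (argument names are assumptions):
    
    from torchvision.datasets import CIFAR10
    from flash.image import ImageClassificationData
    
    datamodule = ImageClassificationData.from_datasets(
        train_dataset=CIFAR10("data", download=True),
        batch_size=4,
    )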

    Changed

    • Changed how pretrained flag works for loading weights for ImageClassifier task (#560)
    • Removed bolts pretrained weights for SSL from ImageClassifier task (#560)
    • Changed the behaviour of the sampler argument of the DataModule to take a Sampler type rather than instantiated object (#651)
    • Changed arguments to ObjectDetector, use head instead of model and append _fpn to the backbone name instead of the fpn argument (#608)

    Fixed

    • Fixed a bug where serve sanity checking would not be triggered using the latest PyTorchLightning version (#493)
    • Fixed a bug where train and validation metrics weren't being correctly computed (#559)
    • Fixed a bug where an uncaught ValueError could be raised when checking if a module is available (#615)
    • Fixed a bug where some tasks were not compatible with PyTorch 1.7 due to use of torch.jit.isinstance (#611)
    • Fixed a bug where custom samplers would not be properly forwarded to the data loader (#651)
    • Fixed a bug where it was not possible to pass no metrics to the ImageClassifier or TextClassifier (#660)
    • Fixed a bug where drop_last would be set to True during prediction and testing (#671)
    • Fixed a bug where flash was not compatible with pytorch-lightning >= 1.4.3 (#690)

    Contributors

    @ananyahjha93 @aniketmaurya @aribornstein @Borda @ethanwharris @flozi00 @hhsecond @hihunjin @karthikrangasai @Kinyugo @PeppeSaccardi @pmeier @SeanNaren @sumanmichael @tchaton @tszumowski

    If we forgot someone due to not matching commit email with GitHub account, let us know :]

    Source code(tar.gz)
    Source code(zip)
  • 0.5.0rc0(Sep 1, 2021)

  • 0.4.0(Jun 22, 2021)

    [0.4.0] - 2021-06-22

    Added

    • Added integration with FiftyOne (#360)
    • Added flash.serve (#399)
    • Added support for torch.jit to tasks where possible and documented task JIT compatibility (#389)
    • Added option to provide a Sampler to the DataModule to use when creating a DataLoader (#390)
    • Added support for multi-label text classification and toxic comments example (#401)
    • Added a sanity checking feature to flash.serve (#423)

    Changed

    • Split backbone argument to SemanticSegmentation into backbone and head arguments (#412)

    Fixed

    • Fixed a bug where the DefaultDataKeys.METADATA couldn't be a dict (#393)
    • Fixed a bug where the SemanticSegmentation task would not work as expected with finetuning callbacks (#412)
    • Fixed a bug where predict batches could not be visualized with ImageClassificationData (#438)

    Contributors

    @ehofesmann @ethanwharris @fstroth @lillekemiker @tchaton

    Additional credits to @rlizzo @hhsecond @lantiga @luiscape for building the Flash Serve Engine.

    If we forgot someone due to not matching commit email with GitHub account, let us know :]

    Source code(tar.gz)
    Source code(zip)
  • 0.3.2(Jun 8, 2021)

  • 0.3.1(Jun 8, 2021)

    [0.3.1] - 2021-06-08

    Added

    • Added deeplabv3, lraspp, and unet backbones for the SemanticSegmentation task #370

    Changed

    • Changed the installation command for extra features #346
    • Changed the resize interpolation default mode to nearest #352

    Deprecated

    • Deprecated SemanticSegmentation backbone names torchvision/fcn_resnet50 and torchvision/fcn_resnet101, use fcn_resnet50 and fcn_resnet101 instead #370

    Fixed

    • Fixed flash.Trainer.add_argparse_args not adding any arguments #343
    • Fixed a bug where the translation task wasn't decoding tokens properly #332
    • Fixed a bug where huggingface tokenizers were sometimes being pickled #332
    • Fixed issue with KorniaParallelTransforms to assure to share the random state between transforms #351
    • Fixed a bug where using val_split with overfit_batches would give an infinite recursion #375
    • Fixed a bug where some timm models were mistakenly given a global_pool argument #377
    • Fixed flash.Trainer.from_argparse_args not passing arguments correctly #380

    Contributors

    @akihironitta @aribornstein @carmocca @deepseek-eoghan @edgarriba @ethanwharris

    If we forgot someone due to not matching commit email with GitHub account, let us know :]

    Source code(tar.gz)
    Source code(zip)
  • 0.3.0(May 20, 2021)

    [0.3.0] - 2021-05-20

    Added

    • Added DataPipeline API (#188 #141 #207)
    • Added timm integration (#196)
    • Added BaseViz Callback (#201)
    • Added backbone API (#204)
    • Added support for Iterable auto dataset (#227)
    • Added multi label support (#230)
    • Added support for schedulers (#232)
    • Added visualisation callback for image classification (#228)
    • Added Video Classification task (#216)
    • Added Dino backbone for image classification (#259)
    • Added Data Sources API (#256 #264 #272)
    • Refactor preprocess_cls to preprocess, add Serializer, add DataPipelineState (#229)
    • Added Semantic Segmentation task (#239 #287 #290)
    • Added Object detection prediction example (#283)
    • Added Style Transfer task and accompanying finetuning and prediction examples (#262)
    • Added a Template task and tutorials showing how to contribute a task to flash (#306)

    Changed

    • Rename valid_ to val_ (#197)
    • Refactor preprocess_cls to preprocess, add Serializer, add DataPipelineState (#229)

    Fixed

    • Fix DataPipeline resolution in Task (#212)
    • Fixed a bug where the backbone used in summarization was not correctly passed to the postprocess (#296)

    Contributors

    @aniketmaurya @carmocca @edgarriba @ethanwharris @pmeier @tchaton

    If we forgot someone due to not matching commit email with GitHub account, let us know :]

    Source code(tar.gz)
    Source code(zip)
  • 0.2.3(Apr 17, 2021)

  • 0.2.2(Apr 5, 2021)

    [0.2.2] - 2021-04-05

    Changed

    • Switch to use torchmetrics (#169)
    • Update lightning version to v1.2 (#133)

    Fixed

    • Fixed classification softmax (#169)
    • Don't download data if exists (#157)
    Source code(tar.gz)
    Source code(zip)
  • 0.2.1(Mar 6, 2021)

    [0.2.1] - 2021-03-06

    Added

    • Added RetinaNet & backbones to ObjectDetector Task (#121)
    • Added .csv image loading utils (#116, #117, #118)

    Changed

    • Set inputs as optional (#109)

    Fixed

    • Set minimal requirements (#62)
    • Fixed VGG backbone num_features (#154)
    Source code(tar.gz)
    Source code(zip)
  • 0.2.0(Feb 12, 2021)

    [0.2.0] - 2021-02-12

    Added

    • Added ObjectDetector Task (#56)
    • Added TabNet for tabular classification (#101)
    • Added support for more backbones (mobilenet, vgg, densenet, resnext) (#45)
    • Added backbones for image embedding model (#63)
    • Added SWAV and SimCLR models to ImageClassifier + backbone reorg (#68)

    Changed

    • Applied transform in FilePathDataset (#97)
    • Moved classification integration from vision root to folder (#86)

    Fixed

    • Unfreeze default number of workers in datamodule (#57)
    • Fixed wrong label in FilePathDataset (#94)

    Removed

    • Removed densenet161 duplicate in DENSENET_MODELS (#76)
    • Removed redundant num_features arg from Classification model (#88)
    Source code(tar.gz)
    Source code(zip)
  • 0.1.0(Feb 2, 2021)

    Flash Lightning First Release

    Overview:

    Lightning Flash is a collection of tasks for fast prototyping, baselining, finetuning and solving problems with deep learning. This release introduces 6 Flash Tasks:

    • ImageClassifier
    • TabularClassifier
    • TextClassifier
    • ImageEmbedder
    • SummarizationTask
    • TranslationTask

    Each task can easily be used for both inference and finetuning.

    [0.1.0] - 2021-02-02

    Added

    • Added flash_notebook examples (#9)
    • Added strategy to trainer.finetune with NoFreeze, Freeze, FreezeUnfreeze, UnfreezeMilestones Callbacks (#39)
    • Added SummarizationData, SummarizationTask and TranslationData, TranslationTask (#37)
    • Added ImageEmbedder (#36)

    Contributors
    
    @Borda @carmocca @justusschock @SeanNaren @SkafteNicki @tchaton @williamFalcon

    If we forgot someone due to not matching commit email with GitHub account, let us know :]

    Source code(tar.gz)
    Source code(zip)
Owner: PyTorch Lightning