Objax (🥉19 · ⭐ 580) - Objax is a machine learning framework that provides an object-oriented design on top of JAX. Apache-2.

Overview

Objax

Tutorials | Install | Documentation | Philosophy

This is not an officially supported Google product.

Objax is an open source machine learning framework that accelerates research and learning thanks to a minimalist object-oriented design and a readable code base. Its name comes from the contraction of Object and JAX -- a popular high-performance framework. Objax is designed by researchers for researchers with a focus on simplicity and understandability. Its users should be able to easily read, understand, extend, and modify it to fit their needs.

This is the developer repository of Objax; there is very little user documentation here. For the full documentation, go to objax.readthedocs.io.

You can find READMEs in the subdirectories of this project, for example:

User installation guide

You can install Objax using pip as follows:

pip install --upgrade objax

Objax supports GPUs but assumes that you already have some version of CUDA installed. Here are the extra steps:

# Update accordingly to your installed CUDA version
CUDA_VERSION=11.0
pip install -f https://storage.googleapis.com/jax-releases/jax_releases.html jaxlib==`python3 -c 'import jaxlib; print(jaxlib.__version__)'`+cuda`echo $CUDA_VERSION | sed s:\\\.::g`

Useful environment configurations

Here are a few useful options:

# Prevent JAX from taking the whole GPU memory
# (useful if you want to run several programs on a single GPU)
export XLA_PYTHON_CLIENT_PREALLOCATE=false

Testing your installation

You can test your installation by running the code below:

import jax
import objax

print(f'Number of GPUs {jax.device_count()}')

x = objax.random.normal(shape=(100, 4))
m = objax.nn.Linear(nin=4, nout=5)
print('Matrix product shape', m(x).shape)  # (100, 5)

x = objax.random.normal(shape=(100, 3, 32, 32))
m = objax.nn.Conv2D(nin=3, nout=4, k=3)
print('Conv2D return shape', m(x).shape)  # (100, 4, 32, 32)

If you get errors running this with CUDA, it usually means your CUDA or cuDNN installation has issues.

Running code examples

Clone the code repository:

git clone https://github.com/google/objax.git
cd objax/examples

Citing Objax

To cite this repository:

@software{objax2020github,
  author = {{Objax Developers}},
  title = {{Objax}},
  url = {https://github.com/google/objax},
  version = {1.2.0},
  year = {2020},
}

Developer documentation

Here is information about development setup and a guide on adding new code.

Comments
  • More control over var/module namespace.


    I got my first 'hello world' model experiment working w/ Objax. I adapted my PyTorch EfficientNet impl. Overall pretty smooth, currently wrapping Conv2d so I can get the padding I want.

    One thing that stuck out after inspecting the model, the var namespace is a mess. An aspect of modelling that I value highly is the ability to have sensible checkpoint/var maps to work with. I often end up dealing with conversions between frameworks, exports for mobile or embedded targets and having your vars (parameters) sensibly named, and often being able to control those names in the originating framework is important.

    Any thoughts on improving this? The current name/scoping mechanism forces the inclusion of the Module class names, is that necessary? Shouldn't attr names through the tree be enough for uniqueness?

    Also, there is no ability to specify names for modules in sequential containers. I use this quite often for frameworks that have it. Sometimes I don't care much (long list of block repeats, 0..n is fine), but for finer grained blocks I like to know what conv is what by looking at the var names. '0.b, o.w' etc isn't very useful.

    I'll post an example of the var keys below, and comparison point for pytorch.
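    (For context, a minimal illustrative sketch, not the poster's example, of the kind of keys Objax currently generates; the exact format may vary by version.)

    import objax

    # Illustrative sketch only: a tiny nested model whose VarCollection shows
    # how module class names are embedded in every variable key.
    class Block(objax.Module):
        def __init__(self):
            self.conv = objax.nn.Conv2D(nin=3, nout=8, k=3)

    class Net(objax.Module):
        def __init__(self):
            self.stem = Block()
            self.head = objax.nn.Linear(nin=8, nout=10)

    print(Net().vars())
    # Keys come out roughly like '(Net).stem(Block).conv(Conv2D).w' rather than
    # the plain attribute paths ('stem.conv.weight') PyTorch users may expect.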

    feature request 
    opened by rwightman 29
  • upsample2d function rough draft


    Hi Team, I am pretty new to contributing to open-source projects. Please have a review of the upsample2d function and let me know if anything is required or should be changed. The function is added in the objax.function.ops module.

    opened by naruto-raj 22
  • Add mean squared logarithmic loss function


    1. Added the mean squared logarithmic loss function.
    2. In the CONTRIBUTIONS.md file, there is no mention of code style, so I am using 4 spaces.
    3. I haven't formatted the code using black, as there is no mention of any formatter either.

    I will add the tests once the above points are clear
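    (For reference, a minimal sketch of what a mean squared logarithmic loss typically computes; the PR's actual function name, signature, and placement in objax.functional.loss may differ.)

    import jax.numpy as jn

    # Minimal sketch: mean squared error computed in log1p space.
    def mean_squared_logarithmic_error(y_pred, y_true):
        return jn.mean((jn.log1p(y_pred) - jn.log1p(y_true)) ** 2)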

    opened by AakashKumarNain 16
  • Initial dot product attention


    Adds attention, per #61. First, I'm really sorry about taking so long, but college got complicated in the pandemic and I wasted a lot of time getting organized. Also, attention is a quite general concept, and even implementations of the same type of attention differ significantly (Haiku, Flax). So @david-berthelot and @aterzis-google, I would like to ask a few questions just to make sure my implementation is going in the right direction:

    1. I think I will implement a dot product attention, a multi-head attention and a masked attention, is that OK? (A rough sketch of plain dot-product attention follows below.)
    2. What do you think of the dot product attention implementation? What do you think I need to change? Thanks for the patience and opportunity.
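    (A rough, illustrative sketch of plain scaled dot-product attention, assuming inputs of shape (batch, length, dim); this is not the code from this PR.)

    import jax
    import jax.numpy as jn

    # Sketch only: scaled dot-product attention with an optional boolean mask.
    def dot_product_attention(q, k, v, mask=None):
        scores = q @ k.transpose((0, 2, 1)) / jn.sqrt(q.shape[-1])
        if mask is not None:
            scores = jn.where(mask, scores, -1e9)  # suppress masked positions
        return jax.nn.softmax(scores, axis=-1) @ v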
    opened by joaogui1 12
  • "objax.variable.VarCollection is not a valid JAX type" when creating a custom optimizer

    Hi, I wish to create a custom optimizer to replace the opt(lr=lr, grads=g) line in the example https://github.com/google/objax/blob/master/examples/classify/img/cifar10_simple.py

    Instead, I replaced it with

    for grad, p in zip(g, model_vars):
          p.value -= lr * grad   
    

    and then supplied model.vars() as an argument to train_op. However, I received an error: objax.variable.VarCollection is not a valid JAX type. Can someone help me with this issue? Here is a minimal working example which reproduces the error.

    import random
    import numpy as np
    import tensorflow as tf

    import objax
    from objax.zoo.wide_resnet import WideResNet
    
    # Data
    (X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.cifar10.load_data()
    X_train = X_train.transpose(0, 3, 1, 2) / 255.0
    X_test = X_test.transpose(0, 3, 1, 2) / 255.0
    
    # Model
    model = WideResNet(nin=3, nclass=10, depth=28, width=2)
    #opt = objax.optimizer.Adam(model.vars())
    predict = objax.Jit(lambda x: objax.functional.softmax(model(x, training=False)),
                        model.vars())
    # Losses
    def loss(x, label):
        logit = model(x, training=True)
        return objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()
    
    gv = objax.GradValues(loss, model.vars())
    
    def train_op(x, y, model_vars, lr):
        g, v = gv(x, y)
        for grad, p in zip(g, model_vars):
          p.value -= lr * grad   
        return v
    
    
    # gv.vars() contains the model variables.
    train_op = objax.Jit(train_op, gv.vars()) #I deleted opt.vars()
    
    for epoch in range(30):
        # Train
        loss = []
        sel = np.arange(len(X_train))
        np.random.shuffle(sel)
        for it in range(0, X_train.shape[0], 64):
            loss.append(train_op(X_train[sel[it:it + 64]], Y_train[sel[it:it + 64]].flatten(), model.vars(), 4e-3 if epoch < 20 else 4e-4)) #I added model.vars() 
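    (One way to avoid this error is to close over the variables instead of passing the VarCollection as a traced argument; below is a minimal sketch under that assumption, reusing the names model and gv from the example above and relying on TrainVar.assign and VarCollection.subset.)

    # Sketch only: keep the trainable variables in the closure and register
    # them with Jit; a VarCollection itself is not a valid JAX argument type.
    train_vars = model.vars().subset(objax.TrainVar)

    def train_op(x, y, lr):
        g, v = gv(x, y)
        for grad, p in zip(g, train_vars):
            p.assign(p.value - lr * grad)  # manual SGD step
        return v

    train_op = objax.Jit(train_op, gv.vars() + model.vars())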
    
    opened by RXZ2020 11
  • Enforcing positivity (or other transformations) of TrainVars


    Hi,

    Is it possible to declare constraints on trainable variables, e.g. forcing them to be positive via an exponential or softplus transformation?

    In an ideal world, we would be able to write something like: self.variance = objax.TrainVar(np.array(1.0), transform=positive)
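    (No such transform argument exists today; as a rough workaround sketch, the raw variable can live in unconstrained space and be read through a softplus. The PositiveScalar class below is hypothetical, not part of Objax.)

    import jax
    import jax.numpy as jn
    import objax

    # Hypothetical workaround: store the trainable value unconstrained and
    # apply softplus whenever it is read, so the effective value stays > 0.
    class PositiveScalar(objax.Module):
        def __init__(self, init: float = 1.0):
            self.raw = objax.TrainVar(jn.array(float(init)))  # unconstrained

        @property
        def value(self):
            return jax.nn.softplus(self.raw.value)  # always strictly positive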

    Thanks,

    Will

    p.s. thanks for the great work on objax so far, it's a pleasure to use.

    opened by wil-j-wil 10
  • Training state as a Module attribute


    As mentioned in a Twitter thread, I am curious about the decision to propagate training state through the call() chain. From my perspective this approach adds more boilerplate code and more chances of making a mistake (not propagating the state to a few instances of a module with a BN or dropout layer, etc.). If the state changed on every call like the input data does, it would make more sense to pass it with every forward pass, but I can't think of cases where that is common. For small models it doesn't make much difference, but as they grow in depth and breadth of submodules, the extra args become more noticeable.

    I feel one of the major benefits of an OO abstraction for NN is being able to push attributes like this into the class structure instead of forcing them to be forwarded through every call in a functional manner. I sit in the middle ground (pragmatic) of OO vs functional. Hidden state can be problematic, but worth it if it keeps interfaces clean.

    Besides TF/Keras, most DL libs manage training state as a module attribute or some sort of context:

    • PyTorch - nn.Module has a self.training attribute, recursively set on train()/eval() calls on the model/modules - https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.eval
    • MxNet Gluon - a context manager sets scope with autograd.train_mode() with autograd.predict_mode() - https://gluon.mxnet.io/chapter03_deep-neural-networks/mlp-dropout-gluon.html
    • Swift for TF - a thread-local context holds learningPhase - https://www.tensorflow.org/swift/api_docs/Structs/Context

    It should be noted that Swift for TF started out Keras- and Objax-like, with the training state passed through call().

    Disclaimer: I like PyTorch; I do quite a bit of work with that framework. It's not perfect, but I feel they really did a good job in terms of interface, usability, and evolution of the API. I've read some other comments here and acknowledge the 'we don't want to be like framework/lib X, or Y just because. If you disagree go fork yourself'. Understood; any suggestions I make are not just to be like X, but to bring elements of X that work really well to improve this library.

    I currently maintain some PyTorch model collections, https://github.com/rwightman/pytorch-image-models and https://github.com/rwightman/efficientdet-pytorch as examples. I'm running into a cost ($$) wall with the GPU experiments supporting my open-source work. TPU pricing is starting to look far more attractive. PyTorch XLA is not proving to be a great option, but JAX with a productive interface looks like it could be a winning solution with even more flexibility.

    I'm willing to contribute code for changes like this, but at this point it's a matter of design philosophy :)
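    (For illustration, a hypothetical sketch of the module-attribute style being requested; this is not how Objax currently works, where training is passed through __call__, and the StatefulModule class and its train()/eval() methods are invented for this sketch.)

    import objax

    # Hypothetical sketch: the training flag lives on the module, and
    # train()/eval() flip it recursively on all submodules.
    class StatefulModule(objax.Module):
        def __init__(self):
            self.training = True

        def train(self, mode: bool = True):
            self.training = mode
            for child in self.__dict__.values():
                if isinstance(child, StatefulModule):
                    child.train(mode)  # recurse into submodules
            return self

        def eval(self):
            return self.train(False)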

    opened by rwightman 9
  • Implementing 2 phases DP-SGD


    This PR implements a two-phase algorithm for per-sample gradient clipping with the goal of improving memory efficiency for the training of private deep models. The two steps are: (1) accumulate the norms of the gradient per sample and (2) use those norm values to perform a weighted backward pass that is equivalent to per-sample clipping. The user can choose whether to use this new algorithm or the currently implemented one through a boolean argument.

    The unit-tests have been adapted to check results for both algorithms.

    Let me know if this fits well!

    opened by lberrada 7
  • Give better error message when calling Parallel() without replicate()


    Currently, if you forget to call replicate() on a Parallel module, it dies somewhere in JAX land, between the 5th and 6th circles of hell. This error makes it possible to understand what's going on and find your way back.

    opened by carlini 7
  • Naming of the `GradValues` function


    If I understand right, GradValues essentially does two things: it computes the gradients and it computes the wrapped function's output values.

    So why not split it into two functions? Or, if we keep the current form, could we name it GradAndValuesFn? I'm just thinking that this is a prominent function and we should keep it as easy as possible for people beginning to use the framework; easy names like fit() and predict() are part of what made scikit-learn.
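    (For readers new to the API, a minimal usage sketch of what GradValues returns: a list of gradients for the listed variables plus a list of the wrapped function's returned values.)

    import objax
    import jax.numpy as jn

    # Minimal sketch: one call yields both gradients and the loss value.
    m = objax.nn.Linear(nin=4, nout=1)

    def loss(x, y):
        return ((m(x) - y) ** 2).mean()

    gv = objax.GradValues(loss, m.vars())
    x = objax.random.normal(shape=(8, 4))
    y = jn.zeros((8, 1))
    grads, values = gv(x, y)  # grads: list of gradients, values: [loss value]
    print(len(grads), values[0])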

    opened by jli05 6
  • Explicit padding mode


    It looks like objax currently limits padding to one of VALID or SAME. This rules out explicit padding and would prevent compatibility with models from PyTorch or Gluon, which only support explicit (symmetric) padding, unless extra Pad layers are added to the model.

    It'd be nice to at minimum add the ability to support TF-style explicit padding (specify both sides of every dim); the underlying jax conv impl is able to receive a [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]] spec like other low-level TF conv ops.

    Even nicer would be simplified, per-spatial-dim symmetric values like PyTorch/Gluon [pad_h, pad_w] or just pad. My default for most 2D convnets in PyTorch is to use pad = ((stride - 1) + dilation * (kernel_size - 1)) // 2, which results in a 'same-ish' padding value. This can always be done on top of the full low/high padding sequence above.

    Some TF models explicitly work around the limitations of SAME padding. By limitations, I mean the fact that you end up with input-dependent padding that can be asymmetric and shift your feature maps relative to each other in a manner that varies as you change your input size. https://github.com/tensorflow/models/blob/146a37c6663e4a249e02d3dff0087b576e3dc3a1/research/deeplab/core/xception.py#L81-L201

    Possible interfaces:

    • padding : Union[ConvPadding, Sequence[Tuple[int, int]]] (like conv_general_dilated but with the enum for valid/same)

    • Add more modes to the enum and associated values for those that need them via a dataclass:

    import enum
    from dataclasses import dataclass
    from typing import Optional, Sequence, Tuple, Union

    class PaddingType(enum.Enum):
        """An Enum holding the possible padding values for convolution modules."""
        SAME = 'SAME'
        VALID = 'VALID'
        RAW = 'RAW'  # specify padding as a sequence of (low, high) tuples
        SYM = 'SYM'  # specify symmetric padding per spatial dim as a tuple for H, W or a single int

    @dataclass
    class Padding:
        type: PaddingType = PaddingType.SAME
        value: Optional[Union[Sequence[Tuple[int, int]], Tuple[int, int], int]] = None

        @classmethod
        def same(cls):
            return Padding(PaddingType.SAME)

        @classmethod
        def valid(cls):
            return Padding(PaddingType.VALID)

        @classmethod
        def raw(cls, value: Sequence[Tuple[int, int]]):
            return Padding(PaddingType.RAW, value=value)

        @classmethod
        def sym(cls, value: Union[Tuple[int, int], int]):
            return Padding(PaddingType.SYM, value=value)
    
    feature request 
    opened by rwightman 6
  • `objax.variable.VarCollection.update` not compliant with key-value assignment


    Hi everyone! Thanks for the awesome work with objax and the JAX environment, and happy holidays!

    I'm trying to load some VarCollection and/or Dict[str, jnp.DeviceArray] params into the model.vars() which is a VarCollection class, and I can do so by:

    for key, value in new_params.items():
        model.vars()[key].assign(value)
    

    But I'd expect objax.variable.VarCollection.update to work the same way e.g.

    model.vars().update(new_params)
    

    And the latter doesn't work while the former does. I'm not sure whether that's just not the intended behavior for VarCollection.update or whether I'm doing something wrong. The first approach works, which for the moment is fine for what I need, but I wanted to mention this in case something is not working as expected.
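    (For illustration, a hypothetical helper mirroring the working loop above; this is not the current VarCollection.update implementation.)

    # Hypothetical helper, not part of Objax: assign values by key from either
    # a VarCollection or a plain dict into an existing VarCollection.
    def assign_params(vc, new_params):
        items = new_params.items() if hasattr(new_params, 'items') else new_params
        for key, value in items:
            vc[key].assign(value.value if hasattr(value, 'value') else value)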

    opened by alvarobartt 1
  • `objax.variable.VarCollection.update` fails when passing `Dict[str, Any]`


    Hi everyone! Thanks for the awesome work with objax and the JAX environment, and happy holidays!

    I was playing around with objax for a bit and realized that if you try to update model.vars() (a VarCollection) using the VarCollection.update method, which overrides the default dict.update, it fails when the argument is a Python dictionary rather than a VarCollection: the argument is cast into a Python list, and the code then loops over the items of that list as if it were a dictionary, so it throws ValueError: too many values to unpack (expected 2).

    https://github.com/google/objax/blob/53b391bfa72dc59009c855d01b625049a35f5f1b/objax/variable.py#L311-L318

    Is this intended? Shouldn't VarCollection.update just loop over classes that allow .items()?

    opened by alvarobartt 0
  • Update nn.rst


    The channel index for 'in' is currently set to c, which is incorrect because c refers to the output channel index. Instead, it needs to be set to t (the variable that iterates over the input channels): in[n,c,i+h,j+w] should be changed to in[n,t,i+h,j+w].

    opened by divyas248 1
  • pmean inside objax.parallel causes multithreading deadlock for more than 2 gpus


    Hi, I've noticed a problem where I'd like to ask for your expertise. I'm not entirely sure whether it is an Objax problem or rather a JAX problem under the hood, but as it is triggered by Objax calls, I'll post it here.

    Description

    In particular, when combining objax.Parallel and objax.functional.pmean (as done in this tutorial) I encounter problems with more than 2 GPUs (with 2 GPUs it works fine). It results in a deadlock situation, where nothing happens anymore. If I understand the tutorial correctly, the pmean is necessary to average the gradients of all cards.

    Minimal reproducible example

    import objax
    import numpy as np
    from objax.zoo.resnet_v2 import ResNet18
    from jax import numpy as jnp, device_count
    from tqdm import tqdm
    
    
    if __name__ == "__main__":
        print(f"Num devices: {device_count()}")
        model = ResNet18(3, 1)
        opt = objax.optimizer.SGD(model.vars())
    
        @objax.Function.with_vars(model.vars())
        def loss(x, label):
            return objax.functional.loss.mean_squared_error(
                model(x, training=True), label
            ).mean()
    
        gv = objax.GradValues(loss, model.vars())
    
        train_vars = model.vars() + gv.vars() + opt.vars()
    
        @objax.Function.with_vars(train_vars)
        def train_op(
            image_batch,
            label_batch,
        ):
    
            grads, loss = gv(image_batch, label_batch)
            # grads = objax.functional.parallel.pmean(grads) # this line
            # loss = objax.functional.parallel.pmean(loss) # and this line
            loss = loss[0]
            opt(1e-3, grads)
            return loss, grads
    
        train_op = objax.Parallel(train_op, reduce=jnp.mean, vc=train_vars)
    
        with (train_vars).replicate():
            for _ in tqdm(range(10), total=10):
                data = jnp.array(np.random.randn(512, 3, 224, 224))
                label = jnp.zeros((512, 1))
                loss, grads = train_op(data, label)
    
    

    Whenever you uncomment the two lines with pmean, the program gets stuck. However, if I understood it correctly, this is necessary to average the gradients over all cards.

    Error traces

    As with most deadlock bugs you don't get an error stack trace. However, I have two clues that I've found so far. One is that if this is uncommented, the following appears:

    2022-08-22 14:55:46.462557: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:31] This thread has been waiting for 10 seconds and may be stuck:
    2022-08-22 14:55:48.543291: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:36] Thread is unstuck! Warning above was a false-positive. Perhaps the timeout is too short.
    

    The other is that if I manually interrupt it with ctrl+c, I get a lengthy stack trace.

    Setup

    We use 4 NVIDIA A40 GPUs with CUDA Version 11.7 (Driver Version 515.65.01), cudnn 8.2.1.32, jax version 0.3.15, objax version 1.6.0

    opened by a1302z 3
Releases(v1.6.0)
  • v1.6.0(Feb 1, 2022)

  • v1.4.0(Apr 1, 2021)

    • Added prototype of ducktyping of Objax variables as JAX arrays
    • Added prototype of automatic variable tracing
    • Added learning rate scheduler
    • Various bugfixes
  • v1.3.1(Feb 3, 2021)

  • v1.3.0(Jan 29, 2021)

    • Feature: Improved error messages overall
    • Feature: Improved BatchNorm numerical stability
    • Feature: Objax2Tf for serving objax using TensorFlow
    • Feature: New API objax.optimizer.ExponentialMovingAverageModule for easy moving average of a model
    • Feature: Automatic broadcasting of scalars for objax.Parallel
    • Feature: New optimizer: LARS
    • Feature: New API added to functional (lax.scan)
    • Feature: Modules can be printed to nicely readable text now (repr)
    • Feature: New interpolate API (for images)
    • Bugfix: make objax.Sequential work with latest JAX
  • v1.2.0(Nov 2, 2020)

    • Feature: Improved error messages.

    • Feature: Extended syntax: allow assigning TrainVar without TrainRef for direct experimentation.

    • Feature: Extended padding options for pad and convolution.

    • Feature: Modified ResNet_V2 to be Keras compatible.

    • Feature: Defaults can be overridden in call for Adam, Momentum.

    • BugFix: Layer norm initialization in GPT-2.

Owner
Google (Google ❤️ Open Source)