
aqt's Introduction

AQT: Accurate Quantized Training

AQT is a software library designed for easy tensor operation quantization in JAX.

Features

Let us know if you run into any problems with AQT applications by filing an issue on GitHub.

Usage

Tensor contraction operations in JAX-based neural network libraries, i.e., any form of (high-order) matrix multiplication, including but not limited to jax.numpy.einsum and flax.linen.DenseGeneral, call lax.dot_general as their core computation. Quantizing a neural network in JAX therefore amounts to substituting lax.dot_general with a quantized variant while keeping the other parts as-is, which we call "quantization injection". JAX-based NN libraries, such as Flax and Pax, provide an API for this substitution when creating layers.

In this section, we show how AQT produces a quantized dot_general and injects it into a neural network defined in JAX. The toy example below can be found in the example colab.

First, install the AQT package, which is published on PyPI as aqtp.

# install the AQT library
!pip install aqtp

Next, import aqt.jax.v2. Other AQT versions are obsolete.

# necessary imports
import aqt.jax.v2.flax.aqt_flax as aqt
import aqt.jax.v2.config as aqt_config
import flax.linen as nn

A sample neural network defined in Flax looks like the following (as a toy example we use a simple MLP, but it can be any model):

class MlpBlock(nn.Module):
  config: aqt_config.DotGeneral | None

  @nn.compact
  def __call__(self, inputs):
    dot_general = aqt.AqtDotGeneral(self.config)
    x = nn.Dense(dot_general=dot_general, features=inputs.shape[-1] * 4)(inputs)
    x = nn.relu(x)
    x = nn.Dense(dot_general=dot_general, features=inputs.shape[-1])(x)
    return x

AQT quantizes the model by simply replacing the dot_general in nn.Dense with a quantized dot_general created from an AQT configuration. The example below uses a configuration that quantizes both the forward and backward passes to int8. Now let's test it.

import jax
import jax.numpy as jnp
import numpy as np

# Generate some random matrices as inputs
def gen_matrix(rows, columns, seed=0):
  np.random.seed(seed)
  return np.random.normal(size=(rows, columns)).reshape((rows, columns))

inputs = gen_matrix(3, 4)

# test function that initializes the model and computes the forward pass
def init_and_eval(name, mlp_block, init_seed=0, eval_seed=0):
  model = mlp_block.init(jax.random.PRNGKey(init_seed), inputs)
  out = mlp_block.apply(model, inputs, rngs={'params': jax.random.key(eval_seed)})
  print(f"{name}:\n", out)

# create a config that quantizes both forward and backward passes to int8
int8_config = aqt_config.fully_quantized(fwd_bits=8, bwd_bits=8)

# run and print results
mlp_fp16 = MlpBlock(config=None)
mlp_int8 = MlpBlock(config=int8_config)
init_and_eval('mlp_fp16', mlp_fp16)
init_and_eval('mlp_int8', mlp_int8)

Results will be the following:

mlp_fp16:
 [[ 0.720744    1.5375545  -2.6456933  -1.7605033 ]
 [-0.01541612  0.09728499 -1.5742414  -0.3737522 ]
 [ 0.4071759   1.1941448  -0.6982092  -0.48336366]]
mlp_int8:
 [[ 0.7030779   1.5099456  -2.6334763  -1.7550919 ]
 [-0.00901393  0.08774488 -1.5644912  -0.3728472 ]
 [ 0.40121436  1.189411   -0.6939187  -0.48000643]]

We can see that the quantized MLP produces outputs similar to those of the unquantized one.

Flexible Quantization Configs

The example in the usage section uses a configuration that quantizes both the forward and backward passes to 8 bits, but AQT provides a much more flexible configuration system. The DotGeneral class configures the forward and backward tensor contraction operations separately.

@dataclasses.dataclass
class DotGeneral:
  """Configuration of quantization of dot_general and its gradients."""
  fwd: DotGeneralRaw
  dlhs: DotGeneralRaw
  drhs: DotGeneralRaw

Each DotGeneralRaw inside DotGeneral configures the quantization of each input tensor of one such op separately, as well as the hardware dtype to use (e.g., jnp.bfloat16, jnp.float16, jnp.float8_e4m3fn, jnp.float8_e5m2, jnp.int8, jnp.int4).

@dataclasses.dataclass
class DotGeneralRaw:
  """Configuration of quantization of one dot_general without gradient."""
  lhs: Tensor  # left hand side
  rhs: Tensor  # right hand side
  dg_in_dtype: Optional[DType]
  dg_accumulator_dtype: Optional[DType]
  local_aqt: Optional[LocalAqt]  # sharded quantization

Inside config.Tensor we can configure the numerics used for each tensor, which include the number of bits, the calibration algorithm, stochastic rounding, and many other quantization parameters.

@dataclasses.dataclass
class Tensor:
  """Configuration of quantization of one tensor or one side of tensor op."""
  numerics: Numerics
  calib_shared_axes: Optional[list[int]]
  scale_stop_grad: bool
  calibration: calibration.Calibration  # calibration algorithm
  po2_scale: bool  # round calibration to power of 2
  use_fake_quant: bool
  use_fwd_quant: Optional[bool]  # use quantized fwd in the bwd pass

AQT Versions

As of today there are several independent AQT implementations in this package:

  • JAX Legacy AQT: an obsolete version of AQT still used by some customers.
  • JAX AQTv1: a version of AQT developed with acceleration of NN inference in mind.
  • TF AQTv1: the TensorFlow counterpart of JAX AQTv1.
  • JAX AQTv2: AQT implementing universal matmul quantization.

AQTv2 is the recommended library. We plan to port the remaining features from AQTv1 to AQTv2 and delete AQTv1 in early Q1 2024. The sections below describe the details.

Inference Acceleration

Some important AQTv1 features have not yet been ported to AQTv2.

The lack of these features prevents AQTv2 from accelerating inference with small batches. The only option today is dynamic quantization, where each tensor op is quantized independently and quantization scales are found just-in-time.
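
For context, here is a minimal sketch (plain JAX, not the AQT API) of the difference between dynamic quantization, which is what the aqt_matmul_int8 example later in this document does, and static quantization, where the scale is fixed ahead of time from calibration data.

import jax.numpy as jnp

def quant_dynamic(x):
  # Dynamic: the scale is recomputed from the live tensor on every call.
  scale = 127.0 / jnp.max(jnp.abs(x))
  return jnp.round(x * scale).astype(jnp.int8), scale

def quant_static(x, scale):
  # Static: the scale was fixed during calibration, so serving only needs a
  # multiply, a round, and a clip; no reduction over the tensor is required.
  return jnp.clip(jnp.round(x * scale), -127, 127).astype(jnp.int8), scale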

Backpropagation Acceleration

AQTv2 speeds up training and fine-tuning. We verified a 1.2x to 1.4x reduction in step time, at a given model quality, on 1B to 16B Transformer models on TPUs.

Today, in order to do this correctly, one needs to understand that for each two-argument tensor op (matmul, einsum, conv) in the forward pass, there are two such ops in the backward pass, and one has to know how to configure them. The sketch below illustrates the two backward-pass ops.
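
To make that concrete, here is a minimal plain-JAX sketch (not the AQT API) of the two backward-pass contractions that the dlhs and drhs fields of DotGeneral configure.

import jax
import jax.numpy as jnp

def forward(lhs, rhs):
  # One two-argument tensor op in the forward pass...
  return jnp.matmul(lhs, rhs)

lhs = jnp.ones((3, 4))
rhs = jnp.ones((4, 5))
out, vjp_fn = jax.vjp(forward, lhs, rhs)
g = jnp.ones_like(out)  # incoming gradient

# ...becomes two tensor ops in the backward pass:
#   d_lhs = g @ rhs.T   (quantization configured by DotGeneral.dlhs)
#   d_rhs = lhs.T @ g   (quantization configured by DotGeneral.drhs)
d_lhs, d_rhs = vjp_fn(g)
assert d_lhs.shape == lhs.shape and d_rhs.shape == rhs.shape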

We will keep the config file updated with current best practices.

AQT Serving

Quantization is applied to both activations and weights right before a matmul. During training, gradients flow "through" the operation and update the latent floating-point weights. During model serving, however, re-quantizing the weights is unnecessary and incurs a large overhead because:

  1. Weights are stationary during serving. Recomputing weight quantization is a waste.
  2. Recomputing quantization requires loading the latent BF16 weights, which consumes 2x more memory bandwidth than loading INT8 weights.

AQT provides a solution to this problem, called "serving conversion". Serving conversion is a constant-folding process that creates new variables in the checkpoint to store the folded INT8 weights. It requires running one dummy inference.

Note that:

  1. Weights can be either lhs or rhs inputs to a matmul. AQT supports storing both in checkpoints, configured by lhs_quant_mode and rhs_quant_mode. Users need to set the correct configuration themselves.
  2. In some cases where checkpoint variables require metadata such as sharding axis, users can configure variable initializers to fit their needs.

In serving mode, AQT looks for the INT8 variables in the checkpoint. If they are found, it skips the weight quantization and feeds the INT8 variables to the matmul as-is, saving memory bandwidth and avoiding recomputation.

Applying the conversion to an unquantized floating-point model is equivalent to post-training quantization (PTQ) serving. Applying the conversion to a forward-only quantized AQT model is quantization-aware training (QAT) serving. In the QAT case it is important to use the same AQT config during training and serving to maintain WYTIWYS (what you train is what you serve).
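
The constant-folding idea behind serving conversion can be illustrated with plain JAX. This is only a conceptual sketch, not the AQT serving API; the actual API is demonstrated in the flax end-to-end example mentioned below.

import jax.numpy as jnp

def quantize_int8(w):
  # Per-output-channel scale, as in the aqt_matmul_int8 example further down.
  scale = 127.0 / jnp.max(jnp.abs(w), axis=0, keepdims=True)
  return jnp.round(w * scale).astype(jnp.int8), scale

# Serving conversion (run once, offline): fold Q(w) and its scale into the
# checkpoint so serving never has to touch the latent bf16 weights.
w = jnp.ones((4, 5), dtype=jnp.bfloat16)
w_int8, w_scale = quantize_int8(w.astype(jnp.float32))
# checkpoint = {'w_int8': w_int8, 'w_scale': w_scale}

# Serving: load only the INT8 weights and the scale, quantize activations
# just-in-time, and rescale the INT32 matmul output.
def serve_matmul(a, w_int8, w_scale):
  a_scale = 127.0 / jnp.max(jnp.abs(a), axis=1, keepdims=True)
  a_int8 = jnp.round(a * a_scale).astype(jnp.int8)
  acc = jnp.matmul(a_int8, w_int8, preferred_element_type=jnp.int32)
  return acc / (a_scale * w_scale)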

The flax end-to-end example provides a code snippet on how to perform serving conversion and model serving in AQT.

Other Weight Transformations

Consider matmul(a, w) as an activation-weight matmul. Sometimes another transformation T is applied to the weights before they are passed into the matmul, i.e., matmul(..., T(w)) is computed. In this case, quantizing the weights directly, i.e., matmul(..., T(Q(w))), reduces the checkpoint size and saves memory bandwidth, but it will not accelerate the matmul because T will most likely return floats. In order to get matmul acceleration, the quantization function Q should be inserted just before the matmul, i.e., matmul(..., Q(T(w))).

AQT pursues the goal of both compressing the checkpoint AND accelerating the matmul. This requires storing the entire w_q = Q(T(w)) in the checkpoint and using it in serving directly, i.e., matmul(..., w_q).

Note that AQT provides a quantized matmul_aqt as a whole, such that matmul_aqt(..., T(w)) = matmul(..., Q(T(w))). Q(T(w)) is not visible outside of matmul_aqt, mainly because matmul_aqt has a custom gradient defined for it.
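
To make the placement of Q concrete, here is a small fake-quant sketch in plain JAX. The transformation T below is hypothetical, and Q dequantizes back to float purely so the two orderings can be compared with the same float matmul; in AQT proper the quantized values and their scale stay in integer form inside matmul_aqt.

import jax.numpy as jnp

def Q(x):
  # Toy per-tensor fake quantizer: quantize to int8, then dequantize.
  s = 127.0 / jnp.max(jnp.abs(x))
  return jnp.round(x * s).astype(jnp.int8).astype(x.dtype) / s

def T(w):
  # Hypothetical float transformation applied to the weights.
  return w * 0.5 + 0.1

a = jnp.ones((3, 4))
w = jnp.ones((4, 5))

# Q applied to the stored weights only: the checkpoint shrinks, but T returns
# floats, so the matmul cannot run on integer operands.
out_compressed_only = jnp.matmul(a, T(Q(w)))

# Q applied just before the matmul: the quantized values (and their scale)
# can be fed to an int8 matmul, which is what matmul_aqt does with Q(T(w)).
out_acceleratable = jnp.matmul(a, Q(T(w)))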

How AQT Works Internally

In this section we:

  • show how to get quantization acceleration in JAX,
  • explain what AQT INT8 does under the hood (using the simplest INT8 configuration),
  • run the code on a simple example.

The code in this section can be found, and executed, in the example colab. Note that this section mainly explains how AQT works internally and why it achieves good quality; for an AQT tutorial, refer to the usage section.

The matmul_true_int8 function below takes real INT8 inputs and returns INT32. The matmul computation jnp.matmul calls lax.dot_general under the hood, which is a JAX wrapper for the XLA DotGeneral op; that op implements all MXU ops except convolution and is where we get INT8 acceleration on TPUs. This is how one can get hardware acceleration of quantized matmul in JAX.

import jax.numpy as jnp

def matmul_true_int8(lhs, rhs):
  assert lhs.dtype == jnp.int8
  assert rhs.dtype == jnp.int8
  result = jnp.matmul(lhs, rhs, preferred_element_type=jnp.int32)
  assert result.dtype == jnp.int32
  return result

Generate some random data:

batch_size = 3
channels_in = 4
channels_out = 5
a = gen_matrix(batch_size, channels_in) # Activations
w = gen_matrix(channels_in, channels_out) # Weights

Below is how AQT works internally, using the simplest INT8 configuration. Even though names such as "batch", "channels", "w", and "a" are evocative of neural networks, the aqt_matmul_int8 algorithm is not DNN-specific.

def aqt_matmul_int8(a, w):
  max_int8 = 127
  # This function is customizable and injectable, i.e:
  # users can inject custom quant code into an AQT config.
  def quant_int8(x):
    return jnp.clip(jnp.round(x), -max_int8, max_int8).astype(jnp.int8)

  # Calibration. Calibration function is also customizable and injectable.
  a_s = max_int8 / jnp.max(jnp.abs(a), axis=1, keepdims=True)
  w_s = max_int8 / jnp.max(jnp.abs(w), axis=0, keepdims=True)
  assert a_s.shape == (batch_size, 1) # shapes checked for illustration
  assert w_s.shape == (1, channels_out)

  # int8 matmul with int32 accumulator
  result = matmul_true_int8(quant_int8(a * a_s), quant_int8(w * w_s)) / (a_s * w_s)
  assert result.shape == (batch_size, channels_out)

  return result

Note that each example in the batch and each output channel gets its own separate scale. This limits the effect of an outlier in "w" or "a" to just one row or column, which gives a tighter calibration and much better quantization quality. Comparing aqt_matmul_int8 to the float matmul below, their outputs are close.

print(f"jnp.matmul(a, w):\n", jnp.matmul(a, w))
print(f"aqt_matmul_int8(a, w):\n", aqt_matmul_int8(a, w))
# should expect the following outputs
jnp.matmul(a, w):
 [[ 3.6095254   5.8575077   1.9510972   4.732388    1.9792626 ]
 [ 4.335892    0.9743651   2.7298734   4.3540883   3.637487  ]
 [-0.07735002  2.7310796  -0.3519049   0.19912864 -1.2023292 ]]
aqt_matmul_int8(a, w):
 [[ 3.5998788   5.8562713   1.9385538   4.7426414   1.9792401 ]
 [ 4.321886    0.99681264  2.737299    4.3591022   3.6352503 ]
 [-0.07714217  2.7415617  -0.35343346  0.20568734 -1.1974115 ]]
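
To quantify "close" beyond eyeballing the printouts, one can, for example, compute the mean relative error between the two results. This snippet is an addition to the colab code:

# Mean relative error between the float matmul and the int8 AQT matmul.
float_out = jnp.matmul(a, w)
aqt_out = aqt_matmul_int8(a, w)
rel_err = jnp.mean(jnp.abs(aqt_out - float_out) / (jnp.abs(float_out) + 1e-9))
print("mean relative error:", rel_err)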

Citing AQT

We will be publishing an AQT whitepaper soon.

aqt's People

Contributors

andsteing, cdh4696, chiamp, dusenberrymw, gnecula, hawkinsp, ivyzx, jianlijianli, jihwanlee-alphago, maxwillzq, rchen152, rybakov, shivaniag, superbobry, tink-expo, yazdanbakhsh, ychzhang


aqt's Issues

Does AQTv2 allow for arbitrary quantization?

Hi everyone! I would like to use AQT to quantize deep learning models that I would then run on my own hardware (FPGAs). Does AQTv2 support arbitrary quantization (e.g., INT4)? I am asking because I have only seen examples using the INT8 data type.

Refactor config/code classes to follow Flax.

For every piece of logic (like Numerics, Calibration, Tensor, DotGeneralRaw, DotGeneral), we should have a single class with dataclass fields that configure it and methods that execute the logic.

Broken `aqtp-0.1.1` package: missing `aqt` package prefix

The most recent aqtp-0.1.1 package is missing the aqt prefix in the installed packages:

$ pip install aqtp==0.1.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting aqtp==0.1.1
  Downloading aqtp-0.1.1-py3-none-any.whl (405 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 405.5/405.5 kB 7.9 MB/s eta 0:00:00
Installing collected packages: aqtp
Successfully installed aqtp-0.1.1


$ pip show -f aqtp
Name: aqtp
Version: 0.1.1
Summary: AQT: Accurate Quantized Training
Home-page: https://github.com/google/aqt
Author: Cerebra Catalyst team
Author-email: [email protected]
License: 
Location: /usr/local/lib/python3.10/dist-packages
Requires: 
Required-by: 
Files:
  aqtp-0.1.1.dist-info/INSTALLER
  aqtp-0.1.1.dist-info/LICENSE
  aqtp-0.1.1.dist-info/METADATA
  aqtp-0.1.1.dist-info/RECORD
  aqtp-0.1.1.dist-info/REQUESTED
  aqtp-0.1.1.dist-info/WHEEL
  aqtp-0.1.1.dist-info/top_level.txt
  common/__init__.py
  common/__pycache__/__init__.cpython-310.pyc
  common/__pycache__/aqt_common.cpython-310.pyc
  common/__pycache__/aqt_config.cpython-310.pyc
  common/__pycache__/aqt_config_schedule_test.cpython-310.pyc
  common/__pycache__/aqt_config_utils.cpython-310.pyc
  common/__pycache__/emulated_floating_points.cpython-310.pyc
  common/__pycache__/emulation_utils.cpython-310.pyc
  common/aqt_common.py
  common/aqt_config.py
  common/aqt_config_schedule_test.py
  common/aqt_config_utils.py
  common/emulated_floating_points.py
  common/emulation_utils.py
  jax/__init__.py
  jax/__pycache__/__init__.cpython-310.pyc
...

Note that these both work as expected:

  • pip install aqtp==0.1.0
  • pip install aqtp@git+https://github.com/google/aqt.git

The problem should be easy to fix:

  1. yank the faulty 0.1.1 version
  2. increase the version to 0.1.2
  3. create a new wheel
  4. verify the wheel installs correctly (I checked at HEAD and it seems to work as expected; not sure what exactly went wrong when uploading the 0.1.1 package)
  5. upload the new version to PyPI

Note that the faulty 0.1.1 package breaks all downstream users, e.g., google-research/vision_transformer#271

To avoid these problems in the future, it might be a good idea to set up an automatic Python publish workflow (e.g. like this example)

Port static quantization from AQTv1 to AQTv2

Right now in AQTv2, we have only dynamic quantization.
It is great for backprop quantization, but we can have much better inference quality (and performance) with static quantization.

Quantized Batch Normalization?

Hi:

Thanks for your great work and the open-sourced quantization code!
I read your ResNet-4bit and PokeBNN papers and am interested in GPU acceleration research based on your models.
However, I am a bit confused about the model's data flow.

If I understand correctly, your batch normalization operator is not quantized, which means it operates in bf16 during inference. So in a block with Conv+BN, the Conv layer outputs 8/4-bit data, which then goes through a bf16 BN operator, so the block's output is bf16. The bf16 activations then go to the next Conv layer, where they are first quantized to 8/4-bit data before the quantized Conv operation.

If we consider two blocks of [Conv + BN], the data flow would be something like: bf16 -> int8/int4 -> (Conv) -> int8/int4 -> bf16 -> (BN) -> bf16 -> next block -> bf16.

I am not sure whether I understand this correctly. Do you have any comments?

Thanks a lot and thanks for your excellent project!

ckpt of 8-bit ResNet-50 teacher model

Hello, thanks for the great project.

Could you share the checkpoint file of the trained 8-bit ResNet-50 teacher model?

I think reproduction would be much better guaranteed if the teacher model's checkpoint were shared.

Thank you.

Publish updated version?

The current version of this package on PyPI, 0.0.9, doesn't include the fix in #56, leading to this ImportError:

  File "/opt/conda/lib/python3.7/site-packages/aqt/jax_legacy/jax/compute_cost_utils.py", line 27, in <module>
    from jax.interpreters import masking
ImportError: cannot import name 'masking' from 'jax.interpreters' (/opt/conda/lib/python3.7/site-packages/jax/interpreters/__init__.py)

Would you consider releasing a new version?

TypeError: dataclass() got an unexpected keyword argument 'frozen'

I ran into this problem when I tried to use the library.

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/g/Seto/GitHub/LargeWorldModel/LWM/lwm/vision_chat.py", line 18, in <module>
    from lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM
  File "/mnt/g/Seto/GitHub/LargeWorldModel/LWM/lwm/vision_llama.py", line 21, in <module>
    from lwm.llama import LLaMAConfig, LLAMA_STANDARD_CONFIGS, FlaxLLaMABlockCollection, RMSNorm
  File "/mnt/g/Seto/GitHub/LargeWorldModel/LWM/lwm/llama.py", line 34, in <module>
    import aqt.jax.v2.flax.aqt_flax as aqt
  File "/mnt/g/Seto/GitHub/LargeWorldModel/LWM/venv/lib/python3.10/site-packages/aqt/jax/v2/flax/aqt_flax.py", line 23, in <module>
    from aqt.jax.v2 import aqt_dot_general
  File "/mnt/g/Seto/GitHub/LargeWorldModel/LWM/venv/lib/python3.10/site-packages/aqt/jax/v2/aqt_dot_general.py", line 29, in <module>
    from aqt.jax.v2 import aqt_tensor
  File "/mnt/g/Seto/GitHub/LargeWorldModel/LWM/venv/lib/python3.10/site-packages/aqt/jax/v2/aqt_tensor.py", line 28, in <module>
    from aqt.jax.v2.numerics import no_numerics
  File "/mnt/g/Seto/GitHub/LargeWorldModel/LWM/venv/lib/python3.10/site-packages/aqt/jax/v2/numerics/no_numerics.py", line 23, in <module>
    class NoNumerics(numerics.AqtNumerics):
TypeError: dataclass() got an unexpected keyword argument 'frozen'

Performance of MNIST example

Hi everyone - thanks for your work on this, very exciting!

I've been playing around a bit with the Flax MNIST example (https://github.com/google/aqt/blob/main/aqt/jax/v2/examples/mnist.py). I've benchmarked the training (as well as eval) on TPU v4 and v5 and can't see a performance improvement compared to bfloat16/float32 training. Both training and eval are around 4% slower when using int8 quantized operations.

Am I doing something wrong, or is this expected? I could imagine that the overhead of converting from float32 to int8 and back is non-negligible at this small scale.

AqtEinsum 'not enough values to unpack'

Hi, I was working with AqtEinsum and in this particular case I got a ValueError, although the same operation works fine with jnp.einsum.

This works fine:

x = jax.random.normal(key, [1, 2, 4])
w = jax.random.normal(key, [2, 4, 4])

z = jnp.einsum('...ij,hjk->...ik', x, w)
z

This is not:

class SimpleDense(nn.Module):
    features: int
    config = aqt_config.fully_quantized()

    @nn.compact
    def __call__(self, x):
        d = x.shape[-1]

        kernel = self.param('kernel', nn.initializers.normal(), (2, d, self.features))
        einsum = aqt.AqtEinsum(self.config)

        return einsum('...ij,hjk->...ik', x, kernel)

model = SimpleDense(features = 4)
params = model.init(key, x)
ValueError                                Traceback (most recent call last)
[<ipython-input-41-bf39ae22f96a>](https://localhost:8080/#) in <cell line: 2>()
      1 model = SimpleDense(features = 4)
----> 2 params = model.init(key, x)

    [... skipping hidden 9 frame]

1 frames
[<ipython-input-40-29c53684ec5e>](https://localhost:8080/#) in __call__(self, x)
     10         einsum = aqt.AqtEinsum(self.config)
     11 
---> 12         return einsum('...ij,hjk->...ik', x, kernel)

    [... skipping hidden 2 frame]

[/usr/local/lib/python3.10/dist-packages/aqt/jax/v2/flax/aqt_flax.py](https://localhost:8080/#) in __call__(self, eqn, lhs_g, rhs_g)
    315     einsum = functools.partial(aqt_dot_general.einsum, eqn=eqn)
    316     a = jax.make_jaxpr(einsum)(lhs=lhs_in, rhs=rhs_in)
--> 317     [lhs_g_id, rhs_g_id] = a.eqns[0].invars
    318     [lhs_l_id, rhs_l_id] = a.jaxpr.invars
    319     not_swap = lhs_g_id == lhs_l_id and rhs_g_id == rhs_l_id

ValueError: not enough values to unpack (expected 2, got 1)

However, if the einsum subscripts and the kernel dimensions are the following:

...
kernel = self.param('kernel', nn.initializers.normal(), (d, self.features))
einsum = aqt.AqtEinsum(self.config)

return einsum('...ij,jk->...ik', x, kernel)
...

The code works as expected, without any errors.

For reference, I'm using AQT version 0.5.0 and a random seed of 42.

Can AQT be used to calculate the QK score?

I see that the sample code all concerns the Attention block or the MLP block. Can AQT int8 only be used for the parts involving parameter computation? For example, can the QK score calculation and the score * V calculation also use AQT int8?
