
How to classify flowers with Tensorflow.js

This post is about how I recreated the flower-classification example from Udacity's TensorFlow course in JavaScript using TensorFlow.js.

The whole example is available in my flower-photos repository on GitHub.

The hardest part of all this was the image augmentation generator, which I didn't find on npm, so I had to write the jimp-based functions myself so that they would work the same way as the original generator from the Python example.

augment images
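To illustrate the idea behind one of these helpers, here is a simplified sketch of a width-shift augmentation operating on a raw RGBA buffer. This is not the repo's jimp implementation (`shiftWidth`), just the underlying concept: each row is moved by a random horizontal offset and the vacated pixels stay black.

```typescript
// Simplified sketch of a width-shift augmentation on a raw RGBA buffer.
// NOT the repo's jimp-based shiftWidth - just the underlying idea.
type RawImage = { width: number; height: number; data: Uint8Array }; // RGBA

const shiftWidthRaw = (image: RawImage, fraction: number): RawImage => {
  // Maximum shift is a fraction of the width, e.g. 0.15 => up to 15 %.
  const maxShift = Math.floor(image.width * fraction);
  // Random offset in [-maxShift, maxShift].
  const offset = Math.floor(Math.random() * (2 * maxShift + 1)) - maxShift;

  const out = new Uint8Array(image.data.length); // zero-filled => black
  for (let y = 0; y < image.height; y++) {
    for (let x = 0; x < image.width; x++) {
      const sx = x - offset; // source column for destination column x
      if (sx < 0 || sx >= image.width) continue;
      const src = (y * image.width + sx) * 4;
      const dst = (y * image.width + x) * 4;
      out.set(image.data.subarray(src, src + 4), dst);
    }
  }
  return { ...image, data: out };
};
```

The jimp version does the same thing at a higher level, compositing the source image onto a blank canvas at a random offset.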

In addition, the image generation had to be spread across all available CPU cores using the workerpool library.

import { random, sampleSize } from "lodash";
import jimp from "jimp";
import workerpool from "workerpool";

import { rotate } from "./rotate";
import { scale } from "./scale";
import { shiftHeight } from "./shift-height";
import { shiftWidth } from "./shift-width";

// Pool of augmentation operations; a random subset is applied to each image.
const operations = [
  (image: jimp) => image.flip(true, false), // horizontal mirror
  (image: jimp) => shiftWidth(image, 0.15),
  (image: jimp) => shiftHeight(image, 0.15),
  (image: jimp) => scale(image, 0.5),
  (image: jimp) => rotate(image, 45),
];

const generateAugmentImage = async (path: string, shape: number) => {
  // Load and resize to the square input shape expected by the network.
  const image = (await jimp.read(path)).resize(shape, shape);

  // Apply a random-sized sample of the operations, chaining each result.
  const operationsImageResult = sampleSize(
    operations,
    random(0, operations.length)
  ).reduce(
    (image: jimp, operation: (image: jimp) => jimp) => operation(image),
    image
  );

  return operationsImageResult.getBufferAsync(jimp.MIME_PNG);
};

// Expose the function to the worker pool.
workerpool.worker({
  generateAugmentImage,
});

Another problem I encountered was the lack of the sparseCategoricalCrossentropy loss function in TensorFlow.js. I replaced it with the categoricalCrossentropy function, which, as we can read on Stack Exchange (sparse categorical crossentropy vs categorical crossentropy), is no different apart from the way labels are presented: as one-hot vectors instead of class indices.

import * as tf from "@tensorflow/tfjs-node";

// Lazily yields a one-hot label tensor for each record.
const labels = (records: Record[], labels: string[]) =>
  function* () {
    for (let index = 0; index < records.length; index++) {
      const record = records[index];
      const indexOfLabel = labels.indexOf(record.label);
      if (indexOfLabel === -1) {
        throw new Error(
          `Something wrong. Missing label: ${
            record.label
          } in labels: ${labels.toString()}`
        );
      }
      // categoricalCrossentropy expects one-hot labels, hence tf.oneHot.
      yield tf.oneHot(indexOfLabel, labels.length);
    }
  };
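The equivalence of the two losses is easy to see in miniature: a sparse label is just the class index, and tf.oneHot turns it into the vector categoricalCrossentropy expects. A plain-TypeScript sketch of that encoding (not tf.oneHot itself):

```typescript
// One-hot encoding: class index i becomes a vector with a 1 at position i.
// sparseCategoricalCrossentropy consumes the index directly;
// categoricalCrossentropy consumes the vector - same information.
const oneHot = (index: number, numClasses: number): number[] =>
  Array.from({ length: numClasses }, (_, i) => (i === index ? 1 : 0));

// e.g. class 2 out of 5 flower classes:
// sparse label:  2
// one-hot label: oneHot(2, 5) => [0, 0, 1, 0, 0]
```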

One of the most important elements was also the prefetch, thanks to which TensorFlow always had images from the generator ready in time for training.

  const trainX = tf.data
    .generator(features(trainRecords))
    .mapAsync(async (path: string) => {
      const image = await augmentImagePool.exec("generateAugmentImage", [
        path,
        IMAGE_SHAPE,
      ]);
      return imageBufferToInputTensor(image);
    })
    .prefetch(BATCH_SIZE * 3);
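Conceptually, prefetch keeps several reads in flight ahead of the consumer, so the expensive augmentation overlaps with training instead of blocking it. A simplified, dependency-free sketch of that idea (not tf.data's actual implementation):

```typescript
// Simplified prefetch: keeps up to `depth` pending reads in flight, so the
// consumer rarely waits for a fresh item if the producer can keep up.
// An illustration of the concept, NOT tf.data's implementation. depth >= 1.
async function* prefetch<T>(
  source: AsyncIterator<T>,
  depth: number
): AsyncGenerator<T> {
  // Prime the buffer with `depth` in-flight reads.
  const buffer: Promise<IteratorResult<T>>[] = [];
  for (let i = 0; i < depth; i++) buffer.push(source.next());

  while (true) {
    const result = await buffer.shift()!;
    if (result.done) return;
    buffer.push(source.next()); // immediately request the next item
    yield result.value;
  }
}
```

With `.prefetch(BATCH_SIZE * 3)`, up to three batches' worth of augmented images are being prepared while the current batch is still training.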

After long work on the generator, I managed to reach an amazing 75% accuracy during neural network validation.

npm run train

> flower-photos@1.0.0 train /flower-photos
> node ./dist/train.js

Overriding the gradient for 'Max'
Overriding the gradient for 'OneHot'
Overriding the gradient for 'PadV2'
Overriding the gradient for 'SpaceToBatchND'
Overriding the gradient for 'SplitV'
2020-08-11 14:51:06.023690: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2400075000 Hz
2020-08-11 14:51:06.024722: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x522b820 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-08-11 14:51:06.024759: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
Epoch 1 / 80
eta=0.0 ======================================================================> 
133405ms 4600184us/step - acc=0.257 loss=2.31 val_acc=0.384 val_loss=1.47 
Epoch 2 / 80
eta=0.0 ======================================================================> 
127012ms 4379732us/step - acc=0.418 loss=1.35 val_acc=0.506 val_loss=1.20 
...
eta=0.0 ======================================================================> 
121768ms 4198905us/step - acc=0.918 loss=0.236 val_acc=0.771 val_loss=0.859 
Epoch 80 / 80
eta=0.0 ======================================================================> 
118522ms 4086980us/step - acc=0.918 loss=0.228 val_acc=0.761 val_loss=0.872 

Before making the model publicly available, it is best to additionally quantize the weights using tensorflowjs_converter to halve the size of the model.

tensorflowjs_converter --quantize_float16 --input_format tfjs_layers_model --output_format tfjs_layers_model model/model.json quantized_model/

Additionally, it is recommended to convert the layers model into a graph model to speed up prediction on weaker devices. We can read about the results of such optimisation in the publication TensorFlow Graph Optimizations.

tensorflowjs_converter --quantize_float16 --input_format tfjs_layers_model --output_format tfjs_graph_model model/model.json quantized_graph_model/

If you want to see how the flower detection works, I invite you to try the live flower detector.

flower photos webcam app screenshot

Flower photos webcam app