Keras conditional passing one model output to another model

I suggest you train the cat/dog binary classification model independently, and likewise the cat breed and dog breed models. Then you can use a custom Keras model for inference. Here is a working example; you just need to load your own dataset and adjust the model architectures to your liking.

import numpy as np
import tensorflow as tf
from tensorflow import keras

np.random.seed(1)
tf.random.set_seed(1)

num_images = 200
num_cat_breeds = 10
num_dog_breeds = 15

X_train = np.random.random([num_images, 32, 32, 3])
y_breed = np.random.randint(num_cat_breeds + num_dog_breeds, size=num_images)
y_is_cat = y_breed < num_cat_breeds
y_cat_breed = y_breed[y_is_cat]
y_dog_breed = y_breed[~y_is_cat] - num_cat_breeds

model_cat_or_dog = keras.Sequential([
    keras.layers.Conv2D(filters=32, kernel_size=3, activation="relu"),
    keras.layers.Flatten(),
    keras.layers.Dense(1, activation="sigmoid")
])
model_cat_or_dog.compile(loss="binary_crossentropy", optimizer="adam")
model_cat_or_dog.fit(X_train, y_is_cat, epochs=2)

model_cat_breed = keras.Sequential([
    keras.layers.Conv2D(filters=32, kernel_size=3, activation="relu"),
    keras.layers.Flatten(),
    keras.layers.Dense(num_cat_breeds, activation="softmax")
])
model_cat_breed.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
model_cat_breed.fit(X_train[y_is_cat], y_cat_breed, epochs=2)

model_dog_breed = keras.Sequential([
    keras.layers.Conv2D(filters=32, kernel_size=3, activation="relu"),
    keras.layers.Flatten(),
    keras.layers.Dense(num_dog_breeds, activation="softmax")
])
model_dog_breed.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
model_dog_breed.fit(X_train[~y_is_cat], y_dog_breed, epochs=2)

class BreedModel(keras.Model):
    def __init__(self, model_cat_or_dog, model_cat_breed, model_dog_breed, **kwargs):
        super().__init__(**kwargs)
        # Keep references to the trained models. keras.models.clone_model()
        # would create copies with freshly initialized weights.
        self.model_cat_or_dog = model_cat_or_dog
        self.model_cat_breed = model_cat_breed
        self.model_dog_breed = model_dog_breed

    def call(self, inputs):
        y_proba_is_cat = self.model_cat_or_dog(inputs)
        # Squeeze only the last axis so a batch of one image stays 1-D.
        y_is_cat = tf.squeeze(y_proba_is_cat > 0.5, axis=-1)
        # Route each image to the breed model for its predicted species.
        cat_images = tf.boolean_mask(inputs, y_is_cat)
        dog_images = tf.boolean_mask(inputs, ~y_is_cat)
        Y_proba_cat_breed = self.model_cat_breed(cat_images)
        Y_proba_dog_breed = self.model_dog_breed(dog_images)
        return y_is_cat, y_proba_is_cat, Y_proba_cat_breed, Y_proba_dog_breed

num_test_images = 50
X_test = np.random.random([num_test_images, 32, 32, 3])
model = BreedModel(model_cat_or_dog, model_cat_breed, model_dog_breed)
y_is_cat, y_proba_is_cat, Y_proba_cat_breed, Y_proba_dog_breed = model(X_test)
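
The breed probabilities come back as two separate arrays, one for the images predicted as cats and one for those predicted as dogs. If you want a single prediction per image in the original order, something like the following could work. This is a minimal NumPy sketch; the helper name `merge_breed_predictions` is hypothetical, and it assumes the rows of each probability array follow the order of the matching images, which is how `tf.boolean_mask` selects them.

```python
import numpy as np

def merge_breed_predictions(is_cat, cat_breed_proba, dog_breed_proba):
    """Return one (species, breed index) pair per image, in input order."""
    is_cat = np.asarray(is_cat, dtype=bool)
    breed = np.empty(is_cat.shape[0], dtype=np.int64)
    # Rows of cat_breed_proba belong to the images where is_cat is True,
    # in their original order; likewise dog_breed_proba where it is False.
    breed[is_cat] = np.argmax(cat_breed_proba, axis=1)
    breed[~is_cat] = np.argmax(dog_breed_proba, axis=1)
    species = np.where(is_cat, "cat", "dog")
    return species, breed

# Tiny hand-made example: 3 images, 2 cat breeds, 3 dog breeds.
is_cat = np.array([True, False, True])
cat_proba = np.array([[0.9, 0.1], [0.2, 0.8]])
dog_proba = np.array([[0.1, 0.7, 0.2]])
species, breed = merge_breed_predictions(is_cat, cat_proba, dog_proba)
```

Note that the breed index is relative to the species: index 1 means cat breed 1 for a cat and dog breed 1 for a dog.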

Method 1.

You can define a Dense layer with fixed (non-trainable) weights that zeroes out some outputs based on the previous model's outputs. However, this is not how it is usually done.
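
One possible reading of Method 1, sketched in plain NumPy rather than as an actual Keras layer: multiply each breed output by a 0/1 gate derived from the cat/dog prediction, so the outputs for the wrong species become 0. The function name `gate_outputs` and the 0.5 threshold are assumptions made for illustration, not something the method prescribes.

```python
import numpy as np

def gate_outputs(p_is_cat, cat_breed_proba, dog_breed_proba, threshold=0.5):
    # 0/1 gate per image, shaped (batch, 1) so it broadcasts over the breeds.
    gate = (np.asarray(p_is_cat).reshape(-1, 1) > threshold).astype(np.float32)
    gated_cat = cat_breed_proba * gate          # zeroed for predicted dogs
    gated_dog = dog_breed_proba * (1.0 - gate)  # zeroed for predicted cats
    return gated_cat, gated_dog

# Two images: the first is predicted as a cat, the second as a dog.
p_is_cat = np.array([0.9, 0.2])
cat_proba = np.array([[0.6, 0.4], [0.5, 0.5]])
dog_proba = np.array([[0.3, 0.7], [0.1, 0.9]])
gated_cat, gated_dog = gate_outputs(p_is_cat, cat_proba, dog_proba)
```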

Method 2. This is what we actually do.

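import numpy as np
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense

def animal_breed(image):
    # `image` is assumed to be a NumPy batch of feature vectors
    # with shape (batch_size, num_features).

    # Just an example for getting some models (untrained placeholders).
    def get_model(num_features):
        inputs = Input(shape=(num_features,))
        y = Dense(5)(inputs)
        y = Dense(5, name='final-1')(y)
        return Model(inputs=inputs, outputs=Dense(10)(y))

    # Define the base model.
    DogCatModel = get_model(image.shape[1])

    result = DogCatModel.predict(image)

    # Pick a model based on a condition, or load your model
    # from any other source.
    def get_specific(value, model1, model2):
        if value[0] > value[1]:
            return model1
        return model2

    # Just a mock of inserting the previous result.
    # In real work you would insert the scalar results
    # into the last layers (after the CNN).
    inputs = np.concatenate([image, result], axis=1)

    # The condition is evaluated on the first example of the batch.
    SpecificModel = get_specific(
        result[0], get_model(inputs.shape[1]), get_model(inputs.shape[1])
    )

    return SpecificModel.predict(inputs)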

Why do it this way? You might expect something else, but this is a common solution that is easy to scale. You would not usually use layers themselves to combine different models. It is also much easier to configure or freeze settings this way.


Here is another solution, which may train faster, run faster, use less RAM, give better performance, and be easier to use than the alternatives listed here.

Just use a single model with multiple outputs: a binary output (cat/dog), a cat breed output (multiclass), and a dog breed output (multiclass). During training, you can use a custom loss function to ignore the loss that corresponds to the wrong species (for example, ignore the cat breed output for dog images).
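
To make the "ignore the loss for the wrong species" idea concrete, here is a small NumPy sketch of a masked sparse cross-entropy. It assumes that a label of -1 marks an example the output should ignore; the function name `masked_sparse_crossentropy` is hypothetical and only illustrates the arithmetic.

```python
import numpy as np

def masked_sparse_crossentropy(y_true, y_proba, ignore_label=-1):
    y_true = np.asarray(y_true)
    keep = y_true != ignore_label
    # Replace ignored labels with a valid index so the lookup is safe;
    # their loss is multiplied by 0 below.
    safe_labels = np.where(keep, y_true, 0)
    picked = y_proba[np.arange(len(y_true)), safe_labels]
    return -np.log(picked) * keep

y_cat_breed = np.array([-1, 2, 0])  # the first image is a dog
y_proba = np.array([[0.2, 0.3, 0.5],
                    [0.1, 0.1, 0.8],
                    [0.7, 0.2, 0.1]])
losses = masked_sparse_crossentropy(y_cat_breed, y_proba)
```

The first image contributes a loss of 0 to the cat breed output, so it produces no gradient there, while the other two contribute the usual negative log-probability of their true class.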

The benefits are:

  • Faster training: just one model to train.
  • Better performance: you can often get better performance when doing multi-task learning like this. That's because dog images and cat images have a lot in common, so it's helpful to train a single base neural network for both and then build specialized networks on top of that.
  • Less RAM and less compute: instead of having to go through two CNNs (one for the cat/dog detector and one for the breed), we just go through one (the base network). This largely compensates for the unnecessary computations that we do by going through the dog breed classifier even when the cat/dog detector says it's a cat.
  • Easier to use: just one call to the model, and you get everything you need all at once. Moreover, if the cat/dog detector is not quite sure (for example it outputs a 50% probability), then you can at least have reasonable candidates for both cats and dogs.

Here's a working example. You just need to replace the data with your own. Note that there are three labels:

  • cat/dog: for example [0, 1, 1, 0] for dog, cat, cat, dog
  • cat breed: for example [-1, 2, 0, -1] for not-a-cat, cat class 2, cat class 0, not-a-cat
  • dog breed: for example [3, -1, -1, 1] for dog class 3, not-a-dog, not-a-dog, dog class 1

import numpy as np
import tensorflow as tf
from tensorflow import keras

np.random.seed(1)
tf.random.set_seed(1)

num_images = 200
num_cat_breeds = 10
num_dog_breeds = 15

X_train = np.random.random([num_images, 32, 32, 3])
y_breed = np.random.randint(num_cat_breeds + num_dog_breeds, size=num_images)
y_is_cat = y_breed < num_cat_breeds
y_cat_breed = np.where(y_is_cat, y_breed, -1)
y_dog_breed = np.where(y_is_cat, -1, y_breed - num_cat_breeds)

base_model = keras.Sequential([
    keras.layers.Conv2D(filters=32, kernel_size=3, activation="relu"),
    keras.layers.Flatten(),
])

model_is_cat = keras.Sequential([
    keras.layers.Dense(1, activation="sigmoid")
])

model_cat_breed = keras.Sequential([
    keras.layers.Dense(num_cat_breeds, activation="softmax")
])

model_dog_breed = keras.Sequential([
    keras.layers.Dense(num_dog_breeds, activation="softmax")
])

image_input = keras.layers.Input(shape=[32, 32, 3])
z = base_model(image_input)
is_cat = model_is_cat(z)
cat_breed = model_cat_breed(z)
dog_breed = model_dog_breed(z)
model = keras.Model(inputs=[image_input],
                    outputs=[is_cat, cat_breed, dog_breed])

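def optional_crossentropy(y_true, y_pred):
    # Flatten the labels: Keras may pass them with shape (batch, 1), which
    # would broadcast against the per-example loss of shape (batch,).
    y_true = tf.reshape(y_true, [-1])
    is_not_ignored = y_true != -1
    # Swap the -1 placeholders for a valid class index; the mask zeroes their loss.
    y_true_no_ignore = tf.where(is_not_ignored, y_true, tf.zeros_like(y_true))
    mask = tf.cast(is_not_ignored, tf.float32)
    return keras.losses.sparse_categorical_crossentropy(y_true_no_ignore, y_pred) * mask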

model.compile(loss=["binary_crossentropy",
                    optional_crossentropy,
                    optional_crossentropy],
              optimizer="adam")
model.fit(X_train, [y_is_cat, y_cat_breed, y_dog_breed], epochs=2)

y_is_cat_pred, y_cat_breed_pred, y_dog_breed_pred = model.predict(X_train[:2])
print(y_is_cat_pred)
print(y_cat_breed_pred)
print(y_dog_breed_pred)
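
Since one call returns all three outputs, they can be combined into a single distribution over every breed, which is handy when the cat/dog output is unsure. The sketch below is one way to do it, under the assumption that each breed output can be read as a probability conditioned on the species, so that P(cat breed k) = P(cat) × P(breed k | cat). The helper name `joint_breed_proba` is hypothetical.

```python
import numpy as np

def joint_breed_proba(p_is_cat, cat_breed_proba, dog_breed_proba):
    # Shape (batch, 1) so the species probability broadcasts over the breeds.
    p_is_cat = np.asarray(p_is_cat).reshape(-1, 1)
    joint_cat = p_is_cat * cat_breed_proba
    joint_dog = (1.0 - p_is_cat) * dog_breed_proba
    # Columns: all cat breeds first, then all dog breeds.
    return np.concatenate([joint_cat, joint_dog], axis=1)

# Two images, 2 cat breeds and 2 dog breeds; the first image is a coin flip.
p_is_cat = np.array([[0.5], [0.9]])
cat_proba = np.array([[0.8, 0.2], [0.6, 0.4]])
dog_proba = np.array([[0.1, 0.9], [0.5, 0.5]])
joint = joint_breed_proba(p_is_cat, cat_proba, dog_proba)
best = joint.argmax(axis=1)
```

With real predictions you would pass `y_is_cat_pred`, `y_cat_breed_pred` and `y_dog_breed_pred`. A column index below the number of cat breeds is a cat breed; otherwise, subtract the number of cat breeds to get the dog breed index.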