CNN model conditional layer in Keras

The problem with conditionals in neural networks

The issue with a switch or conditionals (like if-then-else) as part of a neural network is that conditionals are not differentiable everywhere. Therefore automatic differentiation does not work directly, and working around this is quite complex. Check this for more details.

A shortcut is to train 3 separate models independently, and then at inference time use a conditional control flow to route between them.

#Training - 
model1 = model.fit(all images, P(cat/dog))
model2 = model.fit(all images, P(cat))
model3 = model.fit(all images, P(dog))
final prediction = argmax(model2, model3)

#Inference - 
if model1.predict == Cat: 
    model2.predict
else:
    model3.predict
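
A minimal runnable version of that routing (assuming model1, model2, and model3 are already-trained Keras classifiers, class index 0 means cat, and img is a preprocessed batch containing one image) could look like this:

import numpy as np

def route_and_predict(img, model1, model2, model3):
    # model1 outputs [P(cat), P(dog)]; assume class index 0 is 'cat'
    is_cat = np.argmax(model1.predict(img), axis=1)[0] == 0
    # route the image to the cat expert or the dog expert
    expert = model2 if is_cat else model3
    return expert.predict(img)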

But I don't think you are looking for that. I think you are looking to include conditionals as part of the computation graph itself.

Sadly, as far as I know there is no direct way to build an if-then condition into a computation graph. The keras.backend.switch that you see lets you work with tensor outputs, but not with the layers of a graph during training. That's why you will see it used in loss functions rather than in computation graphs (it throws input errors there).
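
As a quick, hypothetical illustration of that point (not code from the question), keras.backend.switch works fine inside a custom loss, where it only has to choose between two tensors:

from tensorflow.keras import backend as K

def switched_loss(y_true, y_pred):
    # pick MSE when the mean absolute error is small, plain MAE otherwise
    mae = K.mean(K.abs(y_true - y_pred))
    mse = K.mean(K.square(y_true - y_pred))
    return K.switch(mae < 0.5, mse, mae)

# model.compile(optimizer='adam', loss=switched_loss)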

A possible solution - skip connections & soft switching

You can, however, try to build something similar with skip connections and soft switching.

A skip connection is a connection from an earlier layer to a later layer that lets you pass information forward past the layers in between. It is quite common in very deep networks, where information from the original data would otherwise get lost. Check U-Net or ResNet, for example, which use skip connections between layers to carry information to later layers.
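
As a generic illustration (a ResNet-style block, not part of the architecture below), a skip connection simply re-injects an earlier tensor further down the graph:

from tensorflow.keras import layers

def residual_block(x, filters=64):
    # main path: two convolutions
    y = layers.Conv2D(filters, 3, padding='same', activation='relu')(x)
    y = layers.Conv2D(filters, 3, padding='same')(y)
    # skip connection: add the block's input back to its output
    # (assumes x already has `filters` channels so the shapes match)
    return layers.Activation('relu')(layers.Add()([x, y]))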


The next issue is switching. You want to switch between 2 possible paths in the graph. What you can do is soft switching, a method I took as inspiration from this paper. Notice that in order to switch between 2 distributions over words (one from the decoder and another from the input), the authors multiply them by p and (1-p) to get a combined distribution. This is a soft switch that allows the model to pick the next predicted word from either the decoder or the input itself. (This helps when you want your chatbot to reuse words that the user typed as part of its response to them!)
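
In tensor terms, the soft switch is just a convex combination of the two candidate distributions, weighted by a probability p, so gradients can flow through both paths (a toy example with made-up numbers):

import tensorflow as tf

p = tf.constant([[0.8]])                       # switch probability, shape (batch, 1)
dist_decoder = tf.constant([[0.1, 0.7, 0.2]])  # distribution from one path
dist_input   = tf.constant([[0.6, 0.3, 0.1]])  # distribution from the other path

# soft switch: p * one path + (1 - p) * the other, differentiable in p
combined = p * dist_decoder + (1.0 - p) * dist_input
print(combined.numpy())   # [[0.2  0.62 0.18]]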


With an understanding of these 2 concepts, let's try to intuitively build our architecture.

  1. First we need a single-input, multi-output graph, since we are effectively training 2 models in one.

  2. Our first model is a multi-class classifier that predicts separate probabilities for Cat and Dog. It is trained with a softmax activation and a categorical_crossentropy loss.

  3. Next, let's take the output node that predicts the probability of Cat and multiply the output of convolution layer 3 with it. This can be done with a Lambda layer.

  4. Similarly, let's take the probability of Dog and multiply it with the output of convolution layer 2. This can be seen as the following -

    • If the first model predicts a cat and not a dog, perfectly, then the computation will be 1*(Conv3) and 0*(Conv2).
    • If the first model predicts a dog and not a cat, perfectly, then the computation will be 0*(Conv3) and 1*(Conv2).
    • You can think of this as either a soft switch OR a forget gate from an LSTM. The forget gate is a sigmoid (0 to 1) output that multiplies the cell state to gate it and allows the LSTM to forget or remember previous time steps. Similar concept here!
  5. The gated Conv3 and Conv2 outputs can now be further processed, flattened, concatenated, and passed to another Dense layer for the final prediction.

This way, if the model is not sure whether it sees a dog or a cat, both conv2 and conv3 features participate in the second model's prediction. This is how you can use skip connections and a soft-switch-inspired mechanism to add some amount of conditional control flow to your network.

Check my implementation of the computation graph below.

from tensorflow.keras import layers, Model, utils
import numpy as np

# Dummy data: 10 RGB images and 2-class labels
X = np.random.random((10,500,500,3))
y = np.random.random((10,2))

# Model
inp = layers.Input((500,500,3))

x = layers.Conv2D(6, 3, name='conv1')(inp)
x = layers.MaxPooling2D(3)(x)

c2 = layers.Conv2D(9, 3, name='conv2')(x)
c2 = layers.MaxPooling2D(3)(c2)

c3 = layers.Conv2D(12, 3, name='conv3')(c2)
c3 = layers.MaxPooling2D(3)(c3)

x = layers.Conv2D(15, 3, name='conv4')(c3)
x = layers.MaxPooling2D(3)(x)

x = layers.Flatten()(x)

# First output: cat-vs-dog classifier
out1 = layers.Dense(2, activation='softmax', name='first')(x)

# Slice P(cat) and P(dog) from the first output
c = layers.Lambda(lambda x: x[:,:1])(out1)
d = layers.Lambda(lambda x: x[:,1:])(out1)

# Soft switch: gate conv3 by P(cat) and conv2 by P(dog)
c = layers.Multiply()([c3, c])
d = layers.Multiply()([c2, d])

c = layers.Conv2D(15, 3, name='conv5')(c)
c = layers.MaxPooling2D(3)(c)
c = layers.Flatten()(c)

d = layers.Conv2D(12, 3, name='conv6')(d)
d = layers.MaxPooling2D(3)(d)
d = layers.Conv2D(15, 3, name='conv7')(d)
d = layers.MaxPooling2D(3)(d)
d = layers.Flatten()(d)

# Second output: final prediction from the concatenated gated branches
x = layers.concatenate([c,d])
x = layers.Dense(32)(x)
out2 = layers.Dense(2, activation='softmax', name='second')(x)

model = Model(inp, [out1, out2])
model.compile(optimizer='adam', loss='categorical_crossentropy', loss_weights=[0.5, 0.5])

model.fit(X, [y, y], epochs=5)

utils.plot_model(model, show_layer_names=False, show_shapes=True)
Epoch 1/5
1/1 [==============================] - 1s 1s/step - loss: 0.6819 - first_loss: 0.7424 - second_loss: 0.6214
Epoch 2/5
1/1 [==============================] - 0s 423ms/step - loss: 0.6381 - first_loss: 0.6361 - second_loss: 0.6400
Epoch 3/5
1/1 [==============================] - 0s 442ms/step - loss: 0.6137 - first_loss: 0.6126 - second_loss: 0.6147
Epoch 4/5
1/1 [==============================] - 0s 434ms/step - loss: 0.6214 - first_loss: 0.6159 - second_loss: 0.6268
Epoch 5/5
1/1 [==============================] - 0s 427ms/step - loss: 0.6248 - first_loss: 0.6184 - second_loss: 0.6311

(Model graph rendered by utils.plot_model.)


To build a condition-based CNN, we could either pass the full batch of inputs to each sub-model in Model 2 and select the desired outputs from all sub-model outputs based on conditions (which is what the models you defined in the question do), or we can take a faster route and follow the conditions step by step (the three conditions you listed).

Example code showing the condition mechanism:

import tensorflow as tf

number_of_class = 2

# Mimic the test dataset and labels
batch = tf.constant([[1, 2, 3], [2, 3, 1], [3, 1, 2]])
y_all = [tf.one_hot(i, number_of_class, dtype=tf.float32) for i in range(number_of_class)]
# Mimic the outputs of model_01
y_p = tf.constant([[0.9, 0.1], [0.1, 0.9], [0.3, 0.7]])
y_p = tf.one_hot(tf.math.argmax(y_p, axis=1), number_of_class, dtype=tf.float32)
# Mimic the conditions by choosing the samples from the batch whose prediction equals the label of each class
for y in y_all:
    condition = tf.reduce_all(tf.math.equal(y_p, y), 1)
    indices = tf.where(condition)
    choosed_inputs = tf.gather_nd(batch, indices)
    print("label:\n{}\ncondition:\n{}\nindices:\n{}\nchoosed_inputs:\n{}\n".format(y, condition, indices, choosed_inputs))

Outputs:

label:
[1. 0.]
condition:
[ True False False]
indices:
[[0]]
choosed_inputs:
[[1 2 3]]

label:
[0. 1.]
condition:
[False  True  True]
indices:
[[1]
 [2]]
choosed_inputs:
[[2 3 1]
 [3 1 2]]

Example code that builds the condition-based CNN model and trains it in a custom training loop:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.utils import *
import numpy as np

img_rows, img_cols, number_of_class, batch_size = 256, 256, 2, 64

#----------- main model (Model 1) ------------------------------------
inputs = Input(shape=(img_rows, img_cols, 3))
# Note: with positional arguments, Convolution2D(64, 3, 3) in tf.keras means kernel_size=3 and strides=3
conv_01 = Convolution2D(64, 3, 3, activation='relu', name = 'conv_01') (inputs)
conv_02 = Convolution2D(64, 3, 3, activation='relu', name = 'conv_02') (conv_01)
skip_dog = conv_02

conv_03 = Convolution2D(64, 3, 3, activation='relu', name = 'conv_03') (conv_02)
skip_cat = conv_03

conv_04 = Convolution2D(64, 3, 3, activation='relu', name = 'conv_04') (conv_03)

flatten_main_model =  Flatten() (conv_04)
Output_main_model = Dense(units = number_of_class , activation = 'softmax', name = "Output_layer")(flatten_main_model)

#----------- Conditional  Cat model ------------------------------------ 
inputs_1 = Input(shape=skip_cat.shape[1:])
conv_05 = Convolution2D(64, 3, 3, activation='relu', name = 'conv_05') (inputs_1)
flatten_cat_model =  Flatten() (conv_05)
Output_cat_model = Dense(units = number_of_class , activation = 'softmax', name = "Output_layer_cat")(flatten_cat_model)

#----------- Conditional  Dog model ------------------------------------ 
inputs_2 = Input(shape=skip_dog.shape[1:])
conv_06 = Convolution2D(64, 3, 3, activation='relu', name = 'conv_06') (inputs_2)
flatten_dog_model =  Flatten() (conv_06)
Output_dog_model = Dense(units = number_of_class , activation = 'softmax', name = "Output_layer_dog")(flatten_dog_model)

#----------------------------- My discrete 3 models --------------------------------
model_01 = Model(inputs = inputs, outputs = [skip_cat, skip_dog, Output_main_model], name = 'model_main')
model_02_1 = Model(inputs = inputs_1, outputs = Output_cat_model, name = 'Conditional_cat_model')
model_02_2 = Model(inputs = inputs_2, outputs = Output_dog_model, name = 'Conditional_dog_model')

# Get one hot vectors for all the labels
y_all = [tf.one_hot(i, number_of_class, dtype=tf.float32) for i in range(number_of_class)]
sub_models_all = [model_02_1, model_02_2]
# The cat branch consumes skip_cat (output of conv_01..conv_03, i.e. the first 6 variables of model_01),
# the dog branch consumes skip_dog (output of conv_01..conv_02, i.e. the first 4 variables)
sub_models_trainable_variables = [model_01.trainable_variables[:6] + model_02_1.trainable_variables, 
                                  model_01.trainable_variables[:4] + model_02_2.trainable_variables]

cce = keras.losses.CategoricalCrossentropy()
optimizer_01 = keras.optimizers.Adam(learning_rate=1e-3, name='Adam_01')
optimizer_02 = keras.optimizers.Adam(learning_rate=2e-3, name='Adam_02')

@tf.function
def train_step(batch_imgs, labels):
    with tf.GradientTape(persistent=True) as tape:
        model_01_outputs = model_01(batch_imgs)
        y_p_01 = model_01_outputs[-1]
        loss_01 = cce(labels, y_p_01)

        # Convert outputs of model_01 from float in (0, 1) to one hot vectors, no gradients flow back from here
        y_p_01 = tf.one_hot(tf.math.argmax(y_p_01, axis=1), number_of_class, dtype=tf.float32)
        loss_02_all = []
        for i in range(number_of_class):
            condition = tf.reduce_all(tf.math.equal(y_p_01, y_all[i]), 1)
            indices = tf.where(condition)
            choosed_inputs = tf.gather_nd(model_01_outputs[i], indices)
            # Note here the inputs batch size for each sub-model is dynamic
            y_p_02 = sub_models_all[i](choosed_inputs)
            y_t = tf.gather_nd(labels, indices)
            loss_02 = cce(y_t, y_p_02)
            loss_02_all.append(loss_02)

    grads_01 = tape.gradient(loss_01, model_01.trainable_variables)
    optimizer_01.apply_gradients(zip(grads_01, model_01.trainable_variables))

    for i in range(number_of_class):
        grads_02 = tape.gradient(loss_02_all[i], sub_models_trainable_variables[i])
        optimizer_02.apply_gradients(zip(grads_02, sub_models_trainable_variables[i]))

    return loss_01, loss_02_all

def training():
    for j in range(10):
        random_imgs = np.random.rand(batch_size, img_rows, img_cols, 3)
        random_labels = np.eye(number_of_class)[np.random.choice(number_of_class, batch_size)]
        loss_01, loss_02_all = train_step(random_imgs, random_labels)
        print("Step: {}, Loss_01: {}, Loss_02_1: {}, Loss_02_2: {}".format(
            j, loss_01, loss_02_all[0], loss_02_all[1]))

training()

The output looks something like this:

Step: 0, Loss_01: 0.6966696977615356, Loss_02_1: 0.0, Loss_02_2: 0.6886894702911377
Step: 1, Loss_01: 0.6912064552307129, Loss_02_1: 0.6968430280685425, Loss_02_2: 0.6911896467208862
Step: 2, Loss_01: 0.6910352110862732, Loss_02_1: 0.698455274105072, Loss_02_2: 0.6935626864433289
Step: 3, Loss_01: 0.6955667734146118, Loss_02_1: 0.6843984127044678, Loss_02_2: 0.6953505277633667
Step: 4, Loss_01: 0.6941269636154175, Loss_02_1: 0.673763632774353, Loss_02_2: 0.6994296908378601
Step: 5, Loss_01: 0.6872361898422241, Loss_02_1: 0.6769005060195923, Loss_02_2: 0.6907837390899658
Step: 6, Loss_01: 0.6931678056716919, Loss_02_1: 0.7674703598022461, Loss_02_2: 0.6935689449310303
Step: 7, Loss_01: 0.6976977586746216, Loss_02_1: 0.7503389120101929, Loss_02_2: 0.7076178789138794
Step: 8, Loss_01: 0.6932153105735779, Loss_02_1: 0.7428234219551086, Loss_02_2: 0.6935019493103027
Step: 9, Loss_01: 0.693305253982544, Loss_02_1: 0.6476342082023621, Loss_02_2: 0.6916818618774414