Sunday, April 4, 2021

Tensorflow custom layer if-condition branches

Thank you very much for reading my question. Please see the following code:

import tensorflow as tf
from tensorflow import keras

class Custom_Layer1234(keras.layers.Layer):
    def __init__(self, inputname, units=45, input_dim=45):
        super(Custom_Layer1234, self).__init__()
        w_init = tf.random_normal_initializer()
        b_init = tf.zeros_initializer()

        # one set of weights/biases per branch
        self.w_0 = tf.Variable(initial_value=w_init(shape=(input_dim, units), dtype="float32"),
                               name='w0{}'.format(inputname), trainable=True)
        self.w_1 = tf.Variable(initial_value=w_init(shape=(input_dim, units), dtype="float32"),
                               name='w1{}'.format(inputname), trainable=True)
        self.b_0 = tf.Variable(initial_value=b_init(shape=(units,), dtype="float32"),
                               name='b0{}'.format(inputname), trainable=True)
        self.b_1 = tf.Variable(initial_value=b_init(shape=(units,), dtype="float32"),
                               name='b1{}'.format(inputname), trainable=True)

    @tf.function
    def call(self, inputs):
        # scalar taken from the second feature of the first example in the batch
        diff_1 = inputs[0][1]

        if diff_1 <= 0:
            y = tf.matmul(inputs, self.w_0) + self.b_0
        else:
            y = tf.matmul(inputs, self.w_1) + self.b_1

        return tf.nn.relu(y)

I apologize if my question doesn't make sense; I am quite new to TensorFlow. I want to train two sets of variables separated by this condition, and I was able to build, compile, and fit the model. Having only a shallow understanding of GradientTape and backpropagation, the code seems to work and the model converges.
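In eager execution this kind of branching does behave the way the question hopes: the Python `if` runs only one path per call, so the tape records only that path's variables and the other set receives no gradient. A minimal sketch of that effect (the variable names here are made up for illustration):

```python
import tensorflow as tf

w0 = tf.Variable(2.0)
w1 = tf.Variable(3.0)
x = tf.constant(-1.0)

with tf.GradientTape() as tape:
    if x <= 0:          # eager Python branching: only one path executes
        y = w0 * x
    else:
        y = w1 * x

# only the taken branch's variable gets a gradient; the other is None
g0, g1 = tape.gradient(y, [w0, w1])
```

Note that under `@tf.function`, AutoGraph instead converts a tensor-dependent `if` into `tf.cond`, which traces both branches.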

However, when I tried to use parallel processing (I have 2 GPUs), i.e.:

mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
   .....

it throws an exception:

Variable was not created in the distribution strategy scope of (<tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x000001A825A7BAF0>)
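This error is raised because the layer's `tf.Variable`s were created before the strategy could intercept them: `MirroredStrategy` can only mirror variables whose creation happens inside its scope. The usual fix is to construct (and compile) the model inside the scope. A sketch, using a stand-in `Dense` model rather than the custom layer:

```python
import tensorflow as tf
from tensorflow import keras

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # variable creation here is intercepted and mirrored across replicas
    model = keras.Sequential([keras.layers.Dense(4, input_shape=(45,))])
    model.compile(optimizer="adam", loss="mse")
```

A layer that creates its variables eagerly in `__init__`, as the one above does, must therefore be instantiated inside the scope; layers that defer weight creation to `build()` only need the model to be built inside the scope.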

This makes me wonder whether my approach was correct in the first place, or didn't make any sense. Could you please point out a correct way to do branching in a neural network, something like a decision tree under conditions, with multiple sets of trainable variables where only the set on the route matched by the input data is trained?
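One way to make the branching both strategy-friendly and per-example is to defer variable creation to `build()` (so the weights are created wherever the model is constructed, including inside a strategy scope) and to replace the data-dependent Python `if` with `tf.where`, which computes both branches and selects row by row. This is a sketch of that idea, not the only possible design; the layer name is invented, and the choice of column 1 as the routing feature mirrors the original code:

```python
import tensorflow as tf
from tensorflow import keras

class BranchedDense(keras.layers.Layer):
    """Two weight sets; each example is routed to one of them by a feature."""

    def __init__(self, units=45, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # weights are created at build time, so building the model inside
        # a strategy scope creates them inside the scope
        dim = int(input_shape[-1])
        self.w_0 = self.add_weight(name="w0", shape=(dim, self.units),
                                   initializer="random_normal", trainable=True)
        self.w_1 = self.add_weight(name="w1", shape=(dim, self.units),
                                   initializer="random_normal", trainable=True)
        self.b_0 = self.add_weight(name="b0", shape=(self.units,),
                                   initializer="zeros", trainable=True)
        self.b_1 = self.add_weight(name="b1", shape=(self.units,),
                                   initializer="zeros", trainable=True)

    def call(self, inputs):
        y0 = tf.matmul(inputs, self.w_0) + self.b_0
        y1 = tf.matmul(inputs, self.w_1) + self.b_1
        # per-example routing: column 1 of each row picks its branch;
        # shape (batch, 1) broadcasts across the units dimension
        cond = inputs[:, 1:2] <= 0.0
        return tf.nn.relu(tf.where(cond, y0, y1))
```

With this layout each example follows its own route (instead of the whole batch following the first example, as `inputs[0][1]` does), and each weight set receives gradients only from the rows that were routed through it.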

Thank you very much
