Wednesday, July 1, 2020

"If statement" in TensorFlow

I want to use an 'if' statement for a possible early return in TensorFlow 1.15. However, it seems that an 'if' statement doesn't work in TensorFlow.

I thought I could use 'tf.cond', but it seems to be only for computing a value, not for a 'return'.

For example:

    y = tf.cond(condition > 0, lambda: tf.matmul(x, W) + b, lambda: tf.matmul(x, W) - b)
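'tf.cond' does return values. Each branch is a Python function, and the tensors that branch returns become the output of 'tf.cond' when the graph runs. What it cannot do is make the enclosing Python function return early. Below is a minimal sketch that assumes TF 1.15-style graph execution through 'tf.compat.v1.Session'. The placeholder 'token' and the 'token * 10' branch are for illustration only.

```python
import tensorflow as tf

# Build the graph explicitly so this behaves like TF 1.15 graph mode
# (tf.compat.v1 keeps the Session API available under TF 2.x as well).
graph = tf.Graph()
with graph.as_default():
    token = tf.compat.v1.placeholder(tf.int32, shape=[], name="token")
    is_end = tf.equal(token, 2)  # 2 plays the role of the [END] id

    # Both branches return a (bool, int32) pair; tf.cond outputs whichever
    # pair the predicate selects when the graph is run.
    status, value = tf.cond(
        is_end,
        lambda: (tf.constant(True), tf.constant(0)),
        lambda: (tf.constant(False), token * 10),
    )

with tf.compat.v1.Session(graph=graph) as sess:
    end_status, end_value = sess.run([status, value], feed_dict={token: 2})
    other_status, other_value = sess.run([status, value], feed_dict={token: 5})
```

Both branches must return the same structure and dtypes. That is why the END branch returns a dummy 0 rather than nothing.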

Below is my decoder code, where I want to use an 'if' statement for an early return: 'if y_hat[0][-1] == 2:'

    def decode_for_test(self, input):
        seq_len = tf.convert_to_tensor(1, dtype=tf.int32)
        decoder_inputs = tf.ones((tf.shape([1])[0], 1), tf.int32)  # batch of 1, start token id 1
        ys = (decoder_inputs, decoder_inputs, seq_len)
        xs = (input, seq_len)
        memory, src_masks = self.encode(xs, False)

        for _ in tqdm(range(self.training_parameter.get('maxlen_target'))):
            logits, y_hat, y = self.decode(ys, memory, src_masks, False)

            # In TF 1.15 graph mode y_hat is symbolic, so this check cannot see its run-time value
            if y_hat[0][-1] == 2:  # 2 is the [END] token id
                return logits, y_hat

            _decoder_inputs = tf.concat((decoder_inputs, y_hat), 1)
            ys = (_decoder_inputs, y, seq_len)

        return logits, y_hat

If y_hat[0][-1] is 2, it means [END], so I want the function to stop and return the values at that point.
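The stop condition depends on a value that only exists once the graph runs. One common graph-mode approach is to move the check into the condition of 'tf.while_loop', so decoding stops as soon as the last token is 2 ([END]) or the length limit is reached. Below is a minimal sketch that assumes a batch size of 1, as in the decoder code. 'fake_decode_step' is a hypothetical stand-in for 'self.decode' that just predicts the previous token plus one. 'MAXLEN' is a hypothetical stand-in for 'maxlen_target'.

```python
import tensorflow as tf

END_ID = 2   # id of the [END] token
MAXLEN = 10  # hypothetical stand-in for training_parameter['maxlen_target']


def fake_decode_step(tokens):
    """Hypothetical stand-in for self.decode: predicts (last token + 1)."""
    return tokens[:, -1:] + 1  # shape [1, 1]


def greedy_decode(start_tokens, step_fn, end_id, maxlen):
    """Append one token per step until END is produced or maxlen steps have run."""

    def cond(i, tokens):
        # Takes the place of the Python `if ...: return` check: keep looping
        # only while the last token is not END and the step budget remains.
        not_done = tf.not_equal(tokens[0, -1], end_id)
        return tf.logical_and(i < maxlen, not_done)

    def body(i, tokens):
        next_token = step_fn(tokens)
        return i + 1, tf.concat([tokens, next_token], axis=1)

    i0 = tf.constant(0)
    _, tokens = tf.while_loop(
        cond,
        body,
        [i0, start_tokens],
        # The sequence dimension grows every iteration, so leave it unknown.
        shape_invariants=[i0.get_shape(), tf.TensorShape([1, None])],
    )
    return tokens


graph = tf.Graph()
with graph.as_default():
    start = tf.ones((1, 1), tf.int32)  # start token id 1, like tf.ones in the question
    decoded = greedy_decode(start, fake_decode_step, END_ID, MAXLEN)

with tf.compat.v1.Session(graph=graph) as sess:
    result = sess.run(decoded)
```

In the real decoder, 'body' would call 'self.decode(...)' and would also carry 'logits' as a loop variable with its own shape invariant. The Python 'for' loop adds a new copy of the decoder ops to the graph on every iteration. 'tf.while_loop' instead builds the step once and repeats it at run time.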
