Could anyone please help convert this code to TensorFlow? I am trying to find the positions of the data points in the set for which the CNN outputs a value greater than 0.95, to aid in pseudo-labelling.
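positions = []
# Number of full batches; a final partial batch, if any, is skipped
for t in range(dataset.shape[0] // batch_size):
    data = dataset.next_batch(batch_size)  # assumes next_batch is a method taking the batch size
    batch_start_position = t * batch_size  # dataset index of the first sample in this batch
    model_output = sess.run([output], feed_dict={model_input_pl: data})
    # model_output[0] has shape (batch_size, num_outputs)
    for i in range(model_output[0].shape[0]):
        if model_output[0][i][some_nodal_position] > 0.95:
            positions.append(batch_start_position + i)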
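The inner Python loop over each row can be replaced by a single vectorized comparison on the batch output. The sketch below assumes `model_output[0]` is a NumPy array of shape `(batch_size, num_outputs)`; the helper name `confident_positions_in_batch` is hypothetical, and `node` plays the role of `some_nodal_position`.

```python
import numpy as np


def confident_positions_in_batch(batch_output, node, batch_start, threshold=0.95):
    """Return dataset-level indices of rows whose score at `node` exceeds `threshold`.

    batch_output: array of shape (batch_size, num_outputs), e.g. model_output[0]
    node: output column to test (stands in for some_nodal_position)
    batch_start: dataset index of the first row in this batch
    """
    batch_output = np.asarray(batch_output)
    # Boolean mask over the whole batch, computed in one comparison
    mask = batch_output[:, node] > threshold
    # flatnonzero gives the row indices where the mask is True
    return (np.flatnonzero(mask) + batch_start).tolist()


example = np.array([[0.10, 0.97],
                    [0.50, 0.40],
                    [0.02, 0.99]])
print(confident_positions_in_batch(example, node=1, batch_start=64))  # [64, 66]
```

Inside the original loop this would become `positions.extend(confident_positions_in_batch(model_output[0], some_nodal_position, batch_start_position))`.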
Parallelising this code would allow many more models to be tested, but as written it takes a long time to run.
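One possible restructuring, sketched under stated assumptions: slice the dataset by index (so each batch's start position is known exactly and the last partial batch is not dropped), collect the scores for the chosen output node, and threshold the whole dataset in one vectorized step. `find_confident_positions` and `predict_fn` are hypothetical names; `predict_fn` stands in for the model call, for example `lambda x: sess.run(output, feed_dict={model_input_pl: x})` in TF 1.x, and `inputs` is assumed to be an indexable array rather than the original `dataset` object.

```python
import numpy as np


def find_confident_positions(inputs, predict_fn, node, batch_size=256, threshold=0.95):
    """Return indices of rows in `inputs` whose score at `node` exceeds `threshold`.

    predict_fn (hypothetical) maps an array of shape (n, ...) to scores of
    shape (n, num_outputs), e.g. a wrapper around sess.run in TF 1.x.
    """
    n = inputs.shape[0]
    if n == 0:
        return np.array([], dtype=np.int64)
    scores = []
    # Index-based slicing keeps batch offsets exact and includes the final,
    # possibly smaller, batch.
    for start in range(0, n, batch_size):
        batch_scores = np.asarray(predict_fn(inputs[start:start + batch_size]))
        scores.append(batch_scores[:, node])
    all_scores = np.concatenate(scores)
    # One vectorized comparison over the whole dataset
    return np.flatnonzero(all_scores > threshold)


# Toy model: column 1 of the output is simply the input value
toy = np.array([[0.2], [0.97], [0.99], [0.1], [0.96]])
toy_model = lambda x: np.hstack([1 - x, x])
print(find_confident_positions(toy, toy_model, node=1, batch_size=2).tolist())  # [1, 2, 4]
```

This removes the per-element Python loop, but the forward passes themselves are typically the dominant cost, so a larger `batch_size` (memory permitting) is the other lever. Alternatively, the comparison could be built into the TensorFlow graph (for example with `tf.where` on a boolean tensor) so that `sess.run` returns only the matching indices; that variant is not shown here.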