Thursday, January 9, 2020

How to check if NamedTuple is in list?

I am trying to check whether an instance of the NamedTuple "Transition" is equal to any object in the list "self.memory".

Here is the code I tried to run:

from typing import NamedTuple
import random
import torch as t

Transition = NamedTuple('Transition', state=t.Tensor, action=int, reward=int, next_state=t.Tensor, done=int, hidden=t.Tensor)


class ReplayMemory:

    def __init__(self, capacity):
        self.memory = []
        self.capacity = capacity
        self.position = 0

    def store(self, *args):
        print(self.memory == Transition(*args))
        # Skip the transition if an equal one is already stored
        if Transition(*args) in self.memory:
            return
        # Grow the list until it reaches capacity
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        ...

And here is the output:

False
False

And the error I got:

  ...
    if Transition(*args) in self.memory:
RuntimeError: bool value of Tensor with more than one value is ambiguous

This seems odd to me, because the print shows that the "==" comparison returns a plain boolean.
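A likely explanation, sketched without PyTorch: `self.memory == Transition(*args)` compares the whole list against a single tuple. A list never equals a tuple, so Python returns False without looking at any field. `x in some_list`, by contrast, compares `x` with each element using `==` and truth-tests the result. Comparing two tuples does the same thing field by field, so a tensor field ends up inside `bool()`, which PyTorch rejects for tensors holding more than one value. `FakeTensor` below is a hypothetical stand-in that mimics only that behavior (elementwise `==`, ambiguous `bool()`), so the sketch runs without PyTorch installed. The two-field `Transition` is likewise a cut-down assumption.

```python
from typing import NamedTuple


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: == is elementwise, and the
    result cannot be truth-tested when it holds more than one value."""

    def __init__(self, values):
        self.values = list(values)

    def __eq__(self, other):
        if not isinstance(other, FakeTensor):
            return NotImplemented
        # Elementwise comparison returns another "tensor", not a bool
        return FakeTensor(a == b for a, b in zip(self.values, other.values))

    def __bool__(self):
        if len(self.values) != 1:
            raise RuntimeError(
                "bool value of Tensor with more than one value is ambiguous")
        return bool(self.values[0])


class Transition(NamedTuple):  # cut-down version of the question's Transition
    state: FakeTensor
    action: int


memory = [Transition(FakeTensor([1, 2]), 0)]
candidate = Transition(FakeTensor([1, 2]), 0)

# list == tuple: the types differ, so the result is False and no field is compared
print(memory == candidate)  # False

# `in` evaluates item == candidate for each item; tuple equality compares the
# state fields, and truth-testing their elementwise result raises
error_message = None
try:
    candidate in memory
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```

The first print matches the "False" output above. The membership test fails only once the list is non-empty, because only then is there an element to compare against.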

How could this be done correctly?
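One possible approach, sketched under assumptions: replace `in` with an explicit field-by-field comparison that always yields a plain bool. For PyTorch tensors, `torch.equal(a, b)` returns True only when both tensors have the same size and elements, so it can be passed as the `tensor_equal` argument. The helper names `transitions_equal`, `contains_transition` and `list_equal` are hypothetical. The demo uses lists with ordinary `==` in place of tensors only so that it runs without PyTorch.

```python
from typing import Any, Callable, List, NamedTuple, Sequence


def transitions_equal(a: Sequence[Any], b: Sequence[Any],
                      tensor_equal: Callable[[Any, Any], bool]) -> bool:
    """Field-by-field comparison that never truth-tests a tensor."""
    if len(a) != len(b):
        return False
    for x, y in zip(a, b):
        if isinstance(x, (int, float)):
            # Scalar fields (action, reward, done): ordinary == is safe
            if x != y:
                return False
        elif not tensor_equal(x, y):
            # Tensor fields (state, next_state, hidden), e.g. torch.equal
            return False
    return True


def contains_transition(memory: List[Any], candidate: Sequence[Any],
                        tensor_equal: Callable[[Any, Any], bool]) -> bool:
    """Replacement for `candidate in memory` built on transitions_equal."""
    return any(transitions_equal(item, candidate, tensor_equal)
               for item in memory)


# Usage with a cut-down Transition; lists stand in for tensors here
class Transition(NamedTuple):
    state: Any
    action: int
    reward: int


def list_equal(x, y):
    return x == y


memory = [Transition([0.1, 0.2], 1, 0)]
print(contains_transition(memory, Transition([0.1, 0.2], 1, 0), list_equal))  # True
print(contains_transition(memory, Transition([0.1, 0.3], 1, 0), list_equal))  # False
```

With real tensors, the duplicate check in `store` would become `if contains_transition(self.memory, Transition(*args), t.equal): return`. It is still a linear scan with a full tensor comparison per stored transition, so each call costs time proportional to `len(self.memory)`.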

Thank you
