# Reference ("gold") labels, one per position.
gold = [1, 2, 1, 0, 2, 2]

# Candidate predictions, each compared position-by-position against ``gold``.
preds = [
    [1, 2, 2, 0, 2, 2],
    [2, 2, 1, 1, 1, 1],
    [1, 2, 1, 0, 2, 1],
    [1, 1, 0, 2, 1, 2],
    [2, 2, 1, 0, 1, 2],
]
def count_correct(pred, goal=None):
    """Return how many positions of *pred* equal the reference labels.

    Args:
        pred: Predicted labels, compared position-by-position with *goal*.
        goal: Reference labels. When omitted (or None), the module-level
            ``gold`` list is used, looked up at call time.

    Returns:
        The number of indices ``i`` where ``pred[i] == goal[i]``.

    Raises:
        ValueError: If *pred* and *goal* differ in length. Plain ``zip``
            would silently drop the unmatched tail and under-report errors.
    """
    if goal is None:
        # Resolve lazily: a ``goal=gold`` default would freeze whatever list
        # ``gold`` referred to when this ``def`` ran, ignoring later rebinds.
        goal = gold
    # Materialise so arbitrary iterables can be length-checked.
    pred, goal = list(pred), list(goal)
    if len(pred) != len(goal):
        raise ValueError(
            f"length mismatch: {len(pred)} predictions vs {len(goal)} labels"
        )
    return sum(1 for a, b in zip(pred, goal) if a == b)
print(
    sorted(
        preds,
        key=count_correct,  # score = number of positions matching gold
        reverse=True,       # best first; sorted() is stable, so ties keep input order
    )
)
# [[1, 2, 2, 0, 2, 2], [1, 2, 1, 0, 2, 1], [2, 2, 1, 0, 1, 2], [2, 2, 1, 1, 1, 1], [1, 1, 0, 2, 1, 2]]
# For each prediction, compute the indices it gets right; then look for a set cover, i.e. a small subset of predictions whose correct indices together cover every position:
def correct_ids(pred, goal=None):
    """Return the indices at which *pred* matches the reference labels.

    Args:
        pred: Predicted labels, compared position-by-position with *goal*.
        goal: Reference labels. When omitted (or None), the module-level
            ``gold`` list is used, looked up at call time.

    Returns:
        Ascending list of indices ``i`` where ``pred[i] == goal[i]``.

    Raises:
        ValueError: If *pred* and *goal* differ in length. Plain ``zip``
            would silently drop the unmatched tail.
    """
    if goal is None:
        # Resolve lazily: a ``goal=gold`` default would freeze whatever list
        # ``gold`` referred to when this ``def`` ran, ignoring later rebinds.
        goal = gold
    # Materialise so arbitrary iterables can be length-checked.
    pred, goal = list(pred), list(goal)
    if len(pred) != len(goal):
        raise ValueError(
            f"length mismatch: {len(pred)} predictions vs {len(goal)} labels"
        )
    return [i for i, (a, b) in enumerate(zip(pred, goal)) if a == b]
# Correct-position indices for every prediction, in input order.
print(list(map(correct_ids, preds)))
# [[0, 1, 3, 4, 5], [1, 2], [0, 1, 2, 3, 4], [0, 5], [1, 2, 3, 5]]