포스트

[ 백준 ] 4045 Cover Up ( Python )

BOJ 4045

Cover Up 게임의 최적 전략 - 확률 계산하기

안녕하세요! 오늘은 “The Price is Right”라는 인기 게임 쇼에 나오는 “Cover Up”이라는 게임에 대해 알아보려고 합니다. 이 게임의 목표는 새 자동차 가격인 5자리 숫자를 맞추는 것인데요. 최적의 전략을 사용한다면 이길 확률이 어느 정도일지 한번 계산해보겠습니다!

문제 분석

게임의 규칙은 다음과 같습니다:

  1. 각 자릿수마다 선택할 수 있는 숫자의 개수가 주어집니다. (최대 10개)
  2. 각 자릿수의 숫자들 중에는 정답일 가능성이 높은 ‘알려진 후보(known candidates)’가 있습니다.
  3. 알려진 후보 숫자 중 하나가 정답일 확률도 주어집니다.
  4. 한 자릿수라도 맞추면 계속해서 기회를 얻어 못 맞춘 자릿수의 숫자를 추측할 수 있습니다.
  5. 모든 자릿수를 맞추면 승리, 한 턴에 하나도 못 맞추면 패배합니다.

최적의 전략이란, 각 자릿수에서 알려진 후보를 먼저 선택하는 것입니다. 예를 들어 마지막 자릿수의 선택지가 [1,3,5,8,9]이고 5와 9가 후보라면, 5나 9를 먼저 선택해야겠죠.

자 그럼 이제 이길 확률을 계산해봐야겠는데요…생각을 좀 해보니 꽤 복잡해 보입니다. 어떻게 접근하면 좋을까요?

문제 접근 방법

일단 각 자릿수는 독립적이라는 점에 주목합시다. 그리고 어떤 자릿수든 맞추기만 하면 계속 진행할 수 있으므로, 전체 승리 확률을 구하려면 각 자릿수에서 숫자를 하나 이상 맞출 확률을 모두 곱해야 할 것 같습니다.

$P(win) = p_1 * p_2 * … * p_n$

$p_i$ = i번째 자릿수에서 숫자를 하나 이상 맞출 확률

그런데 사실 저 $p_i$를 구하는 것도 쉽지 않아 보입니다. 특정 자릿수에서 맞출 확률을 구하려면, 가능한 모든 숫자 선택 조합을 고려해야 합니다.

예를 들어 어떤 자릿수의 선택지가 [1,2,3,4] 이고 1이 유일한 후보라고 해보죠. (맞출 확률 100%)

  • 1을 첫번째로 선택할 경우 100% 맞출 수 있습니다.
  • 2,3,4중 하나를 선택한 뒤 1을 선택할 경우 맞출 확률은 1/3 = 약 33.3% 입니다.
  • 1이 아닌 두 숫자를 선택한 뒤 1을 선택할 확률은 약 16.7% 입니다.
  • 마지막에 1을 선택할 확률은 약 8.3% 겠죠.

따라서 이 경우 하나 이상 맞출 확률 p는 $p = 1 + (1/3) + (1/3) _ (1/2) + (1/3) _ (1/2) * (1/1) = 1.833333… $

즉 100% 입니다.

풀이 과정

자 그럼 코드를 한줄 한줄 뜯어보면서, 위에서 설명한 접근 방법이 어떻게 구현됐는지 살펴보겠습니다.

우선 prob_of_success 함수가 전체 흐름을 담당합니다. 코드를 보면 base case는 다음과 같이 처리됩니다:

1
2
3
4
5
6
7
8
9
if n == 0:
    return 1.0
elif n == 1:
    P = 0.0
    if A[0] > 0:
        P = B[0] / A[0]
    if C[0] > 0 and (D[0] / C[0] > P):
        P = D[0] / C[0]
    return P

자릿수(n)이 0이면 고려할 것이 없으니 1을 리턴해주면 됩니다.
n이 1이면 후보의 확률(B[0]/A[0])과 비후보의 확률(D[0]/C[0]) 중 더 큰 값을 리턴합니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
P = 0.0
m = power2(n)
choice = [0] * n
for i in range(m):
    convert(choice, i, n)
    OK = True
    for j in range(n):
        if choice[j] == 0 and A[j] == 0:
            OK = False
        if choice[j] == 1 and C[j] == 0:
            OK = False
    if OK:
        P2 = prob_of_choice(n, choice, A, B, C, D)
    else:
        P2 = 0.0

    if P < P2:
        P = P2

$2^n$의 이진수 조합을 모두 고려하면서, 각각의 경우(choice)에 대해 prob_of_choice 함수를 호출합니다. choice는 각 자릿수에서 1이면 비후보를, 0이면 후보를 선택한다는 의미입니다.

예를 들어 n=3이고 choice가 [0,1,0]이라면 첫번째/세번째 자릿수는 후보를, 두번째는 비후보를 선택하겠다는 뜻입니다.

이때 OK 변수는 해당 선택이 가능한지 여부를 판단합니다. 예를 들어 후보가 없는데 후보를 선택하려 하면 False겠죠. 가능한 선택일 경우에만 확률을 계산합니다.

최종적으로 가장 높은 확률을 리턴하는 것이 목표입니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def prob_of_choice(n, choice, likely_num, likely_prob, unlikely_num, unlikely_prob):
    prob = 0.0

    m = power2(n)
    for k in range(1, m):
        hits = [0] * n
        convert(hits, k, n)
        P1 = 1.0
        for i in range(n):
            if hits[i] == 1:
                if choice[i] == 0:
                    if likely_prob[i] == 0.0:
                        P1 = 0.0
                    else:
                        P1 *= likely_prob[i] / likely_num[i]
                else:
                    if unlikely_prob[i] == 0.0:
                        P1 = 0.0
                    else:
                        P1 *= unlikely_prob[i] / unlikely_num[i]
            else:
                if choice[i] == 0:
                    if likely_prob[i] == 0.0:
                        P1 *= 1.0
                    else:
                        if likely_num[i] <= 1:
                            P1 *= unlikely_prob[i]
                        else:
                            P1 *= (likely_prob[i] * (likely_num[i] - 1) / likely_num[i]) + unlikely_prob[i]
                else:
                    if unlikely_prob[i] == 0.0:
                        P1 *= 1.0
                    else:
                        if unlikely_num[i] <= 1:
                            P1 *= likely_prob[i]
                        else:
                            P1 *= (unlikely_prob[i] * (unlikely_num[i] - 1) / unlikely_num[i]) + likely_prob[i]

        if P1 > 0:
            P2 = prob_of_success(j + 1, new_likely_num, new_likely_prob, new_unlikely_num, new_unlikely_prob)
            prob += P1 * P2

    return prob

prob_of_choice 함수는 choice가 주어졌을 때의 확률을 계산합니다. 여기서도 $2^n$의 경우의 수를 모두 고려하는데, 이때의 의미는 각 자릿수를 맞출지 틀릴지를 나타내는 hits 입니다. 예를 들어 hits가 [1,0,1]이라면 1,3번째 자릿수는 맞추고 2번째는 못맞춘다는 뜻입니다.

P1은 해당 choicehits의 조합이 일어날 확률을 나타냅니다.

  • hits[i] == 1이라면 i번째 자릿수를 맞춘 경우이므로, 해당 숫자를 선택할 확률을 곱해줍니다. (likely_prob[i] / likely_num[i] 또는 unlikely_prob[i] / unlikely_num[i])
  • hits[i] == 0이라면 못 맞춘 경우인데, 이때는 선택한 숫자 외의 숫자를 골랐을 확률을 곱해줍니다. 좀 복잡하게 느껴질 수 있는데, 예를 들어 likely_num=3, likely_prob=0.8 이고 후보를 선택했다면 ($choice[i] == 0$), 못맞출 확률은 $1 - 0.8/3 = 0.7333..$ 이 되겠죠. 이걸 좀 정리해서 코드로 옮긴 것입니다.

이렇게 구한 P1이 0보다 크다면, 즉 실제로 가능한 시나리오라면 남은 자릿수에 대해 prob_of_success를 재귀호출 하여 P2를 구하고, 둘을 곱해서 더해줍니다.

그리고 재귀호출을 하기 전에 확률을 업데이트 해줘야 하는데, 예를 들어 [3,1,0.8]에서 첫번째 시도에 후보를 선택해서 맞췄다면, 이제 남은 건 [2,1,0.8] 이 되어야 합니다. 이 부분은 직접 손으로 풀어보시는 걸 추천드립니다.

추가 설명

와 정말 복잡하고 어려운 문제였습니다. 하지만 차근차근 분석하고 접근 방법을 정하니 풀 수 있었어요.

여기서 중요한 포인트는 아래와 같습니다:

  1. 각 자릿수는 독립적이므로 승리 확률은 각 자릿수에서 하나 이상 맞출 확률의 곱이다.

  2. 어떤 자릿수에서 하나 이상 맞출 확률을 구하려면, 가능한 모든 숫자 선택의 조합을 고려해야 한다.

  3. 특정 선택이 주어졌을 때의 확률은, 다시 모든 맞춤/틀림의 조합을 고려해서 계산할 수 있다.

  4. 현재 선택과 맞은 숫자에 따라 남은 숫자의 확률을 업데이트하고, 재귀호출로 남은 자릿수의 승리 확률을 계산할 수 있다.

  5. 2^n의 경우의 수를 비트마스크로 구현하면 깔끔하다. (convert 함수)

혹시라도 이해가 잘 안되는 부분이 있다면 댓글로 물어봐 주세요. 저도 여러분들과 토론하면서 배워가고 싶습니다 :)

전체 코드 보기
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import sys

def print_vectors(n, A, B, C, D):
    print(f"n={n}")
    print("likely: ", end="")
    for i in range(n):
        print(f"{A[i]} {B[i]}, ", end="")
    print()
    print("unlikely: ", end="")
    for i in range(n):
        print(f"{C[i]} {D[i]}, ", end="")
    print()
    print()


def print_prob(P):
    if P > 0.999999999999:
        print(1)
    else:
        print("0.", end="")
        n = int(P * 1000)
        P = P * 1000 - int(P * 1000)
        if P >= 0.499999999999:
            n += 1
        if n % 100 == 0:
            print(n // 100)
        elif n % 10 == 0:
            if n // 100 == 0:
                print('0', end="")
            print(n // 10)
        else:
            if n // 100 == 0:
                print('0', end="")
            if n // 10 == 0:
                print('0', end="")
            print(n)


def power2(n):
    return 2 ** n


def convert(A, x, n):
    for k in range(n):
        A[k] = x % 2
        x //= 2


def prob_of_success(n, A, B, C, D):
    def check_probs(n, choice, likely_num, likely_prob, unlikely_num, unlikely_prob):
        for i in range(n):
            if likely_num[i] == 0 and likely_prob[i] != 0:
                print(f" error!!! Likely: {likely_num[i]} {likely_prob[i]}")
        for i in range(n):
            if unlikely_num[i] == 0 and unlikely_prob[i] != 0:
                print(f" error!!! UnLikely: {unlikely_num[i]} {unlikely_prob[i]}")
        error = False
        for i in range(n):
            if choice[i] != 0 and choice[i] != 1:
                error = True
        if error:
            for i in range(n):
                print(choice[i], end=" ")
            print()

    def prob_of_choice(n, choice, likely_num, likely_prob, unlikely_num, unlikely_prob):
        prob = 0.0

        m = power2(n)
        for k in range(1, m):
            hits = [0] * n
            convert(hits, k, n)
            P1 = 1.0
            for i in range(n):
                if hits[i] == 1:
                    if choice[i] == 0:
                        if likely_prob[i] == 0.0:
                            P1 = 0.0
                        else:
                            P1 *= likely_prob[i] / likely_num[i]
                    else:
                        if unlikely_prob[i] == 0.0:
                            P1 = 0.0
                        else:
                            P1 *= unlikely_prob[i] / unlikely_num[i]
                else:
                    if choice[i] == 0:
                        if likely_prob[i] == 0.0:
                            P1 *= 1.0
                        else:
                            if likely_num[i] <= 1:
                                P1 *= unlikely_prob[i]
                            else:
                                P1 *= (likely_prob[i] * (likely_num[i] - 1) / likely_num[i]) + unlikely_prob[i]
                    else:
                        if unlikely_prob[i] == 0.0:
                            P1 *= 1.0
                        else:
                            if unlikely_num[i] <= 1:
                                P1 *= likely_prob[i]
                            else:
                                P1 *= (unlikely_prob[i] * (unlikely_num[i] - 1) / unlikely_num[i]) + likely_prob[i]

            if P1 > 0:
                new_likely_num = [0] * n
                new_unlikely_num = [0] * n
                new_likely_prob = [0.0] * n
                new_unlikely_prob = [0.0] * n
                j = -1
                for i in range(n):
                    if hits[i] == 0:
                        j += 1
                        if choice[i] == 0:
                            Pwrong = 1 - likely_prob[i] / likely_num[i]
                            if likely_num[i] <= 1:
                                new_likely_num[j] = 0
                                new_likely_prob[j] = 0.0
                                new_unlikely_num[j] = unlikely_num[i]
                                if unlikely_num[i] > 0:
                                    new_unlikely_prob[j] = 1.0
                                else:
                                    new_unlikely_prob[j] = 0.0
                            else:
                                new_likely_num[j] = likely_num[i] - 1
                                new_likely_prob[j] = (likely_prob[i] / likely_num[i]) / Pwrong * (likely_num[i] - 1)
                                new_unlikely_num[j] = unlikely_num[i]
                                new_unlikely_prob[j] = unlikely_prob[i] / Pwrong
                        else:
                            Pwrong = 1 - unlikely_prob[i] / unlikely_num[i]
                            if unlikely_num[i] <= 1:
                                new_unlikely_num[j] = 0
                                new_unlikely_prob[j] = 0.0
                                new_likely_num[j] = likely_num[i]
                                if likely_num[i] > 0:
                                    new_likely_prob[j] = 1.0
                                else:
                                    new_likely_prob[j] = 0.0
                            else:
                                new_unlikely_num[j] = unlikely_num[i] - 1
                                new_unlikely_prob[j] = (unlikely_prob[i] / unlikely_num[i]) / Pwrong * (unlikely_num[i] - 1)
                                new_likely_num[j] = likely_num[i]
                                new_likely_prob[j] = likely_prob[i] / Pwrong

                P2 = prob_of_success(j + 1, new_likely_num, new_likely_prob, new_unlikely_num, new_unlikely_prob)
                prob += P1 * P2

        return prob

    if n == 0:
        return 1.0
    elif n == 1:
        P = 0.0
        if A[0] > 0:
            P = B[0] / A[0]
        if C[0] > 0 and (D[0] / C[0] > P):
            P = D[0] / C[0]
        return P
    else:
        P = 0.0
        m = power2(n)
        choice = [0] * n
        for i in range(m):
            convert(choice, i, n)
            OK = True
            for j in range(n):
                if choice[j] == 0 and A[j] == 0:
                    OK = False
                if choice[j] == 1 and C[j] == 0:
                    OK = False
            if OK:
                P2 = prob_of_choice(n, choice, A, B, C, D)
            else:
                P2 = 0.0

            if P < P2:
                P = P2

        return P


while True:
    n = int(sys.stdin.readline().strip())
    if n == 0:
        break

    likely_num = [0] * n
    unlikely_num = [0] * n
    likely_prob = [0.0] * n
    unlikely_prob = [0.0] * n

    inputs = list(map(float, sys.stdin.readline().strip().split()))

    for i in range(n):
        m = int(inputs[3 * i])
        l = int(inputs[3 * i + 1])
        p = inputs[3 * i + 2]
        likely_num[i] = l
        unlikely_num[i] = m - l
        likely_prob[i] = p
        unlikely_prob[i] = 1.0 - p

    prob = prob_of_success(n, likely_num, likely_prob, unlikely_num, unlikely_prob)
    print_prob(prob)

개인적으로는 prob_of_choice 함수 내부의 P1 계산 부분이 가장 핵심이라고 생각합니다.

  • 선택한 숫자를 맞출 확률 (likely_prob[i] / likely_num[i], unlikely_prob[i] / unlikely_num[i])
  • 선택한 숫자를 못 맞출 확률 (1 - likely_prob[i] / likely_num[i], 1 - unlikely_prob[i] / unlikely_num[i])
  • 이 둘을 잘 조합해서 각 경우의 확률을 계산합니다.

혹시 더 궁금한 점이나 이해가 안 되는 부분이 있다면 언제든 말씀해 주세요. 함께 해결해 나가면서 성장하는 게 목표니까요 :) 긴 설명 읽어주셔서 감사합니다!

이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.