https://www.acmicpc.net/problem/2533

 

2533번: 사회망 서비스(SNS)

페이스북, 트위터, 카카오톡과 같은 사회망 서비스(SNS)가 널리 사용됨에 따라, 사회망을 통하여 사람들이 어떻게 새로운 아이디어를 받아들이게 되는가를 이해하는 문제가 중요해졌다. 사회망

www.acmicpc.net

 

핵심::
루트노드를 따로 주지 않기 때문에 트리를 양방향 그래프로 생성하고

자식노드는 부모노드를 탐색하지 못하게 하는 방법이 핵심이었다.

풀이::

DFS 탐색의 점화식은 다음과 같다.

DP[노드 번호][부모 노드 얼리어댑터 여부] = 필요한 하위 노드 얼리어댑터 총합( 본인 포함 )

기저 사례는 리프 노드에 도달 했을 경우인데 리프 노드의 경우는 자신의 부모의 얼리어댑터 여부만 보고 DP값을 계산 할 수 있다.

자신이 얼리어댑터일 경우의 값(sum1)와 얼리어댑터가 아닐 경우의 값(sum2) 두가지 합을 구한 뒤,

자신의 부모의 얼리어댑터 여부를 보고 min(sum1, sum2) 값을 사용할 것인지, sum1의 값만 쓸 것인지를 결정한다.

자신의 부모가 얼리어댑터가 아닐 경우에는 본인은 무조건 얼리어댑터가 되어야한다.

문제점::
탐색의 방향이 부모 -> 자식으로 정해져 있기 때문에 방향 그래프로 생성을 하니, 

간선 입력시 정점 번호를 거꾸로 주면 탐색 자체가 완료되지 않는다.

양방향 그래프로 생성할 경우 자식노드에 자신의 부모노드가 포함되기 때문에 이를 효율적으로 배제할 수 있는 방법을 찾아야 한다.

양방향 그래프로 생성하고 함수에 진입할 때마다 visit 배열에 표시하는 방식을 사용하면 이미 탐색된 정점에 대해서 DP값을 가져오지 않는다.

이를 해결하기 위해 visit 배열 표시와 확인을 자식 정점에 진입하기 전에 진행하고 만약 자식 노드에 DP값이 있을 때만 해당 값을 더하는 식으로 해결하였다.

만약 부모노드 일 경우엔 DP값이 계산되지 않았을 것이고 계산된 자식노드라면 정상적으로 DP값을 가지고 올 수 있다.

의견::
트리 디피 첨 해보는데 그냥 그래프 + DP같은 느낌인듯? 다익스트라랑 비슷한 것 같다.

코드::

더보기
#define _CRT_SECURE_NO_WARNINGS
#include <iostream>
#include <vector>
#include <algorithm>
#include <cstdio>

using namespace std;
vector<vector<int>> arr;
vector<vector<int>> dp;
vector<bool> vis;

int getm(int x, bool pad) {//pad : 부모 노드의 얼리어댑터 여부
	int& ret = dp[x][pad ? 1 : 0];
	if (ret != -1) {
		return ret;
	}
	if (arr[x].size() == 1 && x != 1) {
		return ret = (pad ? 0 : 1);//부모의 얼댑 여부에 따라 리프 노드 값 결정
	}
	int sum1 = 1;//현재 노드가 얼댑인경우
	int sum2 = 0;//반대
	for (auto v : arr[x]) {				
		if (vis[v]) {
			if (dp[v][1] != -1) {
				sum1 += dp[v][1];
			}
			if (dp[v][0] != -1) {
				sum2 += dp[v][0];
			}
			continue;
		}
		vis[v] = true;
		sum1 += getm(v, true);//현재노드가 얼댑인 경우 하위노드 최소값
		sum2 += getm(v, false);//반대
	}
	if (pad) {
		return ret = min(sum1, sum2);
	}
	else {
		return ret = sum1;
	}
}
int main()
{
	int n; cin >> n;
	arr = vector<vector<int>>(n + 1);
	dp = vector<vector<int>>(n + 1, vector<int>(2, -1));
	vis = vector<bool>(n + 1, false);
	for (int i = 1; i < n; i++) {
		int a, b;
		scanf("%d %d", &a, &b);
		arr[a].push_back(b);
		arr[b].push_back(a);
	}
	vis[1] = true;
	cout << getm(1, true);

	return 0;
}