코딩하는 오리

[백준] 11049_행렬 곱셈 순서 (java) 본문

알고리즘/백준

[백준] 11049_행렬 곱셈 순서 (java)

jooeun 2023. 1. 13. 17:45

https://www.acmicpc.net/problem/11049

 

11049번: 행렬 곱셈 순서

첫째 줄에 입력으로 주어진 행렬을 곱하는데 필요한 곱셈 연산의 최솟값을 출력한다. 정답은 231-1 보다 작거나 같은 자연수이다. 또한, 최악의 순서로 연산해도 연산 횟수가 231-1보다 작거나 같

www.acmicpc.net

 

 

문제를 읽어보고 바로 DP문제라는 사실을 알아내지 못했다.

처음에는 어떤 두 행렬을 먼저 곱할지 순서를 순열을 통해서 확인하는 순열(완전 탐색)을 생각했다. 

그러나 최대 499!은 말도 안되는 비교 횟수이기에 다른 방법으로 해결해야했다.

 

DP를 적용하기 위해서는 점화식을 잘 세워야 한다. 이전 값을 중복해서 잘 사용할 수 있어야 한다.

 

문제에서 N은 최대 500이지만 N을 3으로 줄여 총 3개의 행렬이 있다고 생각해보자.

 

편의상 (i~j)를 i번째 행렬부터 j번째 행렬까지 행렬 연산에 대한 최솟값이라 표현하겠다.

 

이때의 결과값은(1~3) 연산의 최솟값은 ((1, 2), (3)) 또는 ((1), (2, 3)) 두가지 중 최솟값이된다.

(1~3)의 최솟값을 구했다면, (2~4), (3~5) 또한 같은 원리로 구할 수있다.

N이 4일 때는 (1~4)는 ((1), (2~4)) 또는 ((1~2), (3~4)) 또는 ((1~3) (4))  중 최솟값이 된다. 

 

이렇게 1~N개의 행렬이 있을 때 이렇게 ((1~K), (K+1,N))로 이분하여 (1~N)의 최솟값을 구할 수 있다.

 

dp를 위한 N*N 이차원배열을 생성하고, dp[i][j]는 (i~j)에 대한 값을 저장한다.

(i~i)는 자기자신에 대해서는 연산이 없으므로 dp[i][i]는 0이 되겠고, i와 j(i보다 큰값)에 대해 값을 연산한다.

 

 

이를 위해서 3차 for문을 구성한다.

 

가장 바깥 for문은 dp배열의 n개의 대각선을 채우기 위한 것이다. d값을 증가하며 연산하는 행렬의 갯수를 늘려간다.

 

두번째 반복문은 dp배열의 대각선 한개를 채우기 위한 것이며, i는 대각선의 행, i+d는 대각선의 열이다.

 

세번째 반복문에서는 k기준으로 연산을 나눠서 dp[i][j]의 최소 연산수를 찾는 것이다. 

아까 N=4 일 때 (1~4)는 ((1), (2~4)) 또는 ((1~2), (3~4)) 또는 ((1~3) ,(4))  중 최솟값 이었던 것 처럼 최소 ((i~k), (k+1,j))를 찾는다.

 

이렇게 dp[i][j]보다 작은 dp[i][k] + dp[k+1][j] + arr[i][0] * arr[k][1] * arr[j][1] 값으로 갱신해주면 된다.

 

sum()

 

1~N개의 행렬 연산의 최솟값을 구하는 것인데 작은 것 부터 시작해서 (1~2), (2~3), ..., (1~3), (2~4), ... (1~4), (2~5), ... ... 이렇게 늘려가며 1~N 사이 연산의 최솟값을 구할 수 있다. 이렇게 두 개의 정점을 사용하므로 이차원 배열에 저장하여 누적해나가면 되겠다.

 

너무 어렵다. ㅎ

 

 

코드

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main {
	
	static int[][] arr;

	public static void main(String[] args) throws Exception {
		
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st;
		
		int N = Integer.parseInt(br.readLine());
		
        	//행렬 정보 저장
		arr = new int[N][2];
		for(int i=0;i<N;i++) {
			st = new StringTokenizer(br.readLine());
			arr[i][0] = Integer.parseInt(st.nextToken());
			arr[i][1] = Integer.parseInt(st.nextToken());
		}
	
		int[][] dp = new int[N][N];
		
		for(int d=1;d<N;d++) {
			for(int i=0;i+d<N;i++) {
				int min = Integer.MAX_VALUE;
				for(int k=i;k<i+d;k++) {
					int val = dp[i][k]+dp[k+1][i+d] + sum(i,k,i+d);
					min = Math.min(min, val);
				}
				dp[i][i+d] = min;
			}
		}
		
		System.out.println(dp[0][N-1]);
	}

	private static int sum(int i, int k, int j) {
		return arr[i][0] * arr[k][1] * arr[j][1];
	}

}