본문 바로가기
컴퓨터/수학이랑

미분방정식에 대한 수치해석학적 해(Runge-Kutta) - 구현

by adnoctum 2010. 5. 23.
참고: 소스 코드는 이 글에 있는 것을 그냥 이용하면 됩니다. 만약 아래의 코드를 보았음에도 불구하고 어떻게 이용해야 하는지 모른다면 아직 C++ 에 익숙하지 않다는 것을 의미합니다. 이 경우 저에게 코드를 요청하셔도 별로 의미가 없습니다. 왜냐 하면,  이 경우 제가 소스를 드려도 제대로 이해하기 어려울 것이기 때문입니다. 그리고 제가 보내는 소스가 이 글에 나와 있는 것과 별반 다르지 않기 때문이기도 합니다. 


   이전 글에서 보았던 미분방정식에 대한 해를 수치해석학적으로 구하는 Runge-Kutta of Order Four 알고리즘을 구현해 보자. 단, 약간 일반화시켜서 구현한다. 알고리즘은 Richard L. Burden; J. Douglas Faires, Numerical Analysis, 8th edition, pp.278‐279, Algorithm 5.2 Runge‐Kutta (Order Four) 에 기반한다. 즉 이 곳에 구현해 놓은 것은 system of differential equations 을 구하는 알고리즘이 되겠다. 즉,

y1' = f1(t, y1, ..., yN)
y2' = f2(t, y1, ..., yN)
...
yN' = fN(t, y1, ..., yN)

의 형태로 되어 있는 문제를 푸는 것이다. 예를 들면,

y1' = t - y1*y2
y2' = t^2 + y1 - y2

인 경우 y1(t)와 y2(t) 를 구하기 위한 코드이다. 이 코드는 사용 예제를 뒤에 첨부한다.



   전체적인 구조는 다음과 같다.


ode_solve::solve 함수에서 실제 작업을 진행하는데, 각 step xi에서 처리를 하기 전/후에 preprocessing 과 postprocessing 함수를 호출해서 반환값이 false 이면 해를 더이상 구하지 않고 중단하는 것이다. 그래서 pre-/post- processing 함수는 virtual 로 구현되어 있다. 이와 같은 것까지 고려하면 전체 구조는 다음과 같이 표현할 수 있다.





ode_solve::solve의 prototype 은 다음과 같다.

bool ode_solve::solve(

   double start_point, // 추정 구간의 시작점

   double end_point,   // 추정 구간의 끝점

   int point_number,   // 추정 구간을몇 개의 점으로 나눌 것인가.step size를 결정

   vector<double> initial, // 초기값

   vector<diff_function*> *function, // 미분된 식

   vector<vector<double> > *result   // 결과

)


위에서 vector<diff_function*> *function 이 바로 미분된 식을 표현하는 개체이다. ode_solve는 diff_function 클래스에서 상속된 클래스를 파라미터로 받아서 이 객체의 operator() 를 호출한다. 즉, Runge-Kutta of Order F(RKF) 알고리즘 내에서 함수 값을 받아 가야 하는 경우에 operator() 를 호출하는 것이다. 따라서 diff_fuction 은 fuctor 로 사용되고 있으며, diff_function 의 operator() 는 순수가상함수로 해 놓는다. ode_solve 의 헤더 파일은 다음과 같다.

// ode_solve.h

 

#ifndef __DEFINITION_OF_ODE_SOLVE__

#define __DEFINITION_OF_ODE_SOLVE__

 

#include <vector>

 

using namespace std;

 

class diff_function{

public:

    virtual double operator()(double time, vector<double> *param) = 0;

};

 

class ode_solve{

private:

    char    _termination_code;

public:

    bool solve(double start_point, double end_point, int point_number, vector<double> initial, vector<diff_function*> *function, vector<vector<double> > *result);

    virtual bool preprocessing(double time, std::vector<double>* initial, std::vector<double> *approximate);

    virtual bool postprocessing(double time, std::vector<double>* initial, std::vector<double> *approximate);

    char get_termination_code(){ return _termination_code; };

 

};

 

#endif



실제 구현 파일 ode_solve.cpp 파일은 다음과 같다.

    1 // ode_solve.cpp

    2 #include "stdafx.h"

    3 #include "ode_solve.h"

    4 #include <conio.h>

    5 

    6 bool ode_solve::solve(

    7 double start_point,

    8 double end_point,

    9 int point_number,

   10 vector<double> initial,

   11 vector<diff_function*> *function,

   12 vector<vector<double> > *result)

   13 {

   14     int p = 0;

   15 

   16     int m = function->size(); // number of equations

   17 

   18     // Set h = (b-a)/N;

   19     double h = (end_point - start_point) / point_number; //(b - a) / N;

   20     // t = a;

   21     double t = start_point; //a;

   22 

   23     // For j = 1, 2, ..., m set wj = aj

   24     vector<double> w(initial.begin(), initial.end());

   25 

   26     vector<double> k_1(m,0);

   27     vector<double> k_2(m,0);

   28     vector<double> k_3(m,0);

   29     vector<double> k_4(m,0);

   30 

   31 

   32     vector<double> w1(m);

   33     vector<double> w2(m);

   34     vector<double> w3(m);

   35 

   36     // For i = 1, 2, ... N do steps 5 - 11

   37     for(int i = 1; i <= point_number; i++){

   38         if(preprocessing(t, &initial, &w) == false){

   39             _termination_code = -1;

   40             break;

   41         }

   42         // Step 5 - For j = 1, 2, ... , m set

   43         // k_1,j = hf_j(t,w_1, ..., w_m).

   44         int j = 0;

   45         for(int j = 1; j<=m; j++){

   46             k_1[j-1] = h* (*(function->operator [](j-1)))(t, &w);

   47         }

   48 

   49 

   50         // Step 6 - For j = 1, 2, ... , m set

   51         // k_2,j = hf_j(t+h/2, w1+k_11/2, w2 + k_12/2, ... , wm + k_1m/2);

   52         for(p = 0; p<m; p++){

   53             w1[p] = w[p] + k_1[p] / 2;       

   54         }

   55         for(j = 1; j<=m; j++){

   56             k_2[j-1] = h* (*(function->operator [](j-1)))(t + h/2, &w1);

   57         }

   58 

   59         // Step 7 - For j = 1, 2, ... , m set

   60         // k_3,j = hf_j(t+h/2, w1+k_21/2, w2 + k_22/2, ... , wm + k_2m/2);

   61         for(p = 0; p<m; p++) w2[p] = w[p] + k_2[p]/2;

   62         for(j = 1; j<=m; j++){

   63             k_3[j-1] = h* (*(function->operator [](j-1)))(t + h/2, &w2);

   64         }

   65 

   66         // Step 8 - For j = 1, 2, ... , m set

   67         // k_4,j = hf_j(t+h/2, w1+k_31/2, w2 + k_32/2, ... , wm + k_3m/2);

   68         for(p = 0; p<m; p++) w3[p] = w[p] + k_3[p];///2;

   69         for(j = 1; j<=m; j++){

   70             k_4[j-1] = h* (*(function->operator [](j-1)))(t + h/*/2*/, &w3);

   71         }

   72 

   73         // Step 9 - For j = 1, 2, ... , m set

   74         // w[j] = w[j] + (k_1[j] + 2k_2[j] + 2k_3[j] + k_4[j])/6;

   75         for(j = 1; j<=m; j++){

   76             w[j-1] = w[j-1] + ( k_1[j-1] + 2*k_2[j-1] + 2*k_3[j-1] + k_4[j-1])/6;

   77         }

   78 

   79         // Step 10 - Set t = a + ih;

   80         t = start_point + i*h;

   81 

   82 

   83         // In cases wi should be positive, set them 0.

   84         for(p = 0; p<m; p++){

   85             if(w[p] < 0.0000005) w[p] = 0;

   86         }

   87 

   88         // CALLBACK function.

   89         // Process appropriate works at each step.

   90         if(postprocessing(t, &initial, &w) == false){

   91             _termination_code = 1;

   92             break;

   93         }

   94 

   95         // Step 11 - OUTPUT(t,w1, w2, ... , wm);

   96         result->push_back(w);

   97     }

   98 

   99     _termination_code = 0;

  100 

  101 

  102     return true;

  103 }

  104 bool ode_solve::preprocessing(double time, std::vector<double>* initial, std::vector<double> *approximate)

  105 {

  106     return true;

  107 }

  108 bool ode_solve::postprocessing(double time, std::vector<double>* initial, std::vector<double> *approximate)

  109 {

  110     return true;

  111 }


위 코드에서 줄번호를 제외한 코드는 다음에 있다.



코드를 살펴 보자. 실제 함수 yi 를 추정한 값이 w 이다. 이 값은 24 줄에서 처음 나오며, 처음에는 파라미터로 넘겨 준 초기값으로 설정되는 것을 알 수 있다. 추정을 하는 실제적인 부분은 42~80 줄 까지이다. 추정하려는 구간을 구간 개수로 나눈 이후, 각 step 에서 추정을 하기 위해 for 문을 도는 것을 볼 수 있다. 또한, 각 step 에서 추정을 하기 직전(38 번 줄)/직후(90 번 줄) 에서, 시간, 이 직전 step 에서 구한 추정치, 초기값을 파라미터로 넘겨서 혹시 추정을 중단해야 하는지를 확인한다. 초기값도 같이 넘겨 주는 이유는, 어떤 경우에는 초기값보다 2배 이상 y2 가 되었으면 중단한다, 와 같은 경우가 있을 수도 있기 때문이다.

46 번 줄을 보자. C++ 에 익숙하지 않으면 매우 혼동될 수 있는 코드인데, 설명하자면 다음과 같다.

46             k_1[j-1] = h* (*(function->operator [](j-1)))(t, &w);

원래의 알고리즘은 k1[j-1] = h * fj(t, w) 이다. 즉, k1 의 j 번째 요소에 j 번째 f 로 t 와 w 를 넣고 계산한 값을 넣으라는 것이다. 우리는 함수 형태가 뭔지에 상관이 없이 이 알고리즘을 사용할 수 있기 때문에, 함수 객체 diff_function 의 포인터를 요소로 갖고 있는 vector 인 function 을 인자로 받았다. 따라서 function 의 j 번째 요소가 j 번째 함수식(yj' = fj(t, y1, ..., yN) 에서 우변)인 것이다. 이 함수에 t 와 함수값을 넘기면 되는 것이다. 또한 diff_function 에서는 pure virtual function 으로 선언된 operator() 를, diff_function 을 상속받은 클래스에서 구현을 한 뒤 넘겨 주면, 46 번 줄에서처럼 operator() 를 호출하면 그 함수 값을 받아갈 수 있는 것이다. 또한, diff_function 에서 상속받은 클래스들은 두 번째로 받는 std::vector<double>* w의 j 번째 요소는 yj를 의미한다.

84~86 번 줄은, 내가 이 코드를 생물학 관련 코드에 사용하느라 넣은 부분인데, 추정하려는 값이 0 이하로 갈 수 없다면 0 으로 해준다. 만약 이와 같은 가정이 필요 없다면 이 부분을 주석처리 해서 사용하면 된다.


이 코드는 예제를 좀 써 본다. 3개만 보자.

첫 번째 예제는 다음과 같은 상황에 대한 식이다. 이 식은 물론 간단하니까 analytic 하게 풀 수 있긴 하다.


위 도식은 다음과 같이 모델링 할 수 있다.

X' = -aX + b.Y
Y' = aX - deg.Y

이 때, a 는 X 가 Y 로 변하는 속도이고, b 는 Y 가 X -> Y 로의 전환을 억제하는 정도, 이고, deg 는 Y 의 분해 속도이다. X 가 감소하는 양은 X 의 양에 a 를 곱한 값일테고, 이 감소분은 Y 가 많으면 많을수록 줄어들 것이다. Y 의 변화량은, X 에서 전환되는 양만큼 증가할테고, 줄어드는 양은 Y 의 양에 비례할 것이다. 따라서 위와 같은 모델링이 가능하다.

위 상황을 코드로 옮겨 보자. 우선 X 에 대한 식은 다음과 같이 할 수 있다.

    1 class diff_x : public diff_function{

    2 public:

    3     double operator()(double time, std::vector<double>* p){

    4         return -_alpha*(*p)[0] + _beta*(*p)[1];

    5     }

    6 };


4 번 줄에서 보는 바와 같이 코드는 dX / dt 를 수식 그대로 표현하고 있다. (개인적인 코딩 convention 상 밑줄 하나로 시작하는 변수는 전역/멤버 변수임). p 가 사용된 변수를 나타내는 변수이고, 0 은 X, 1 은 Y 를 의미한다(이것은 프로그램 작성자의 암묵적 가정에 기반한다). Y 에 대한 식은 다음과 같이 할 수 있다.

    1 class diff_y : public diff_function{

    2 public:

    3     double operator()(double time, std::vector<double>* p){

    4         return _alpha*(*p)[0] - _deg*(*p)[1];

    5     }

    6 };


이제 실제로 미방을 푸는 코드를 보면,

    1 diff_x dx;

    2 diff_y dy;

    3 

    4 vector<diff_function*> function;

    5 function.push_back(&dx);

    6 function.push_back(&dy);

    7 

    8 

    9 _alpha = 0.8;

   10 _beta = 0.2;

   11 _deg = 0.1;

   12 vector<double> initial;

   13 initial.push_back(10);

   14 initial.push_back(0);

   15 

   16 vector<vector<double> > result;

   17 

   18 ode_solve solver;

   19 solver.solve(0,100,10000,initial,&function,&result);



19 번 줄에서 보는 바와 같이 [0, 100] 까지 X, Y 를 구하며, 이 구간을 10000 개로 잘라서 구하고 있다. 만약 단위가 초(second) 라면 0.01 초마다 X, Y 값을 구하는 것이다. 이렇게 하면

result[t][0] 은 X 값을,
result[t][1] 은 Y 값을

담고 있게 된다.

이번 예제는, 위의 예제에 time delay 를 넣은 것이다. 이렇게 하면 oscillation 생기는 것을 알 수 있다. 즉, 단 두 개의 component 로 이루어진 system 이 oscillation 을 내기 위해서는 time delay 가 필요한 것이다. 세포 내에서는 time delay 가 주로 transcription/translation 단계에서 일어 난다.

모델은, 위의 경우에서, Y 가 생기자마자 X 의 전환을 막을 수 있는 것이 아니라 약간의 시간 뒤에서부터 Y 가 작용한다고 하자. 이 경우의 X' 는 다음과 같이 코드로 표현할 수 있다.


    1 

    2 class delayed_dx : public diff_function{

    3 private:

    4     queue<double> pre_value;

    5 public:

    6     double operator()(double time, vector<double>* p){

    7         pre_value.push((*p)[1]);

    8 

    9         double val = 0;

   10 

   11         if(pre_value.size() > _tau){

   12             val = pre_value.front();

   13             pre_value.pop();

   14         }

   15 

   16         return -_alpha*(*p)[0] + _beta*val;

   17     }

   18 }



큐에 Y 값을 집어 넣다가(7번 줄), time delay 로 설정한 변수 _tau 보다 큐 길이가 커지면(11번 줄) 그 때부터 앞의 요소부터 빼내서(13번 줄) X 를 막는 것을 나타내는 값으로 사용(16 번 줄)하면 되는 것이다.

이번 예제는 내 석사 졸업 논문에 사용한 것이고, 이 논문에 약간의 설명이 나와 있다. 사용된 모델은 논문을 참조하는 편이 좋고(너무 생물학적이라...), 수식은


위와 같다. 이 모델을 코드로 표현하면 다음과 같다.


    1 

    2 // x0 : ubiquitin

    3 class simple_total_dx0 : public diff_function{

    4 public:

    5     double operator()(double time, vector<double> *p){

    6         return c0;

    7     }

    8 };

    9 

   10 // x1 : inactive TNFR complex

   11 class simple_total_dx1 : public diff_function{

   12 public:

   13     double operator()(double time, vector<double> *p){

   14         return a21*(*p)[2] - a12*(*p)[1]*pow((*p)[0],u)*S;

   15 

   16     }

   17 };

   18 

   19 // x2 : active TNFR complex

   20 class simple_total_dx2 : public diff_function{

   21 public:

   22     double operator()(double time, vector<double> *p){

   23         return a12*(*p)[1]*pow((*p)[0],u)*S + a82*(*p)[8] - a21*(*p)[2] - a28*(*p)[2]*(*p)[3];

   24     }

   25 };

   26 

   27 // x3 : active A20

   28 class simple_total_dx3 : public diff_function{

   29 public:

   30     double operator()(double time, vector<double> *p){

   31         return a43*(*p)[4]*(*p)[2];

   32     }

   33 };

   34 

   35 // x4 : inactive A20

   36 class simple_total_dx4 : public diff_function{

   37 public:

   38     double operator()(double time, vector<double> *p){

   39         return -a43*(*p)[4]*(*p)[2];

   40     }

   41 };

   42 

   43 // x5 : inactive IKK complex

   44 class simple_total_dx5 : public diff_function{

   45 public:

   46     double operator()(double time, vector<double> *p){

   47         return a65*(*p)[6] - a56*(*p)[5]*(*p)[2];

   48     }

   49 };

   50 

   51 // x6 : active IKK complex

   52 class simple_total_dx6 : public diff_function{

   53 public:

   54     double operator()(double time, vector<double> *p){

   55         return a56*(*p)[5]*(*p)[2] - a65*(*p)[6];

   56     }

   57 };

   58 

   59 // x7 : proteasome

   60 class simple_total_dx7 : public diff_function{

   61 public:

   62     double operator()(double time, vector<double> *p){

   63         return 0;

   64     }

   65 };

   66 

   67 class simple_total_dx8 : public diff_function{

   68 public:

   69     double operator()(double time, vector<double> *p){

   70         return a28*(*p)[2]*(*p)[3] - a82*(*p)[8] - a8*(*p)[8]*pow((*p)[0],g1)*pow((*p)[7],h1);

   71     }

   72 };



위에서 본 바와 같이, 알고리즘을 직접 구현하면, 그 내부에 우리가 하고 싶은 개입(intervention) 을 마음껏 적용시킬 수 있다. 나도 간단한 것은 그냥 matlab 으로 하기는 하는데, 그러면 0 이 되는 것을 막는 것이라던가, 하는 것을 하기 까다롭다. 바로 이렇게 중간에 여러 개입을 각각의 상황에 맞게 할 수 있기 때문에 나는 대부분의 코드를 직접 구현하는 것을 좋아한다. ㅋ