雑感等

音楽,数学,語学,その他に関するメモを記す.

二次元の波動方程式を陰解法+ガウス・ザイデル法で解く(Rust) ※バグあり

※バグあり メッシュ数を縦≠横にするとおかしくなる

とりあえずソースと結果を貼っておく。

陰解法と書いたが、実際は陰解法と陽解法を平均した差分方程式を使っている。(クランクニコルソン法っぽくした)

↓の記事で導出した式と、陽解法を組み合わせている。

波動方程式の有限差分法を半自動で導出(sympy) - 雑感等

ソースの条件だとクーラン数=0.1573<1で陽解法でも解ける。

const DT: f64 = 1.0 / 44100.0;

const DX: f64 = 0.05;

このあたりを変えるとクーラン数が変わる。

クーラン数を1以上にしても一応解けたが、上げるとガウス・ザイデル法の反復回数が増える(収束が遅くなる)。

下のソースだと反復回数に上限を設けているから、収束し無くても打ち切る場合がある。

ある程度まで反復回数が増えて、一定を超えると、誤差が増えていくはず。

■出力結果

■Rustプログラムのソース

use std::time::{Duration, Instant};

use plotters::prelude::*;

const T_MAX: usize = 100;
const X_MAX: usize = 50;
const Y_MAX: usize = 50;

const M_ELEM: usize = 3 * 4 + 4 * 2 * (X_MAX - 2) + 4 * 2 * (Y_MAX - 2) + 5 * (X_MAX - 2) * (Y_MAX - 2);


const T_WINDOW: usize = 4;

struct SparseMatrix {
    col: [usize; M_ELEM],
    row: [usize; M_ELEM],
    val: [f64; M_ELEM],
}

fn sub_main() {
    println!("M_ELEM = {:?}", M_ELEM);
    // plot ////////////////////////////////////////////////////////////////////////////////////////
    let root = BitMapBackend::gif("testrustimage.gif", (200, 200), 50).unwrap()
        .into_drawing_area();

    let cells = root.split_evenly((X_MAX, Y_MAX));

    let mut cell_nx;
    let mut cell_ny;
    let mut val;
    let mut col;


    // calc ////////////////////////////////////////////////////////////////////////////////////////
    const DT: f64 = 1.0 / 44100.0;
    const DX: f64 = 0.05;
    const C: f64 = 347.0;
    const COURANT: f64 = C * DT / DX;

    println!("DT:{},DX:{},C:{},COURANT:{}", DT, DX, C, COURANT);

    //u[time][y][x]
    let mut u: [[[f64; X_MAX]; Y_MAX]; T_WINDOW] = [[[0.0f64; X_MAX]; Y_MAX]; T_WINDOW];

    let mut M: [[f64; X_MAX * Y_MAX]; X_MAX * Y_MAX] = [[0.0f64; X_MAX * Y_MAX]; X_MAX * Y_MAX];
    let mut x: [f64; X_MAX * Y_MAX] = [0.0f64; X_MAX * Y_MAX];
    let mut b: [f64; X_MAX * Y_MAX] = [0.0f64; X_MAX * Y_MAX];

    let r = 0.5;
    let nu2 = COURANT * COURANT;


    //define M
    {
        let diag = 4.0 * r * nu2 + 2.0;
        let oth_elem = -r * nu2;
        let mut curr_ind = 0;
        for ny in 0..Y_MAX {
            for nx in 0..X_MAX {
                curr_ind = ny * X_MAX + nx;
                
                M[curr_ind][curr_ind] = diag;
                if nx > 0 {
                    M[curr_ind][curr_ind - 1] = oth_elem;
                }
                if nx < X_MAX - 1 {
                    M[curr_ind][curr_ind + 1] = oth_elem;
                }
                if ny > 0 {
                    M[curr_ind][curr_ind - X_MAX] = oth_elem;
                }
                if ny < Y_MAX - 1 {
                    M[curr_ind][curr_ind + X_MAX] = oth_elem;
                }
            }
        }
    }
    let M = M;

    println!("u:{}", u[0][0][X_MAX - 1]);

    let mut u_max = 2.0;
    let mut u_max_t: usize = 0;
    let mut u_max_x: usize = 0;
    let mut u_max_y: usize = 0;
    let mut u_min = -1.0;
    let mut u_min_t: usize = 0;
    let mut u_min_x: usize = 0;
    let mut u_min_y: usize = 0;

    u[1][10][20] = 1.0;


    let mut t_curr;
    let mut t_prev;
    let mut t_next;

    let loop_start = Instant::now();

    //main loop
    //gauss seidel: solve Mx=b
    {
        let omega = 1.0;
        let mut x_temp = 0.0;
        let mut m_temp = 0.0;
        let mut gs_diff = 0.0;

        let mut curr_ind = 0;
        for nt in 1..T_MAX - 1 {
            t_curr = nt % T_WINDOW;
            t_prev = (nt - 1) % T_WINDOW;
            t_next = (nt + 1) % T_WINDOW;

            //define b
            for ny in 1..=Y_MAX - 2 {
                for nx in 1..=X_MAX - 2 {
                    curr_ind = ny * X_MAX + nx;
                    b[curr_ind] =
                        r * nu2 * (u[t_prev][ny][nx - 1] + u[t_prev][ny - 1][nx] + u[t_prev][ny][nx + 1] + u[t_prev][ny + 1][nx])
                            + 2.0 * (1.0 - r) * nu2 * (u[t_curr][ny][nx - 1] + u[t_curr][ny - 1][nx] + u[t_curr][ny][nx + 1] + u[t_curr][ny + 1][nx])
                            - (4.0 * r * nu2 + 2.0) * u[t_prev][ny][nx] + (8.0 * (r - 1.0) * nu2 + 4.0) * u[t_curr][ny][nx];
                }
            }
            //gauss seidel
            for lp in 0_usize..200_usize {
                gs_diff = 0.0;
                for ny in 0..X_MAX * Y_MAX {
                    for nx in 0..ny {
                        m_temp -= M[ny][nx] * x[nx];
                    }
                    for nx in (ny + 1)..X_MAX * Y_MAX {
                        m_temp -= M[ny][nx] * x[nx];
                    }
                    x_temp = (b[ny] + m_temp) / M[ny][ny];
                    gs_diff += (x_temp - x[ny]).abs();
                    // println!("{:?}:{:?}:{:?}", nt, lp, x_temp - x[ny]);
                    x[ny] = omega * x_temp + (1.0 - omega) * x[ny];

                    m_temp = 0.0;
                }
                if gs_diff < f64::MIN_POSITIVE {
                    println!("break {}:{}", nt, lp);
                    break;
                }
            }
            println!("{:?}:{:?}", nt, gs_diff);
            //x->u
            for ny in 0..Y_MAX {
                for nx in 0..X_MAX {
                    curr_ind = ny * X_MAX + nx;
                    u[t_next][ny][nx] = x[curr_ind];
                }
            }
            // plot ////////////////////////////////////////////////////////////////////////////
            for (area, num) in (&cells).into_iter().zip(0..) {
                cell_nx = num % X_MAX;
                cell_ny = num / X_MAX;
                val = u[t_curr][cell_nx][cell_ny];
                col = (val - u_min) / (u_max - u_min);
                area.fill(&RGBColor((col * 255.0) as u8, (col * 255.0) as u8, (col * 255.0) as u8)).unwrap();
            }
            root.present().expect("unable to write");
        }
    }
    let duration = loop_start.elapsed();
    println!("elapsed:{:?}", duration);


    println!("max u[{}][{}][{}]={}", u_max_t, u_max_y, u_max_x, u_max);
    println!("min u[{}][{}][{}]={}", u_min_t, u_min_y, u_min_x, u_min);
}


fn main() {
    {
        const STACK_SIZE: usize = 1024 * 1024 * 1024 * 1024;
        std::thread::Builder::new()
            .stack_size(STACK_SIZE)
            .spawn(sub_main)
            .unwrap()
            .join()
            .unwrap();
    }
}