% Load and preprocess data
% Training data
% train.txt holds one record per line: "<record folder> <integer class label>".
train_list_path = 'D:\hjKwon\KCL_202406\list\test_0708\train.txt';
fileID_train = fopen(train_list_path, 'r');
if fileID_train < 0
    error('Cannot open training list file: %s', train_list_path);
end
train_list = textscan(fileID_train, '%s %d');
fclose(fileID_train);

file_list_train  = train_list{1,1};   % record folder names
class_list_train = train_list{1,2};   % class label of each record (%d -> int32)
train_size = length(file_list_train);

% Read the first num_samples_train lines of each record's rx_plain.dat.
% Each line is parsed with str2double and only its real part is kept.
num_samples_train = 2000;
arr_rx_train = cell(1, train_size);

for k = 1 : train_size
    if mod(k, 100) == 0
        fprintf('%d / %d\n', k, train_size);   % progress indicator
    end
    file_name_train = fullfile('D:\hjKwon\KCL_202406\data\ALL', file_list_train{k}, 'rx_plain.dat');
    fileID2 = fopen(file_name_train, 'r');
    if fileID2 < 0
        error('Cannot open data file: %s', file_name_train);
    end
    rx_data_train = zeros(num_samples_train, 1);
    for j = 1 : num_samples_train
        data_train = fgetl(fileID2);
        if ~ischar(data_train)
            % fgetl returns -1 at end of file; without this check the
            % missing samples would silently become NaN.
            fclose(fileID2);
            error('%s has %d lines; %d are required.', file_name_train, j - 1, num_samples_train);
        end
        rx_data_train(j,1) = real(str2double(data_train));
    end
    fclose(fileID2);
    arr_rx_train{1,k} = rx_data_train;
end
debug = "import data_done";

num_rx_files_train = length(file_list_train);

% Crop each training record to exactly crop_size samples starting at
% sample start_rx (inclusive end index start_rx + crop_size - 1).
crop_size = 2000;
start_rx = 1;
crop_idx = start_rx : start_rx + crop_size - 1;

crop_rx_train = cell(1, num_rx_files_train);
for k = 1 : num_rx_files_train
    crop_rx_train{1,k} = arr_rx_train{1,k}(crop_idx);
end
debug = "crop data_done";

debug = "import train data done"
%%
% Validation (test) data
% test.txt holds one record per line: "<record folder> <integer class label>".
test_list_path = 'D:\hjKwon\KCL_202406\list\test_0708\test.txt';
fileID_test = fopen(test_list_path, 'r');
if fileID_test < 0
    error('Cannot open test list file: %s', test_list_path);
end
test_list = textscan(fileID_test, '%s %d');
fclose(fileID_test);

file_list_test  = test_list{1,1};   % record folder names
class_list_test = test_list{1,2};   % class label of each record (%d -> int32)
test_size = length(file_list_test);

% Read the first num_samples_test lines of each record's rx_plain.dat.
% Each line is parsed with str2double and only its real part is kept.
num_samples_test = 2000;
arr_rx_test = cell(1, test_size);

for k = 1 : test_size
    if mod(k, 100) == 0
        fprintf('%d / %d\n', k, test_size);   % progress indicator
    end
    file_name_test = fullfile('D:\hjKwon\KCL_202406\data\ALL', file_list_test{k}, 'rx_plain.dat');
    fileID2 = fopen(file_name_test, 'r');
    if fileID2 < 0
        error('Cannot open data file: %s', file_name_test);
    end
    rx_data_test = zeros(num_samples_test, 1);
    for j = 1 : num_samples_test
        data_test = fgetl(fileID2);
        if ~ischar(data_test)
            % fgetl returns -1 at end of file; without this check the
            % missing samples would silently become NaN.
            fclose(fileID2);
            error('%s has %d lines; %d are required.', file_name_test, j - 1, num_samples_test);
        end
        rx_data_test(j,1) = real(str2double(data_test));
    end
    fclose(fileID2);
    arr_rx_test{1,k} = rx_data_test;
end
debug = "import data_done";

num_rx_files_test = length(file_list_test);
crop_rx_test = cell(1, num_rx_files_test);

% Crop each test record to exactly crop_size samples starting at sample
% start_rx (inclusive end index start_rx + crop_size - 1).
crop_size = 2000;
start_rx = 1;
crop_idx = start_rx : start_rx + crop_size - 1;

for k = 1 : num_rx_files_test
    crop_rx_test{1,k} = arr_rx_test{1,k}(crop_idx);
end

debug = "crop data_done";

debug = "import test data_done"
%%

% Input: CSV listing the STFT parameter combinations to evaluate.
csvFilePath = 'D:\hjKwon\KCL_202406\testroom\stft\results_stft_top5.csv';

% Keep the column headers exactly as written (they contain spaces).
opts = detectImportOptions(csvFilePath);
opts.VariableNamingRule = 'preserve';
data = readtable(csvFilePath, opts);

% One vector per parameter; row r of each vector is one combination.
windowSizes     = data.('Window Size');
overlaps        = data.('Overlap');
fftSizes        = data.('FFT Size');
windowFunctions = data.('Window Function');
numParams       = size(data, 1);   % number of parameter combinations

fs = 1e6;   % sampling frequency

% Output log: (re)create the results file with only its header row.
csvFileName = 'results3_stft.csv';
header = {'Window Size', 'Overlap', 'FFT Size', 'Window Function', 'Accuracy'};
fid = fopen(csvFileName, 'w');
fprintf(fid, '%s\n', strjoin(header, ','));
fclose(fid);

% Parameter search over the STFT settings read from the CSV.
% For each combination: build |STFT| features once, train a fresh patternnet
% numIterations times, and append the best test accuracy of those runs to
% csvFileName.
numIterations   = 10;   % training runs per parameter combination
hiddenLayerSize = 10;   % hidden layer size of the pattern-recognition network
threshold       = 0.5;  % an output >= threshold is read as 1, otherwise 0

% One-hot labels are built once; they do not depend on the STFT parameters.
% Train and test use the same number of rows, so network outputs and test
% labels always line up whatever classes each list happens to contain.
% NOTE(review): labels are assumed to be 1-based class indices, as the
% original dummyvar encoding required.
allLabels = double([class_list_train(:); class_list_test(:)]);
if isempty(allLabels) || any(allLabels < 1)
    error('Class labels must be positive integers (1-based class indices).');
end
numClasses = max(allLabels);
Target_filed_train = double((1:numClasses)' == double(class_list_train(:))');
test_filed         = double((1:numClasses)' == double(class_list_test(:))');
num_test_samples   = size(test_filed, 2);
debug = "make Target_filed done";

count = 0; % number of parameter combinations processed
for paramIdx = 1:numParams
    % Current parameter combination
    windowSize = windowSizes(paramIdx);
    overlap    = overlaps(paramIdx);
    fftSize    = fftSizes(paramIdx);
    windowFunc = char(windowFunctions(paramIdx)); % convert to char for switch/fprintf

    % Window function. Reject unknown names: otherwise the previous
    % combination's window would be silently reused.
    switch windowFunc
        case 'hamming'
            win = hamming(windowSize);
        case 'hann'
            win = hann(windowSize);
        case 'blackman'
            win = blackman(windowSize);
        case 'bartlett'
            win = bartlett(windowSize);
        otherwise
            error('Unsupported window function "%s" in parameter row %d.', windowFunc, paramIdx);
    end

    count = count + 1;
    count

    % STFT options. The overlap value is multiplied by the window length,
    % i.e. it is treated as a fraction of the window.
    ol = ceil(windowSize * overlap);
    optSTFT = {"Window", win, "OverlapLength", ol, "FFTLength", fftSize, "FrequencyRange", "onesided"};

    funcName = windowFunc;
    param_comb = {windowSize, overlap, fftSize, funcName};
    param_comb

    % Features: |STFT| of each mean-removed record, flattened to one column per
    % record (abs because the network cannot take complex inputs). They depend
    % only on the STFT parameters, so they are computed once per combination.
    s_rx_train = cell(1, num_rx_files_train);
    for k = 1 : num_rx_files_train
        rec = crop_rx_train{1,k};
        s_rx_train{1,k} = reshape(stft(rec - mean(rec), fs, optSTFT{:}), [], 1);
    end
    inputD_train = abs([s_rx_train{:}]);
    debug = "make train data done"

    s_rx_test = cell(1, num_rx_files_test);
    for k = 1 : num_rx_files_test
        rec = crop_rx_test{1,k};
        s_rx_test{1,k} = reshape(stft(rec - mean(rec), fs, optSTFT{:}), [], 1);
    end
    testD = abs([s_rx_test{:}]);
    debug = "make test data done"

    % Repeated training; only the network initialisation differs between runs.
    bestAccuracy = 0;
    for iter_count = 1:numIterations
        fprintf("Iteration %d\n", iter_count);

        % Network hyperparameters
        net = patternnet(hiddenLayerSize);
        net.trainParam.max_fail = 100;   % validation-failure limit for early stopping
        net.trainParam.min_grad = 1e-5;  % minimum performance gradient
        net.trainParam.epochs   = 1000;  % maximum number of epochs

        trNet = train(net, inputD_train, Target_filed_train);
        %trNet = train(net, inputD_train, Target_filed_train, 'useGPU', 'yes');
        save trNet trNet   % keep the most recently trained network on disk
        debug = "make AI done"

        % Evaluate: threshold every output. A sample counts as correct only
        % when its whole thresholded output column equals its one-hot label,
        % which works for any number of classes.
        y = trNet(testD);
        predicted = y >= threshold;
        accuracy = nnz(all(predicted == test_filed, 1)) / num_test_samples

        % Track the best accuracy for this parameter combination
        bestAccuracy = max(bestAccuracy, accuracy)

        % Release per-run memory
        clear net trNet y predicted

        disp("----------------------------------------------------------------------");
    end

    % Release per-combination feature memory
    clear s_rx_train s_rx_test inputD_train testD

    % Append this combination's result to the CSV file
    fid = fopen(csvFileName, 'a');
    if fid < 0
        error('Cannot open %s for appending.', csvFileName);
    end
    fprintf(fid, '%d,%f,%d,%s,%f\n', windowSize, overlap, fftSize, funcName, bestAccuracy);
    fclose(fid);
end

debug = "All Task Done"