% Data loading and preprocessing
% Each list file is read with textscan('%s %d'): a record folder name and an
% integer class label per entry. For every record, the first num_samples
% lines of <data_root><folder>\rx_plain.dat are parsed with str2double and
% their real parts are kept. Each record is then cropped to crop_size samples.

list_root   = 'D:\hjKwon\KCL_202406\list\class_18\';
data_root   = 'D:\hjKwon\KCL_202406\data\ALL\';
num_samples = 2000;   % lines read from each rx_plain.dat
crop_size   = 2000;   % samples kept per record after cropping
start_rx    = 1;      % index of the first sample kept

% Training data
train_list_path = [list_root 'train.txt'];
fileID_train = fopen(train_list_path, 'r');
if fileID_train < 0
    error('Cannot open list file: %s', train_list_path);
end
train_list = textscan(fileID_train, '%s %d');
fclose(fileID_train);

file_list_train  = train_list{1,1};
class_list_train = train_list{1,2};
train_size = length(file_list_train);

arr_rx_train = cell(1, train_size);
for i = 1 : train_size
    if mod(i, 100) == 0
        i   % progress display
    end
    file_name_train = [data_root file_list_train{i} '\rx_plain.dat'];
    fileID2 = fopen(file_name_train);
    if fileID2 < 0
        error('Cannot open data file: %s', file_name_train);
    end
    rx_data_train = zeros(num_samples, 1);
    for j = 1 : num_samples
        data_train = fgetl(fileID2);
        if ~ischar(data_train)
            % fgetl returned -1 (end of file): pad the missing samples with
            % NaN (what str2double yields for -1) and report the short file.
            warning('%s has only %d lines; padding with NaN.', file_name_train, j - 1);
            rx_data_train(j:end) = NaN;
            break;
        end
        rx_data_train(j,1) = real(str2double(data_train));
    end
    fclose(fileID2);
    arr_rx_train{1,i} = rx_data_train;
end

num_rx_files_train = train_size;

% Crop: start_rx + crop_size - 1 is the last index kept, so every record
% contributes exactly crop_size samples.
crop_rx_train = cell(1, num_rx_files_train);
for i = 1 : num_rx_files_train
    crop_rx_train{1,i} = arr_rx_train{1,i}(start_rx : start_rx + crop_size - 1);
end

debug = "import train data done"
%%
% Validation (test) data, loaded and cropped exactly like the training data
test_list_path = [list_root 'test.txt'];
fileID_test = fopen(test_list_path, 'r');
if fileID_test < 0
    error('Cannot open list file: %s', test_list_path);
end
test_list = textscan(fileID_test, '%s %d');
fclose(fileID_test);

file_list_test  = test_list{1,1};
class_list_test = test_list{1,2};
test_size = length(file_list_test);

arr_rx_test = cell(1, test_size);
for i = 1 : test_size
    if mod(i, 100) == 0
        i   % progress display
    end
    file_name_test = [data_root file_list_test{i} '\rx_plain.dat'];
    fileID2 = fopen(file_name_test);
    if fileID2 < 0
        error('Cannot open data file: %s', file_name_test);
    end
    rx_data_test = zeros(num_samples, 1);
    for j = 1 : num_samples
        data_test = fgetl(fileID2);
        if ~ischar(data_test)
            % fgetl returned -1 (end of file): pad with NaN and report it.
            warning('%s has only %d lines; padding with NaN.', file_name_test, j - 1);
            rx_data_test(j:end) = NaN;
            break;
        end
        rx_data_test(j,1) = real(str2double(data_test));
    end
    fclose(fileID2);
    arr_rx_test{1,i} = rx_data_test;
end

num_rx_files_test = test_size;

% Crop to exactly crop_size samples, matching the training records.
crop_rx_test = cell(1, num_rx_files_test);
for i = 1 : num_rx_files_test
    crop_rx_test{1,i} = arr_rx_test{1,i}(start_rx : start_rx + crop_size - 1);
end

debug = "import test data_done"
%%

% Search grid for the STFT parameters
fs              = 1e6;                                  % sample rate passed to stft
windowFunctions = {@hann};                              % window-generating functions
fftSizes        = [512, 1024, 2000];                    % candidate FFT lengths
overlaps        = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8]; % overlap, as a fraction of the window
windowSizes     = [256, 512, 1024];                     % candidate window lengths (samples)

% Results file: (re)create it with a single header row; result rows are
% appended afterwards, one per evaluated combination.
csvFileName = 'results_stft.csv';
header = {'Window Size', 'Overlap', 'FFT Size', 'Window Function', 'Accuracy'};

fid = fopen(csvFileName, 'w');
fprintf(fid, '%s,%s,%s,%s,%s\n', header{:});
fclose(fid);

% Parameter search: for each (window size, overlap, FFT size, window function)
% combination with windowSize <= fftSize, extract |STFT| features, train a
% patternnet on the training records, score it on the test records, and
% append one row to csvFileName.

% Targets do not depend on the STFT parameters, so build them once.
% Training targets: one-hot (numClasses x numTrain) via dummyvar, which
% requires the class labels to be positive integers.
Target_filed_train = dummyvar(double(class_list_train(:)))';
numClasses = size(Target_filed_train, 1);

% Test targets get the same number of rows as the network output, so a test
% set that lacks the highest class still lines up with the predictions.
test_labels = double(class_list_test(:))';
if any(test_labels < 1 | test_labels > numClasses | test_labels ~= fix(test_labels))
    error('Test class labels must be integers in 1..%d.', numClasses);
end
test_filed = zeros(numClasses, num_rx_files_test);
test_filed(sub2ind(size(test_filed), test_labels, 1:num_rx_files_test)) = 1;

threshold = 0.5;   % network outputs >= threshold become 1, the rest 0

count = 0; % number of combinations evaluated so far
for windowSize = windowSizes
    for overlap = overlaps
        for fftSize = fftSizes
            for windowFunc = windowFunctions

                if windowSize > fftSize
                    continue; % window longer than the FFT: skip this combination
                end

                count = count + 1;
                count

                % STFT options for this combination
                win = windowFunc{1}(windowSize);
                ol = ceil(windowSize * overlap);
                optSTFT = {"Window", win, "OverlapLength", ol, "FFTLength", fftSize, "FrequencyRange", "onesided"};

                funcName = func2str(windowFunc{1});
                param_comb = {windowSize, overlap, fftSize, funcName}

                % Training features: |STFT| of each mean-removed record,
                % flattened to one column per record (abs is used because
                % the network does not accept complex input).
                s_rx_train = cell(1, num_rx_files_train);
                for i = 1 : num_rx_files_train
                    rec = crop_rx_train{1,i};
                    S = stft(rec - mean(rec), fs, optSTFT{:});
                    s_rx_train{1,i} = S(:);
                end
                inputD_train = abs([s_rx_train{:}]);

                debug = "make train data done"

                % Network hyperparameters
                hiddenLayerSize = 10;
                net = patternnet(hiddenLayerSize);
                net.trainParam.max_fail = 100;   % validation failures allowed before early stopping
                net.trainParam.min_grad = 1e-5;  % stop when the performance gradient falls below this
                net.trainParam.epochs = 1000;    % maximum number of training epochs

                [trNet, tr] = train(net, inputD_train, Target_filed_train);
                %[trNet, tr] = train(net, inputD_train, Target_filed_train, 'useGPU', 'yes');
                save trNet trNet   % keep the most recently trained network in trNet.mat

                debug = "make AI done"

                % Test features, built exactly like the training features
                s_rx_test = cell(1, num_rx_files_test);
                for i = 1 : num_rx_files_test
                    rec = crop_rx_test{1,i};
                    S = stft(rec - mean(rec), fs, optSTFT{:});
                    s_rx_test{1,i} = S(:);
                end
                testD = abs([s_rx_test{:}]);

                debug = "make test data done"

                % Binarize the network outputs, which are not exactly 0 or 1
                y = double(trNet(testD) >= threshold);

                debug = "test AI done"

                % A test record counts as correct only when its whole
                % binarized output column equals its one-hot target column.
                numCorrect = nnz(all(y == test_filed, 1));
                accuracy = numCorrect / num_rx_files_test;
                accuracy
                disp("----------------------------------------------------------------------");

                % Append this combination's result to the CSV file
                fid = fopen(csvFileName, 'a');
                if fid < 0
                    error('Cannot open results file: %s', csvFileName);
                end
                fprintf(fid, '%d,%f,%d,%s,%f\n', windowSize, overlap, fftSize, funcName, accuracy);
                fclose(fid);

                % Release per-combination data
                clear s_rx_train inputD_train net trNet s_rx_test testD y S rec;
            end
        end
    end
end
debug = "All Task Done"