function [outStruct,warningFlag,obj] = crossvalidationGUI(obj,inStruct,warningFlag)
% Graphical user interface to set options for the calculation of the cross
% validation error
%
% [outStruct,warningFlag,obj] = crossvalidationGUI(obj,inStruct,warningFlag)
%
% The function blocks (waitfor) until the GUI window is closed, either by
% one of its two buttons or by the user closing the window.
%
% Inputs
%
% obj:              model object. It is expected to provide the properties
%                   input, output, kFold, history.leafModelIter (cell
%                   array), history.kFoldCVlossFunction and the method
%                   crossvalidation(noGroups,noLM).
% inStruct:         (optional) Structure containing all already set
%                   crossvalidation options:
%                   numberOfBatches   - number of groups (default: 2)
%                   modelComplexities - preselected model complexities
%                                       (default: 1:20 for an untrained
%                                       model, size(leafModelIter,2) for
%                                       a trained model)
%                   Missing fields are filled with these defaults.
% warningFlag:      (optional) If true, GUItrain.crossValidationWarning is
%                   shown before the cross validation is performed
%                   (default: 1)
%
%
% Outputs
%
% outStruct:        Structure containing all changed/new crossvalidation
%                   options (equal to inStruct if the window is closed
%                   without saving or performing the cross validation)
% warningFlag:      Warning flag, possibly updated by
%                   GUItrain.crossValidationWarning
% obj:              Model object; kFold and history.kFoldCVlossFunction
%                   are updated when the settings are saved or the cross
%                   validation is performed

% Julian Belz, 19-Mrz-2012
% Institut für Mess- und Regelungstechnik, Universität Siegen, Deutschland
% Institute of Measurement and Control, University of Siegen, Germany
% Copyright (c) 2012 by Julian Belz

% A model counts as trained if the first entry of its leaf model history is
% filled. The first check avoids an indexing error on an empty cell array.
isTrained = ~isempty(obj.history.leafModelIter) && ...
    ~isempty(obj.history.leafModelIter{1,1});

% Default options (also fill in fields missing from a user supplied struct)
if nargin < 2 || isempty(inStruct)
    inStruct = struct();
end
if ~isfield(inStruct,'numberOfBatches')
    inStruct.numberOfBatches = 2;
end
if ~isfield(inStruct,'modelComplexities')
    if isTrained
        % Model is already trained
        inStruct.modelComplexities = size(obj.history.leafModelIter,2);
    else
        % Model is not trained yet
        inStruct.modelComplexities = 1:20;
    end
end
if nargin < 3
    % Without this default, performCrossValidation would fail on an
    % undefined warningFlag when the function is called with two inputs
    warningFlag = 1;
end

outStruct       = inStruct;
if isempty(obj.input) || isempty(obj.output)
    % Without a dataSet it makes no sense to call this function. Therefore
    % the function is aborted.
    msgbox('Because there is no dataSet defined in your model, no crossvalidationGUI is opened!','Error','error');
    return;
end

% Number of data points = maximum number of groups (leave-one-out)
nData           = size(obj.input,1);

% Largest model complexity offered as a checkbox and the resulting number
% of checkbox lines
if isTrained
    maxComplexity = size(obj.history.leafModelIter,2);
else
    maxComplexity = 20;
end
cbPerLine               = 11;
numberOfLines           = ceil(maxComplexity/cbPerLine);

% definition of the position of the gui elements
widthEdit       = 100;
heightEdit      = 25;
textHeight1     = 52;
buttonHeight    = 30;
buttonWidth     = 200;
deltaWidth      = 45;
deltaHeight     = 30;
cbWidth         = 90;
cbHeight        = 25;
p2textHeight    = 16;
distanceLeft    = 10;
distanceTop     = 10;
heightP1        = 4*distanceTop+textHeight1+heightEdit;
widthP1         = 550;
heightP2        = 5*distanceTop+p2textHeight+2*cbHeight+...
    (numberOfLines-1)*deltaHeight;
widthP2         = 550;
width           = 570;
height          = heightP1+heightP2+4*distanceTop+buttonHeight;

% defining some values for the gui
scrsize         = get(0,'screensize');
windowPosition  = [(scrsize(3) - width)/2, ...
    (scrsize(4) - height)/2, width, height];
color1          = [0.7 0.7 0.7];
color2          = [1 1 1];
fontsize        = 12;
fontname        = 'Times New Roman';

% defining gui elements
fh              = figure('Visible','off',...
    'Position',windowPosition,...
    'MenuBar','none',...
    'ToolBar','figure',...
    'Color',color1);

%% Panel 1: number of groups

panels{1,1} = uipanel('Parent',fh,...
    'Title','Selection: Number of groups',...
    'Units','pixels',...
    'Position',[distanceLeft, height-distanceTop-heightP1, ...
    widthP1 heightP1],...
    'FontSize',fontsize);

ht{1,1} = uicontrol('Parent',panels{1,1},...
    'Style','text',...
    'Position',[distanceLeft, heightEdit+distanceTop*2, ...
    widthP1-2*distanceLeft textHeight1],...
    'FontSize',fontsize,...
    'FontName',fontname,...
    'HorizontalAlignment','left',...
    'String',['Choose the number in how many groups the training data '...
    'should be divided for the calculation of the cross validation '...
    'error (if you choose the maximum number of groups that are '...
    'possible you perform a leave-one-out cross validation):']);

he = uicontrol('Parent',panels{1,1},...
    'Style','edit',...
    'String',num2str(inStruct.numberOfBatches),...
    'backgroundcolor',color2,...
    'Position',[distanceLeft distanceTop widthEdit heightEdit],...
    'Callback',@correctEditInput);

ht{2,1} = uicontrol('Parent',panels{1,1},...
    'Style','text',...
    'Position',[2*distanceLeft+widthEdit, distanceTop/2, ...
    370 25],...
    'FontSize',fontsize,...
    'FontName',fontname,...
    'HorizontalAlignment','left',...
    'String',['(maximum number of groups: ',...
    num2str(nData),')']);

%% Panel 2: model complexities
panels{2,1} = uipanel('Parent',fh,...
    'Title','Selection: Complexity of models for cross validation',...
    'Units','pixels',...
    'Position',[distanceLeft, distanceTop*2+buttonHeight, ...
    widthP2 heightP2],...
    'FontSize',fontsize);

ht{3,1} = uicontrol('Parent',panels{2,1},...
    'Style','text',...
    'Position',[distanceLeft, ...
    numberOfLines*deltaHeight+distanceTop*2+cbHeight, ...
    widthP2-2*distanceLeft p2textHeight],...
    'FontSize',fontsize,...
    'FontName',fontname,...
    'HorizontalAlignment','left',...
    'String',['For which model complexity (number of local models) ',...
    'should the cross validation error be calculated:']);

ypos    = numberOfLines*deltaHeight+distanceTop;

cbAll{1,1} = uicontrol(panels{2,1},...
    'Style','checkbox',...
    'String','Select all',...
    'Value',0,...
    'Position',[distanceLeft ypos 80 cbHeight],...
    'Callback',@selectModelComplexities);

cbAll{2,1} = uicontrol(panels{2,1},...
    'Style','checkbox',...
    'String','Deselect all',...
    'Value',0,...
    'Position',[distanceLeft*2+71 ypos 95 cbHeight],...
    'Callback',@selectModelComplexities);

% One checkbox per model complexity, cbPerLine checkboxes per line,
% starting in the top line. Preselected complexities are checked.
hcb                     = cell(maxComplexity,1);
xpos                    = distanceLeft;
ypos                    = (numberOfLines-1)*deltaHeight+distanceTop;
for complexity = 1:maxComplexity
    hcb{complexity,1}   = uicontrol(panels{2,1},...
        'Style','checkbox',...
        'String',num2str(complexity),...
        'Value',double(ismember(complexity,inStruct.modelComplexities)),...
        'Position',[xpos ypos cbWidth cbHeight]);
    
    xpos = xpos + deltaWidth;
    if mod(complexity,cbPerLine) == 0
        % line is full: continue in the next line below
        ypos = ypos - deltaHeight;
        xpos = distanceLeft;
    end
end

%% Buttons

hButton{1,1} = uicontrol('Parent',fh,...
    'Style','Pushbutton',...
    'String','Perform cross validation',...
    'Enable','on',...
    'Position',[distanceLeft distanceTop buttonWidth buttonHeight],...
    'Callback',@performCrossValidation);

hButton{2,1} = uicontrol('Parent',fh,...
    'Style','Pushbutton',...
    'String','Save cross validation settings',...
    'Enable','on',...
    'Position',[distanceLeft*2+buttonWidth, distanceTop, ...
    buttonWidth buttonHeight],...
    'Callback',@saveCrossValidationSettings);

set(fh,'Visible','on');

% Block until the window is closed; the outputs are modified by the nested
% callback functions below, which share this workspace.
waitfor(fh);

%% Callback and helper functions (nested, share the workspace above)

    function noGroups = sanitizeNumberOfGroups(value)
        % Map a (possibly invalid) user entry to a usable number of groups:
        % non-numeric entries (str2double -> NaN) fall back to the initial
        % setting; the result is rounded and clipped to [2, nData], where
        % nData groups correspond to a leave-one-out cross validation.
        if ~isfinite(value)
            value = inStruct.numberOfBatches;
        end
        noGroups = min(max(round(value),2),nData);
    end % end sanitizeNumberOfGroups

    function correctEditInput(~,~)
        % Callback of the edit field: immediately show the number of groups
        % that will actually be used.
        noGroups = sanitizeNumberOfGroups(str2double(get(he,'String')));
        set(he,'String',num2str(noGroups));
    end % end correctEditInput

    function discardOutdatedCrossValidationErrors(noGroups)
        % If the requested number of groups differs from the initial one,
        % all cross validation errors have to be calculated again, so the
        % old values are deleted.
        if inStruct.numberOfBatches ~= noGroups
            obj.history.kFoldCVlossFunction = [];
        end
    end % end discardOutdatedCrossValidationErrors

    function saveCrossValidationSettings(~,~,crossvalidationValues)
        % Store the current GUI settings in outStruct and obj and close the
        % window. The optional third argument holds freshly calculated
        % cross validation errors that replace the stored ones.
        [noGroups,noLM]     = getCrossValidationOptions;
        
        if nargin == 3
            obj.history.kFoldCVlossFunction = crossvalidationValues;
        else
            discardOutdatedCrossValidationErrors(noGroups);
        end
        
        outStruct.numberOfBatches       = noGroups;
        outStruct.modelComplexities     = noLM;
        obj.kFold                       = noGroups;
        
        close(fh);
    end % end saveCrossValidationSettings

    function performCrossValidation(hObject,eventdata)
        % Calculate the cross validation errors for the selected model
        % complexities, then save the settings and close the window.
        [noGroups,noLM]     = getCrossValidationOptions;
        
        if isempty(noLM)
            msgbox('Please select at least one model complexity for the cross validation.','Warning','warn');
            return;
        end
        
        % Show warning because of the time consuming cross validation error
        % calculation
        if warningFlag
            [flag,warningFlag] = GUItrain.crossValidationWarning;
        else
            flag = 1;
        end
        if flag
            % Only now (not when the user cancels) old results become
            % invalid
            discardOutdatedCrossValidationErrors(noGroups);
            obj.history.kFoldCVlossFunction(noLM) = obj.crossvalidation(noGroups,noLM);
            saveCrossValidationSettings(hObject,eventdata,obj.history.kFoldCVlossFunction);
        end
    end % end performCrossValidation

    function [noGroups,localModels] = getCrossValidationOptions()
        % Read the GUI state without side effects.
        % noGroups:    number of groups requested by the user (sanitized)
        % localModels: row vector of the checked model complexities
        noGroups            = sanitizeNumberOfGroups(str2double(get(he,'String')));
        
        % get returns a cell array for several handles but a plain value
        % for a single handle, so cell2mat must only be applied to cells
        cbValues            = get([hcb{:}],'Value');
        if iscell(cbValues)
            cbValues        = cell2mat(cbValues);
        end
        localModels         = reshape(find(cbValues),1,[]);
    end % end getCrossValidationOptions

    function selectModelComplexities(hObject,~)
        % Callback of the 'Select all' / 'Deselect all' checkboxes: check
        % or uncheck all complexity checkboxes and reset the other one.
        if hObject == cbAll{1,1}
            % function call from the 'Select all' checkbox
            if get(cbAll{1,1},'Value')
                set([hcb{:}],'Value',1);
                set(cbAll{2,1},'Value',0);
            end
        elseif get(cbAll{2,1},'Value')
            % function call from the 'Deselect all' checkbox
            set([hcb{:}],'Value',0);
            set(cbAll{1,1},'Value',0);
        end
    end % end selectModelComplexities

end