function obj = findInputCombinations(obj,strategy,minInputs,maxInputs,fixedInputs)
% Function should determine 'good' input combinations with respect to the
% loss function evaluated by obj.evaluateLossFunctionValue (e.g. Akaike's
% Information Criterion (AIC)). The user can define the search method, the
% minimum number of inputs and the maximum number of inputs. Additionally,
% fixed inputs can be passed, which must never be discarded.
%
% obj = findInputCombinations(obj,strategy,minInputs,maxInputs,fixedInputs)
%
%
% INPUTS
%
% obj           sensitivityAnalysis object
%
% strategy      String: Defines the search strategy. Implemented search
%               strategies are: bruteForce, forwardSelection,
%               backwardElimination. If omitted or empty, obj.strategy is
%               used unchanged.
%
% minInputs     Integer: minimum number of inputs (default: 0). It is
%               raised to the number of fixed inputs if it is smaller.
%
% maxInputs     Integer: maximum number of inputs (default and upper
%               bound: number of columns of obj.input). It is raised to
%               minInputs if the fixed inputs require more inputs.
%               If maxInputs < minInputs, both values are swapped.
%
% fixedInputs   [1 x nf]: Vector containing the column numbers of the
%               training data that must never be discarded. If omitted,
%               obj.fixedInputs is used unchanged.
%
%
% OUTPUTS
%
% obj       sensitivityAnalysis object, which contains the results of the
%           demanded search in obj.combinationMatrix and
%           obj.sensitivityResults (fields model, lossFunctionValue and
%           ranking)

% Julian Belz, 31st-Oct-2012
% Institute of Mechanics & Automatic Control, University of Siegen, Germany
% Copyright (c) 2012 by Prof. Dr.-Ing. Oliver Nelles

% Check if obj contains data, otherwise no input combinations can be
% determined
if isempty(obj.input) || isempty(obj.output)
    fprintf('\n The object contains no data, so no input selection can be performed. \n');
    return;
end

% Make sure all input variables are set properly
if exist('strategy','var') && ~isempty(strategy)
    obj.strategy = strategy;
end
if nargin > 4 && exist('fixedInputs','var')
    obj.fixedInputs = fixedInputs;
end

numInputs = size(obj.input,2);
if ~exist('minInputs','var') || isempty(minInputs)
    minInputs = 0;
end
if ~exist('maxInputs','var') || isempty(maxInputs)
    maxInputs = numInputs;
end

% Make sure minInputs is lower than maxInputs
if maxInputs < minInputs
    tmp = minInputs;
    minInputs = maxInputs;
    maxInputs = tmp;
end

% Limit both bounds to the number of available inputs. This must happen
% AFTER the swap, otherwise a too large minInputs would end up as
% maxInputs and the sequential searches would run out of inputs to
% add/remove.
maxInputs = min(maxInputs,numInputs);
minInputs = min(minInputs,maxInputs);

% Check if the original in- and output is already passed to the
% corresponding properties. Necessary for further steps!
if isempty(obj.originalInput)
    obj.originalInput = obj.unscaledInput;
end
if isempty(obj.originalOutput)
    obj.originalOutput = obj.unscaledOutput;
end

% If there are results from former calculations, delete them.
obj.sensitivityResults = struct('model',{},'lossFunctionValue',{},'ranking',{});

% Plausibility check: fixed inputs are never discarded, so every input
% combination contains at least all of them. Raise both bounds
% accordingly (otherwise maxInputs-minInputs could become negative).
numFixed = numel(obj.fixedInputs);
if minInputs < numFixed
    minInputs = numFixed;
end
if maxInputs < minInputs
    maxInputs = minInputs;
end

% buildCombinationMatrix is always called with at least one input; the
% empty combination is added separately where it is needed.
if minInputs == 0
    minInputs4combinationMatrixDetermination = 1;
else
    minInputs4combinationMatrixDetermination = minInputs;
end

switch obj.strategy
    case 'bruteForce'
        
        % Determine all possible input combinations, as if the fixed
        % inputs could also be discarded
        obj.combinationMatrix = obj.buildCombinationMatrix(minInputs4combinationMatrixDetermination,size(obj.originalInput,2));
        
        if ~isempty(obj.fixedInputs)
            % Set all fixed inputs to true in every input combination and
            % delete the redundant combinations this creates
            obj.combinationMatrix(:,obj.fixedInputs) = true;
            obj.combinationMatrix = unique(obj.combinationMatrix,'rows');
        end
        
        % If the minimum number of inputs is zero, add the empty
        % combination separately to the combination matrix.
        if minInputs == 0
            obj.combinationMatrix = [false(1,size(obj.combinationMatrix,2));obj.combinationMatrix];
        end
        
        % Delete all input combinations in the combinationMatrix where more
        % than the maximum number of desired inputs are considered
        remainIdx = sum(obj.combinationMatrix,2)<=maxInputs;
        obj.combinationMatrix = obj.combinationMatrix(remainIdx,:);
        
        % Predefine temporary variables, which contain the results of the
        % input selection and will be passed to the object after finishing
        % all calculations
        lossFunctionValues = inf(1,size(obj.combinationMatrix,1));
        lmnCell = cell(1,size(obj.combinationMatrix,1));
        lmnCell2 = lmnCell;
        lmnCell3 = lmnCell;
        
        % Loop over all input combinations
        parfor kk=1:size(obj.combinationMatrix,1)
            % Logical row vector marking the inputs of this combination
            combination = obj.combinationMatrix(kk,:); %#ok<PFBNS>
            
            % Train a model on the current input combination
            netz = obj.selectDataAndTrainModel(combination,obj);
            
            % Save current model, its input indices and abbreviation
            lmnCell{1,kk} = netz;
            lmnCell2{1,kk} = find(combination);
            lmnCell3{1,kk} = obj.determineAbbreviation(combination);
            
            % Evaluate the lossFunctionValue for the trained model. The
            % combination vector is passed (not the loop index), exactly
            % as in the sequential searches below.
            lossFunctionValues(kk) = obj.evaluateLossFunctionValue(netz,obj,combination);
        end
        
        % Save results to the object
        obj.sensitivityResults(1).model(1,:) = lmnCell;
        obj.sensitivityResults(1).model(2,:) = lmnCell2;
        obj.sensitivityResults(1).model(3,:) = lmnCell3;
        obj.sensitivityResults(1).lossFunctionValue = lossFunctionValues;
        
    case {'forwardSelection','backwardElimination'}
        
        % method == true: forward selection (one input is added per step)
        % method == false: backward elimination (one input is removed per
        % step)
        method = strcmp(obj.strategy,'forwardSelection');
        
        % A forward selection starts with minInputs inputs, a backward
        % elimination starts with maxInputs inputs
        if method
            startSize = minInputs;
        else
            startSize = maxInputs;
        end
        
        % Create matrix containing one subset in every line where at least
        % minInputs inputs are contained
        combiMatrix = obj.buildCombinationMatrix(minInputs4combinationMatrixDetermination,size(obj.originalInput,2));
        
        % Make sure all fixedInputs are contained in every input
        % combination and delete the redundant combinations this creates
        combiMatrix(:,obj.fixedInputs) = true;
        combiMatrix = unique(combiMatrix,'rows');
        
        % Keep only the combinations with the start number of inputs
        remainIdx = sum(combiMatrix,2)==startSize;
        combiMatrix = combiMatrix(remainIdx,:);
        
        % The empty combination is never produced by
        % buildCombinationMatrix; add it if the search starts from it.
        % (For a backward elimination the empty set is the end point of
        % the search, never a valid starting candidate.)
        if startSize == 0
            combiMatrix = [false(1,size(combiMatrix,2));combiMatrix];
        end
        
        % Pass everything to the sequentialSearch function, where the
        % addition or removal of one input per loop is carried out until
        % minInputs or maxInputs is reached
        [obj.sensitivityResults(1).lossFunctionValue,...
            obj.sensitivityResults(1).model,...
            obj.combinationMatrix] = ...
            sequentialSearch(combiMatrix,method,maxInputs-minInputs,obj.fixedInputs,obj);
        
    otherwise
        error('findInputCombinations:unknownStrategy',...
            'Unknown search strategy. Implemented strategies are: bruteForce, forwardSelection, backwardElimination.');
end

% Create entries for the sensitivityResults property of the object: all
% evaluated combinations, sorted by ascending loss function value
[~, minIdx] = sort(obj.sensitivityResults.lossFunctionValue);
obj.sensitivityResults(1).ranking = cell(3,length(minIdx)+1);
obj.sensitivityResults(1).ranking{1,1} = 'Used Inputs:';
obj.sensitivityResults(1).ranking{2,1} = 'Relative error related to the smallest error:';
obj.sensitivityResults(1).ranking{3,1} = 'Abbreviation of used inputs:';
for kk=1:length(minIdx)
    obj.sensitivityResults(1).ranking{1,kk+1} = find(obj.combinationMatrix(minIdx(kk),:));
    obj.sensitivityResults(1).ranking{2,kk+1} = obj.sensitivityResults.lossFunctionValue(minIdx(kk))/obj.sensitivityResults.lossFunctionValue(minIdx(1));
    obj.sensitivityResults(1).ranking{3,kk+1} = obj.determineAbbreviation(obj.combinationMatrix(minIdx(kk),:));
end


% Show information about the performed sensitivity analysis. The strategy
% is passed as an argument (not concatenated into the format string), so
% characters like '%' or '\' in it are printed literally.
fprintf('\n Chosen Strategy: %s\n Fixed Inputs: [%s] \n \n',obj.strategy,num2str(obj.fixedInputs));


    function [lossFunctionValues,lmnCell,combinationMatrix] = sequentialSearch(combinationMatrix,method,numberOfLoops,fixedInputs,netz)
        % Greedy sequential search: starting from the best row of
        % combinationMatrix, add (method == true) or remove
        % (method == false) one input per loop, numberOfLoops times,
        % always keeping the addition/removal with the smallest loss
        % function value.
        %
        % combinationMatrix  logical matrix, one start candidate per row
        % method             true: forward selection, false: backward
        %                    elimination
        % numberOfLoops      number of add/remove steps after the start
        % fixedInputs        column numbers that must not be removed
        % netz               sensitivityAnalysis object used to train and
        %                    evaluate the models
        %
        % lossFunctionValues [1 x numberOfLoops+1] loss of each step
        % lmnCell            {3 x numberOfLoops+1}: model, used input
        %                    indices and abbreviation of each step
        % combinationMatrix  [numberOfLoops+1 x nInputs] logical matrix;
        %                    row ii+1 is row ii with one input added or
        %                    removed
        
        % Check if fixedInputs has been passed. The contained information
        % is only needed for the backward elimination procedure
        if ~exist('fixedInputs','var') || isempty(fixedInputs)
            fixedInputs = NaN;
        end
        
        % First determine which possible combination contained in
        % the combinationMatrix is the 'best' to initialize the
        % sequential adding/removal of further input variables
        lossFunctionValues = inf(1,size(combinationMatrix,1));
        lmn = cell(1,size(combinationMatrix,1));
        parfor ii=1:size(combinationMatrix,1)
            
            % Train a model on the current input combination
            lmn{1,ii} = netz.selectDataAndTrainModel(combinationMatrix(ii,:),netz); %#ok<PFBNS>
            
            % Save lossFunction value of the current input variable subset
            lossFunctionValues(ii) = netz.evaluateLossFunctionValue(lmn{1,ii},netz,combinationMatrix(ii,:));
            
        end
        
        % Pick the best input combination
        [~,idx] = min(lossFunctionValues);
        
        % Because of the sequential search, only the best combination
        % contained in the combinationMatrix will be passed back together
        % with further input combinations with more or less input variables
        combinationMatrix = [combinationMatrix(idx,:);repmat(method,numberOfLoops,size(combinationMatrix,2))];
        
        % Predefine the output variable lmnCell, that contains the models,
        % that correspond to the input combinations contained in the
        % combinationMatrix
        lmnCell = cell(3,numberOfLoops+1);
        lmnCell{1,1} = lmn{1,idx};
        lmnCell{2,1} = find(combinationMatrix(1,:));
        lmnCell{3,1} = netz.determineAbbreviation(combinationMatrix(1,:));
        
        % Predefine the output variable lossFunctionValues
        lossFunctionValues = [lossFunctionValues(1,idx) inf(1,numberOfLoops)];
        
        % Candidate inputs for the next addition/removal
        if method
            % Forward selection: all inputs not yet contained
            remainingInputs = find(~combinationMatrix(1,:));
        else
            % Backward elimination: all inputs not yet removed
            remainingInputs = find(combinationMatrix(1,:));
        end
        for ii=1:numberOfLoops
            
            % Predefine variables to contain the loss function values and
            % the corresponding model for the remainingInputs
            lmnTMP = cell(1,size(remainingInputs,2));
            lossFunctionValuesTMP = inf(1,size(remainingInputs,2));
            
            % Calculate the loss function for the addition/removal of every
            % remaining input
            parfor ll=1:size(remainingInputs,2)
                
                % All already contained inputs, before adding or removing the
                % next input
                combinationVector = combinationMatrix(ii,:); %#ok<PFBNS>
                
                % Add/remove input. remainingInputs(ll) points to the input
                % that has to be added or removed. Method contains the
                % information, if the input should be added or removed
                combinationVector(remainingInputs(ll)) = method;
                
                % Fixed inputs must never be removed: they keep a loss of
                % Inf and are therefore never chosen. fixedInputs is NaN if
                % there are none; in a forward selection they are already
                % part of every combination and never among the
                % remainingInputs.
                if ~any(fixedInputs==remainingInputs(ll))
                    
                    % Train model with currently chosen input combination
                    lmnTMP{1,ll} = netz.selectDataAndTrainModel(combinationVector,netz); %#ok<PFBNS>
                    
                    % Evaluate the lossFunctionValue for the trained model
                    lossFunctionValuesTMP(1,ll) = netz.evaluateLossFunctionValue(lmnTMP{1,ll},netz,combinationVector);
                    
                end
                
            end
            
            % Find out which removal/addition led to the 'best' result
            [~,idx] = min(lossFunctionValuesTMP);
            
            % Update the combinationMatrix
            combinationMatrix(ii+1,:) = combinationMatrix(ii,:);
            combinationMatrix(ii+1,remainingInputs(idx)) = method;
            
            % Update the remainingInputs
            remainingInputs(idx) = [];
            
            % Save the best model and the corresponding loss function value
            lmnCell{1,ii+1} = lmnTMP{1,idx};
            lossFunctionValues(1,ii+1) = lossFunctionValuesTMP(1,idx);
            
            % Save additional information such as the input variables and
            % the abbreviation for the input combination
            lmnCell{2,ii+1} = find(combinationMatrix(ii+1,:));
            lmnCell{3,ii+1} = netz.determineAbbreviation(combinationMatrix(ii+1,:));
            
        end
        
    end

end
