%
% Runs the basic fragment-based method, with fragments selected from a
% random pool and without the attention mechanism. Produces the RP
% column of the paper's results for cars (see README).
%

% Start from an empty workspace.
clear;

% Restore the legacy 'state' generator from a stored seed so that the
% shuffles below, and therefore the fold contents, are reproducible.
% Seed_one is the seed just after reset.
load('Seeds/Seed_one', 'SeedState');
rand('state', SeedState);

% Example images are stacked along the third dimension of each array.
load('datasets/carsL', 'positives', 'negatives');

% Shuffle each class independently along dim 3. The order of the two
% randperm calls (positives first, then negatives) fixes the RNG stream.
positives = positives(:, :, randperm(size(positives, 3)));
negatives = negatives(:, :, randperm(size(negatives, 3)));


% Prepare n-fold cv:
% Both classes are truncated to the size of the smaller one (n_pairs) so
% every fold holds the same number of positive and negative examples.
% Examples beyond n_folds*floor(n_pairs/n_folds) in either class are not
% used by any fold.
n_folds = 10;
n_pos = size(positives,3);
n_neg = size(negatives,3);

n_pairs = min([n_pos n_neg]);
pos_in_fold = floor(n_pairs/n_folds);
neg_in_fold = floor(n_pairs/n_folds);
% With fewer examples than folds every fold would be empty and training
% would fail much later with an obscure error; stop early instead.
if pos_in_fold < 1,
   error('Need at least %d examples of each class for %d-fold CV (have %d pos, %d neg).', ...
      n_folds, n_folds, n_pos, n_neg);
end

% Fold k owns the contiguous index block ((k-1)*m+1):(k*m), m being the
% per-fold count. The blocks tile 1:(n_folds*m) exactly, so that range is
% the union of all folds (built directly rather than with union() calls).
all_pos_idx = 1:(n_folds*pos_in_fold);
all_neg_idx = 1:(n_folds*neg_in_fold);

% Per-fold record. train_stat, W and theta are initialized here but are
% not assigned anywhere else in this script.
folds = [];
for k=1:n_folds,
   folds(k).pos_idx = ((k-1)*pos_in_fold+1):(k*pos_in_fold);
   folds(k).neg_idx = ((k-1)*neg_in_fold+1):(k*neg_in_fold);
   folds(k).train_stat     = [];
   folds(k).test_stat      = [];
   folds(k).saved_features = [];
   folds(k).W              = [];
   folds(k).theta          = [];
end
% Keep only the (shuffled) examples covered by some fold.
all_positives = positives(:,:,all_pos_idx);
all_negatives = negatives(:,:,all_neg_idx);

%all_positives = make_noise(all_positives);
%all_negatives = make_noise(all_negatives);

% Candidate specifications; indexed by fold number in the CV loop below.
load cand_spec cand_spec;

% Cross-validation loop. For each fold:
%   1. the examples of that fold form the training set and the examples
%      of all other folds form the test set
%      (NOTE(review): this trains on ONE fold and tests on the remaining
%      n_folds-1 folds, the inverse of conventional k-fold CV; presumably
%      deliberate since cand_spec is indexed per fold -- confirm),
%   2. candidate features and their match table are computed,
%   3. a feature pool is grown greedily (max-min pairwise gain),
%   4. a linear SVM is trained on the binarized pool responses,
%   5. test predictions are tallied and the workspace is checkpointed.
for current_fold=1:n_folds,
   
   fprintf(1,'\nFold %d starting..\n', current_fold);
   
   % Prepare to train:
   clear positives;
   clear negatives;
   positives = all_positives(:,:,folds(current_fold).pos_idx);
   negatives = all_negatives(:,:,folds(current_fold).neg_idx);
   
   fprintf(1,'\nThis training set has %d pos and %d neg examples.\n',size(positives,3),size(negatives,3));
   
   saved_features        = [];   % selected pool; element .feature holds the feature
   match_table_saved     = [];   % one match-table row per selected feature
   max_feature_pool_size = 50;   % growth stops once the pool reaches this size
   
   % Training images stacked along dim 3 (positives first); labels are
   % +1 for positives and -1 for negatives in the same order.
   current_images = cat(3,positives,negatives);
   current_labels = [ones(size(positives,3),1); -ones(size(negatives,3),1)];
   
   %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
   
   % Pruning assumed:
   current_features    = extract_features_L(current_images,current_labels,...
      cand_spec(current_fold));
   fprintf(1,'\nThere are %d features\n',size(current_features,2));
   pause(1);
   match_table_current = compute_mtable_batch(current_features,current_images);
   fprintf(1,'\nMatch table ready..\n');  
   
   %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
   % Pairwise gain cache indexed by (candidate unique_id, saved unique_id).
   % marks(a,b)==1 means gain(a,b) already holds a computed value.
   % NOTE(review): assumes unique_id values lie in 1..par1 -- confirm.
   % At most one entry is written per (candidate, saved feature) pair, and
   % the pool is capped at max_feature_pool_size. Reserving
   % par1*min(par1,max_feature_pool_size) nonzeros is therefore enough.
   % The former par1^2 reservation wasted memory quadratic in the number
   % of candidates. nzmax is only a hint; the matrices still grow if needed.
   par1 = size(current_features,2);
   gain_tables2(1,1).gain = spalloc(par1,par1,par1*min(par1,max_feature_pool_size));
   gain_tables2(1,1).marks= spalloc(par1,par1,par1*min(par1,max_feature_pool_size));
   
   %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
   
   % 2: Growth..
   % Greedy max-min selection. Each round, every remaining candidate i is
   % thresholded via generate_split. It is then scored by its worst
   % (smallest) gain against the features already in the pool
   % (measure_gain2c), or by measure_gain1c alone while the pool is empty.
   % The candidate with the largest worst-case gain moves to the pool if
   % that gain is > 0. Growth stops when no candidate has positive gain,
   % the pool is full, or no candidates remain.
   keep_going = 1;
   while keep_going,
      
      fprintf(1,'g');
      
      max_min_val = 0;
      max_min_idx = [];
      for i = 1:size(current_features,2),
         X_i = current_features(i).feature;
         % TH  = X_i.TH_range;
         
         [split_loc,okey] = generate_split(match_table_current(i,:),current_labels);
         
         % The split location is used directly as the threshold and is
         % also recorded as th_idx.
         th_idx = split_loc; % XXX: a hack
         X_i.th = split_loc;
         min_val = -inf;
         if okey,
            %for th_idx = 1:size(TH,2),
            %X_i.th  = TH(th_idx);
            min_val = inf;
            min_idx = [];
            % Columns of the match table where candidate i fires.
            X_i_occ = find(match_table_current(i,:)>=X_i.th);
            X_i_occ = X_i_occ(:);
            if size(saved_features,2)==0,
               min_val = measure_gain1c(X_i_occ,current_labels);
               min_idx = 0;
               % This thing is used only once.. The pool will have
               % 2+ features even if pruned...
               % => No need for table lookups..
            else
               for j = 1:size(saved_features,2),
                  X_j     = saved_features(j).feature;
                  % th_idx2 = X_j.th_idx;
                  % Cached gains are reused across rounds. NOTE(review):
                  % this assumes generate_split returns the same threshold
                  % for an unchanged match-table row.
                  if gain_tables2(1,1).marks(X_i.unique_id,X_j.unique_id)==1,
                     g = gain_tables2(1,1).gain(X_i.unique_id,X_j.unique_id);
                  else
                     X_j_occ = find(match_table_saved(j,:)>=X_j.th);
                     X_j_occ = X_j_occ(:);
                     g = measure_gain2c(X_i_occ,X_j_occ,current_labels);
                     gain_tables2(1,1).marks(X_i.unique_id,X_j.unique_id) = 1;
                     gain_tables2(1,1).gain(X_i.unique_id,X_j.unique_id) = g;
                  end
                  
                  if g < min_val,
                     min_val = g;
                     min_idx = j;
                  end
               end
            end
            % The worst opponent for i is now known (min_val,min_idx).
            % See if X_i has the least bad worst opponent:
            if min_val > max_min_val,
               % Best of the worst so far..
               max_min_val = min_val;
               max_min_idx = [i th_idx min_idx];
            end
         end
      end
      
      if max_min_val>0,
         % Something useful found..
         
         X_k      = current_features(max_min_idx(1)).feature;
         %TH       = X_k.TH_range;
         X_k.th   = max_min_idx(2); %%%: A hack (max_min_idx(2) is the split location)
         X_k.th_idx = max_min_idx(2); % Save the th_idx !!
         X_k.gain = max_min_val; 
         % This is transient: Valid only for this instant..
         % fprintf(1,'.');
         
         if isempty(saved_features),
            saved_features(1).feature = X_k;
            match_table_saved = match_table_current(max_min_idx(1),:);
         else
            saved_features(size(saved_features,2)+1).feature = X_k;
            match_table_saved = [match_table_saved; match_table_current(max_min_idx(1),:)];
         end
         
         % Remove the feature from current features:
         idx_other_current_features = setdiff(1:size(current_features,2),max_min_idx(1));
         if not(isempty(idx_other_current_features)),
            current_features    = current_features(idx_other_current_features);
            match_table_current = match_table_current(idx_other_current_features,:);
         else
            current_features    = [];
            match_table_current = [];
         end
      else
         % Nothing useful..
         keep_going = 0;
      end
      
      
      if size(saved_features,2) >= max_feature_pool_size,
         keep_going = 0;
      end
      if isempty(current_features),
         keep_going = 0;
      end
      
      fprintf(1,'%d ',size(saved_features,2));
   end % End of the feature learning phase..
   %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
   % Clear countless megabytes of unnecessary stuff..
   clear gain_tables2;
   clear par1;
   
   fprintf(1,'\nFeature pool size = %d\n',size(saved_features,2));  
   
   % An empty pool leaves nothing to train the SVM on; report it clearly
   % instead of failing inside LinearSVC with empty input data.
   if isempty(saved_features),
      error('Fold %d: no candidate feature had positive gain; the feature pool is empty.', current_fold);
   end
   
   % Change the data into svm format (just binarize)..
   % Row k is 1 where selected feature k reaches its threshold, else 0.
   svm_data = match_table_saved;
   for k = 1:size(saved_features,2),
      svm_data(k,:) = (match_table_saved(k,:) >= saved_features(k).feature.th);
   end
   addpath osu_svm;
   svm_answers = current_labels';
   % Learn a linear svm:
   [AlphaY,SVs,Bias,Parameters,nSV,nLabel] = LinearSVC(svm_data,svm_answers);
   svm = [];
   svm.AlphaY = AlphaY;
   svm.SVs = SVs;
   svm.Bias = Bias;
   svm.Parameters = Parameters;
   svm.nSV = nSV;
   svm.nLabel = nLabel;
   
   % Remember the classifier:
   folds(current_fold).saved_features = saved_features;
   folds(current_fold).svm = svm;
   
   % Prepare to test: every example not in the current (training) fold.
   clear positives;
   clear negatives;
   pos_idx_test = setdiff(all_pos_idx,folds(current_fold).pos_idx);
   neg_idx_test = setdiff(all_neg_idx,folds(current_fold).neg_idx);
   positives = all_positives(:,:,pos_idx_test);
   negatives = all_negatives(:,:,neg_idx_test);
   n_mistakes_pos = 0;
   n_correct_pos  = 0;
   n_mistakes_neg = 0;
   n_correct_neg  = 0;
   
   fprintf(1,'\nTesting with %d + %d images\n',size(positives,3),size(negatives,3));
   
   % A prediction of 1 counts as "positive"; anything else as "negative".
   for test_idx = 1:size(positives,3),
      pred = predict_example_svm(saved_features,svm,positives(:,:,test_idx));
      if pred==1,
         n_correct_pos = n_correct_pos + 1;
      else
         n_mistakes_pos = n_mistakes_pos + 1;
      end
   end
   for test_idx = 1:size(negatives,3),
      pred = predict_example_svm(saved_features,svm,negatives(:,:,test_idx));
      if pred==1,
         n_mistakes_neg = n_mistakes_neg + 1;
      else
         n_correct_neg = n_correct_neg + 1;
      end
   end
   % test_stat layout: [correct_pos mistakes_pos; correct_neg mistakes_neg]
   folds(current_fold).test_stat = ...
      [n_correct_pos n_mistakes_pos; n_correct_neg n_mistakes_neg];
   [n_correct_pos n_mistakes_pos; n_correct_neg n_mistakes_neg]
   
   %%%%%%%
   % Save: checkpoint the whole workspace after every fold.
   save results/roth_basic_10fold;
   pause(2);
   
end % Crossvalidation loop...

% Calculate stats, etc...
% Passes the per-fold records (including each fold's test_stat matrix
% [correct_pos mistakes_pos; correct_neg mistakes_neg]) to validation_stats.
% NOTE(review): the meaning of the arguments 1 and 'r*' is defined by
% validation_stats, which is not visible here; 'r*' looks like a plot
% marker spec -- confirm.

final_stats = validation_stats(folds,1,'r*');

% End.
