""" Naive Bayes Classifier """

import math

import numpy as np

class NaiveBayes(object):
    """Naive Bayes classifier for discrete (hashable) feature values.

    Probabilities are maximum-likelihood estimates taken directly from the
    training counts, without smoothing: a feature value that never occurred
    together with a label gives that label probability zero.
    """

    def __init__(self, training_data):
        """Tabulate label counts and per-feature value counts.

        Args:
            training_data: Non-empty sequence of ``(feature_vector, label)``
                pairs. All feature vectors must have the same length;
                feature values and labels must be hashable.

        Raises:
            ValueError: If ``training_data`` is empty, or if a feature
                vector's length differs from that of the first one.
        """
        if len(training_data) == 0:
            raise ValueError("training_data must contain at least one instance")
        self.feature_count = len(training_data[0][0])
        self.data_size = len(training_data)
        # Number of training instances per label.
        self.labels_counts = {}
        # Per label: a list with one dict per feature, mapping each observed
        # feature value to the number of instances of that label having it.
        self.cond_feature_counts = {}

        for feature_vector, label in training_data:
            if len(feature_vector) != self.feature_count:
                raise ValueError(
                    "expected %d features, got %d"
                    % (self.feature_count, len(feature_vector)))
            if label not in self.labels_counts:
                # First occurrence of this label: create its entries in both
                # tables. A distinct dict is built for every feature, because
                # ``[{}] * n`` would make all features share one dict.
                self.labels_counts[label] = 0
                self.cond_feature_counts[label] = [
                    {} for _ in range(self.feature_count)]

            self.labels_counts[label] += 1
            for value_counts, feature_value in zip(
                    self.cond_feature_counts[label], feature_vector):
                value_counts[feature_value] = (
                    value_counts.get(feature_value, 0) + 1)

    def classify(self, instance):
        """Return the most probable label for ``instance``.

        Each label is scored by P(label) * prod_i P(x_i | label), estimated
        from the training counts. This simplifies to

            prod_i n(label, x_i) / (n(label) ** (F - 1) * N)

        where F is ``feature_count`` and N is ``data_size``. The score is
        evaluated in log space. Multiplying the raw counts directly would
        overflow fixed-width integers and floats once there are many features.

        Args:
            instance: Feature vector with ``feature_count`` values.

        Returns:
            The label with the highest score, or ``None`` if every label has
            probability zero, i.e. for each label some feature value of
            ``instance`` never occurred with it in training. On equal scores
            the label that comes first in ``labels_counts`` iteration order
            is kept.

        Raises:
            ValueError: If ``instance`` does not have ``feature_count`` values.
        """
        if len(instance) != self.feature_count:
            raise ValueError(
                "expected %d features, got %d"
                % (self.feature_count, len(instance)))

        log_data_size = math.log(self.data_size)
        best_log_score = float("-inf")
        best_label = None

        for label, label_count in self.labels_counts.items():
            counts = [
                value_counts.get(feature_value, 0)
                for value_counts, feature_value in zip(
                    self.cond_feature_counts[label], instance)
            ]
            if 0 in counts:
                # Zero probability: such a label can never be chosen.
                continue
            log_score = (
                sum(math.log(count) for count in counts)
                - (self.feature_count - 1) * math.log(label_count)
                - log_data_size
            )
            # Strict '>' keeps the earlier label on ties.
            if log_score > best_log_score:
                best_log_score = log_score
                best_label = label

        return best_label
