#!/usr/bin/env python

import os
import io
import errno
import unittest

from pfr_bitstream import PFRBitstream
from pfm_data import PFMData
from pfm_spi_region import SPIRegionDef
from pfm_smbus_rule import SMBusRuleDef
from convert import Convert


class PFMFactory:
    """
    Read and write utilities for PFM data.

    Binary files are classified by their leading bytes:
      * a 4-byte value equal to ``PFMData.magic()`` -> ``PFMData``
      * otherwise a 1-byte definition type equal to
        ``SPIRegionDef.def_type()`` or ``SMBusRuleDef.def_type()``
        -> the corresponding definition object.
    """

    def __init__(self):
        pass

    def peek_value(self, stream, size):
        """
        Return the integer stored in the next ``size`` bytes of ``stream``
        without moving the stream position.

        Buffered readers (such as the object returned by ``io.open(..., "rb")``)
        are inspected with ``peek()``. ``peek()`` may return fewer bytes than
        requested (it performs at most one raw read) or more. If it comes back
        short, or the stream has no ``peek()`` at all (e.g. ``io.BytesIO``),
        a seekable stream is read directly and its position is restored
        afterwards.

        :param stream: readable binary stream positioned at the value of interest.
        :param size: number of bytes making up the value.
        :raises ValueError: if fewer than ``size`` bytes are available.
        """
        read_bytes = b""
        if hasattr(stream, "peek"):
            read_bytes = stream.peek(size)

        seekable = getattr(stream, "seekable", None)
        if len(read_bytes) < size and seekable is not None and seekable():
            position = stream.tell()
            try:
                read_bytes = stream.read(size)
            finally:
                # Always restore the position so "peek" semantics hold.
                stream.seek(position)

        # peek() may return more than requested; only the first `size` bytes matter.
        read_bytes = read_bytes[:size]
        if len(read_bytes) < size:
            raise ValueError('Unknown Stream - Too Short')
        return Convert().bytearray_to_integer(read_bytes, offset=0, num_of_bytes=size)

    def mkdir_p(self, path):
        """
        Create ``path`` and any missing parents, like ``mkdir -p``.

        ``None`` or an empty path is a no-op. An already existing directory is
        not an error. Any other ``OSError`` (including ``path`` existing as a
        regular file) is re-raised.
        """
        if not path:
            return
        try:
            os.makedirs(path)
        except OSError as exc:
            # Tolerate "already exists" only when it really is a directory.
            if not (exc.errno == errno.EEXIST and os.path.isdir(path)):
                raise

    @staticmethod
    def _read_validated(pfm_object, stream):
        """Populate ``pfm_object`` from ``stream``, validate it and return it."""
        pfm_object.read(fp=stream)
        pfm_object.validate()
        return pfm_object

    def generate_from_file(self, filename):
        """
        Parse the binary file ``filename`` into the matching PFM object.

        :returns: a validated ``PFMData``, ``SPIRegionDef`` or ``SMBusRuleDef``.
        :raises ValueError: if the file does not exist, is too short to hold
            the 4-byte magic value, or matches none of the known formats.
        """
        if not os.path.isfile(filename):
            raise ValueError("Cannot find file " + filename + ".")

        with io.open(filename, "rb") as stream:
            magic = self.peek_value(stream=stream, size=4)
            if magic == PFMData.magic():
                # Full PFM data structure.
                return self._read_validated(PFMData(), stream)

            # Not a full PFM: classify by the leading definition-type byte.
            def_type = self.peek_value(stream=stream, size=1)
            if def_type == SPIRegionDef.def_type():
                return self._read_validated(SPIRegionDef(), stream)
            if def_type == SMBusRuleDef.def_type():
                return self._read_validated(SMBusRuleDef(), stream)

        raise ValueError("The input bitstream does not match any PFM format. ")

    def save_to_file(self, pfm_bistream, filename):
        """
        Write ``pfm_bistream`` to ``filename`` as a binary file.

        Missing parent directories are created first.

        :param pfm_bistream: the ``PFRBitstream`` to serialize.
        :param filename: destination path.
        :raises TypeError: if ``pfm_bistream`` is not a ``PFRBitstream``.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(pfm_bistream, PFRBitstream):
            raise TypeError(
                "Expected a PFRBitstream instance, got "
                + type(pfm_bistream).__name__ + "."
            )
        self.mkdir_p(os.path.dirname(filename))

        # "wb+" (read/write) is kept from the original. Whether write() needs
        # read access is not visible here.
        with io.open(filename, "wb+") as stream:
            pfm_bistream.write(fp=stream)


class PFMFactoryTest(unittest.TestCase):
    """
    Round-trip tests for PFMFactory: save an object, read it back, and check
    both the dispatched type and the decoded field values.

    Uses ``assertEqual`` because the ``assertEquals`` alias was removed in
    Python 3.12.
    """

    def test_rw_on_pfm_data(self):
        data = PFMData()
        data.initialize()

        factory = PFMFactory()
        binary_f = "example_pfm_data.bin"
        factory.save_to_file(data, binary_f)

        generated_data = factory.generate_from_file(filename=binary_f)
        # The magic value must route the file back to PFMData.
        self.assertIsInstance(generated_data, PFMData)
        generated_data.validate()

    def test_rw_on_pfm_spi_region_def(self):
        region_def = SPIRegionDef()
        region_def.initialize()

        factory = PFMFactory()
        binary_f = "example_spi_region.bin"
        factory.save_to_file(region_def, binary_f)

        generated_region_def = factory.generate_from_file(filename=binary_f)
        self.assertIsInstance(generated_region_def, SPIRegionDef)
        generated_region_def.validate()

    def test_rw_on_pfm_smbus_rule_def(self):
        rule_def = SMBusRuleDef()
        rule_def.initialize()
        rule_def.set_bus_id(2)
        rule_def.set_rule_id(5)
        rule_def.set_device_addr(0xDC)

        factory = PFMFactory()
        binary_f = "example_smbus_rule.bin"
        factory.save_to_file(rule_def, binary_f)

        generated_rule_def = factory.generate_from_file(filename=binary_f)
        self.assertIsInstance(generated_rule_def, SMBusRuleDef)
        generated_rule_def.validate()

        self.assertEqual(generated_rule_def.bus_id(), 2)
        self.assertEqual(generated_rule_def.rule_id(), 5)
        self.assertEqual(generated_rule_def.device_addr(), 0xDC)

    def test_long_pfm_data(self):
        # TODO: Break this test down once the data structure is finalized.
        # TODO: waiting on PFM magic number and value for reserved area
        data = PFMData()
        data.initialize(size=256)

        # Instantiate some rule definitions
        rule_def1 = SMBusRuleDef()
        rule_def1.initialize()
        rule_def1.set_bus_id(1)
        rule_def1.set_rule_id(2)
        rule_def1.set_device_addr(0xAC)

        rule_def2 = SMBusRuleDef()
        rule_def2.initialize()
        rule_def2.set_bus_id(2)
        rule_def2.set_rule_id(8)
        rule_def2.set_device_addr(0xB0)

        rule_def3 = SMBusRuleDef()
        rule_def3.initialize()
        rule_def3.set_bus_id(2)
        rule_def3.set_rule_id(11)
        rule_def3.set_device_addr(0xE4)

        # Instantiate some region definitions
        region_def1 = SPIRegionDef()
        region_def1.initialize()
        region_def1.set_protection_mask(3)
        region_def1.set_start_addr(176)
        region_def1.set_end_addr(190)

        region_def2 = SPIRegionDef()
        region_def2.initialize()
        region_def2.set_protection_mask(0b11011)
        region_def2.set_start_addr(208)
        region_def2.set_end_addr(268)

        region_def3 = SPIRegionDef()
        region_def3.initialize()
        region_def3.set_protection_mask(1)
        region_def3.set_start_addr(268)
        region_def3.set_end_addr(500)
        region_def3.set_hash_algorithm(0b1)

        # Add definitions to PFM
        data.append_pfm_def(rule_def1)
        data.append_pfm_def(rule_def2)
        data.append_pfm_def(rule_def3)
        data.append_pfm_def(region_def1)
        data.append_pfm_def(region_def2)
        data.append_pfm_def(region_def3)
        data.set_padding()

        factory = PFMFactory()
        binary_f = "example_complex_pfm_data.bin"
        factory.save_to_file(data, binary_f)

        generated_data = factory.generate_from_file(filename=binary_f)
        self.assertIsInstance(generated_data, PFMData)
        factory.save_to_file(generated_data, "example_complex_pfm_data_dup.bin")
        generated_data.validate()

        # The first three definitions were appended as SMBus rules, in order.
        expected_rules = [(1, 2, 0xAC), (2, 8, 0xB0), (2, 11, 0xE4)]
        for index, (bus_id, rule_id, device_addr) in enumerate(expected_rules):
            pfm_def = generated_data.get_pfm_def(index)
            self.assertEqual(pfm_def.bus_id(), bus_id)
            self.assertEqual(pfm_def.rule_id(), rule_id)
            self.assertEqual(pfm_def.device_addr(), device_addr)


if __name__ == '__main__':
    # Running this module directly executes the PFMFactoryTest suite.
    unittest.main()
