#!/usr/bin/env python

import unittest

from pfr_bitstream import PFRBitstream


class SPIRegionDef(PFRBitstream):
    """
    This is a class that represents Platform Firmware Manifest SPI region definition.

    Fields handled by the accessors below (byte offset, width):
      0x0  1 byte   definition type; validate() requires def_type() (0x1)
      0x1  1 byte   protection level mask (see the _PROTECT_MASK_* bits)
      0x2  2 bytes  hash algorithm bitfield (0 means no hash present)
      0x4  4 bytes  reserved; validate() requires 0xFFFFFFFF
      0x8  4 bytes  SPI region start address
      0xC  4 bytes  SPI region end address
      0x10 n bytes  optional region hash; n is derived from the hash algorithm
    """

    # When there's no hash present, SPI region definition is the smallest.
    _DEFINITION_MIN_SIZE = 16

    # Protection level mask bits (byte 0x1)
    _PROTECT_MASK_READ_ALLOWED               = 0b1
    _PROTECT_MASK_WRITE_ALLOWED              = 0b10
    _PROTECT_MASK_RECOVER_ON_FIRST_RECOVERY  = 0b100
    _PROTECT_MASK_RECOVER_ON_SECOND_RECOVERY = 0b1000
    _PROTECT_MASK_RECOVER_ON_THIRD_RECOVERY  = 0b10000

    # Hash algorithm values accepted by set_hash_algorithm(), mapped to keys of self.hash_size.
    _HASH_NAME_BY_ALGORITHM = {0b1: "SHA256", 0b10: "SHA384", 0b100: "SHA512"}

    def __init__(self):
        super(SPIRegionDef, self).__init__()
        # Hash digest lengths in bytes, keyed by algorithm name.
        self.hash_size = {"SHA256": 32, "SHA384": 48, "SHA512": 64}

    def initialize(self, byte_array=None, size=_DEFINITION_MIN_SIZE):
        """
        Initialize the underlying bitstream.

        :param byte_array: existing definition bytes to wrap. When falsy (None or
            empty), a fresh definition of ``size`` bytes is created instead.
        :param size: size in bytes of a freshly created definition.

        A fresh definition gets definition type 1, protection mask 0b11 (read and
        write allowed) and hash algorithm 0 (no hash).
        """
        if byte_array:
            super(SPIRegionDef, self).initialize(byte_array=byte_array)
        else:
            super(SPIRegionDef, self).initialize(size=size)
            self.set_def_type(1)
            self.set_protection_mask(3)
            self.set_hash_algorithm(0)

    def validate(self):
        """
        Check that this SPI region definition is internally consistent.

        :raises ValueError: if the definition type is wrong, the reserved field is
            not 0xFFFFFFFF, hash presence does not match the protection mask, both
            SHA256 and SHA384 are declared, or the size does not match the
            declared hash algorithm.
        :returns: whatever PFRBitstream.validate() returns.
        """
        # definition type: PFM SPI region definition
        if self.read_def_type() != SPIRegionDef.def_type():
            raise ValueError("The definition type is wrong in this PFM SPI region definition.")

        # Reserved
        if self.get_value(offset=0x4, size=4) != 0xFFFFFFFF:
            raise ValueError("The reserved area is corrupted in this PFM SPI region definition.")

        mask = self.protection_mask()
        algorithm = self.hash_algorithm()

        # There should be hash present for static region but not for dynamic region
        if mask & SPIRegionDef._PROTECT_MASK_WRITE_ALLOWED:
            # In dynamic region, hash should not be present
            if algorithm != 0:
                raise ValueError("There cannot be hash present, when this SPI region is writable. ")
        elif mask & SPIRegionDef._PROTECT_MASK_READ_ALLOWED and algorithm == 0:
            # In static region, there must be hash present
            raise ValueError("There must be hash present in this SPI region definition, " + hex(self.start_addr()) +
                             " - " + hex(self.end_addr()) + ", "
                             "because this is a read-only (static) SPI region. ")

        # For phase 1, assume only 1 hash. Checked before the size check so that a
        # definition declaring both hashes gets this explicit error instead of a
        # size mismatch (calc_size() only accounts for one hash).
        if algorithm & 0b1 and algorithm & 0b10:
            raise ValueError("This SPI region definition contains two region hashes, which is not supported. ")

        # Size check
        self.check_size()

        return super(SPIRegionDef, self).validate()

    def read(self, fp):
        """
        Append one SPI region definition read from the file object ``fp``.

        Reads the fixed-size part first, then as many hash bytes as calc_size()
        says the just-read hash algorithm field requires. Does nothing if ``fp``
        is closed. Short reads are not detected here.
        """
        if fp.closed:
            return
        # Read until hash area
        self.append(byte_array=bytearray(fp.read(SPIRegionDef._DEFINITION_MIN_SIZE)))

        # Then the hash itself, if the header declares one.
        hash_length = self.calc_size() - SPIRegionDef.def_min_size()
        if hash_length > 0:
            self.append(byte_array=bytearray(fp.read(hash_length)))

    @staticmethod
    def def_type():
        """Return the definition type value identifying an SPI region definition."""
        return 0x1

    @staticmethod
    def def_min_size():
        """Return the size in bytes of a definition with no hash."""
        return SPIRegionDef._DEFINITION_MIN_SIZE

    @staticmethod
    def def_max_size():
        """Return the largest allowed definition size in bytes (SHA384 hash)."""
        # TODO: Only allow 1 hash and the largest hash allowed is SHA 384
        return SPIRegionDef._DEFINITION_MIN_SIZE + 48

    def read_def_type(self):
        """Return the definition type byte (offset 0x0)."""
        return self.get_value(0x0, size=1)

    def set_def_type(self, value=1):
        """Set the definition type byte (offset 0x0)."""
        self.set_value(value=value, offset=0x0, size=1)

    def protection_mask(self):
        """
        Return region protection level mask.
        """
        return self.get_value(0x1, size=1)

    def set_protection_mask(self, value):
        """
        Set region protection level mask with value.
        """
        self.set_value(value=value, offset=0x1, size=1)

    def hash_algorithm(self):
        """
        Return SPI region hash algorithm
        """
        return self.get_value(0x2, size=2)

    def set_hash_algorithm(self, value):
        """
        Set SPI region hash algorithm and make room for the hash.

        ``value`` 0b1, 0b10 and 0b100 select SHA256, SHA384 and SHA512; any other
        value reserves no hash bytes. Only the shortfall between the bytes
        already present beyond the fixed header and the bytes the hash requires
        is appended. Repeated calls therefore do not keep growing the definition.
        The definition is never shrunk: switching to a shorter (or no) hash
        leaves the extra bytes in place, and check_size() will reject it.

        NOTE(review): SHA512 (0b100) is accepted here, but calc_size()/read()
        only recognise SHA256/SHA384 and def_max_size() caps the hash at 48
        bytes, so a SHA512 definition will not pass validate(). Confirm whether
        SHA512 should be supported or rejected.
        """
        hash_name = SPIRegionDef._HASH_NAME_BY_ALGORITHM.get(value)
        required = self.hash_size[hash_name] if hash_name else 0
        # Hash bytes already allocated after the fixed-size header (0 if none yet).
        present = max(0, self.size() - SPIRegionDef.def_min_size())
        if required > present:
            self.append(size=required - present)
        self.set_value(value=value, offset=0x2, size=2)

    def start_addr(self):
        """
        Return SPI region start address
        """
        return self.get_value(0x8, size=4)

    def set_start_addr(self, value):
        """
        Set SPI region start address
        """
        self.set_value(value=value, offset=0x8, size=4)

    def end_addr(self):
        """
        Return SPI region end address
        """
        return self.get_value(0xC, size=4)

    def set_end_addr(self, value):
        """
        Set SPI region end address
        """
        self.set_value(value=value, offset=0xC, size=4)

    def calc_size(self):
        """
        Return the expected definition size in bytes for the current hash algorithm.

        Bit 0 (SHA256) takes precedence over bit 1 (SHA384); any other value is
        treated as having no hash.
        """
        if self.hash_algorithm() & 0b1:
            return SPIRegionDef.def_min_size() + self.hash_size["SHA256"]
        elif self.hash_algorithm() & 0b10:
            return SPIRegionDef.def_min_size() + self.hash_size["SHA384"]
        return SPIRegionDef.def_min_size()

    def check_size(self):
        """
        Check the actual size against the minimum size and calc_size().

        :raises ValueError: if the definition is shorter than the minimum size or
            its size differs from calc_size().
        """
        if self.size() < SPIRegionDef.def_min_size():
            raise ValueError("The actual size (" + str(self.size()) +
                             " bytes) of this SPI region definition is less than the minimum size.")

        if self.size() != self.calc_size():
            raise ValueError("The actual size (" + str(self.size()) +
                             " bytes) of this SPI region definition does not match the expected size ("
                             + str(self.calc_size()) + " bytes). ")


class SPIRegionDefTest(unittest.TestCase):
    """Unit tests for SPIRegionDef construction and validation."""

    def test_initialize(self):
        # A freshly initialized definition (read/write mask, no hash) validates.
        region_def = SPIRegionDef()
        region_def.initialize()
        region_def.validate()

    def test_example_region_def(self):
        region_def = SPIRegionDef()
        region_def.initialize()
        region_def.set_hash_algorithm(value=0b0)
        region_def.set_protection_mask(value=0b1011)
        region_def.set_start_addr(135)
        region_def.set_end_addr(144)
        region_def.validate()
        self.assertEqual(region_def.size(), 16)
        self.assertEqual(region_def.hash_algorithm(), 0)
        self.assertEqual(region_def.protection_mask(), 0b1011)
        self.assertEqual(region_def.start_addr(), 135)
        self.assertEqual(region_def.end_addr(), 144)

    def test_static_region_with_sha256_hash(self):
        # A read-only region carrying a SHA256 hash grows by 32 bytes and validates.
        region_def = SPIRegionDef()
        region_def.initialize()
        region_def.set_protection_mask(value=SPIRegionDef._PROTECT_MASK_READ_ALLOWED)
        region_def.set_hash_algorithm(value=0b1)
        region_def.validate()
        self.assertEqual(region_def.size(), SPIRegionDef.def_min_size() + 32)
        self.assertEqual(region_def.hash_algorithm(), 0b1)

    def test_static_region_without_hash_is_rejected(self):
        # A read-only (static) region must carry a hash.
        region_def = SPIRegionDef()
        region_def.initialize()
        region_def.set_protection_mask(value=SPIRegionDef._PROTECT_MASK_READ_ALLOWED)
        with self.assertRaises(ValueError):
            region_def.validate()


if "__main__" == __name__:
    # Executed directly as a script: run the SPIRegionDef unit tests.
    unittest.main()
