import sys
import os 
import re
from time import sleep
from subprocess import Popen
import getpass
import traceback
import logging

import utils

# Substrings that mark a module-summary line as a *data* contribution when a
# non-ARC map file is analysed (see map_parser.calucalteModulesSize).
KEYWORDS = [
    '.bss',
    '.data',
    '.initialization_data',
    '.initialization',
    '.rodata',
    '.DDR_TEXT',
    '.DDR_BSS',
    '.DDR_DATA',
]
# Substrings that mark a module-summary line as a *text* (code) contribution
# in the same non-ARC analysis.
KEYWORDS_text = [
    '.text',
]

def unhandeledException(type, value, tb):
    """Report an uncaught exception, optionally pause, and exit with status 1.

    The signature matches ``sys.excepthook`` (exception type, value, traceback),
    so the function can be installed as the interpreter's excepthook.

    The banner and the traceback are both written to stderr, so they stay
    together and in order when stdout/stderr are redirected.

    The console ``pause`` (a Windows cmd command) is issued only on Windows and
    only when stdin is an interactive terminal. Unattended runs (CI, pipes)
    therefore never block waiting for a key press.
    """
    # ``type`` shadows the builtin; the name is kept for interface compatibility.
    print("\n\nUNHANDELED EXCEPTION: " + os.path.basename(sys.argv[0]) + "\n", file=sys.stderr)
    traceback.print_exception(type, value, tb)
    stdin = sys.stdin
    if os.name == "nt" and stdin is not None and stdin.isatty():
        os.system("pause")
    sys.exit(1)

class map_parser:
    """Build per-module memory-consumption reports from linker map files.

    Module membership comes from the source tree under ``<branch>/firmware``:
    every top-level directory is one module.

    Map files are read from ``<branch>/output/<flavor>/images/map.map``. One
    report per flavor is written to ``<branch>/output/memory_analysis``.

    Two map layouts are handled:

    * flavors whose name contains ``wave600`` use the ARC/MetaWare layout
      ("SECTION SUMMARY" / "SECTION DETAILS" / "SYMBOL SUMMARY");
    * all other flavors use the "Image Summary" / "Module Summary" /
      "Global Symbols" layout.
    """

    def __init__(self, branch_path_name):
        """Derive the source, map-file and report directories from *branch_path_name*."""
        self.debug = True
        self.parent_dir = os.path.join(branch_path_name, "firmware")
        self.compilation_output = os.path.join(branch_path_name, "output")
        self.out = os.path.join(self.compilation_output, "memory_analysis")
        self.ModulesList = []
        self.ext_list = [".c", ".mip", ".s"]
        # Per-run state, (re)assigned while parsing. It is initialised here so
        # every attribute exists even before parse_map_file() runs.
        self.flavor_list = []
        self.compilationOutput = ""
        self.arc_analysis = False
        self.padding_size = 0
        self.stack_size = 0
        self.wave600HwMemoryMapPath = None
        self.shared_ram_base_address = None
        self.shared_ram_size = None
        self.descriptor_ram_base_address = None
        self.descriptor_ram_size = None

    def validate(self, argv):
        """Report a usage error through ``utils.error`` when *argv* is empty or None."""
        if not argv:
            utils.error("syntax: map_parser must get at lease one argument")

    def prepare(self):
        """Prepare the report directory ``self.out`` via ``utils.recreate_dir``."""
        utils.recreate_dir(self.out)

    @staticmethod
    def _sum_arc_object_sizes(lines):
        """Sum per-object sizes from ARC/MetaWare section-detail lines.

        Only lines ending in ``.o`` are counted. The object name is the last
        field, with leading/trailing '.' and '/' characters stripped. The size
        is the second-to-last field, converted with ``utils.hex_to_dec``.
        """
        sizes = {}
        for line in lines:
            if line.endswith(".o\n"):
                fields = line.split()
                object_name = fields[-1].strip("./")
                sizes[object_name] = sizes.get(object_name, 0) + utils.hex_to_dec(fields[-2])
        return sizes

    @staticmethod
    def _sum_keyword_object_sizes(lines, keywords):
        """Sum per-object sizes from module-summary lines containing any of *keywords*.

        '+' separators are treated as whitespace. The object name is the last
        field and the size is the second field, converted with
        ``utils.hex_to_dec``.
        """
        sizes = {}
        for line in lines:
            if any(keyword in line for keyword in keywords):
                fields = line.replace("+", " ").split()
                object_name = fields[-1]
                sizes[object_name] = sizes.get(object_name, 0) + utils.hex_to_dec(fields[1])
        return sizes

    def calucalteModulesSize(self, mapfile, startLine, endLine, globalSymbolsSize, resultFileHandle):
        """Attribute text/data bytes to modules and write the module size table.

        Args:
            mapfile: all lines of the map file (as returned by ``readlines``).
            startLine, endLine: bounds of the module/section-detail part.
            globalSymbolsSize: total global-symbol bytes, reported as its own row.
            resultFileHandle: open text file the summary table is written to.

        Side effect: sets ``self.stack_size`` to the size on the ARC
        ``.stack   bss`` line, or 0 when there is none.

        Returns:
            Detailed report rows ``[name, total, text, data]`` for every object
            of every module in ``self.ModulesList``, plus header and total rows.
        """
        self.stack_size = 0

        if self.arc_analysis:
            text_start_line = data_start_line = data_end_line = stack_line = None
            # Last occurrence of each marker wins, as in a plain scan.
            for offset, line in enumerate(mapfile[startLine:endLine]):
                line_index = startLine + offset
                if line.startswith(".text    text"):
                    text_start_line = line_index
                if line.startswith(".rodata  lit"):
                    data_start_line = line_index
                if line.startswith(".end_of_ram"):
                    data_end_line = line_index
                if line.startswith(".stack   bss"):
                    stack_line = line_index

            if stack_line is not None:
                self.stack_size = utils.hex_to_dec(mapfile[stack_line].split()[-1])

            # A missing marker yields an empty region instead of an
            # UnboundLocalError. Text runs up to the first data line; data runs
            # up to .end_of_ram (or the end of the searched range).
            if data_end_line is None:
                data_end_line = endLine
            if text_start_line is None:
                text_lines = []
            else:
                text_end_line = data_start_line if data_start_line is not None else data_end_line
                text_lines = mapfile[text_start_line:text_end_line]
            data_lines = [] if data_start_line is None else mapfile[data_start_line:data_end_line]

            object_list_text_size = self._sum_arc_object_sizes(text_lines)
            object_list_data_size = self._sum_arc_object_sizes(data_lines)
        else:
            module_lines = mapfile[startLine:endLine]
            object_list_text_size = self._sum_keyword_object_sizes(module_lines, KEYWORDS_text)
            object_list_data_size = self._sum_keyword_object_sizes(module_lines, KEYWORDS)

        detailedObjList = [
            ["\n*****************************************", "", "", ""],
            ["* 	    Detailed Modules Size		    *", "", "", ""],
            ["*****************************************", "", "", ""],
        ]

        resultFileHandle.write("========================================================\n")
        resultFileHandle.write("{:<30}".format('Module') + "Size = text + data [Bytes]\n")
        resultFileHandle.write("========================================================\n")

        # Global symbols, padding and stack are reported as "0 + size" (no text part).
        for label, value in (('Global Symbols', globalSymbolsSize),
                             ('Padding', self.padding_size),
                             ('Stack', self.stack_size)):
            resultFileHandle.write("{:<30}".format(label) + "{:<10}".format(str(value)) + "\t=\t0 + " + str(value) + "\n")
        totalMemorySize = globalSymbolsSize + self.padding_size + self.stack_size

        for Module in self.ModulesList:
            module_name = Module[0]
            detailedObjList.append(["=========================================", "", "", ""])
            detailedObjList.append([str(module_name), "", "", ""])
            detailedObjList.append(["=========================================", "", "", ""])

            text_size = 0
            data_size = 0
            for object_name in Module[1:]:
                obj_text = object_list_text_size.get(object_name, 0)
                obj_data = object_list_data_size.get(object_name, 0)
                detailedObjList.append([object_name, str(obj_data + obj_text), str(obj_text), str(obj_data)])
                text_size += obj_text
                data_size += obj_data
            size = text_size + data_size

            resultFileHandle.write("{:<30}".format(module_name) + "{:<10}".format(str(size)) + "\t=\t" + str(text_size) + " + " + str(data_size) + "\n")

            detailedObjList.append(["-----------------------------------------", "", "", ""])
            detailedObjList.append([module_name + " Total Size:", size, "", ""])
            detailedObjList.append(["-----------------------------------------", "", "", "\n"])
            totalMemorySize += size

        # External modules are not measured; the section is always reported with size 0.
        external_size = 0
        detailedObjList.append(["=========================================", "", "", ""])
        detailedObjList.append(["External Modules", "", "", ""])
        detailedObjList.append(["=========================================", "", "", ""])
        detailedObjList.append(["-----------------------------------------", "", "", ""])
        detailedObjList.append(["External Modules Total Size:", external_size, "", ""])
        detailedObjList.append(["-----------------------------------------", "", "", "\n"])

        resultFileHandle.write("{:<30}".format("External Modules") + str(external_size) + "\n")
        totalMemorySize += external_size

        resultFileHandle.write("-----------------------------------------\n")
        resultFileHandle.write("{:<30}{}".format('Total Size:', totalMemorySize) + "\n")
        resultFileHandle.write("-----------------------------------------\n\n\n")

        return detailedObjList

    def buildModuleList(self, parent_dir):
        """Append one ``[module, obj1.o, obj2.o, ...]`` entry per module to ``self.ModulesList``.

        Every directory directly under *parent_dir* is a module. Each file in
        its tree whose extension (case-insensitive) is in ``self.ext_list``
        contributes ``<stem>.o``; duplicates are listed once. Modules without
        any such file are omitted. Directories and files are visited in sorted
        order so the generated reports are reproducible.

        Example result::

            [['802.11protocol', 'frame.o', 'ieee_address.o', 'Protocol_AirTimeAndDuration.o',
              'Protocol_WmeDefinitions.o'], ['Ager', 'Ager.o', 'AgerEmulator.o']]
        """
        module_names = sorted(
            entry for entry in os.listdir(parent_dir)
            if os.path.isdir(os.path.join(parent_dir, entry))
        )

        for module_name in module_names:
            objects = [module_name]
            seen = set()
            for _dirpath, dirnames, files in os.walk(os.path.join(parent_dir, module_name)):
                dirnames.sort()  # in-place sort makes os.walk's descent order deterministic
                for file_name in sorted(files):
                    stem, extension = os.path.splitext(file_name)
                    if extension.strip().lower() not in self.ext_list:
                        continue
                    object_name = stem + ".o"
                    if object_name not in seen:
                        seen.add(object_name)
                        objects.append(object_name)
            if len(objects) > 1:
                self.ModulesList.append(objects)

    def build_compilation_flavor_list(self, parent_dir):
        """Set ``self.flavor_list`` to the sorted names of the subdirectories of *parent_dir*."""
        self.flavor_list = sorted(
            flavor for flavor in os.listdir(parent_dir)
            if os.path.isdir(os.path.join(parent_dir, flavor))
        )

    @staticmethod
    def _address_in_region(address, base, size):
        """Return True if *address* lies in the half-open range ``[base, base + size)``.

        Returns False when the region is unknown (*base* or *size* is None). A
        section placed exactly at the region base is inside the region.
        """
        if base is None or size is None:
            return False
        return base <= address < base + size

    @staticmethod
    def _section_address(mapfile, lineNum, fields):
        """Return the address of the section whose header is ``mapfile[lineNum]``.

        *fields* is the already-split header line. When it carries only the
        section name (fewer than two fields), the address is read from the
        second field of the following line instead.
        """
        if len(fields) < 2:
            return utils.hex_to_dec(mapfile[lineNum + 1].split()[1])
        return utils.hex_to_dec(fields[1])

    def calculateSectionSize(self, mapfile, startLine, endLine, resultFileHandle):
        """Measure IRAM, DDR and shared-RAM usage from the section summary.

        Scans ``mapfile[startLine:endLine]`` and writes a "Summary" table to
        *resultFileHandle*.

        Returns:
            ``(sharedRamSections, sharedRamDescSection)``: report rows
            ``[label, size]`` for the shared-RAM sections (sorted by name) and
            for the gen6 descriptor-RAM sections, each closed by a total row.
        """
        iramImageSize = 0
        iramTextImageSize = 0  # stays 0 when the summary has no ".text" line
        ddrImageSize = 0
        sharedRamSize = 0
        sharedRamDescSize = 0
        sharedRamSectionsTable = []

        if self.arc_analysis:
            text_val_index = -1
            start_of_ram_section = None
            # Independent checks: if several markers match, the last one wins.
            for core_marker, vectors_section in (("lower_mac", ".vectors_lm0"),
                                                 ("lower1_mac", ".vectors_lm1"),
                                                 ("upper_mac", ".vectors_um")):
                if core_marker in self.compilationOutput:
                    start_of_ram_section = vectors_section
            end_of_ram_section = ".end_of_ram"
            if start_of_ram_section is None:
                utils.error("map_parser: cannot determine the MAC core of flavor " + self.compilationOutput)
        else:
            start_of_ram_section = ".base_inst"
            end_of_ram_section = ".secinfo"
            text_val_index = 2

        sharedRamSections = [
            ["=========================================", ""],
            ["Shared Ram Sections (alphabetically sorted)", "Size"],
            ["=========================================", ""],
        ]
        sharedRamDescSection = [
            ["=========================================", ""],
            ["Shared Ram Descriptors Sections", "Size"],
            ["=========================================", ""],
        ]

        start_of_ram_section_addr = 0
        for lineNum in range(startLine, endLine):
            currentLine_split = mapfile[lineNum].replace('+', ' ').split()
            if not currentLine_split:
                continue
            section_name = currentLine_split[0]

            if section_name == start_of_ram_section:
                start_of_ram_section_addr = self._section_address(mapfile, lineNum, currentLine_split)
            if section_name == end_of_ram_section:
                iramImageSize = self._section_address(mapfile, lineNum, currentLine_split) - start_of_ram_section_addr
            if section_name == ".text":
                iramTextImageSize = utils.hex_to_dec(currentLine_split[text_val_index])

            # Classify by the first '_'-separated component, e.g. ".SHARED_RAM_X" -> ".shared".
            name_parts = section_name.replace('_', ' ').split()
            if not name_parts:
                continue
            section_prefix = name_parts[0].lower()

            if section_prefix == ".shared":
                dec_size = 0
                desc_dec_size = 0
                section_label = section_name.replace(".SHARED_RAM_", "")
                if len(currentLine_split) < 2:
                    # gen6: address (field 2) and size (field 3) are on the next line.
                    next_fields = mapfile[lineNum + 1].split()
                    section_addr = utils.hex_to_dec(next_fields[2])
                    if self._address_in_region(section_addr, self.shared_ram_base_address, self.shared_ram_size):
                        dec_size = utils.hex_to_dec(next_fields[3])
                    if self._address_in_region(section_addr, self.descriptor_ram_base_address, self.descriptor_ram_size):
                        desc_dec_size = utils.hex_to_dec(next_fields[3])
                        sharedRamDescSection.append([section_label, desc_dec_size])
                else:
                    # gen5: the size is a decimal field on the same line.
                    dec_size = int(currentLine_split[3])

                sharedRamSize += dec_size
                sharedRamDescSize += desc_dec_size
                sharedRamSectionsTable.append([section_label, dec_size])
            elif section_prefix == ".ddr":
                ddrImageSize += int(currentLine_split[3])

        sharedRamSections.extend(sorted(sharedRamSectionsTable, key=lambda section: section[0]))

        resultFileHandle.write("\n===============================================")
        resultFileHandle.write("\nSummary")
        resultFileHandle.write("\n===============================================\n")
        for label, value in (('Total Iram Size:', iramImageSize),
                             ('Total Iram Text Size:', iramTextImageSize),
                             ('Total DDR Size:', ddrImageSize),
                             ('Total Iram + DDR Size:', iramImageSize + ddrImageSize),
                             ('Total Shram Size:', sharedRamSize),
                             ('Total Shram Descriptor(gen6) Size:', sharedRamDescSize)):
            resultFileHandle.write("{:<40}{}".format(label, value) + "\n")
        resultFileHandle.write("===============================================\n\n\n")

        sharedRamSections.append(["-----------------------------------------", ""])
        sharedRamSections.append(["Total Shram  Size:", str(sharedRamSize)])
        sharedRamSections.append(["-----------------------------------------", ""])

        sharedRamDescSection.append(["-----------------------------------------", ""])
        sharedRamDescSection.append(["Total Shram Descriptor  Size:", str(sharedRamDescSize)])
        sharedRamDescSection.append(["-----------------------------------------", ""])

        return (sharedRamSections, sharedRamDescSection)

    def calculateGlobalSymbolsSize(self, mapfile, startLine, endLine):
        """Sum the sizes of global ``.bss`` / ``.ddr_bss`` symbols.

        Only the non-ARC layout is parsed; for ARC maps the total is 0. A line
        counts when its first field is ``.bss`` or ``.ddr_bss``
        (case-insensitive) and its third field (the symbol) contains no '.'.
        The size is the hex value after '+' in the second field; zero-sized
        symbols are skipped. Malformed lines are reported through
        ``utils.error``.

        Returns:
            ``(TotalGlobalVarSize, GlobalVariables)``, where GlobalVariables is
            a list of ``[symbol, hex size without leading zeros]`` rows after
            three header rows.
        """
        GlobalVariables = [
            ["=================================================", ""],
            ["Global Symbols:", "Size"],
            ["=================================================", ""],
        ]
        TotalGlobalVarSize = 0
        if self.arc_analysis:
            return TotalGlobalVarSize, GlobalVariables

        for i in range(startLine, endLine):
            currentLine = mapfile[i].split()
            if not currentLine:
                continue
            if currentLine[0].strip().lower() not in (".bss", ".ddr_bss"):
                continue
            try:
                symbol_parts = currentLine[2].split(".")
                if len(symbol_parts) >= 2:
                    continue  # dotted names are not counted as global symbols
                size = currentLine[1].split("+")[1]
                size_value = int(size, 16)
            except (IndexError, ValueError):
                utils.error("error in calculateGlobalSymbolsSize")
                continue
            if size_value > 0:
                TotalGlobalVarSize += size_value
                GlobalVariables.append([symbol_parts[0], size.lstrip('0')])

        return TotalGlobalVarSize, GlobalVariables

    def _configure_for_flavor(self, compilationOutput):
        """Select the map layout for a flavor; for wave600 also load its HW memory map.

        Resets the per-flavor state. For wave600 flavors, reads the shared-RAM
        and descriptor-RAM base/size from
        ``firmware/builds/metaware/<WAVE600|WAVE600B|WAVE600D2>/mem_def.lcf``.
        """
        self.arc_analysis = False
        self.wave600HwMemoryMapPath = None
        self.shared_ram_base_address = None
        self.shared_ram_size = None
        self.descriptor_ram_base_address = None
        self.descriptor_ram_size = None

        if "wave600" not in compilationOutput:
            self.section_summary_start = "Image Summary"
            self.section_summary_end_module_start = "Module Summary"
            self.section_module_end_global_symbols_start = "Global Symbols"
            return

        self.arc_analysis = True
        self.section_summary_start = "SECTION SUMMARY"
        self.section_summary_end_module_start = "SECTION DETAILS"
        self.section_module_end_global_symbols_start = "SYMBOL SUMMARY"

        if "wave600b" in compilationOutput:
            hw_dir = "WAVE600B"
        elif "wave600d2" in compilationOutput:
            hw_dir = "WAVE600D2"
        else:
            hw_dir = "WAVE600"
        self.wave600HwMemoryMapPath = os.path.join(self.parent_dir, "builds/metaware", hw_dir, "mem_def.lcf")

        lcf_keys = (
            ("SHARED_RAM_BASE_ADDR", "shared_ram_base_address"),
            ("SHARED_RAM_SIZE", "shared_ram_size"),
            ("DESCRIPTOR_RAM_BASE_ADDR", "descriptor_ram_base_address"),
            ("DESCRIPTOR_RAM_SIZE", "descriptor_ram_size"),
        )
        with open(self.wave600HwMemoryMapPath, 'r') as in_file:
            for line in in_file:
                for key, attribute in lcf_keys:
                    if key in line:
                        # Value = third whitespace field once ';' is removed
                        # (presumably "NAME = value;"). Base 0 accepts 0x-prefixed
                        # and plain decimal literals.
                        setattr(self, attribute, int(line.replace(";", "").split()[2], 0))

    def _locate_sections(self, mapfile):
        """Find the section headers of *mapfile* and total its ``<padding>`` bytes.

        Returns:
            ``(summary_start, modules_start, globals_start, padding)``. Each
            start is the 1-based number of its header line (equivalently the
            0-based index of the line after it), or 0 when the header is absent.
            The last occurrence of a header wins.
        """
        summary_start = modules_start = globals_start = 0
        padding = 0
        for line_number, line in enumerate(mapfile, start=1):
            if self.section_summary_start in line:
                summary_start = line_number
            if self.section_summary_end_module_start in line:
                modules_start = line_number
            if self.section_module_end_global_symbols_start in line:
                globals_start = line_number
            if "<padding>" in line:
                padding += utils.hex_to_dec(line.split()[-1])
        return summary_start, modules_start, globals_start, padding

    @staticmethod
    def _write_detail_tables(handle, sharedRamSections, sharedRamDescSection, objectList, globalSymbolParams):
        """Write the shared-RAM, per-object and global-symbol detail tables to *handle*."""
        for entry in sharedRamSections:
            handle.write("{:<35}".format(entry[0]) + str(entry[1]) + "\n")
        for entry in sharedRamDescSection:
            handle.write("{:<35}".format(entry[0]) + str(entry[1]) + "\n")

        for entry in objectList:
            if entry[2] == "":
                handle.write("{:<30}{}".format(entry[0], entry[1]) + "\n")
            else:
                handle.write("{:<30}{}".format(entry[0], entry[1]) + " (" + entry[2] + " + " + entry[3] + ")\n")

        for entry in globalSymbolParams[1]:
            handle.write("\n{:<40}".format(entry[0]) + entry[1])

        handle.write("\n-----------------------------------------\n")
        handle.write("{:<30}{}".format('Total Global Symbol Size:', globalSymbolParams[0]) + "\n")
        handle.write("-----------------------------------------\n\n\n")

    def parse_map_file(self, flavor_list):
        """Write ``<flavor>_mem_consumption.txt`` into ``self.out`` for each flavor.

        Flavors without ``images/map.map`` under ``self.compilation_output``
        are skipped (the output directory may contain other folders). An empty
        map file is reported through ``utils.error`` and skipped.
        """
        for compilationOutput in flavor_list:
            mapfilePathFileName = os.path.join(self.compilation_output, compilationOutput, "images", "map.map")
            if not os.path.exists(mapfilePathFileName):
                continue

            self.compilationOutput = compilationOutput
            with open(mapfilePathFileName, 'r') as infile:
                mapfile = infile.readlines()
            if not mapfile:
                utils.error("input file is empty")
                continue

            self._configure_for_flavor(compilationOutput)
            summary_line, modules_line, globals_line, self.padding_size = self._locate_sections(mapfile)

            globalSymbolParams = self.calculateGlobalSymbolsSize(mapfile, globals_line, len(mapfile))
            resultsFile = os.path.join(self.out, compilationOutput + "_mem_consumption.txt")
            # "with" guarantees the report is closed even if parsing raises.
            with open(resultsFile, "w") as moduleResultFileHandle:
                sharedRamSections, sharedRamDescSection = self.calculateSectionSize(
                    mapfile, summary_line, modules_line, moduleResultFileHandle)
                objectList = self.calucalteModulesSize(
                    mapfile, modules_line, globals_line, globalSymbolParams[0], moduleResultFileHandle)
                self._write_detail_tables(moduleResultFileHandle, sharedRamSections,
                                          sharedRamDescSection, objectList, globalSymbolParams)
	
def main(branch_path_name):
    """Generate the memory-consumption reports for one branch.

    *branch_path_name* is the branch root: module sources are taken from its
    ``firmware`` directory and flavor map files from its ``output`` directory.
    """
    # To get a pause-and-exit report on crashes, install
    # unhandeledException as sys.excepthook here.
    analyzer = map_parser(branch_path_name)
    analyzer.validate(branch_path_name)
    analyzer.prepare()

    analyzer.buildModuleList(analyzer.parent_dir)
    analyzer.build_compilation_flavor_list(analyzer.compilation_output)
    analyzer.parse_map_file(analyzer.flavor_list)

if __name__ == '__main__':
    # The branch root is taken to be the parent of the current working directory.
    main(os.path.dirname(os.getcwd()))