'''
    SPDX-FileCopyrightText: 2024 Agata Cacko <cacko.azh@gmail.com>

    This file is part of Fast Sketch Cleanup Plugin  for Krita

    SPDX-License-Identifier: GPL-3.0-or-later
'''


import threading

import numpy as np
import openvino as ov


class InferenceRunner:
	"""Run an OpenVINO model over a large array by splitting it into overlapping tiles.

	Every tile ("part") is ``partSize`` long on axes 2 and 3 and carries
	``margin`` elements of context on each side. Only the central
	``partSize - 2*margin`` region of each model output is copied into the
	result. Parts are submitted to an ``ov.AsyncInferQueue``. Progress and
	completion are reported to ``owner`` through
	``owner.stepInInference(done, total)`` and
	``owner.finishInference(outputData, width=..., height=...)``.
	"""

	# NOTE(review): mutable class-level attribute shared by every instance and not
	# used by any code visible here; kept only in case external code references it.
	requests = []


	def __init__(self, owner):
		"""Remember the object that receives progress/finish notifications."""
		self.owner = owner
		# Set by startInference; None means nothing has been submitted yet.
		self.infer_queue = None
		# Guards the remaining-parts counter, which is decremented from the
		# inference queue's callbacks (these may run on worker threads).
		self._progressLock = threading.Lock()


	@staticmethod
	def _planParts(width, height, partSize, margin):
		"""Lay out the tiles covering a ``width`` x ``height`` area (axes 2 and 3).

		Returns ``(partsList, extraWidth, extraHeight)``. Each entry of
		``partsList`` is a triple of ``(startW, endW, startH, endH)`` ranges:

		0. the tile's window inside the input padded by ``margin`` on every side,
		1. the region of the model output to keep (the tile minus its margins),
		2. where that kept region goes in the output buffer.

		``extraWidth``/``extraHeight`` is the zero padding needed after the data
		so that the last column/row of tiles fits completely. At least one tile
		is produced per axis.

		:raises ValueError: if ``margin`` leaves no usable area inside a tile.
		"""
		trueDataPartSize = partSize - 2*margin
		if trueDataPartSize <= 0:
			raise ValueError(f"Part size {partSize} leaves no data area with margin {margin}")

		# Ceiling division: number of tiles needed to cover each axis.
		partsW = max(1, (width + trueDataPartSize - 1) // trueDataPartSize)
		partsH = max(1, (height + trueDataPartSize - 1) // trueDataPartSize)

		keep = (margin, partSize - margin, margin, partSize - margin)
		partsList = []
		for iW in range(partsW): # x/width
			startW = iW*trueDataPartSize
			for iH in range(partsH): # y/height
				startH = iH*trueDataPartSize
				partsList.append((
					(startW, startW + partSize, startH, startH + partSize),
					keep,
					(startW, startW + trueDataPartSize, startH, startH + trueDataPartSize),
				))

		# Only the padding that is actually missing after the data, not the whole tiled size.
		extraWidth = partsW*trueDataPartSize - width
		extraHeight = partsH*trueDataPartSize - height
		return partsList, extraWidth, extraHeight


	def startInference(self, model: "ov.CompiledModel", data, partSize, margin, divisableBy):
		"""Tile ``data`` and submit every tile to an asynchronous inference queue.

		:param model: compiled OpenVINO model. Per the original note it is assumed
			to be static, presumably taking ``partSize`` x ``partSize`` tiles (TODO confirm).
		:param data: 4-D numpy array; axes 2 and 3 are tiled and are called
			``width``/``height`` here (see the FIXME in ``inferenceCallback``).
		:param partSize: tile length fed to the model, margins included.
		:param margin: context on each side of a tile, dropped from the output.
		:param divisableBy: ``partSize`` and both tiled sizes must be multiples of it.
		:raises AssertionError: if the divisibility requirements are not met.
		:raises ValueError: if ``margin`` leaves no usable area inside a tile.
		"""
		# we assume static model, therefore no smaller parts!
		assert partSize%divisableBy == 0, f"Max size of the part must be divisable by the divisableBy, but is: {partSize}"

		shape = data.shape
		width = shape[2]
		height = shape[3]

		data = data.astype(model.input().get_element_type().to_dtype())

		assert (width % divisableBy == 0 and height % divisableBy == 0), f"Size of data must be already divisable by the divisableBy, but is: {width} {height} (shape: {data.shape})"

		partsList, extraWidth, extraHeight = self._planParts(width, height, partSize, margin)

		# Model input: margin on every side plus trailing zeros so the last tiles are complete.
		padded = np.pad(data, ((0, 0), (0, 0), (margin, margin + extraWidth), (margin, margin + extraHeight)), 'constant', constant_values=(0, 0))

		# Output buffer, padded only at the end so every kept tile region fits;
		# it is cropped back to width x height once the last part finishes.
		# NOTE(review): the model output is written into a copy of the input, so
		# its channel axis must be assignable to the input's - confirm.
		self.outputDataPadded = np.pad(data, ((0, 0), (0, 0), (0, extraWidth), (0, extraHeight)), 'constant', constant_values=(0, 0))

		# All state read by the callback is set before the first request is submitted.
		self.inferenceRequestsLeft = len(partsList)
		self.inferenceRequestsAllCount = len(partsList)

		self.width = width
		self.height = height

		infer_queue = ov.AsyncInferQueue(model)
		infer_queue.set_callback(self.inferenceCallback)
		self.infer_queue = infer_queue

		for partInfo in partsList:
			startW, endW, startH, endH = partInfo[0]
			part = padded[:, :, startW:endW, startH:endH]
			# partInfo travels as userdata and comes back as the callback's second argument.
			infer_queue.start_async({0: ov.Tensor(array=part)}, partInfo)

	def waitForQueue(self):
		"""Block until every submitted inference request has completed (no-op if none)."""
		if self.infer_queue is not None:
			self.infer_queue.wait_all()


	def inferenceCallback(self, i_request, request):
		"""Copy one finished tile into the output buffer and report progress.

		:param i_request: the finished infer request; its output tensor is read.
		:param request: the userdata given to ``start_async``, i.e. the part
			triple produced by ``_planParts``.

		After the last part, the buffer is cropped to the original size and
		passed to ``owner.finishInference``. Otherwise ``owner.stepInInference``
		is called with ``(finishedCount, totalCount)``.
		"""
		_, keep, dest = request
		result = i_request.get_output_tensor().data

		# Tiles write disjoint regions, so the copy itself needs no lock. It must
		# happen before the decrement so the final callback sees every tile.
		self.outputDataPadded[:, :, dest[0]:dest[1], dest[2]:dest[3]] = result[:, :, keep[0]:keep[1], keep[2]:keep[3]]

		# Decrement and read atomically so exactly one callback observes zero.
		with self._progressLock:
			self.inferenceRequestsLeft -= 1
			left = self.inferenceRequestsLeft

		if left > 0:
			self.owner.stepInInference(self.inferenceRequestsAllCount - left, self.inferenceRequestsAllCount)
			return

		outputData = self.outputDataPadded[:, :, 0:self.width, 0:self.height]

		# FIXME: it's a weird mixup - probably from numpy and Qt having the opposite ideas in which order height and width should go
		self.owner.finishInference(outputData, width=self.height, height=self.width)

	def cancelAllInference(self):
		"""Cancel every request in the queue and wait for them to settle (no-op if none)."""
		# NOTE(review): the original author noted "should work, but I can't call it";
		# this relies on AsyncInferQueue being iterable over its infer requests.
		if self.infer_queue is None:
			return
		for infer_request in self.infer_queue:
			infer_request.cancel()
		self.infer_queue.wait_all()

