Location: Stomach Annotator for SPARC @ 9ea33dda840d / digitiser / zincwidgets.py

Author:
rjag008 <rjag008@auckland.ac.nz>
Date:
2018-08-18 17:01:42+12:00
Desc:
Final Release
Permanent Source URI:
https://models.physiomeproject.org/workspace/51d/rawfile/9ea33dda840d259caf68e26d21300bece4eeff3f/digitiser/zincwidgets.py

'''
Created on 15/06/2018

@author: rjag008
'''
from opencmiss.zinc._graphics import Graphicslineattributes_SHAPE_TYPE_CIRCLE_EXTRUSION
from digitiser.dataset import PaintDatasetItem
from opencmiss.zincwidgets.sceneviewerwidget import SceneviewerWidget
from opencmiss.zinc.sceneviewerinput import Sceneviewerinput
from opencmiss.zinc.glyph import Glyph
from opencmiss.zinc.scenecoordinatesystem import SCENECOORDINATESYSTEM_WINDOW_PIXEL_TOP_LEFT, SCENECOORDINATESYSTEM_WORLD

import sys
from os import path
from digitiser.model import PaintModel, signalHandle
from collections import OrderedDict
from opencmiss.zinc.context import Context
from opencmiss.zinc.field import Field, FieldFindMeshLocation


# Resource root: the directory of the launching script. A frozen (py2exe)
# build uses that directory directly, otherwise its parent is used.
dir_path = path.dirname(path.realpath(sys.argv[0]))
dir_path = dir_path if hasattr(sys, 'frozen') else path.join(dir_path, "..")

# Qt Designer form loaded by ZincDigitiserWidgetBase.
uiFile = path.join(dir_path, "./digitiser/PaintUI.ui")

try:
    from PySide import QtCore, QtGui, QtOpenGL
    from pysideuiutils.uic import loadUi
except ImportError as importError:
    # Everything below dereferences QtCore/QtGui/QtOpenGL and
    # ZincDigitiserWidgetBase while the module is still being imported, so
    # swallowing this error only postpones the failure to an unrelated
    # NameError.  Fail here and name the missing dependency instead.
    # NOTE: a PyQt4 fallback built on uic.loadUiType(uiFile) used to live in
    # this branch but had been disabled; it is intentionally not revived.
    raise ImportError(
        "digitiser.zincwidgets requires PySide and pysideuiutils "
        "(the PyQt4 fallback is disabled): %s" % importError)


class ZincDigitiserWidgetBase(QtGui.QWidget):
    '''
    Base QWidget whose contents are loaded from the Qt Designer file uiFile.
    '''

    def __init__(self, parent=None):
        QtGui.QWidget.__init__(self, parent)
        # Presumably installs the widgets declared in the form as attributes
        # of self (subclasses rely on names they never assign) - confirm
        # against pysideuiutils.
        loadUi(uiFile, self)

# Qt mouse buttons that SceneViewerWidget.mouseReleaseEvent reacts to, paired
# with the matching Zinc sceneviewer-input button type (right button only).
button_map = dict([(QtCore.Qt.RightButton, Sceneviewerinput.BUTTON_TYPE_RIGHT)])

class SceneViewerWidget(SceneviewerWidget):
    '''
    Scene viewer widget that reports digitising clicks through Qt signals.

    Signals:
        doubleClicked       -- emitted from mouseDoubleClickEvent.
        pointSelectionClick -- emitted with [x, y] (world coordinates) when a
                               button listed in button_map is released and
                               Shift is not held.
        pointDeletionClick  -- emitted with [x, y] (world coordinates) when it
                               is released while Shift is held.  While this
                               signal is being delivered nodeOfInterest holds
                               the result of getNearestNode for the click.
    '''
    try:
        # PySide
        doubleClicked = QtCore.Signal()
        pointSelectionClick = QtCore.Signal(tuple)
        pointDeletionClick = QtCore.Signal(tuple)
    except AttributeError:
        # PyQt
        doubleClicked = QtCore.pyqtSignal()
        pointSelectionClick = QtCore.pyqtSignal(tuple)
        pointDeletionClick = QtCore.pyqtSignal(tuple)

    def __init__(self, parent=None, shared=None):
        '''
        Create the viewer; parent and shared are forwarded to SceneviewerWidget.
        '''
        SceneviewerWidget.__init__(self, parent, shared)
        # Double-click bookkeeping: a single-shot timer is started on every
        # release that is not the tail of a double click.  Nothing in this
        # class connects to the timer's timeout.
        self.timer = QtCore.QTimer()
        self.timer.setSingleShot(True)
        self.__doubleClickedFlag = False
        self.__doubleClickInterval = QtGui.qApp.doubleClickInterval()
        # True while Shift is held down (see keyPressEvent/keyReleaseEvent).
        self._selectionKeyPressed = False
        # Node picked by the click currently being dispatched, else None.
        self.nodeOfInterest = None

    # Override mouse events so that they can be broadcast to listeners
    def mousePressEvent(self, event):
        '''Forward the press to the base scene viewer unchanged.'''
        super(SceneViewerWidget, self).mousePressEvent(event)

    def mouseMoveEvent(self, event):
        '''Forward the move to the base scene viewer unchanged.'''
        super(SceneViewerWidget, self).mouseMoveEvent(event)

    def mouseReleaseEvent(self, event):
        '''
        Translate a button release into a selection or deletion signal.

        The base class handler is deliberately not called.  For a right-button
        release the click position is converted to world coordinates and
        either pointSelectionClick (no Shift) or pointDeletionClick (Shift) is
        emitted.  Any other button clears nodeOfInterest.
        '''
        event.accept()

        if event.button() in button_map:
            x = event.x()
            y = event.y()
            px, py, _ = self.transformToWorld(x, y)

            if event.button() == QtCore.Qt.RightButton:
                if not self._selectionKeyPressed:
                    self.pointSelectionClick.emit([px, py])
                else:
                    self.nodeOfInterest = self.getNearestNode(x, y)
                    self.pointDeletionClick.emit([px, py])
                    # The picked node only describes this click.  Leaving it
                    # set made a later plain right-click look like a click on
                    # that (possibly recycled) node, so listeners moved an
                    # existing point instead of adding a new one.
                    self.nodeOfInterest = None
        else:
            self.nodeOfInterest = None

        if not self.__doubleClickedFlag:
            self.timer.start(self.__doubleClickInterval)
        else:
            self.__doubleClickedFlag = False

    def mouseDoubleClickEvent(self, event):
        '''Emit doubleClicked and flag the release that follows as its tail.'''
        event.accept()
        self.timer.stop()
        self.__doubleClickedFlag = True
        self.doubleClicked.emit()

    def keyPressEvent(self, event):
        '''Record Shift going down (auto-repeat ignored); ignore other keys.'''
        if event.key() == QtCore.Qt.Key_Shift and not event.isAutoRepeat():
            self._selectionKeyPressed = True
            event.setAccepted(True)
        else:
            event.ignore()

    def keyReleaseEvent(self, event):
        '''Record Shift going up (auto-repeat ignored); ignore other keys.'''
        if event.key() == QtCore.Qt.Key_Shift and not event.isAutoRepeat():
            self._selectionKeyPressed = False
            event.setAccepted(True)
        else:
            event.ignore()

    def transformToWorld(self, x, y):
        '''
        Convert widget pixel coordinates (x, y) to world coordinates.

        Returns the position produced by the sceneviewer's
        transformCoordinates call for the pixel-to-world direction.
        '''
        sceneviewer = self._sceneviewer
        scene = sceneviewer.getScene()
        # First find the depth by projecting the world origin on to pixel
        # coordinates; that depth is reused for the clicked pixel.
        _, loc = sceneviewer.transformCoordinates(
            SCENECOORDINATESYSTEM_WORLD,
            SCENECOORDINATESYSTEM_WINDOW_PIXEL_TOP_LEFT,
            scene, [0.0, 0.0, 0.0])
        # Due to right handed coordinate system, y should be sent in as -ve
        _, pos = sceneviewer.transformCoordinates(
            SCENECOORDINATESYSTEM_WINDOW_PIXEL_TOP_LEFT,
            SCENECOORDINATESYSTEM_WORLD,
            scene, [x, -y, loc[2]])
        return pos



class ZincGraphicsEllipseItem(object):
    '''
    Plain record for one digitised point: its key, its colour value
    (material) and the object stored as parent.
    '''

    def __init__(self, material, key,parent):
        '''Store the three values; no validation is performed.'''
        self._key = key
        self.material = material
        self.parent = parent

    def getItemKey(self):
        '''Return the key given at construction.'''
        return self._key

    def setMaterial(self,color):
        '''Replace the stored material with color.'''
        self.material = color


class DigitizationView(object):
    #PenNames = ['red','green','blue','cyan','magenta','yellow']
    PenColors = [QtGui.QColor(col) for col in [ QtCore.Qt.red, QtCore.Qt.green, QtCore.Qt.blue,\
                                                         QtCore.Qt.cyan, QtCore.Qt.magenta, QtCore.Qt.yellow]]
    
    def __init__(self, model, zincContext, graphicsScene, listDatasets):
        '''
        Bind the view to its model, Zinc context, scene viewer and list widget.

        model         -- supplies datasetColors, the modelChanged signal and
                         the point/dataset accessors used by this class
        zincContext   -- Zinc context; its default region becomes rootRegion
        graphicsScene -- scene viewer whose view parameters are set here
        listDatasets  -- list widget showing one item per dataset
        '''
        self._model = model
        self._graphicsScene = graphicsScene
        self._listDatasets = listDatasets
        self.zincContext = zincContext
        self.rootRegion = zincContext.getDefaultRegion()

        # The colour table is the model's own object; seed key 0 when empty.
        self.datasetColors = model.datasetColors
        if not self.datasetColors:
            self.datasetColors[0] = self.createPenAndMaterial(self.PenColors[0])

        # Fixed initial view.  Presumably (eye, lookat, up, view angle) -
        # confirm against SceneviewerWidget.setViewParameters.
        eye = [0.0, 0.0, 4.742128877220128]
        lookat = [0.0, 0.0, 0.0]
        up = [0.0, 1.0, 0.0]
        angle = 0.6981317007977312
        graphicsScene.setViewParameters(eye, lookat, up, angle)

        self.reset()
        model.modelChanged.connect(self.modelChanged)
        self.modelChanged()
        listDatasets.itemChanged.connect(self.updateDatasetName)
        listDatasets.itemDoubleClicked.connect(self.changeItemColor)
        self.freeDataNodes = []
        self.numberOfTemporaryColors = 0

    def changeItemColor(self,item):
        '''
        Ask the list item for a new colour and recolour its dataset's points.

        Does nothing when item.changeColor() returns None.
        '''
        picked = item.changeColor()
        if picked is None:
            return
        ds = item._key
        self.datasetColors[ds] = self.createPenAndMaterial(picked)
        rgb = self.datasetColors[ds][0]
        # Push the new rgb triple to every point that belongs to the dataset.
        for pointKey in self._pointItems:
            pointItem, owner = self._pointItems[pointKey]
            if owner == ds:
                pointItem.setMaterial(rgb)
                self.fieldCache.setNode(pointItem.parent)
                self.nodeColorField.assignReal(self.fieldCache, rgb)
                     

    def createPenAndMaterial(self,color):
        '''
        Return [rgb, color] where rgb holds color's red, green and blue
        channels each divided by 255.0.
        '''
        channels = (color.red(), color.green(), color.blue())
        return [[channel / 255.0 for channel in channels], color]

    def reset(self):
        '''
        Remove the 'digitizer' child region if present, clear the dataset
        list widget and start with empty point/dataset/node bookkeeping.
        '''
        stale = self.rootRegion.findChildByName('digitizer')
        if stale.isValid():
            self.rootRegion.removeChild(stale)
        self._listDatasets.clear()

        # key -> (ZincGraphicsEllipseItem, dataset key)
        self._pointItems = {}
        # dataset key -> list widget item
        self._datasetItems = {}
        # data nodes released by removePointItem, available for reuse
        self.freeDataNodes = []
        # node identifier -> point key
        self.usedDataNodes = {}


    def loadMesh(self,meshGenerator):
        '''
        Recreate the 'digitizer' child region, let
        meshGenerator.generateMesh(region) populate it, then call useMesh().
        '''
        root = self.rootRegion
        previous = root.findChildByName('digitizer')
        if previous.isValid():
            root.removeChild(previous)
        digitizer = root.createChild('digitizer')
        meshGenerator.generateMesh(digitizer)
        self.region = digitizer
        self.useMesh()
        
    def useMesh(self):
        '''
        Attach to the 'digitizer' region and prepare the per-point fields.

        Looks up (or creates and names) the annotation, visibility, size and
        colour fields, builds the datapoint node template that defines them
        together with 'coordinates', and creates a region scene filter.
        '''
        region = self.rootRegion.findChildByName('digitizer')
        self.region = region
        fieldModule = region.getFieldmodule()
        self.xiCoordinates = fieldModule.findFieldByName('texture')

        def existingOrNew(name, create):
            # Reuse the field called `name`; otherwise build it with create()
            # and give it that name.
            field = fieldModule.findFieldByName(name)
            if not field.isValid():
                field = create()
                field.setName(name)
            return field

        self.nodeAnnotationField = existingOrNew(
            'nodeAnnotation', fieldModule.createFieldStoredString)
        self.nodeVisibilityFlagField = existingOrNew(
            'nodeVisibilityFlag', lambda: fieldModule.createFieldFiniteElement(1))
        self.nodeSizeField = existingOrNew(
            'nodeSizeField', lambda: fieldModule.createFieldFiniteElement(9))
        self.nodeColorField = existingOrNew(
            'nodeColor', lambda: fieldModule.createFieldFiniteElement(3))

        self.datanodeset = fieldModule.findNodesetByFieldDomainType(self.xiCoordinates.DOMAIN_TYPE_DATAPOINTS)
        self.fieldCache = fieldModule.createFieldcache()
        self.coordinatesField = fieldModule.findFieldByName('coordinates').castFiniteElement()

        # Template used by createDataNode for every new data point.
        self.nodetemplate = self.datanodeset.createNodetemplate()
        for field in (self.coordinatesField,
                      self.nodeAnnotationField,
                      self.nodeVisibilityFlagField,
                      self.nodeSizeField,
                      self.nodeColorField):
            self.nodetemplate.defineField(field)

        # Scene filter restricted to this region; linkToSceneViewer installs it.
        sceneFilterModule = self.zincContext.getScenefiltermodule()
        self.sceneFilter = sceneFilterModule.createScenefilterRegion(region)
        
    def linkToSceneViewer(self):
        '''
        Point the scene viewer at this view's region scene and filter.

        Does nothing unless useMesh() has produced a valid sceneFilter.
        '''
        if not (hasattr(self, 'sceneFilter') and self.sceneFilter.isValid()):
            return
        viewer = self._graphicsScene
        viewer.setScenefilter(self.sceneFilter)
        viewer.setScene(self.region.getScene()) #Picking fails without this
        viewer.setTumbleRate(0.0)

    def setTexture(self,imageFile):
        '''
        Read imageFile into a new image field, wrap it in a managed material
        named "regionMap" and apply that material to surfaceGraphics when the
        surface graphics already exist.
        '''
        imageField = self.region.getFieldmodule().createFieldImage()
        imageField.setFilterMode(imageField.FILTER_MODE_LINEAR)
        imageField.setWrapMode(imageField.WRAP_MODE_EDGE_CLAMP)
        imageField.setDomainField(self.coordinatesField)
        imageField.setTextureCoordinateSizes([1, 1, 1])

        # Stream information object used to read the image file from disk.
        streamInfo = imageField.createStreaminformationImage()
        streamInfo.createStreamresourceFile(imageFile)
        imageField.read(streamInfo)

        material = self.region.getScene().getMaterialmodule().createMaterial()
        material.setManaged(True)
        material.setName("regionMap")
        material.setTextureField(1, imageField)

        # Drop any previously stored material before keeping the new one.
        if hasattr(self, 'textureMaterial'):
            del self.textureMaterial
        self.textureMaterial = material
        if hasattr(self, 'surfaceGraphics'):
            self.surfaceGraphics.setMaterial(material)

    
    def setupGraphics(self):
        '''
        Build the region's graphics inside one scene change block:
        textured surfaces ('gold'), extruded lines ('silver'), an "RGB"
        spectrum reading three field components, and sphere glyphs for the
        data points coloured through that spectrum.

        The axis-glyph and element-label graphics that used to sit here as
        disabled code are not created.
        '''
        scene = self.region.getScene()
        scene.beginChange()

        coordinateField = self.region.getFieldmodule().findFieldByName('coordinates')

        glyphModule = scene.getGlyphmodule()
        glyphModule.defineStandardGlyphs()
        glyphModule.beginChange()
        glyphModule.endChange()

        materialModule = scene.getMaterialmodule()
        materialModule.defineStandardMaterials()
        surfaceMaterial = materialModule.findMaterialByName('gold')
        lineMaterial = materialModule.findMaterialByName('silver')

        # Surfaces, texture-mapped through the 'texture' field.
        surfaces = self.surfaceGraphics = scene.createGraphicsSurfaces()
        surfaces.setCoordinateField(coordinateField)
        surfaces.setTextureCoordinateField(self.xiCoordinates)
        surfaces.setMaterial(surfaceMaterial)

        # Lines, drawn as circle extrusions.
        lines = self.surfaceLines = scene.createGraphicsLines()
        lines.setCoordinateField(coordinateField)
        lineattributes = lines.getGraphicslineattributes()
        lineattributes.setShapeType(Graphicslineattributes_SHAPE_TYPE_CIRCLE_EXTRUSION)
        lineattributes.setBaseSize([0.01,0.01,0.01])
        lines.setMaterial(lineMaterial)

        # Spectrum that picks up rgb from the nodeColor field.
        spectrum = scene.getSpectrummodule().createSpectrum()
        spectrum.setMaterialOverwrite(True) #This will ensure that the transparency of the material is used
        spectrum.setName("RGB")
        mappingNames = ('COLOUR_MAPPING_TYPE_RED',
                        'COLOUR_MAPPING_TYPE_GREEN',
                        'COLOUR_MAPPING_TYPE_BLUE')
        channels = []
        for componentNumber, mappingName in enumerate(mappingNames, 1):
            channel = spectrum.createSpectrumcomponent()
            channel.setColourMappingType(getattr(channel, mappingName))
            channel.setFieldComponent(componentNumber)
            channels.append(channel)
        for channel in channels:
            channel.setRangeMinimum(0.0)
            channel.setRangeMaximum(1.0)
            channel.setColourMinimum(0.0)
            channel.setColourMaximum(1.0)
            channel.setExtendBelow(True)
            channel.setExtendAbove(True)
        self.spectrumR, self.spectrumG, self.spectrumB = channels
        self.spectrum = spectrum

        # Data points: spheres filtered by the visibility flag, scaled by the
        # size field and coloured by nodeColor through the spectrum.
        points = scene.createGraphicsPoints()
        points.setFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS)
        points.setCoordinateField(coordinateField)
        points.setSubgroupField(self.nodeVisibilityFlagField)
        points.setDataField(self.nodeColorField)

        pointAttributes = points.getGraphicspointattributes()
        pointAttributes.setGlyphShapeType(Glyph.SHAPE_TYPE_SPHERE)
        pointAttributes.setOrientationScaleField(self.nodeSizeField)
        points.setSpectrum(spectrum)

        scene.endChange()

    def getUsedNodesWithXiAndData(self):
        '''
        Describe every data node currently in use.

        Returns a dict mapping node identifier to
        [element identifier, xi, annotation, visibility, size, colour], where
        element/xi come from a nearest-location search on the 2D mesh.
        '''
        fieldModule = self.region.getFieldmodule()
        cache = fieldModule.createFieldcache()
        surfaceMesh = fieldModule.findMeshByDimension(2)
        locator = fieldModule.createFieldFindMeshLocation(self.coordinatesField,self.coordinatesField,surfaceMesh)
        locator.setSearchMode(FieldFindMeshLocation.SEARCH_MODE_NEAREST)

        described = dict()
        for identifier in self.usedDataNodes:
            cache.setNode(self.datanodeset.findNodeByIdentifier(identifier))
            element, xi = locator.evaluateMeshLocation(cache, 2)
            annotation = self.nodeAnnotationField.evaluateString(cache)
            # evaluateReal returns (result code, values); only values are kept.
            visibility = self.nodeVisibilityFlagField.evaluateReal(cache,1)[1]
            size = self.nodeSizeField.evaluateReal(cache,9)[1]
            colour = self.nodeColorField.evaluateReal(cache,3)[1]
            described[identifier] = [element.getIdentifier(),xi,annotation,visibility,size,colour]
        return described
        
    def modelChanged(self):
        '''
        Re-read the model's colour table and resynchronise the view.

        The dataset list is rebuilt before the points because a point's
        annotation is taken from its dataset's list item.
        '''
        model = self._model
        self.datasetColors = model.datasetColors
        self.updateListDatasets()
        self.updatePointItems()
        

    def updatePointItems(self):
        '''
        Add view items for model points that have none and remove view items
        whose key is no longer in the model.  Keys present on both sides are
        left untouched.
        '''
        points = self._model.getPoints()
        modelKeys = set(points.keys())
        viewKeys = set(self._pointItems.keys())
        sharedKeys = modelKeys.intersection(viewKeys)

        # Points that exist only in the model.
        for key in modelKeys - sharedKeys:
            position, dataset = points[key]
            self.addPointItem(key, [position[0], position[1]], dataset)

        # Points that exist only in the view.
        for key in viewKeys - sharedKeys:
            self.removePointItem(key)


    def createDataNode(self,key):
        '''
        Return a data node for point key, reusing the oldest released node
        when one is available, and record it in usedDataNodes.
        '''
        if self.freeDataNodes:
            node = self.freeDataNodes.pop(0)
        else:
            node = self.datanodeset.createNode(-1, self.nodetemplate)
        self.usedDataNodes[node.getIdentifier()] = key
        return node

    def setPosition(self,node,pos):
        '''
        Make node the field cache's current node and store
        (pos[0], pos[1], 0.0) as its coordinates.
        '''
        cache = self.fieldCache
        cache.setNode(node)
        xyz = [pos[0],pos[1],0.0]
        self.coordinatesField.assignReal(cache,xyz)
        
    def setAnnotation(self,node,annot):
        '''
        Make node the field cache's current node and store str(annot) in the
        annotation field.
        '''
        cache = self.fieldCache
        cache.setNode(node)
        text = str(annot)
        self.nodeAnnotationField.assignString(cache,text)
        
    def addPointItem(self, key, pos, dataset):
        '''
        Create the view item and data node for a model point.

        Points whose x or y lies outside [-1.0, 1.0] are skipped.
        '''
        insideBounds = (-1.0 <= pos[0] <= 1.0) and (-1.0 <= pos[1] <= 1.0)
        if not insideBounds:
            return

        rgb = self.datasetColors[dataset][0]
        node = self.createDataNode(key)
        item = ZincGraphicsEllipseItem(rgb, key, node)
        self._pointItems[key] = (item, dataset)

        # setPosition leaves the field cache on this node, so the assignments
        # below all apply to it.
        self.setPosition(item.parent, pos)
        cache = self.fieldCache
        self.nodeSizeField.assignReal(cache,[0.1,0,0,0,0.1,0,0,0,0.1])
        self.nodeVisibilityFlagField.assignReal(cache,1.0)
        self.nodeColorField.assignReal(cache,rgb)
        datasetName = self._datasetItems[dataset]._name
        self.nodeAnnotationField.assignString(cache,str(datasetName))

    def removePointItem(self, key):
        '''
        Hide the point's data node, mark it "Deleted", move it to the free
        list for reuse and forget the view item.
        '''
        item = self._pointItems[key][0]
        node = item.parent
        cache = self.fieldCache
        cache.setNode(node)
        self.nodeVisibilityFlagField.assignReal(cache,0.0)
        self.nodeAnnotationField.assignString(cache,"Deleted")
        self.freeDataNodes.append(node)
        del self.usedDataNodes[node.getIdentifier()]
        item.parent = None
        del self._pointItems[key]

    def listSelectionChanged(self,dataset):
        '''
        Resize every point: scale 0.1 for points of the given dataset,
        0.05 for all others.
        '''
        for pointKey in self._pointItems:
            item, owner = self._pointItems[pointKey]
            self.fieldCache.setNode(item.parent)
            scale = 0.1 if owner == dataset else 0.05
            self.nodeSizeField.assignReal(self.fieldCache,[scale,0,0,0,scale,0,0,0,scale])


    def updateDatasetName(self,item):
        '''
        Slot for the list widget's itemChanged signal.

        When the item's text differs from its stored name the dataset is
        renamed in the model and on all of its points; otherwise the change
        is treated as a visibility toggle.
        '''
        annot = str(item.text())
        if item._name == annot:
            self.listItemVisibilityToggled(item)
            return

        key = item._key
        item._name = annot
        self._model.changeDatasetName(key,annot)
        for pointKey in self._pointItems:
            pointItem, owner = self._pointItems[pointKey]
            if owner == item._key:
                self.fieldCache.setNode(pointItem.parent)
                self.nodeAnnotationField.assignString(self.fieldCache,annot)
            
    def updateListDatasets(self):
        '''
        Synchronise the dataset list widget with the model: refresh names of
        datasets known to both, add items for new datasets, remove items of
        vanished datasets and select the model's current dataset.
        '''
        datasets = self._model.getDatasets()
        modelKeys = set(datasets.keys())
        viewKeys = set(self._datasetItems.keys())
        sharedKeys = modelKeys.intersection(viewKeys)

        listWidget = self._listDatasets

        # Set the names to match for common keys.
        for item in (listWidget.item(row) for row in range(listWidget.count())):
            if item._key in sharedKeys:
                name = datasets[item._key]
                item._name = name
                item.setText(name)

        # Datasets that exist only in the model.
        for key in modelKeys - sharedKeys:
            self.addListItem(key, datasets[key])

        # Datasets that exist only in the view.
        for key in viewKeys - sharedKeys:
            self.removeListItem(key)

        # Mirror the model's current dataset in the widget.
        current = self._model.getCurrentDataset()
        if current is not None:
            listWidget.setCurrentItem(self._datasetItems[current])
        

    def addListItem(self, key, name):
        '''
        Append a list item for dataset key.  A dataset without a colour entry
        gets one from PenColors, cycling by key.
        '''
        if key not in self.datasetColors:
            palette = self.PenColors
            pen, col = self.createPenAndMaterial(palette[key % len(palette)])
            self.datasetColors[key] = [pen, col]
        col = self.datasetColors[key][1]

        item = PaintDatasetItem(key, name, col)
        self._datasetItems[key] = item
        self._listDatasets.addItem(item)
       

    def listItemVisibilityToggled(self,item):
        '''
        Write 1.0 (item checked) or 0.0 (unchecked) to the visibility flag of
        every point in the item's dataset.
        '''
        state = 1.0 if item.checkState() else 0.0
        key = item._key
        for pointKey in self._pointItems:
            pointItem, owner = self._pointItems[pointKey]
            if owner == key:
                self.fieldCache.setNode(pointItem.parent)
                self.nodeVisibilityFlagField.assignReal(self.fieldCache,state)
            
            
    def removeListItem(self, key):
        '''
        Take the first list item whose _key equals key out of the widget (if
        there is one) and drop key from _datasetItems.
        '''
        listWidget = self._listDatasets
        for row in range(listWidget.count()):
            if listWidget.item(row)._key == key:
                listWidget.takeItem(row)
                break

        del self._datasetItems[key]

class ZincPainterWidget(ZincDigitiserWidgetBase):
    
    graphicsInitialized = signalHandle()
    
    def __init__(self,zincContext,parent=None,shared=None,project=None):
        '''
        Build the painter widget: embed a SceneViewerWidget, create the
        PaintModel, decorate and wire the buttons, then load project when a
        non-empty filename is given.
        '''
        super(ZincPainterWidget,self).__init__(parent)
        self.sceneLayout = QtGui.QVBoxLayout(self.graphicsViewHolder)
        self.qgl = QtOpenGL.QGLWidget()
        self.graphicsView = SceneViewerWidget(self.qgl,shared)
        self.graphicsView.setContext(zincContext)
        self.sceneLayout.addWidget(self.graphicsView)
        self.zincContext = zincContext
        self.filename = None
        self.model = PaintModel()
        self.model.modelChanged.connect(self.updateInterface)

        # Standard Qt icons for the tool buttons.
        buttonIcons = (
            (self.addNewDataset, 'SP_FileDialogNewFolder'),
            (self.removeCurrentDataset, 'SP_MessageBoxCritical'),
            (self.moveUp, 'SP_ArrowUp'),
            (self.moveDown, 'SP_ArrowDown'),
            (self.saveData, 'SP_DialogSaveButton'),
            (self.loadData, 'SP_DialogOpenButton'),
        )
        for button, pixmapName in buttonIcons:
            button.setIcon(self.style().standardIcon(getattr(QtGui.QStyle, pixmapName)))

        # Button actions.
        self.addNewDataset.clicked.connect(self.addDataset)
        self.removeCurrentDataset.clicked.connect(self.removeDataset)
        self.saveData.clicked.connect(self.saveRegions)
        self.loadData.clicked.connect(self.loadRegions)
        self.moveUp.clicked.connect(self.moveDatasetItemUp)
        self.moveDown.clicked.connect(self.moveDatasetItemDown)

        # Dataset list selection.
        self.listDatasets.itemSelectionChanged.connect(self.selectedDatasetChanged)

        # Scene viewer notifications.
        viewer = self.graphicsView
        viewer.graphicsInitialized.connect(self.graphicsInitialized)
        viewer.pointSelectionClick.connect(self.pointSelected)
        viewer.pointDeletionClick.connect(self.pointDeleted)
        viewer.doubleClicked.connect(self.viewAll)
        self.loadproject(project)


    def moveDatasetItemUp(self):
        '''Move the current dataset item one row up, keeping it current.'''
        datasets = self.listDatasets
        if datasets.count() <= 1:
            return
        row = datasets.currentRow()
        if row <= 0:
            return
        moved = datasets.takeItem(row)
        datasets.insertItem(row - 1, moved)
        datasets.setCurrentRow(row - 1)

    def moveDatasetItemDown(self):
        '''Move the current dataset item one row down, keeping it current.'''
        datasets = self.listDatasets
        total = datasets.count()
        if total <= 1:
            return
        row = datasets.currentRow()
        if row >= total - 1:
            return
        moved = datasets.takeItem(row)
        datasets.insertItem(row + 1, moved)
        datasets.setCurrentRow(row + 1)

    def setupMeshBackground(self,backgroundMesh):
        '''
        Start a 'Default' dataset and build a fresh view whose mesh is
        generated by backgroundMesh (an object offering generateMesh(region)).
        Assumes the standard Stomach Meshobject is used
        '''
        self.model.newDataSet('Default')
        view = DigitizationView(self.model, self.zincContext, self.graphicsView, self.listDatasets)
        self.view = view
        view.loadMesh(backgroundMesh)
        view.linkToSceneViewer()
        view.setupGraphics()
        self.updateInterface()

    def useBackgroundMesh(self):
        '''
        Start a 'Default' dataset and build a fresh view on the mesh already
        present in the 'digitizer' region.
        Assumes the standard Stomach Meshobject is used
        '''
        self.model.newDataSet('Default')
        view = DigitizationView(self.model, self.zincContext, self.graphicsView, self.listDatasets)
        self.view = view
        view.useMesh()
        view.linkToSceneViewer()
        view.setupGraphics()
        self.updateInterface()

    def setTexture(self,filename):
        '''Forward filename to the current view's setTexture.'''
        view = self.view
        view.setTexture(filename)
    
    def getMarkersAndColors(self):
        '''
        Return the model's transformed datasets as an OrderedDict keyed by
        dataset name, in the order the names appear in the list widget.
        '''
        byName = self.model.getDatasetInTransformedCoordinates()
        datasets = self.listDatasets
        names = (datasets.item(row)._name for row in range(datasets.count()))
        return OrderedDict((name, byName[name]) for name in names)
    
    def viewAll(self):
        '''
        Slot for the scene viewer's doubleClicked signal: restore the fixed
        view parameters that DigitizationView applies at construction.
        '''
        # This widget stores the scene viewer as graphicsView; the previously
        # used _graphicsScene attribute exists only on DigitizationView, so
        # every double click raised AttributeError.
        self.graphicsView.setViewParameters([0.0, 0.0, 4.742128877220128], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0], 0.6981317007977312)
    
    def saveRegions(self,filename=None):
        '''
        Save the datasets to filename, asking the user when none is given.

        Returns the result of model.saveDataset, or False when no filename
        was chosen.
        '''
        # Connected to a button's clicked signal, which can deliver its
        # `checked` bool as the first argument; a bool is never a filename,
        # so treat it like "not supplied" and open the dialog.
        if filename is None or isinstance(filename, bool):
            filename = QtGui.QFileDialog.getSaveFileName(self,
                'Dataset file', '',
                "Data Files (*.dat)")
            if isinstance(filename,tuple): #Handle pyside
                filename = str(filename[0])
        if filename is not None and not str(filename) == "":
            return self.model.saveDataset(filename)
        return False
            
    def loadRegions(self,filename=None):
        '''
        Load datasets from filename, asking the user when none is given, then
        re-apply the current-dataset highlighting on the view.
        '''
        # Connected to a button's clicked signal, which can deliver its
        # `checked` bool as the first argument; a bool is never a filename,
        # so treat it like "not supplied" and open the dialog.
        if filename is None or isinstance(filename, bool):
            filename = QtGui.QFileDialog.getOpenFileName(self,
                'Dataset file', '',
                "Data Files (*.dat)")
            if isinstance(filename,tuple): #Handle pyside
                filename = str(filename[0])
        if filename is not None and not str(filename) == "":
            #Do not clear the lists, it is automatically taken care by the model
            self.model.loadDataset(filename)
            # The view only exists after one of the setup/load methods ran;
            # same guard as selectedDatasetChanged.
            if hasattr(self, 'view'):
                self.view.listSelectionChanged(self.model._currentDataset)
        
    def getUsedNodesWithXiAndData(self):
        '''Return the current view's getUsedNodesWithXiAndData() result.'''
        nodeData = self.view.getUsedNodesWithXiAndData()
        return nodeData

    def newBackground(self,filename):
        '''
        Register the image with the model (as str), use it as the view's
        texture and refresh the interface.
        '''
        imagePath = str(filename)
        self.model.addImage(imagePath)
        self.view.setTexture(filename)
        self.updateInterface()
    
    def defaultProject(self):
        '''
        Start a 'Default' dataset and build a fresh view without loading or
        attaching a mesh.
        '''
        self.model.newDataSet('Default')
        view = DigitizationView(self.model, self.zincContext, self.graphicsView, self.listDatasets)
        self.view = view
        view.linkToSceneViewer()
        view.setupGraphics()
        self.updateInterface()
        

    def loadproject(self,filename):
        '''
        Load the dataset file into the model and build a fresh view for it.
        A filename of None or "" is ignored.
        '''
        if filename is None or str(filename) == "":
            return
        self.model.loadDataset(filename)
        view = DigitizationView(self.model, self.zincContext, self.graphicsView, self.listDatasets)
        self.view = view
        view.linkToSceneViewer()
        view.setupGraphics()
        self.updateInterface()


    def updateInterface(self):
        '''
        Hook connected to the model's modelChanged signal and called after
        view setup; intentionally does nothing here.
        '''
        return None

    def addDataset(self):
        '''Slot for the "add dataset" button: delegate to the model.'''
        model = self.model
        model.addDataset()
        

    def removeDataset(self):
        '''
        Slot for the "remove dataset" button: take the current item out of
        the list widget and remove its dataset from the model.

        Does nothing when the list has no current item.
        '''
        row = self.listDatasets.currentRow()
        # currentRow() is -1 without a current item and takeItem() then
        # yields None; dereferencing it used to raise AttributeError.
        if row < 0:
            return
        item = self.listDatasets.takeItem(row)
        if item is None:
            return
        self.model.removeDataset(item._key)

    def selectedDatasetChanged(self):
        '''
        Slot for the list's itemSelectionChanged signal: make the selected
        dataset current in the model (None when nothing is selected) and, if
        a view exists, update the point highlighting.
        '''
        current = self.listDatasets.currentItem()
        if current is None:
            self.model.changeCurrentDataset(None)
            return
        ds = self.listDatasets.currentItem()._key
        self.model.changeCurrentDataset(ds)
        if hasattr(self, 'view'):
            self.view.listSelectionChanged(ds)

    def pointSelected(self,pos):
        '''
        Slot for pointSelectionClick.  When the viewer's nodeOfInterest is a
        valid node that backs an existing point, move that point to pos;
        otherwise add a new point at pos.
        '''
        node = self.graphicsView.nodeOfInterest
        if node is not None and self.graphicsView.nodeOfInterest.isValid():
            nid = self.graphicsView.nodeOfInterest.getIdentifier()
            usedNodes = self.view.usedDataNodes
            if nid in usedNodes:
                self.pointItemMoved(usedNodes[nid], [pos[0],pos[1]])
                return
        self.addPoint([pos[0],pos[1]])
    
    def pointDeleted(self,pos):
        '''
        Slot for pointDeletionClick.  Remove the point backed by the viewer's
        nodeOfInterest; do nothing when that node is missing, invalid or not
        a used data node.  pos is not used.
        '''
        node = self.graphicsView.nodeOfInterest
        if node is None or not self.graphicsView.nodeOfInterest.isValid():
            return
        nid = self.graphicsView.nodeOfInterest.getIdentifier()
        usedNodes = self.view.usedDataNodes
        if nid in usedNodes:
            self.model.removePoint(usedNodes[nid])
                
    def addPoint(self, pos):
        '''Ask the model to add a point at pos.'''
        model = self.model
        model.addPoint(pos)
    
    def removePoint(self,key):
        '''Ask the model to remove the point identified by key.'''
        model = self.model
        model.removePoint(key)

    def pointItemMoved(self, key, pos):
        '''Ask the model to move the point identified by key to pos.'''
        model = self.model
        model.movePoint(key, pos)
        

from digitiser.mapping import MeshMapper
if __name__ == '__main__':
    # Stand-alone demo: show the painter widget on a flat mesh generated from
    # a pickled MeshMapper, with a sample image as background texture.

    class MeshGen(object):
        '''Adapter giving a MeshMapper the generateMesh(region) interface.'''

        def __init__(self,mesh):
            self.meshgen = mesh

        def generateMesh(self,region):
            self.meshgen.generateFlatMesh(region, 8,11)

    app = QtGui.QApplication(sys.argv)
    zincContext = Context('ctx')
    window = ZincPainterWidget(zincContext)
    window.show()

    mapper = MeshMapper()
    mapper.setupByPickle('symmetricstomachsurface.pkl')

    window.setupMeshBackground(MeshGen(mapper))
    window.newBackground(r'../textureprocessing/foot1820.jpg')

    sys.exit(app.exec_())