Hitachi Vantara Pentaho Community Wiki
Child pages
  • LearningCurve.groovy
Skip to end of metadata
Go to start of metadata

You are viewing an old version of this page. View the current version.

Compare with Current View Page History

Version 1 Next »

import java.io.Serializable
import java.util.Vector
import java.util.Enumeration
import org.pentaho.dm.kf.KFGroovyScript
import org.pentaho.dm.kf.GroovyHelper
import weka.core.*
import weka.gui.Logger
import weka.gui.beans.*
import weka.classifiers.bayes.NaiveBayes
import weka.classifiers.functions.Logistic
import weka.classifiers.Evaluation
import weka.classifiers.Classifier

import groovy.swing.SwingBuilder
import javax.swing.*
import java.awt.*


// add further imports here if necessary

/**
 * Example Groovy script that generates a learning curve for a classifier.
 * Allows the classifier to be connected via a "configuration" event or
 * specified via an environment variable (CLASSIFIER_NAME). Classifier options
 * and the parameters of the learning curve could be specified via environment
 * variables as well through just minor changes to the script.
 *
 * Generates both a "TextEvent" containing the curve information and a
 * "DataSetEvent". The latter can be visualized in a DataVisualizer component.
 *
 * Also demonstrates how to allow the user to set options for the script
 * via a graphical pop-up window.
 *
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}org)
 */
class LearningCurve
        implements
                KFGroovyScript,
                EnvironmentHandler,
                BeanCommon,
                EventConstraints,
                UserRequestAcceptor,
                TrainingSetListener,
                TestSetListener,
                DataSourceListener,
                InstanceListener,
                TextListener,
                BatchClassifierListener,
                IncrementalClassifierListener,
                BatchClustererListener,
                GraphListener,
                ChartListener,
                ThresholdDataListener,
                VisualizableErrorListener,
                ConfigurationListener,
                Serializable {

  /** Don't delete!!
   *  GroovyHelper has the following useful methods:
   *
   *  notifyListenerType(Object event) - GroovyHelper will pass on event
   *    appropriate listener type for you
   *  ArrayList<TrainingSetListener> getTrainingSetListeners() - get
   *    a list of any directly connected components that are listening
   *    for TrainingSetEvents from us
   *  ArrayList<TestSetListener> getTestSetListeners()
   *  ArrayList<InstanceListener> getInstanceListeners()
   *  ArrayList<TextListener> getTextListeners()
   *  ArrayList<DataSourceListener> getDataSourceListeners()
   *  ArrayList<BatchClassifierListener> getBatchClassifierListeners()
   *  ArrayList<IncrementalClassifierListener> getIncrementalClassifierListeners()
   *  ArrayList<BatchClustererListener> getBatchClustererListeners()
   *  ArrayList<GraphListener> getGraphListeners()
   *  ArrayList<ChartListener> getChartListeners()
   *  ArrayList<ThresholdDataListener> getThresholdDataListeners()
   *  ArrayList<VisualizableErrorListener> getVisualizableErrorListeners()
   */
  GroovyHelper m_helper

  /** Log supplied by the KnowledgeFlow (may be null if none has been set) */
  Logger m_log = null

  /** Environment used to resolve variables in the option strings below */
  Environment m_env = Environment.getSystemWide()

  /** Percentage of the (shuffled) data held out for evaluation */
  String m_holdoutSize = "33.0"

  /** Number of training instances added at each step of the curve */
  String m_stepSize = "100"

  /** Maximum number of steps (points) in the curve */
  String m_numSteps = "10"

  /** Classifier class name; only used when no configurable is connected */
  String m_classifierName = "\${CLASSIFIER_NAME}"

  /** Classifier options; only used when no configurable is connected */
  String m_classifierOptions = null

  /** Source of the incoming "trainingSet" connection (null if none) */
  Object m_incomingConnection = null

  /** Classifier component connected via a "configuration" event (null if none) */
  weka.gui.beans.Classifier m_connectedConfigurable = null

  /** True while a learning curve is being computed */
  private transient volatile boolean m_busy = false

  /** Set by stop(); checked between steps of the learning curve */
  private transient volatile boolean m_stopRequested = false

  /** Don't delete!! */
  void setManager(GroovyHelper manager) { m_helper = manager }

  /** Alter or add to in order to tell the KnowledgeFlow
   *  environment whether a certain incoming connection type is allowed.
   *  Accepts at most one "trainingSet" and one "configuration" connection.
   */
  boolean connectionAllowed(String eventName) {
    if ("trainingSet".equals(eventName)) {
      return m_incomingConnection == null
    }

    if ("configuration".equals(eventName)) {
      return m_connectedConfigurable == null
    }

    return false
  }

  /** Add (optional) code to do something when you have been
   *  registered as a listener with a source for the named event
   */
  void connectionNotification(String eventName, Object source) {
    if ("trainingSet".equals(eventName)) {
      m_incomingConnection = source
    }

    if ("configuration".equals(eventName)) {
      // check the type of the configurable
      if (source instanceof weka.gui.beans.Classifier) {
        m_connectedConfigurable = (weka.gui.beans.Classifier)source
      } else {
        logError("[LearningCurve] Connected configurable is not a classifier!!")
      }
    }
  }

  /** Add (optional) code to do something when you have been
   *  deregistered as a listener with a source for the named event
   */
  void disconnectionNotification(String eventName, Object source) {
    if ("trainingSet".equals(eventName)) {
      m_incomingConnection = null
    }

    if ("configuration".equals(eventName)) {
      m_connectedConfigurable = null
    }
  }

  /** Custom name of this component. Do something with it if you
   *  like. GroovyHelper already stores it and alters the icon text
   *  for you    */
  void setCustomName(String name) { }

  /** Custom name of this component. No need to return anything
   *  GroovyHelper already stores it and alters the icon text
   *  for you    */
  String getCustomName() { return null }

  /** Returns true while a learning curve is being computed
   */
  boolean isBusy() { return m_busy }

  /** Store and use this logging object in order to post messages
   *  to the log
   */
  void setLog(Logger logger) {
    m_log = logger
  }

  /** Store and use this Environment object in order to lookup and
   *  use the values of environment variables
   */
  void setEnvironment(Environment env) {
    m_env = env
  }

  /** Requests that a running learning curve computation stops. The
   *  request is honoured before the next step of the curve starts.
   */
  void stop() {
    m_stopRequested = true
  }

  /** Alter or add to in order to tell the KnowledgeFlow
   *  whether, at the current time, the named event could
   *  be generated.
   */
  boolean eventGeneratable(String eventName) {
    return "text".equals(eventName) || "dataSet".equals(eventName)
  }

  /** Implement this to tell KnowledgeFlow about any methods
   *  that the user could invoke (i.e. to show a popup visualization
   *  or something).
   */
  Enumeration enumerateRequests() {
    Vector items = new Vector(0)
    items.add("Set options...")
    return items.elements()
  }

  /** Make the user-requested action happen here. "Set options..." pops
   *  up a window for editing the holdout size, number of steps and
   *  step size.
   */
  void performRequest(String requestName) {
    if (!"Set options...".equals(requestName)) {
      return
    }

    def swing = new SwingBuilder()

    // Declared before the closures below so that the button handlers can
    // close this window explicitly. A bare dispose() call inside a builder
    // closure is not resolved against the frame.
    JFrame optionsFrame = null

    def holderP1 = {
      swing.panel() {
        borderLayout()
        label (text:'Holdout set size: ', constraints:BorderLayout.WEST)
        hSize = textField(text:m_holdoutSize, columns:6,
                          actionPerformed: {
                            m_holdoutSize = hSize.text
                          }, constraints:BorderLayout.CENTER)
      }
    }

    def holderP2 = {
      swing.panel() {
        borderLayout()
        label (text:'Number of steps: ', constraints:BorderLayout.WEST)
        nSteps = textField(text:m_numSteps, columns:6,
                           actionPerformed: {
                             m_numSteps = nSteps.text
                           }, constraints:BorderLayout.CENTER)
      }
    }

    def holderP3 = {
      swing.panel() {
        borderLayout()
        label (text:'Step size: ', constraints:BorderLayout.WEST)
        sSize = textField(text:m_stepSize, columns:6,
                          actionPerformed: {
                            m_stepSize = sSize.text
                          }, constraints:BorderLayout.CENTER)
      }
    }

    def holderP4 = {
      swing.panel() {
        boxLayout(axis:BoxLayout.Y_AXIS)
        widget(holderP1())
        widget(holderP2())
        widget(holderP3())
      }
    }

    def holderP5 = {
      swing.panel() {
        boxLayout(axis:BoxLayout.X_AXIS)
        button(text:'OK',
               actionPerformed: {
                 m_holdoutSize = hSize.text
                 m_numSteps = nSteps.text
                 m_stepSize = sSize.text
                 optionsFrame.dispose()
               })
        button(text:"CANCEL",
               actionPerformed: {
                 optionsFrame.dispose()
               })
      }
    }

    optionsFrame = swing.frame(title:'Learning Curve Options', size:[300,600]) {
      borderLayout()
      widget(holderP4(), constraints:BorderLayout.NORTH)
      widget(holderP5(), constraints:BorderLayout.SOUTH)
    }
    optionsFrame.pack()
    // setVisible(true) replaces the deprecated show()
    optionsFrame.setVisible(true)
  }

  //--------------- Incoming events ------------------
  //--------------- Implement as necessary -----------

  /** Computes a learning curve from the supplied training set and passes
   *  it on as a TextEvent and a DataSetEvent. Structure-only events are
   *  ignored. isBusy() returns true for the duration of the computation.
   */
  void acceptTrainingSet(TrainingSetEvent e) {
    if (e.isStructureOnly()) {
      return
    }

    m_stopRequested = false
    m_busy = true
    try {
      buildLearningCurve(e)
    } finally {
      // always clear the flag, even if something unexpected is thrown
      m_busy = false
    }
  }

  /** Does the actual work for acceptTrainingSet(): shuffles a copy of the
   *  data, sets aside the last part as a holdout set, then trains on
   *  progressively larger prefixes of the rest and records percent correct
   *  on the holdout set for each prefix. Problems are reported to the log
   *  and abort the computation without notifying any listeners.
   */
  private void buildLearningCurve(TrainingSetEvent e) {
    // work on a shuffled copy so the incoming data is left untouched
    Instances insts = new Instances(e.getTrainingSet())
    insts.randomize(new Random(1))

    String hSize = substituteVariables(m_holdoutSize)
    String sSize = substituteVariables(m_stepSize)
    String nSteps = substituteVariables(m_numSteps)

    weka.classifiers.Classifier classifierToUse = resolveClassifier()
    if (classifierToUse == null) {
      // the problem has already been logged
      return
    }

    double holdoutFraction
    int stepSize
    int numSteps
    try {
      holdoutFraction = Double.parseDouble((hSize ?: "").trim()) / 100.0
      stepSize = Integer.parseInt((sSize ?: "").trim())
      numSteps = Integer.parseInt((nSteps ?: "").trim())
    } catch (NumberFormatException ex) {
      logError("[LearningCurve] Unable to parse settings (holdout size: '" + hSize
               + "', step size: '" + sSize + "', number of steps: '" + nSteps + "')")
      return
    }

    int numInstances = insts.numInstances()
    int numInHoldout = (int) (holdoutFraction * numInstances)
    int numAvailableForTraining = numInstances - numInHoldout
    if (stepSize < 1 || numSteps < 1
        || numInHoldout < 1 || numAvailableForTraining < 1) {
      logError("[LearningCurve] Invalid settings: need a positive step size and "
               + "number of steps, and at least one instance in both the holdout ("
               + numInHoldout + ") and training (" + numAvailableForTraining
               + ") portions of the data")
      return
    }

    // the holdout set is the tail of the shuffled data
    Instances holdoutI = new Instances(insts, numInHoldout)
    for (int i = numAvailableForTraining; i < numInstances; i++) {
      holdoutI.add(insts.instance(i))
    }

    String classifierSetUpString = classifierToUse.class.toString() + " "
    if (classifierToUse instanceof OptionHandler) {
      classifierSetUpString += Utils.joinOptions(((OptionHandler)classifierToUse).getOptions())
    }

    if (m_log != null) {
      m_log.logMessage("[LearningCurve] Using classifier " + classifierSetUpString)
    }

    // create the instances structure to hold the learning curve results
    Attribute setSize = new Attribute("NumInstances")
    Attribute aucA = new Attribute("PercentCorrect")
    FastVector atts = new FastVector()
    atts.addElement(setSize)
    atts.addElement(aucA)

    // The preceding "__" tells the DataVisualizer to connect the points with lines
    Instances learnCInstances = new Instances("__Learning curve: " + classifierSetUpString, atts, 0)

    StringBuffer buff = new StringBuffer()
    // grows by (up to) stepSize instances on every iteration
    Instances training = new Instances(insts, 0)
    try {
      for (int i = 0; i < numSteps; i++) {
        if (m_stopRequested) {
          if (m_log != null) {
            m_log.statusMessage(statusPrefix() + "INTERRUPTED")
            m_log.logMessage("[LearningCurve] Stopped before the learning curve was complete")
          }
          return
        }

        if (m_log != null) {
          m_log.statusMessage(statusPrefix() + "Processing set " + (i + 1))
        }

        // clamp the final step to the data that is actually available
        boolean lastStep = false
        int numInThisStep = (i + 1) * stepSize
        if (numInThisStep >= numAvailableForTraining) {
          numInThisStep = numAvailableForTraining
          lastStep = true
        }
        for (int k = i * stepSize; k < numInThisStep; k++) {
          training.add(insts.instance(k))
        }

        // train a fresh copy on this set and evaluate it on the holdout data
        Classifier newModel = Classifier.makeCopies(classifierToUse, 1)[0]
        newModel.buildClassifier(training)

        Evaluation eval = new Evaluation(holdoutI)
        eval.evaluateModel(newModel, holdoutI)
        double pc = (1.0 - eval.errorRate()) * 100.0
        buff.append("" + numInThisStep + "," + pc + "\n")

        Instance newInst = new Instance(2)
        newInst.setValue(0, (double)numInThisStep)
        newInst.setValue(1, pc)
        learnCInstances.add(newInst)

        if (lastStep) {
          break
        }
      }
    } catch (Exception ex) {
      logError("[LearningCurve] Problem building or evaluating the classifier: "
               + ex.getMessage())
      return
    }

    if (m_log != null) {
      m_log.statusMessage(statusPrefix() + "Finished.")
    }
    m_helper.notifyTextListeners(new TextEvent(this, buff.toString(), "learning curve"))
    m_helper.notifyDataSourceListeners(new DataSetEvent(this, learnCInstances))
  }

  /** Returns the classifier to generate the curve for: the one held by the
   *  connected configurable if there is one, otherwise a new instance
   *  created from m_classifierName and m_classifierOptions (after variable
   *  substitution). Returns null, after logging the reason, if no
   *  classifier can be obtained.
   */
  private weka.classifiers.Classifier resolveClassifier() {
    if (m_connectedConfigurable != null) {
      return m_connectedConfigurable.getClassifier()
    }

    // try and instantiate from the supplied classifier name
    String classifierName = substituteVariables(m_classifierName)
    String classifierOptions = substituteVariables(m_classifierOptions)

    if (classifierName == null || classifierName.length() == 0) {
      logError("[LearningCurve] No classifier supplied!")
      return null
    }

    String[] splitOptions = null
    if (classifierOptions != null && classifierOptions.length() > 0) {
      try {
        splitOptions = Utils.splitOptions(classifierOptions)
      } catch (Exception ex) {
        logError("[LearningCurve] Problem parsing classifier options")
        return null
      }
    }

    try {
      return Classifier.forName(classifierName, splitOptions)
    } catch (Exception ex) {
      logError("[LearningCurve] Unable to instantiate classifier '" + classifierName
               + "': " + ex.getMessage())
      return null
    }
  }

  /** Replaces environment variables in value using m_env. Returns value
   *  unchanged if it is null/empty, if no environment is available, or if
   *  substitution fails (the failure is written to the log).
   */
  private String substituteVariables(String value) {
    if (value == null || value.length() == 0 || m_env == null) {
      return value
    }
    try {
      return m_env.substitute(value)
    } catch (Exception ex) {
      if (m_log != null) {
        m_log.logMessage("[LearningCurve] Unable to substitute variables in '"
                         + value + "': " + ex.getMessage())
      }
      return value
    }
  }

  /** Prefix shared by all of this component's status messages; the "|"
   *  separates the component identifier from the message text.
   */
  private String statusPrefix() {
    return "LearningCurve\$" + hashCode() + "|"
  }

  /** Flags an error in the status area and writes message to the log
   *  (does nothing if no log has been set).
   */
  private void logError(String message) {
    if (m_log != null) {
      m_log.statusMessage(statusPrefix() + "ERROR (see log for details)")
      m_log.logMessage(message)
    }
  }

  void acceptTestSet(TestSetEvent e) { }

  void acceptDataSet(DataSetEvent e) { }

  void acceptInstance(InstanceEvent e) { }

  void acceptText(TextEvent e) { }

  void acceptClassifier(BatchClassifierEvent e) { }

  void acceptClassifier(IncrementalClassifierEvent e) { }

  void acceptClusterer(BatchClustererEvent e) { }

  void acceptGraph(GraphEvent e) { }

  void acceptDataPoint(ChartEvent e) { }

  void acceptDataSet(ThresholdDataEvent e) { }

  void acceptDataSet(VisualizableErrorEvent e) { }

  void acceptConfiguration(ConfigurationEvent e) { }

}
  • No labels