diff options
7 files changed, 109 insertions, 47 deletions
diff --git a/tests/sketch/src/com/android/gesture/GestureLibrary.java b/tests/sketch/src/com/android/gesture/GestureLibrary.java index 3e753e7..915b840 100644 --- a/tests/sketch/src/com/android/gesture/GestureLibrary.java +++ b/tests/sketch/src/com/android/gesture/GestureLibrary.java @@ -49,11 +49,11 @@ public class GestureLibrary { private static final String NAMESPACE = ""; public static final int SEQUENCE_INVARIANT = 1; - // when SEQUENCE_SENSITIVE is used, only single stroke gestures are allowed + // when SEQUENCE_SENSITIVE is used, only single stroke gestures are currently allowed public static final int SEQUENCE_SENSITIVE = 2; + // ORIENTATION_SENSITIVE and ORIENTATION_INVARIANT are only for SEQUENCE_SENSITIVE gestures public static final int ORIENTATION_INVARIANT = 1; - // ORIENTATION_SENSITIVE is only available for single stroke gestures public static final int ORIENTATION_SENSITIVE = 2; private int mSequenceType = SEQUENCE_SENSITIVE; @@ -77,8 +77,8 @@ public class GestureLibrary { } /** - * Specify whether the gesture library will handle orientation sensitive - * gestures. Use ORIENTATION_INVARIANT or ORIENTATION_SENSITIVE + * Specify how the gesture library will handle orientation. + * Use ORIENTATION_INVARIANT or ORIENTATION_SENSITIVE * * @param style */ @@ -114,8 +114,8 @@ public class GestureLibrary { * @return a list of predictions of possible entries for a given gesture */ public ArrayList<Prediction> recognize(Gesture gesture) { - Instance instance = Instance.createInstance(this, gesture, null); - return mClassifier.classify(this, instance); + Instance instance = Instance.createInstance(mSequenceType, gesture, null); + return mClassifier.classify(mSequenceType, instance.vector); } /** @@ -134,7 +134,7 @@ public class GestureLibrary { mEntryName2gestures.put(entryName, gestures); } gestures.add(gesture); - mClassifier.addInstance(Instance.createInstance(this, gesture, entryName)); + mClassifier.addInstance(Instance.createInstance(mSequenceType, gesture, entryName)); mChanged = true; } @@ -300,7 +300,7 @@ public class GestureLibrary { mGestures = null; } else if (localName.equals(GestureConstants.XML_TAG_GESTURE)) { mGestures.add(mCurrentGesture); - mClassifier.addInstance(Instance.createInstance(GestureLibrary.this, + mClassifier.addInstance(Instance.createInstance(mSequenceType, mCurrentGesture, mEntryName)); mCurrentGesture = null; } else if (localName.equals(GestureConstants.XML_TAG_STROKE)) { diff --git a/tests/sketch/src/com/android/gesture/GestureStroke.java b/tests/sketch/src/com/android/gesture/GestureStroke.java index 3555010..c2ebc17 100644 --- a/tests/sketch/src/com/android/gesture/GestureStroke.java +++ b/tests/sketch/src/com/android/gesture/GestureStroke.java @@ -244,4 +244,12 @@ public class GestureStroke { public void invalidate() { mCachedPath = null; } + + /** + * Compute an oriented bounding box of the stroke + * @return OrientedBoundingBox + */ + public OrientedBoundingBox computeOrientedBoundingBox() { + return GestureUtilities.computeOrientedBoundingBox(points); + } } diff --git a/tests/sketch/src/com/android/gesture/GestureUtilities.java b/tests/sketch/src/com/android/gesture/GestureUtilities.java index 2798616..92de987 100755 --- a/tests/sketch/src/com/android/gesture/GestureUtilities.java +++ b/tests/sketch/src/com/android/gesture/GestureUtilities.java @@ -26,7 +26,7 @@ import java.io.IOException; import static com.android.gesture.GestureConstants.*; -public final class GestureUtilities { +final class GestureUtilities { private static final int TEMPORAL_SAMPLING_RATE = 16; private GestureUtilities() { @@ -348,33 +348,31 @@ public final class GestureUtilities { /** * Calculate the cosine distance between two instances * - * @param in1 - * @param in2 + * @param vector1 + * @param vector2 * @return the distance between 0 and Math.PI */ - static double cosineDistance(Instance in1, Instance in2) { + static double cosineDistance(float[] vector1, float[] vector2) { float sum = 0; - float[] vector1 = in1.vector; - float[] vector2 = in2.vector; int len = vector1.length; for (int i = 0; i < len; i++) { sum += vector1[i] * vector2[i]; } - return Math.acos(sum / (in1.magnitude * in2.magnitude)); + return Math.acos(sum); } - public static OrientedBoundingBox computeOrientedBoundingBox(ArrayList<GesturePoint> pts) { + static OrientedBoundingBox computeOrientedBoundingBox(ArrayList<GesturePoint> pts) { GestureStroke stroke = new GestureStroke(pts); float[] points = temporalSampling(stroke, TEMPORAL_SAMPLING_RATE); return computeOrientedBoundingBox(points); } - public static OrientedBoundingBox computeOrientedBoundingBox(float[] points) { + static OrientedBoundingBox computeOrientedBoundingBox(float[] points) { float[] meanVector = computeCentroid(points); return computeOrientedBoundingBox(points, meanVector); } - public static OrientedBoundingBox computeOrientedBoundingBox(float[] points, float[] centroid) { + static OrientedBoundingBox computeOrientedBoundingBox(float[] points, float[] centroid) { Matrix tr = new Matrix(); tr.setTranslate(-centroid[0], -centroid[1]); tr.mapPoints(points); diff --git a/tests/sketch/src/com/android/gesture/Instance.java b/tests/sketch/src/com/android/gesture/Instance.java index 011d1fc..b2e030e 100755 --- a/tests/sketch/src/com/android/gesture/Instance.java +++ b/tests/sketch/src/com/android/gesture/Instance.java @@ -23,7 +23,7 @@ package com.android.gesture; class Instance { private static final int SEQUENCE_SAMPLE_SIZE = 16; - private static final int PATCH_SAMPLE_SIZE = 8; + private static final int PATCH_SAMPLE_SIZE = 16; private final static float[] ORIENTATIONS = { 0, 45, 90, 135, 180, -0, -45, -90, -135, -180 @@ -35,22 +35,26 @@ class Instance { // the label can be null final String label; - // the length of the vector - final float magnitude; - // the id of the instance final long id; - + private Instance(long id, float[] sample, String sampleName) { this.id = id; vector = sample; label = sampleName; + } + + private void normalize() { + float[] sample = vector; float sum = 0; int size = sample.length; for (int i = 0; i < size; i++) { sum += sample[i] * sample[i]; } - magnitude = (float) Math.sqrt(sum); + float magnitude = (float) Math.sqrt(sum); + for (int i = 0; i < size; i++) { + sample[i] /= magnitude; + } } /** @@ -60,21 +64,25 @@ class Instance { * @param label * @return the instance */ - static Instance createInstance(GestureLibrary gesturelib, Gesture gesture, String label) { + static Instance createInstance(int samplingType, Gesture gesture, String label) { float[] pts; - if (gesturelib.getGestureType() == GestureLibrary.SEQUENCE_SENSITIVE) { - pts = temporalSampler(gesturelib, gesture); + Instance instance; + if (samplingType == GestureLibrary.SEQUENCE_SENSITIVE) { + pts = temporalSampler(samplingType, gesture); + instance = new Instance(gesture.getID(), pts, label); + instance.normalize(); } else { pts = spatialSampler(gesture); + instance = new Instance(gesture.getID(), pts, label); } - return new Instance(gesture.getID(), pts, label); + return instance; } - + private static float[] spatialSampler(Gesture gesture) { return GestureUtilities.spatialSampling(gesture, PATCH_SAMPLE_SIZE); } - private static float[] temporalSampler(GestureLibrary gesturelib, Gesture gesture) { + private static float[] temporalSampler(int samplingType, Gesture gesture) { float[] pts = GestureUtilities.temporalSampling(gesture.getStrokes().get(0), SEQUENCE_SAMPLE_SIZE); float[] center = GestureUtilities.computeCentroid(pts); @@ -82,7 +90,7 @@ class Instance { orientation *= 180 / Math.PI; float adjustment = -orientation; - if (gesturelib.getOrientationStyle() == GestureLibrary.ORIENTATION_SENSITIVE) { + if (samplingType == GestureLibrary.ORIENTATION_SENSITIVE) { int count = ORIENTATIONS.length; for (int i = 0; i < count; i++) { float delta = ORIENTATIONS[i] - orientation; diff --git a/tests/sketch/src/com/android/gesture/InstanceLearner.java b/tests/sketch/src/com/android/gesture/InstanceLearner.java index 335719a..4495256 100644 --- a/tests/sketch/src/com/android/gesture/InstanceLearner.java +++ b/tests/sketch/src/com/android/gesture/InstanceLearner.java @@ -34,21 +34,21 @@ class InstanceLearner extends Learner { private static final String LOGTAG = "InstanceLearner"; @Override - ArrayList<Prediction> classify(GestureLibrary lib, Instance instance) { + ArrayList<Prediction> classify(int gestureType, float[] vector) { ArrayList<Prediction> predictions = new ArrayList<Prediction>(); ArrayList<Instance> instances = getInstances(); int count = instances.size(); TreeMap<String, Double> label2score = new TreeMap<String, Double>(); for (int i = 0; i < count; i++) { Instance sample = instances.get(i); - if (sample.vector.length != instance.vector.length) { + if (sample.vector.length != vector.length) { continue; } double distance; - if (lib.getGestureType() == GestureLibrary.SEQUENCE_SENSITIVE) { - distance = GestureUtilities.cosineDistance(sample, instance); + if (gestureType == GestureLibrary.SEQUENCE_SENSITIVE) { + distance = GestureUtilities.cosineDistance(sample.vector, vector); } else { - distance = GestureUtilities.squaredEuclideanDistance(sample.vector, instance.vector); + distance = GestureUtilities.squaredEuclideanDistance(sample.vector, vector); } double weight; if (distance == 0) { diff --git a/tests/sketch/src/com/android/gesture/Learner.java b/tests/sketch/src/com/android/gesture/Learner.java index b4183d2..15b2053 100755 --- a/tests/sketch/src/com/android/gesture/Learner.java +++ b/tests/sketch/src/com/android/gesture/Learner.java @@ -79,5 +79,5 @@ abstract class Learner { instances.removeAll(toDelete); } - abstract ArrayList<Prediction> classify(GestureLibrary library, Instance instance); + abstract ArrayList<Prediction> classify(int gestureType, float[] vector); } diff --git a/tests/sketch/src/com/android/gesture/LetterRecognizer.java b/tests/sketch/src/com/android/gesture/LetterRecognizer.java index 73151de..086aedf 100644 --- a/tests/sketch/src/com/android/gesture/LetterRecognizer.java +++ b/tests/sketch/src/com/android/gesture/LetterRecognizer.java @@ -20,12 +20,14 @@ import android.content.Context; import android.content.res.Resources; import android.util.Log; -import java.io.IOException; -import java.io.DataInputStream; import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.DataInputStream; +import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; +import java.util.HashMap; public class LetterRecognizer { private static final String LOG_TAG = "LetterRecognizer"; @@ -37,8 +39,13 @@ public class LetterRecognizer { private final String[] mClasses; - private final int mInputCount; + private final int mPatchSize; + + static final String GESTURE_FILE_NAME = "letters.xml"; + private GestureLibrary mGestureLibrary; + private final static int ADJUST_RANGE = 3; + private static class SigmoidUnit { final float[] mWeights; @@ -62,11 +69,15 @@ public class LetterRecognizer { } private LetterRecognizer(int numOfInput, int numOfHidden, String[] classes) { - mInputCount = (int)Math.sqrt(numOfInput); + mPatchSize = (int)Math.sqrt(numOfInput); mHiddenLayer = new SigmoidUnit[numOfHidden]; mClasses = classes; mOutputLayer = new SigmoidUnit[classes.length]; } + + public void save() { + mGestureLibrary.save(); + } public static LetterRecognizer getLetterRecognizer(Context context, int type) { switch (type) { @@ -78,7 +89,12 @@ public class LetterRecognizer { } public ArrayList<Prediction> recognize(Gesture gesture) { - return classify(GestureUtilities.spatialSampling(gesture, mInputCount)); + float[] query = GestureUtilities.spatialSampling(gesture, mPatchSize); + ArrayList<Prediction> predictions = classify(query); + if (mGestureLibrary != null) { + adjustPrediction(gesture, predictions); + } + return predictions; } private ArrayList<Prediction> classify(float[] vector) { @@ -151,16 +167,16 @@ public class LetterRecognizer { SigmoidUnit[] outputLayer = new SigmoidUnit[oCount]; for (int i = 0; i < hCount; i++) { - float[] weights = new float[iCount]; - for (int j = 0; j < iCount; j++) { + float[] weights = new float[iCount + 1]; + for (int j = 0; j <= iCount; j++) { weights[j] = in.readFloat(); } hiddenLayer[i] = new SigmoidUnit(weights); } for (int i = 0; i < oCount; i++) { - float[] weights = new float[hCount]; - for (int j = 0; j < hCount; j++) { + float[] weights = new float[hCount + 1]; + for (int j = 0; j <= hCount; j++) { weights[j] = in.readFloat(); } outputLayer[i] = new SigmoidUnit(weights); @@ -170,11 +186,43 @@ public class LetterRecognizer { classifier.mOutputLayer = outputLayer; } catch (IOException e) { - Log.d(LOG_TAG, "Failed to load gestures:", e); + Log.d(LOG_TAG, "Failed to load handwriting data:", e); } finally { GestureUtilities.closeStream(in); } return classifier; } + + public void enablePersonalization(boolean enable) { + if (enable) { + mGestureLibrary = new GestureLibrary(GESTURE_FILE_NAME); + mGestureLibrary.setGestureType(GestureLibrary.SEQUENCE_INVARIANT); + mGestureLibrary.load(); + } else { + mGestureLibrary = null; + } + } + + public void addExample(String letter, Gesture example) { + mGestureLibrary.addGesture(letter, example); + } + + private void adjustPrediction(Gesture query, ArrayList<Prediction> predictions) { + ArrayList<Prediction> results = mGestureLibrary.recognize(query); + HashMap<String, Prediction> topNList = new HashMap<String, Prediction>(); + for (int j = 0; j < ADJUST_RANGE; j++) { + Prediction prediction = predictions.remove(0); + topNList.put(prediction.name, prediction); + } + int count = results.size(); + for (int j = count - 1; j >= 0 && !topNList.isEmpty(); j--) { + Prediction item = results.get(j); + Prediction original = topNList.get(item.name); + if (original != null) { + predictions.add(0, original); + topNList.remove(item.name); + } + } + } } |