diff options
Diffstat (limited to 'graphics/java/android/renderscript/ScriptGroup.java')
-rw-r--r-- | graphics/java/android/renderscript/ScriptGroup.java | 90 |
1 files changed, 69 insertions, 21 deletions
diff --git a/graphics/java/android/renderscript/ScriptGroup.java b/graphics/java/android/renderscript/ScriptGroup.java index 4efb45b..7afdb39 100644 --- a/graphics/java/android/renderscript/ScriptGroup.java +++ b/graphics/java/android/renderscript/ScriptGroup.java @@ -32,6 +32,12 @@ import java.util.ArrayList; * user supplied allocation. Inputs are similar but supply the * input of a kernal. Inputs bounds to a script are set directly * upon the script. + * <p> + * A ScriptGroup must contain at least one kernel. A ScriptGroup + * must contain only a single directed acyclic graph (DAG) of + * script kernels and connections. Attempting to create a + * ScriptGroup with multiple DAGs or attempting to create + * a cycle within a ScriptGroup will throw an exception. * **/ public final class ScriptGroup extends BaseObj { @@ -71,7 +77,7 @@ public final class ScriptGroup extends BaseObj { ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>(); ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>(); ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>(); - boolean mSeen; + int dagNumber; Node mNext; @@ -169,39 +175,70 @@ public final class ScriptGroup extends BaseObj { mRS = rs; } - private void validateRecurse(Node n, int depth) { - n.mSeen = true; - - //android.util.Log.v("RSR", " validateRecurse outputCount " + n.mOutputs.size()); - for (int ct=0; ct < n.mOutputs.size(); ct++) { - final ConnectLine cl = n.mOutputs.get(ct); + // do a DFS from original node, looking for original node + // any cycle that could be created must contain original node + private void validateCycle(Node target, Node original) { + for (int ct = 0; ct < target.mOutputs.size(); ct++) { + final ConnectLine cl = target.mOutputs.get(ct); if (cl.mToK != null) { Node tn = findNode(cl.mToK.mScript); - if (tn.mSeen) { + if (tn.equals(original)) { throw new RSInvalidStateException("Loops in group not allowed."); } - validateRecurse(tn, depth + 1); + validateCycle(tn, original); } if (cl.mToF != null) { Node tn = findNode(cl.mToF.mScript); - if (tn.mSeen) { + if (tn.equals(original)) { throw new RSInvalidStateException("Loops in group not allowed."); } - validateRecurse(tn, depth + 1); + validateCycle(tn, original); } } } - private void validate() { - //android.util.Log.v("RSR", "validate"); - + private void mergeDAGs(int valueUsed, int valueKilled) { for (int ct=0; ct < mNodes.size(); ct++) { - for (int ct2=0; ct2 < mNodes.size(); ct2++) { - mNodes.get(ct2).mSeen = false; + if (mNodes.get(ct).dagNumber == valueKilled) + mNodes.get(ct).dagNumber = valueUsed; + } + } + + private void validateDAGRecurse(Node n, int dagNumber) { + // combine DAGs if this node has been seen already + if (n.dagNumber != 0 && n.dagNumber != dagNumber) { + mergeDAGs(n.dagNumber, dagNumber); + return; + } + + n.dagNumber = dagNumber; + for (int ct=0; ct < n.mOutputs.size(); ct++) { + final ConnectLine cl = n.mOutputs.get(ct); + if (cl.mToK != null) { + Node tn = findNode(cl.mToK.mScript); + validateDAGRecurse(tn, dagNumber); + } + if (cl.mToF != null) { + Node tn = findNode(cl.mToF.mScript); + validateDAGRecurse(tn, dagNumber); } + } + } + + private void validateDAG() { + for (int ct=0; ct < mNodes.size(); ct++) { Node n = mNodes.get(ct); if (n.mInputs.size() == 0) { - validateRecurse(n, 0); + if (n.mOutputs.size() == 0 && mNodes.size() > 1) { + throw new RSInvalidStateException("Groups cannot contain unconnected scripts"); + } + validateDAGRecurse(n, ct+1); + } + } + int dagNumber = mNodes.get(0).dagNumber; + for (int ct=0; ct < mNodes.size(); ct++) { + if (mNodes.get(ct).dagNumber != dagNumber) { + throw new RSInvalidStateException("Multiple DAGs in group not allowed."); } } } @@ -274,7 +311,7 @@ public final class ScriptGroup extends BaseObj { Node nf = findNode(from); if (nf == null) { - throw new RSInvalidStateException("From kernel not found."); + throw new RSInvalidStateException("From script not found."); } Node nt = findNode(to.mScript); @@ -288,7 +325,7 @@ public final class ScriptGroup extends BaseObj { nf.mOutputs.add(cl); nt.mInputs.add(cl); - validate(); + validateCycle(nf, nf); return this; } @@ -309,7 +346,7 @@ public final class ScriptGroup extends BaseObj { Node nf = findNode(from); if (nf == null) { - throw new RSInvalidStateException("From kernel not found."); + throw new RSInvalidStateException("From script not found."); } Node nt = findNode(to); @@ -323,7 +360,7 @@ public final class ScriptGroup extends BaseObj { nf.mOutputs.add(cl); nt.mInputs.add(cl); - validate(); + validateCycle(nf, nf); return this; } @@ -336,6 +373,17 @@ public final class ScriptGroup extends BaseObj { * @return ScriptGroup The new ScriptGroup */ public ScriptGroup create() { + + if (mNodes.size() == 0) { + throw new RSInvalidStateException("Empty script groups are not allowed"); + } + + // reset DAG numbers in case we're building a second group + for (int ct=0; ct < mNodes.size(); ct++) { + mNodes.get(ct).dagNumber = 0; + } + validateDAG(); + ArrayList<IO> inputs = new ArrayList<IO>(); ArrayList<IO> outputs = new ArrayList<IO>(); |