Index: src/InitTweak/FixInit.cc
===================================================================
--- src/InitTweak/FixInit.cc	(revision 92b3de18e700712b3d68077d08943feb95d6e9e4)
+++ src/InitTweak/FixInit.cc	(revision dc2334c815416c4950e1e3b716f8372f8f3f1701)
@@ -185,5 +185,5 @@
 		};
 
-		class FixCopyCtors final : public GenPoly::PolyMutator {
+		class FixCopyCtors final : public WithStmtsToAdd, public WithShortCircuiting, public WithVisitorRef<FixCopyCtors> {
 		  public:
 			FixCopyCtors( UnqCount & unqCount ) : unqCount( unqCount ){}
@@ -192,9 +192,7 @@
 			static void fixCopyCtors( std::list< Declaration * > &translationUnit, UnqCount & unqCount );
 
-			typedef GenPoly::PolyMutator Parent;
-			using Parent::mutate;
-			virtual Expression * mutate( ImplicitCopyCtorExpr * impCpCtorExpr ) override;
-			virtual Expression * mutate( UniqueExpr * unqExpr ) override;
-			virtual Expression * mutate( StmtExpr * stmtExpr ) override;
+			Expression * postmutate( ImplicitCopyCtorExpr * impCpCtorExpr );
+			void premutate( StmtExpr * stmtExpr );
+			void premutate( UniqueExpr * unqExpr );
 
 			UnqCount & unqCount;
@@ -312,5 +310,5 @@
 
 		void FixCopyCtors::fixCopyCtors( std::list< Declaration * > & translationUnit, UnqCount & unqCount ) {
-			FixCopyCtors fixer( unqCount );
+			PassVisitor<FixCopyCtors> fixer( unqCount );
 			mutateAll( translationUnit, fixer );
 		}
@@ -512,8 +510,7 @@
 		}
 
-		Expression * FixCopyCtors::mutate( ImplicitCopyCtorExpr * impCpCtorExpr ) {
+		Expression * FixCopyCtors::postmutate( ImplicitCopyCtorExpr * impCpCtorExpr ) {
 			CP_CTOR_PRINT( std::cerr << "FixCopyCtors: " << impCpCtorExpr << std::endl; )
 
-			impCpCtorExpr = strict_dynamic_cast< ImplicitCopyCtorExpr * >( Parent::mutate( impCpCtorExpr ) );
 			std::list< ObjectDecl * > & tempDecls = impCpCtorExpr->get_tempDecls();
 			std::list< ObjectDecl * > & returnDecls = impCpCtorExpr->get_returnDecls();
@@ -522,8 +519,8 @@
 			// add all temporary declarations and their constructors
 			for ( ObjectDecl * obj : tempDecls ) {
-				stmtsToAdd.push_back( new DeclStmt( noLabels, obj ) );
+				stmtsToAddBefore.push_back( new DeclStmt( noLabels, obj ) );
 			} // for
 			for ( ObjectDecl * obj : returnDecls ) {
-				stmtsToAdd.push_back( new DeclStmt( noLabels, obj ) );
+				stmtsToAddBefore.push_back( new DeclStmt( noLabels, obj ) );
 			} // for
 
@@ -533,5 +530,4 @@
 			} // for
 
-			// xxx - update to work with multiple return values
 			ObjectDecl * returnDecl = returnDecls.empty() ? nullptr : returnDecls.front();
 			Expression * callExpr = impCpCtorExpr->get_callExpr();
@@ -566,21 +562,23 @@
 		}
 
-		Expression * FixCopyCtors::mutate( StmtExpr * stmtExpr ) {
+		void FixCopyCtors::premutate( StmtExpr * stmtExpr ) {
 			// function call temporaries should be placed at statement-level, rather than nested inside of a new statement expression,
 			// since temporaries can be shared across sub-expressions, e.g.
 			//   [A, A] f();
 			//   g([A] x, [A] y);
-			//   f(g());
+			//   g(f());
 			// f is executed once, so the return temporary is shared across the tuple constructors for x and y.
+			// Explicitly mutating children instead of mutating the inner compound statment forces the temporaries to be added
+			// to the outer context, rather than inside of the statement expression.
+			visit_children = false;
 			std::list< Statement * > & stmts = stmtExpr->get_statements()->get_kids();
 			for ( Statement *& stmt : stmts ) {
-				stmt = stmt->acceptMutator( *this );
+				stmt = stmt->acceptMutator( *visitor );
 			} // for
-			// stmtExpr = strict_dynamic_cast< StmtExpr * >( Parent::mutate( stmtExpr ) );
 			assert( stmtExpr->get_result() );
 			Type * result = stmtExpr->get_result();
 			if ( ! result->isVoid() ) {
 				for ( ObjectDecl * obj : stmtExpr->get_returnDecls() ) {
-					stmtsToAdd.push_back( new DeclStmt( noLabels, obj ) );
+					stmtsToAddBefore.push_back( new DeclStmt( noLabels, obj ) );
 				} // for
 				// add destructors after current statement
@@ -589,8 +587,7 @@
 				} // for
 				// must have a non-empty body, otherwise it wouldn't have a result
-				CompoundStmt * body = stmtExpr->get_statements();
-				assert( ! body->get_kids().empty() );
+				assert( ! stmts.empty() );
 				assert( ! stmtExpr->get_returnDecls().empty() );
-				body->get_kids().push_back( new ExprStmt( noLabels, new VariableExpr( stmtExpr->get_returnDecls().front() ) ) );
+				stmts.push_back( new ExprStmt( noLabels, new VariableExpr( stmtExpr->get_returnDecls().front() ) ) );
 				stmtExpr->get_returnDecls().clear();
 				stmtExpr->get_dtors().clear();
@@ -598,12 +595,11 @@
 			assert( stmtExpr->get_returnDecls().empty() );
 			assert( stmtExpr->get_dtors().empty() );
-			return stmtExpr;
-		}
-
-		Expression * FixCopyCtors::mutate( UniqueExpr * unqExpr ) {
+		}
+
+		void FixCopyCtors::premutate( UniqueExpr * unqExpr ) {
+			visit_children = false;
 			unqCount[ unqExpr->get_id() ]--;
 			static std::unordered_map< int, std::list< Statement * > > dtors;
 			static std::unordered_map< int, UniqueExpr * > unqMap;
-			static std::unordered_set< int > addDeref;
 			// has to be done to clean up ImplicitCopyCtorExpr nodes, even when this node was skipped in previous passes
 			if ( unqMap.count( unqExpr->get_id() ) ) {
@@ -616,30 +612,16 @@
 					stmtsToAddAfter.splice( stmtsToAddAfter.end(), dtors[ unqExpr->get_id() ] );
 				}
-				if ( addDeref.count( unqExpr->get_id() ) ) {
-					// other UniqueExpr was dereferenced because it was an lvalue return, so this one should be too
-					return UntypedExpr::createDeref( unqExpr );
-				}
-				return unqExpr;
-			}
-			FixCopyCtors fixer( unqCount );
+				return;
+			}
+			PassVisitor<FixCopyCtors> fixer( unqCount );
 			unqExpr->set_expr( unqExpr->get_expr()->acceptMutator( fixer ) ); // stmtexprs contained should not be separately fixed, so this must occur after the lookup
-			stmtsToAdd.splice( stmtsToAdd.end(), fixer.stmtsToAdd );
+			stmtsToAddBefore.splice( stmtsToAddBefore.end(), fixer.pass.stmtsToAddBefore );
 			unqMap[unqExpr->get_id()] = unqExpr;
 			if ( unqCount[ unqExpr->get_id() ] == 0 ) {  // insert destructor after the last use of the unique expression
 				stmtsToAddAfter.splice( stmtsToAddAfter.end(), dtors[ unqExpr->get_id() ] );
 			} else { // remember dtors for last instance of unique expr
-				dtors[ unqExpr->get_id() ] = fixer.stmtsToAddAfter;
-			}
-			if ( UntypedExpr * deref = dynamic_cast< UntypedExpr * >( unqExpr->get_expr() ) ) {
-				// unique expression is now a dereference, because the inner expression is an lvalue returning function call.
-				// Normalize the expression by dereferencing the unique expression, rather than the inner expression
-				// (i.e. move the dereference out a level)
-				assert( getFunctionName( deref ) == "*?" );
-				unqExpr->set_expr( getCallArg( deref, 0 ) );
-				getCallArg( deref, 0 ) = unqExpr;
-				addDeref.insert( unqExpr->get_id() );
-				return deref;
-			}
-			return unqExpr;
+				dtors[ unqExpr->get_id() ] = fixer.pass.stmtsToAddAfter;
+			}
+			return;
 		}
 
