Skip to content

Commit

Permalink
Fixes for localIds and Nary function resolution (#1356)
Browse files Browse the repository at this point in the history
  • Loading branch information
JPercival committed Apr 23, 2024
1 parent ec84b9e commit 70722a1
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.apache.commons.lang3.StringUtils;
import org.cqframework.cql.cql2elm.model.*;
import org.cqframework.cql.cql2elm.model.invocation.*;
import org.cqframework.cql.elm.IdObjectFactory;
import org.cqframework.cql.elm.tracking.TrackBack;
import org.cqframework.cql.elm.tracking.Trackable;
import org.hl7.cql.model.*;
Expand Down Expand Up @@ -42,17 +43,18 @@ public enum SignatureLevel {
All
}

public LibraryBuilder(LibraryManager libraryManager, ObjectFactory objectFactory) {
public LibraryBuilder(LibraryManager libraryManager, IdObjectFactory objectFactory) {
this(null, libraryManager, objectFactory);
}

public LibraryBuilder(NamespaceInfo namespaceInfo, LibraryManager libraryManager, ObjectFactory objectFactory) {
public LibraryBuilder(NamespaceInfo namespaceInfo, LibraryManager libraryManager, IdObjectFactory objectFactory) {
this.libraryManager = Objects.requireNonNull(libraryManager);
this.of = Objects.requireNonNull(objectFactory);

this.namespaceInfo = namespaceInfo; // Note: allowed to be null, implies global namespace
this.modelManager = libraryManager.getModelManager();
this.typeBuilder = new TypeBuilder(of, this.modelManager);
this.systemFunctionResolver = new SystemFunctionResolver(this, this.of);

this.library = of.createLibrary()
.withSchemaIdentifier(of.createVersionedIdentifier()
Expand Down Expand Up @@ -100,7 +102,7 @@ public List<CqlCompilerException> getExceptions() {
return exceptions;
}

public ObjectFactory getObjectFactory() {
public IdObjectFactory getObjectFactory() {
return of;
}

Expand All @@ -112,7 +114,7 @@ public LibraryManager getLibraryManager() {

private final Map<String, ResultWithPossibleError<NamedTypeSpecifier>> nameTypeSpecifiers = new HashMap<>();
private final Map<String, CompiledLibrary> libraries = new LinkedHashMap<>();
private final SystemFunctionResolver systemFunctionResolver = new SystemFunctionResolver(this);
private final SystemFunctionResolver systemFunctionResolver;
private final Stack<String> expressionContext = new Stack<>();
private final ExpressionDefinitionContextStack expressionDefinitions = new ExpressionDefinitionContextStack();
private final Stack<FunctionDef> functionDefs = new Stack<>();
Expand Down Expand Up @@ -142,7 +144,7 @@ public ConversionMap getConversionMap() {
return conversionMap;
}

private final ObjectFactory of;
private final IdObjectFactory of;
private final org.hl7.cql_annotations.r1.ObjectFactory af = new org.hl7.cql_annotations.r1.ObjectFactory();
private boolean listTraversal = true;
private final CqlCompilerOptions options;
Expand Down Expand Up @@ -2454,6 +2456,7 @@ public Expression resolveIdentifier(String identifier, boolean mustResolve) {
if (element instanceof IncludeDef) {
checkLiteralContext();
LibraryRef libraryRef = new LibraryRef();
libraryRef.setLocalId(of.nextId());
libraryRef.setLibraryName(((IncludeDef) element).getLocalIdentifier());
return libraryRef;
}
Expand Down Expand Up @@ -3009,14 +3012,14 @@ private Expression resolveQueryResultElement(String identifier) {
QueryContext query = peekQueryContext();
if (query.inSortClause() && !query.isSingular()) {
if (identifier.equals("$this")) {
IdentifierRef result = new IdentifierRef().withName(identifier);
IdentifierRef result = of.createIdentifierRef().withName(identifier);
result.setResultType(query.getResultElementType());
return result;
}

PropertyResolution resolution = resolveProperty(query.getResultElementType(), identifier, false);
if (resolution != null) {
IdentifierRef result = new IdentifierRef().withName(resolution.getName());
IdentifierRef result = of.createIdentifierRef().withName(resolution.getName());
result.setResultType(resolution.getType());
return applyTargetMap(result, resolution.getTargetMap());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
import org.cqframework.cql.cql2elm.model.PropertyResolution;
import org.cqframework.cql.cql2elm.model.SystemModel;
import org.cqframework.cql.cql2elm.model.invocation.*;
import org.cqframework.cql.elm.IdObjectFactory;
import org.hl7.elm.r1.*;

public class SystemFunctionResolver {
private final ObjectFactory of = new ObjectFactory();
private final IdObjectFactory of;
private final LibraryBuilder builder;

public SystemFunctionResolver(LibraryBuilder builder) {
public SystemFunctionResolver(LibraryBuilder builder, IdObjectFactory of) {
this.builder = builder;
this.of = builder.getObjectFactory();
}

public Invocation resolveSystemFunction(FunctionRef fun) {
Expand Down Expand Up @@ -220,15 +222,12 @@ public Invocation resolveSystemFunction(FunctionRef fun) {
}

case "Contains":
case "Except":
case "Expand":
case "In":
case "Includes":
case "IncludedIn":
case "Intersect":
case "ProperIncludes":
case "ProperIncludedIn":
case "Union": {
case "ProperIncludedIn": {
return resolveBinary(fun);
}

Expand All @@ -241,8 +240,10 @@ public Invocation resolveSystemFunction(FunctionRef fun) {
return resolveUnary(fun);
}

// Nullological Functions
case "Coalesce": {
case "Coalesce":
case "Intersect":
case "Union":
case "Except": {
return resolveNary(fun);
}

Expand Down Expand Up @@ -699,93 +700,60 @@ private ConvertInvocation resolveConvert(FunctionRef fun) {

// General Function Support

private UnaryExpressionInvocation resolveUnary(FunctionRef fun) {
UnaryExpression operator = null;
@SuppressWarnings("unchecked")
private <T extends Expression> T createExpression(FunctionRef fun) {
try {
Class<?> clazz = Class.forName("org.hl7.elm.r1." + fun.getName());
if (UnaryExpression.class.isAssignableFrom(clazz)) {
operator = (UnaryExpression) clazz.getConstructor().newInstance();
checkNumberOfOperands(fun, 1);
operator.setOperand(fun.getOperand().get(0));
UnaryExpressionInvocation invocation = new UnaryExpressionInvocation(operator);
builder.resolveInvocation("System", fun.getName(), invocation);
return invocation;
}
return (T) of.getClass().getMethod("create" + fun.getName()).invoke(of);
} catch (Exception e) {
// Do nothing but fall through
throw new CqlInternalException(
String.format("Could not create instance of Element \"%s\"", fun.getName()),
!fun.getTrackbacks().isEmpty() ? fun.getTrackbacks().get(0) : null,
e);
}
return null;
}

private UnaryExpressionInvocation resolveUnary(FunctionRef fun) {
UnaryExpression operator = createExpression(fun);
checkNumberOfOperands(fun, 1);
operator.setOperand(fun.getOperand().get(0));
UnaryExpressionInvocation invocation = new UnaryExpressionInvocation(operator);
builder.resolveInvocation("System", fun.getName(), invocation);
return invocation;
}

private BinaryExpressionInvocation resolveBinary(FunctionRef fun) {
BinaryExpression operator = null;
try {
Class<?> clazz = Class.forName("org.hl7.elm.r1." + fun.getName());
if (BinaryExpression.class.isAssignableFrom(clazz)) {
operator = (BinaryExpression) clazz.getConstructor().newInstance();
checkNumberOfOperands(fun, 2);
operator.getOperand().addAll(fun.getOperand());
BinaryExpressionInvocation invocation = new BinaryExpressionInvocation(operator);
builder.resolveInvocation("System", fun.getName(), invocation);
return invocation;
}
} catch (Exception e) {
// Do nothing but fall through
}
return null;
BinaryExpression operator = createExpression(fun);
checkNumberOfOperands(fun, 2);
operator.getOperand().addAll(fun.getOperand());
BinaryExpressionInvocation invocation = new BinaryExpressionInvocation(operator);
builder.resolveInvocation("System", fun.getName(), invocation);
return invocation;
}

private TernaryExpressionInvocation resolveTernary(FunctionRef fun) {
TernaryExpression operator = null;
try {
Class<?> clazz = Class.forName("org.hl7.elm.r1." + fun.getName());
if (TernaryExpression.class.isAssignableFrom(clazz)) {
operator = (TernaryExpression) clazz.getConstructor().newInstance();
checkNumberOfOperands(fun, 3);
operator.getOperand().addAll(fun.getOperand());
TernaryExpressionInvocation invocation = new TernaryExpressionInvocation(operator);
builder.resolveInvocation("System", fun.getName(), invocation);
return invocation;
}
} catch (Exception e) {
// Do nothing but fall through
}
return null;
TernaryExpression operator = createExpression(fun);
checkNumberOfOperands(fun, 3);
operator.getOperand().addAll(fun.getOperand());
TernaryExpressionInvocation invocation = new TernaryExpressionInvocation(operator);
builder.resolveInvocation("System", fun.getName(), invocation);
return invocation;
}

private NaryExpressionInvocation resolveNary(FunctionRef fun) {
NaryExpression operator = null;
try {
Class<?> clazz = Class.forName("org.hl7.elm.r1." + fun.getName());
if (NaryExpression.class.isAssignableFrom(clazz)) {
operator = (NaryExpression) clazz.getConstructor().newInstance();
operator.getOperand().addAll(fun.getOperand());
NaryExpressionInvocation invocation = new NaryExpressionInvocation(operator);
builder.resolveInvocation("System", fun.getName(), invocation);
return invocation;
}
} catch (Exception e) {
// Do nothing but fall through
}
return null;
NaryExpression operator = createExpression(fun);
operator.getOperand().addAll(fun.getOperand());
NaryExpressionInvocation invocation = new NaryExpressionInvocation(operator);
builder.resolveInvocation("System", fun.getName(), invocation);
return invocation;
}

private AggregateExpressionInvocation resolveAggregate(FunctionRef fun) {
AggregateExpression operator = null;
try {
Class<?> clazz = Class.forName("org.hl7.elm.r1." + fun.getName());
if (AggregateExpression.class.isAssignableFrom(clazz)) {
operator = (AggregateExpression) clazz.getConstructor().newInstance();
checkNumberOfOperands(fun, 1);
operator.setSource(fun.getOperand().get(0));
AggregateExpressionInvocation invocation = new AggregateExpressionInvocation(operator);
builder.resolveInvocation("System", fun.getName(), invocation);
return invocation;
}
} catch (Exception e) {
// Do nothing but fall through
}
return null;
AggregateExpression operator = createExpression(fun);
checkNumberOfOperands(fun, 1);
operator.setSource(fun.getOperand().get(0));
AggregateExpressionInvocation invocation = new AggregateExpressionInvocation(operator);
builder.resolveInvocation("System", fun.getName(), invocation);
return invocation;
}

private void checkNumberOfOperands(FunctionRef fun, int expectedOperands) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.util.Objects;
import java.util.Set;
import org.cqframework.cql.cql2elm.model.QueryContext;
import org.cqframework.cql.elm.IdObjectFactory;
import org.cqframework.cql.gen.cqlParser;
import org.hl7.cql.model.*;
import org.hl7.elm.r1.*;
Expand All @@ -14,7 +15,7 @@
* Created by Bryn on 12/27/2016.
*/
public class SystemMethodResolver {
private final ObjectFactory of;
private final IdObjectFactory of;
private final Cql2ElmVisitor visitor;
private final LibraryBuilder builder;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
import java.util.List;
import javax.xml.namespace.QName;
import org.cqframework.cql.cql2elm.model.Model;
import org.cqframework.cql.elm.IdObjectFactory;
import org.hl7.cql.model.*;
import org.hl7.elm.r1.ObjectFactory;
import org.hl7.elm.r1.ParameterTypeSpecifier;
import org.hl7.elm.r1.TupleElementDefinition;
import org.hl7.elm.r1.TypeSpecifier;
import org.hl7.elm_modelinfo.r1.ModelInfo;

public class TypeBuilder {

private ObjectFactory of;
private IdObjectFactory of;
private ModelResolver mr;

public static class InternalModelResolver implements ModelResolver {
Expand All @@ -28,12 +28,12 @@ public Model getModel(String modelName) {
}
}

public TypeBuilder(ObjectFactory of, ModelResolver mr) {
public TypeBuilder(IdObjectFactory of, ModelResolver mr) {
this.of = of;
this.mr = mr;
}

public TypeBuilder(ObjectFactory of, ModelManager modelManager) {
public TypeBuilder(IdObjectFactory of, ModelManager modelManager) {
this(of, new InternalModelResolver(modelManager));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.cqframework.cql.cql2elm.model.Chunk;
import org.cqframework.cql.cql2elm.model.FunctionHeader;
import org.cqframework.cql.cql2elm.model.Model;
import org.cqframework.cql.elm.IdObjectFactory;
import org.cqframework.cql.elm.tracking.TrackBack;
import org.cqframework.cql.elm.tracking.Trackable;
import org.cqframework.cql.gen.cqlBaseVisitor;
Expand All @@ -32,7 +33,7 @@
* Common functionality used by {@link CqlPreprocessor} and {@link Cql2ElmVisitor}
*/
public class CqlPreprocessorElmCommonVisitor extends cqlBaseVisitor<Object> {
protected final ObjectFactory of;
protected final IdObjectFactory of;
protected final org.hl7.cql_annotations.r1.ObjectFactory af = new org.hl7.cql_annotations.r1.ObjectFactory();
private boolean implicitContextCreated = false;
private String currentContext = "Unfiltered";
Expand All @@ -42,7 +43,6 @@ public class CqlPreprocessorElmCommonVisitor extends cqlBaseVisitor<Object> {
protected LibraryInfo libraryInfo = new LibraryInfo();
private boolean annotate = false;
private boolean detailedErrors = false;
private int nextLocalId = 1;
private boolean locate = false;
private boolean resultTypes = false;
private boolean dateRangeOptimization = false;
Expand Down Expand Up @@ -92,6 +92,16 @@ public Object visit(ParseTree tree) {
// ERROR:
try {
o = super.visit(tree);
if (o instanceof Element) {
Element element = (Element) o;
if (element.getLocalId() == null) {
throw new CqlInternalException(
String.format(
"Internal translator error. 'localId' was not assigned for Element \"%s\"",
element.getClass().getName()),
getTrackBack(tree));
}
}
} catch (CqlIncludeException e) {
CqlCompilerException translatorException =
new CqlCompilerException(e.getMessage(), getTrackBack(tree), e);
Expand Down Expand Up @@ -289,9 +299,6 @@ private void popChunk(ParseTree tree, Object o, boolean pushedChunk) {
Chunk chunk = chunks.pop();
if (o instanceof Element) {
Element element = (Element) o;
if (element.getLocalId() == null) {
element.setLocalId(Integer.toString(getNextLocalId()));
}
chunk.setElement(element);

if (!(tree instanceof cqlParser.LibraryContext)) {
Expand Down Expand Up @@ -779,10 +786,6 @@ public static String stripLeading(String s) {
return s.substring(index);
}

public int getNextLocalId() {
return nextLocalId++;
}

private void addExpression(Expression expression) {
expressions.add(expression);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import com.tngtech.archunit.core.domain.JavaClasses;
import com.tngtech.archunit.core.importer.ClassFileImporter;
import org.cqframework.cql.cql2elm.model.LibraryRef;
import org.cqframework.cql.elm.IdObjectFactory;
import org.hl7.elm.r1.Element;
import org.hl7.elm.r1.ObjectFactory;
import org.junit.Test;

public class ArchitectureTest {
Expand All @@ -27,7 +27,7 @@ public void ensureNoDirectElmConstruction() {
.should()
.onlyBeCalled()
.byClassesThat()
.areAssignableTo(ObjectFactory.class)
.areAssignableTo(IdObjectFactory.class)
.because("ELM classes should never be instantiated directly, "
+ "use an ObjectFactory that ensures that "
+ "the classes are initialized and tracked correctly.")
Expand Down
Loading

0 comments on commit 70722a1

Please sign in to comment.