Skip to content

Commit a2f4594

Browse files
authored
Serialized GMM no longer depends on command line annotation order (#1632)
Order annotations by the order in the model report.
1 parent 09b4cf7 commit a2f4594

File tree

4 files changed

+168
-10
lines changed

4 files changed

+168
-10
lines changed

protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.java

+15-3
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,15 @@ public List<VariantDatum> getData() {
110110
return data;
111111
}
112112

113-
public void normalizeData(final boolean calculateMeans) {
113+
/**
114+
* Normalize annotations to mean 0 and standard deviation 1.
115+
* Order the variant annotations by the provided list {@code theOrder}, or by standard deviation when no list is provided.
116+
*
117+
* @param calculateMeans Boolean indicating whether or not to calculate the means
118+
* @param theOrder a list of integers specifying the desired annotation order. If this is null
119+
* annotations will get sorted in decreasing size of their standard deviations.
120+
*/
121+
public void normalizeData(final boolean calculateMeans, List<Integer> theOrder) {
114122
boolean foundZeroVarianceAnnotation = false;
115123
for( int iii = 0; iii < meanVector.length; iii++ ) {
116124
final double theMean, theSTD;
@@ -150,15 +158,19 @@ public void normalizeData(final boolean calculateMeans) {
150158

151159
// re-order the data by increasing standard deviation so that the results don't depend on the order things were specified on the command line
152160
// standard deviation over the training points is used as a simple proxy for information content, perhaps there is a better thing to use here
153-
final List<Integer> theOrder = calculateSortOrder(meanVector);
161+
// or use the serialized report's annotation order via the argument theOrder
162+
if (theOrder == null){
163+
theOrder = calculateSortOrder(meanVector);
164+
}
154165
annotationKeys = reorderList(annotationKeys, theOrder);
155166
varianceVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(varianceVector), theOrder));
156167
meanVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(meanVector), theOrder));
157168
for( final VariantDatum datum : data ) {
158169
datum.annotations = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.annotations), theOrder));
159170
datum.isNull = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.isNull), theOrder));
160171
}
161-
logger.info("Annotations are now ordered by their information content: " + annotationKeys.toString());
172+
logger.info("Annotation order is: " + annotationKeys.toString());
173+
162174
}
163175

164176
public double[] getMeanVector() {

protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java

+34-7
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151

5252
package org.broadinstitute.gatk.tools.walkers.variantrecalibration;
5353

54+
import com.google.common.annotations.VisibleForTesting;
5455
import htsjdk.variant.variantcontext.Allele;
5556
import org.broadinstitute.gatk.utils.commandline.*;
5657
import org.broadinstitute.gatk.engine.CommandLineGATK;
@@ -312,6 +313,9 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
312313
@Argument(fullName = "trustAllPolymorphic", shortName = "allPoly", doc = "Trust that all the input training sets' unfiltered records contain only polymorphic sites to drastically speed up the computation.", required = false)
313314
protected Boolean TRUST_ALL_POLYMORPHIC = false;
314315

316+
@VisibleForTesting
317+
protected List<Integer> annotationOrder = null;
318+
315319
/////////////////////////////
316320
// Private Member Variables
317321
/////////////////////////////
@@ -372,18 +376,15 @@ public void initialize() {
372376
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
373377
final GATKReportTable anMeansTable = reportIn.getTable("AnnotationMeans");
374378
final GATKReportTable anStDevsTable = reportIn.getTable("AnnotationStdevs");
375-
final int numAnnotations = dataManager.annotationKeys.size();
376379

377-
if( numAnnotations != pmmTable.getNumColumns()-1 || numAnnotations != nmmTable.getNumColumns()-1 ) { // -1 because the first column is the gaussian number.
378-
throw new UserException.CommandLineException( "Annotations specified on the command line do not match annotations in the model report." );
379-
}
380+
orderAndValidateAnnotations(anMeansTable, dataManager.annotationKeys);
380381

381382
final Map<String, Double> anMeans = getMapFromVectorTable(anMeansTable);
382383
final Map<String, Double> anStdDevs = getMapFromVectorTable(anStDevsTable);
383384
dataManager.setNormalization(anMeans, anStdDevs);
384385

385-
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
386-
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
386+
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, annotationOrder.size());
387+
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, annotationOrder.size());
387388
}
388389

389390
final Set<VCFHeaderLine> hInfo = new HashSet<>();
@@ -401,6 +402,32 @@ public void initialize() {
401402

402403
}
403404

405+
/**
406+
* Order and validate annotations according to the annotations in the serialized model
407+
* Annotations on the command line must be the same as those in the model report or this will throw an exception.
408+
* Sets the {@code annotationOrder} list to map from command line order to the model report's order.
409+
* n^2 because we typically use 7 or less annotations.
410+
* @param annotationTable GATKReportTable of annotations read from the serialized model file
411+
*/
412+
protected void orderAndValidateAnnotations(final GATKReportTable annotationTable, final List<String> annotationKeys){
413+
annotationOrder = new ArrayList<Integer>(annotationKeys.size());
414+
415+
for (int i = 0; i < annotationTable.getNumRows(); i++){
416+
String serialAnno = (String)annotationTable.get(i, "Annotation");
417+
for (int j = 0; j < annotationKeys.size(); j++) {
418+
if (serialAnno.equals( annotationKeys.get(j) )){
419+
annotationOrder.add(j);
420+
}
421+
}
422+
}
423+
424+
if(annotationOrder.size() != annotationTable.getNumRows() || annotationOrder.size() != annotationKeys.size()) {
425+
final String errorMsg = "Annotations specified on the command line:"+annotationKeys.toString() +" do not match annotations in the model report:"+inputModel;
426+
throw new UserException.CommandLineException(errorMsg);
427+
}
428+
429+
}
430+
404431

405432
//---------------------------------------------------------------------------------------------------------------
406433
//
@@ -518,7 +545,7 @@ public void onTraversalDone( final ExpandingArrayList<VariantDatum> reduceSum )
518545
for (int i = 1; i <= max_attempts; i++) {
519546
try {
520547
dataManager.setData(reduceSum);
521-
dataManager.normalizeData(inputModel.isEmpty()); // Each data point is now (x - mean) / standard deviation
548+
dataManager.normalizeData(inputModel.isEmpty(), annotationOrder); // Each data point is now (x - mean) / standard deviation
522549

523550
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
524551
final List<VariantDatum> negativeTrainingData;

protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java

+71
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151

5252
package org.broadinstitute.gatk.tools.walkers.variantrecalibration;
5353

54+
import org.broadinstitute.gatk.utils.exceptions.UserException;
55+
import org.broadinstitute.gatk.utils.exceptions.UserException.CommandLineException;
5456
import org.broadinstitute.gatk.utils.variant.VCIterable;
5557
import org.broadinstitute.gatk.engine.walkers.WalkerTest;
5658
import htsjdk.variant.variantcontext.VariantContext;
@@ -60,6 +62,7 @@
6062
import org.testng.annotations.Test;
6163

6264
import java.io.File;
65+
import java.io.IOException;
6366
import java.util.Arrays;
6467
import java.util.List;
6568

@@ -390,5 +393,73 @@ private void setPDFsForDeletion( final List<File> walkerOutputFiles ) {
390393
new File(outputFile.getAbsolutePath() + ".pdf").deleteOnExit();
391394
}
392395
}
396+
397+
@Test
398+
public void testVQSRAnnotationOrder() throws IOException {
399+
final String inputFile = privateTestDir + "oneSNP.vcf";
400+
final String exacModelReportFilename = privateTestDir + "subsetExAC.snps_model.report";
401+
final String annoOrderRecal = privateTestDir + "anno_order.recal";
402+
final String annoOrderTranches = privateTestDir + "anno_order.tranches";
403+
final String goodMd5 = "d41d8cd98f00b204e9800998ecf8427e";
404+
final String base = "-R " + b37KGReference +
405+
" -T VariantRecalibrator" +
406+
" -input " + inputFile +
407+
" -L 1:110201699" +
408+
" -resource:truth=true,training=true,prior=15.0 " + inputFile +
409+
" -an FS -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR"+
410+
" --recal_file " + annoOrderRecal +
411+
" -tranchesFile " + annoOrderTranches +
412+
" --input_model " + exacModelReportFilename +
413+
" -ignoreAllFilters -mode SNP" +
414+
" --no_cmdline_in_header" ;
415+
416+
final WalkerTestSpec spec = new WalkerTestSpec(base, 1, Arrays.asList(goodMd5));
417+
spec.disableShadowBCF(); // TODO -- enable when we support symbolic alleles
418+
419+
List<File> outputFiles = executeTest("testVQSRAnnotationOrder", spec).getFirst();
420+
setPDFsForDeletion(outputFiles);
421+
422+
423+
final String base2 = "-R " + b37KGReference +
424+
" -T VariantRecalibrator" +
425+
" -input " + inputFile +
426+
" -L 1:110201699" +
427+
" -resource:truth=true,training=true,prior=15.0 " + inputFile +
428+
" -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR -an FS "+
429+
" --recal_file " + annoOrderRecal +
430+
" -tranchesFile " + annoOrderTranches +
431+
" --input_model " + exacModelReportFilename +
432+
" -ignoreAllFilters -mode SNP" +
433+
" --no_cmdline_in_header" ;
434+
435+
final WalkerTestSpec spec2 = new WalkerTestSpec(base2, 1, Arrays.asList(goodMd5));
436+
spec2.disableShadowBCF(); // TODO -- enable when we support symbolic alleles
437+
outputFiles = executeTest("testVQSRAnnotationOrder2", spec2).getFirst();
438+
setPDFsForDeletion(outputFiles);
439+
}
440+
441+
@Test(expectedExceptions={RuntimeException.class, CommandLineException.class})
442+
public void testVQSRAnnotationMismatch() throws IOException {
443+
final String inputFile = privateTestDir + "oneSNP.vcf";
444+
final String exacModelReportFilename = privateTestDir + "subsetExAC.snps_model.report";
445+
final String annoOrderRecal = privateTestDir + "anno_order.recal";
446+
final String annoOrderTranches = privateTestDir + "anno_order.tranches";
447+
final String goodMd5 = "d41d8cd98f00b204e9800998ecf8427e";
448+
final String base = "-R " + b37KGReference +
449+
" -T VariantRecalibrator" +
450+
" -input " + inputFile +
451+
" -L 1:110201699" +
452+
" -resource:truth=true,training=true,prior=15.0 " + inputFile +
453+
" -an FS -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR -an BaseQRankSum"+
454+
" --recal_file " + annoOrderRecal +
455+
" -tranchesFile " + annoOrderTranches +
456+
" --input_model " + exacModelReportFilename +
457+
" -ignoreAllFilters -mode SNP" +
458+
" --no_cmdline_in_header" ;
459+
460+
final WalkerTestSpec spec = new WalkerTestSpec(base, 1, Arrays.asList(goodMd5));
461+
spec.disableShadowBCF(); // TODO -- enable when we support symbolic alleles
462+
executeTest("testVQSRAnnotationMismatch", spec).getFirst();
463+
}
393464
}
394465

protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorModelOutputUnitTest.java

+48
Original file line numberDiff line numberDiff line change
@@ -276,4 +276,52 @@ private GaussianMixtureModel getBadGMM(){
276276
return new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts);
277277
}
278278

279+
@Test
280+
public void testAnnotationOrderAndValidate() {
281+
final VariantRecalibrator vqsr = new VariantRecalibrator();
282+
final List<String> annotationList = new ArrayList<>();
283+
annotationList.add("QD");
284+
annotationList.add("FS");
285+
annotationList.add("ReadPosRankSum");
286+
annotationList.add("MQ");
287+
annotationList.add("MQRankSum");
288+
annotationList.add("SOR");
289+
290+
double[] meanVector = {16.13, 2.45, 0.37, 59.08, 0.14, 0.91};
291+
final String columnName = "Mean";
292+
final String formatString = "%.3f";
293+
GATKReportTable annotationTable = vqsr.makeVectorTable("AnnotationMeans", "Mean for each annotation, used to normalize data", annotationList, meanVector, columnName, formatString);
294+
vqsr.orderAndValidateAnnotations(annotationTable, annotationList);
295+
296+
for (int i = 0; i < vqsr.annotationOrder.size(); i++){
297+
Assert.assertEquals(i, (int)vqsr.annotationOrder.get(i));
298+
}
299+
300+
annotationList.remove(0);
301+
annotationList.add("QD");
302+
vqsr.orderAndValidateAnnotations(annotationTable, annotationList);
303+
for (int i = 0; i < vqsr.annotationOrder.size(); i++) {
304+
if (i == 0) {
305+
Assert.assertEquals(annotationList.size()-1, (int)vqsr.annotationOrder.get(i));
306+
} else {
307+
Assert.assertEquals(i - 1, (int)vqsr.annotationOrder.get(i));
308+
}
309+
}
310+
311+
final List<String> annotationList2 = new ArrayList<>();
312+
annotationList2.add("ReadPosRankSum");
313+
annotationList2.add("MQRankSum");
314+
annotationList2.add("MQ");
315+
annotationList2.add("SOR");
316+
annotationList2.add("QD");
317+
annotationList2.add("FS");
318+
319+
final VariantRecalibrator vqsr2 = new VariantRecalibrator();
320+
vqsr2.orderAndValidateAnnotations(annotationTable, annotationList2);
321+
for (int i = 0; i < vqsr2.annotationOrder.size(); i++){
322+
Assert.assertEquals(annotationList.get(vqsr.annotationOrder.get(i)), annotationList2.get(vqsr2.annotationOrder.get(i)));
323+
}
324+
325+
}
326+
279327
}

0 commit comments

Comments
 (0)