diff --git a/src/main/java/org/ohdsi/webapi/Constants.java b/src/main/java/org/ohdsi/webapi/Constants.java index 2069ed108d..422c0b2546 100644 --- a/src/main/java/org/ohdsi/webapi/Constants.java +++ b/src/main/java/org/ohdsi/webapi/Constants.java @@ -82,6 +82,7 @@ interface Params { String EXECUTABLE_FILE_NAME = "executableFilename"; String GENERATION_ID = "generation_id"; String DESIGN_HASH = "design_hash"; + String DEMOGRAPHIC_STATS = "demographic_stats"; } interface Variables { diff --git a/src/main/java/org/ohdsi/webapi/cohortcharacterization/GenerateLocalCohortTasklet.java b/src/main/java/org/ohdsi/webapi/cohortcharacterization/GenerateLocalCohortTasklet.java index 48a112c37b..74cc329644 100644 --- a/src/main/java/org/ohdsi/webapi/cohortcharacterization/GenerateLocalCohortTasklet.java +++ b/src/main/java/org/ohdsi/webapi/cohortcharacterization/GenerateLocalCohortTasklet.java @@ -1,6 +1,7 @@ package org.ohdsi.webapi.cohortcharacterization; import org.ohdsi.webapi.cohortdefinition.CohortDefinition; +import org.ohdsi.webapi.cohortdefinition.CohortDefinitionDetails; import org.ohdsi.webapi.cohortdefinition.CohortGenerationRequestBuilder; import org.ohdsi.webapi.cohortdefinition.CohortGenerationUtils; import org.ohdsi.webapi.generationcache.GenerationCacheHelper; @@ -32,6 +33,7 @@ import static org.ohdsi.webapi.Constants.Params.SOURCE_ID; import static org.ohdsi.webapi.Constants.Params.TARGET_TABLE; +import static org.ohdsi.webapi.Constants.Params.DEMOGRAPHIC_STATS; public class GenerateLocalCohortTasklet implements StoppableTasklet { @@ -89,14 +91,14 @@ public RepeatStatus execute(StepContribution contribution, ChunkContext chunkCon if (useAsyncCohortGeneration) { List executions = cohortDefinitions.stream() .map(cd -> - CompletableFuture.supplyAsync(() -> generateCohort(cd, source, resultSchema, targetTable), + CompletableFuture.supplyAsync(() -> generateCohort(cd, source, resultSchema, targetTable), Executors.newSingleThreadExecutor() ) 
).collect(Collectors.toList()); CompletableFuture.allOf(executions.toArray(new CompletableFuture[]{})).join(); } else { CompletableFuture.runAsync(() -> - cohortDefinitions.stream().forEach(cd -> generateCohort(cd, source, resultSchema, targetTable)), + cohortDefinitions.stream().forEach(cd -> generateCohort(cd, source, resultSchema, targetTable)), Executors.newSingleThreadExecutor() ).join(); } @@ -113,8 +115,8 @@ private Object generateCohort(CohortDefinition cd, Source source, String resultS sessionId, resultSchema ); - - int designHash = this.generationCacheHelper.computeHash(cd.getDetails().getExpression()); + CohortDefinitionDetails details = cd.getDetails(); + int designHash = this.generationCacheHelper.computeHash(details.getExpression()); CohortGenerationUtils.insertInclusionRules(cd, source, designHash, resultSchema, sessionId, cancelableJdbcTemplate); try { diff --git a/src/main/java/org/ohdsi/webapi/cohortcharacterization/converter/BaseCcDTOToCcEntityConverter.java b/src/main/java/org/ohdsi/webapi/cohortcharacterization/converter/BaseCcDTOToCcEntityConverter.java index 509a0c2c53..abaf33dba0 100644 --- a/src/main/java/org/ohdsi/webapi/cohortcharacterization/converter/BaseCcDTOToCcEntityConverter.java +++ b/src/main/java/org/ohdsi/webapi/cohortcharacterization/converter/BaseCcDTOToCcEntityConverter.java @@ -2,7 +2,6 @@ import com.odysseusinc.arachne.commons.utils.ConverterUtils; import org.apache.commons.lang3.StringUtils; -import org.ohdsi.analysis.CohortMetadata; import org.ohdsi.analysis.Utils; import org.ohdsi.analysis.cohortcharacterization.design.CcResultType; import org.ohdsi.webapi.cohortcharacterization.domain.CcStrataConceptSetEntity; @@ -18,7 +17,6 @@ import org.ohdsi.webapi.tag.domain.Tag; import org.springframework.beans.factory.annotation.Autowired; -import java.util.List; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; diff --git a/src/main/java/org/ohdsi/webapi/cohortdefinition/CohortGenerationInfo.java 
b/src/main/java/org/ohdsi/webapi/cohortdefinition/CohortGenerationInfo.java index f79d3fd1db..39cf7e77db 100644 --- a/src/main/java/org/ohdsi/webapi/cohortdefinition/CohortGenerationInfo.java +++ b/src/main/java/org/ohdsi/webapi/cohortdefinition/CohortGenerationInfo.java @@ -84,6 +84,21 @@ public CohortGenerationInfo(CohortDefinition definition, Integer sourceId) @ManyToOne(fetch = FetchType.LAZY) @JoinColumn(name = "created_by_id") private UserEntity createdBy; + + @Column(name = "cc_generate_id") + private Long ccGenerateId; + + // If true, then demographic has been selected. + @Column(name = "is_demographic") + private boolean isDemographic; + + public boolean isDemographic() { + return isDemographic; + } + + public void setIsDemographic(boolean isDemographic) { + this.isDemographic = isDemographic; + } public CohortGenerationInfoId getId() { return id; @@ -187,4 +202,13 @@ public void setCreatedBy(UserEntity createdBy) { public UserEntity getCreatedBy() { return createdBy; } + + public Long getCcGenerateId() { + return ccGenerateId; + } + + public void setCcGenerateId(Long ccGenerateId) { + this.ccGenerateId = ccGenerateId; + } + } diff --git a/src/main/java/org/ohdsi/webapi/cohortdefinition/CohortGenerationRequest.java b/src/main/java/org/ohdsi/webapi/cohortdefinition/CohortGenerationRequest.java index 647fb9251e..4fb31c5116 100644 --- a/src/main/java/org/ohdsi/webapi/cohortdefinition/CohortGenerationRequest.java +++ b/src/main/java/org/ohdsi/webapi/cohortdefinition/CohortGenerationRequest.java @@ -11,7 +11,8 @@ public class CohortGenerationRequest { private String targetSchema; private Integer targetId; - public CohortGenerationRequest(CohortExpression expression, Source source, String sessionId, Integer targetId, String targetSchema) { + public CohortGenerationRequest(CohortExpression expression, Source source, String sessionId, Integer targetId, + String targetSchema) { this.expression = expression; this.source = source; diff --git 
a/src/main/java/org/ohdsi/webapi/cohortdefinition/CohortGenerationUtils.java b/src/main/java/org/ohdsi/webapi/cohortdefinition/CohortGenerationUtils.java index 232466c464..38597e07c7 100644 --- a/src/main/java/org/ohdsi/webapi/cohortdefinition/CohortGenerationUtils.java +++ b/src/main/java/org/ohdsi/webapi/cohortdefinition/CohortGenerationUtils.java @@ -70,7 +70,7 @@ public static String[] buildGenerationSql(CohortGenerationRequest request) { "results_database_schema.cohort_inclusion_stats", "results_database_schema.cohort_summary_stats", "results_database_schema.cohort_censor_stats", - "results_database_schema.cohort_inclusion" + "results_database_schema.cohort_inclusion" }, new String[] { COHORT_CACHE, diff --git a/src/main/java/org/ohdsi/webapi/cohortdefinition/GenerateCohortTasklet.java b/src/main/java/org/ohdsi/webapi/cohortdefinition/GenerateCohortTasklet.java index 47173e8b45..6b8549d236 100644 --- a/src/main/java/org/ohdsi/webapi/cohortdefinition/GenerateCohortTasklet.java +++ b/src/main/java/org/ohdsi/webapi/cohortdefinition/GenerateCohortTasklet.java @@ -16,21 +16,39 @@ package org.ohdsi.webapi.cohortdefinition; import org.ohdsi.circe.helper.ResourceHelper; +import org.ohdsi.cohortcharacterization.CCQueryBuilder; +import org.ohdsi.sql.BigQuerySparkTranslate; import org.ohdsi.sql.SqlRender; import org.ohdsi.sql.SqlSplit; import org.ohdsi.sql.SqlTranslate; +import org.ohdsi.webapi.cohortcharacterization.domain.CohortCharacterizationEntity; import org.ohdsi.webapi.common.generation.CancelableTasklet; +import org.ohdsi.webapi.common.generation.GenerationUtils; +import org.ohdsi.webapi.feanalysis.domain.FeAnalysisEntity; +import org.ohdsi.webapi.feanalysis.repository.FeAnalysisEntityRepository; import org.ohdsi.webapi.generationcache.GenerationCacheHelper; +import org.ohdsi.webapi.shiro.Entities.UserRepository; import org.ohdsi.webapi.source.Source; import org.ohdsi.webapi.source.SourceService; import org.ohdsi.webapi.util.CancelableJdbcTemplate; import 
org.ohdsi.webapi.util.SessionUtils; +import org.ohdsi.webapi.util.SourceUtils; import org.slf4j.LoggerFactory; import org.springframework.batch.core.scope.context.ChunkContext; import org.springframework.batch.core.step.tasklet.StoppableTasklet; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.transaction.support.TransactionTemplate; +import com.google.common.collect.ImmutableList; +import com.odysseusinc.arachne.commons.types.DBMSType; + +import java.sql.SQLException; +import java.util.Arrays; +import java.util.HashSet; import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.ohdsi.webapi.Constants.Params.*; @@ -44,54 +62,151 @@ public class GenerateCohortTasklet extends CancelableTasklet implements Stoppabl private final GenerationCacheHelper generationCacheHelper; private final CohortDefinitionRepository cohortDefinitionRepository; private final SourceService sourceService; + private final FeAnalysisEntityRepository feAnalysisRepository; + + public GenerateCohortTasklet(final CancelableJdbcTemplate jdbcTemplate, final TransactionTemplate transactionTemplate, + final GenerationCacheHelper generationCacheHelper, + final CohortDefinitionRepository cohortDefinitionRepository, final SourceService sourceService) { + super(LoggerFactory.getLogger(GenerateCohortTasklet.class), jdbcTemplate, transactionTemplate); + this.generationCacheHelper = generationCacheHelper; + this.cohortDefinitionRepository = cohortDefinitionRepository; + this.sourceService = sourceService; + this.feAnalysisRepository = null; + } public GenerateCohortTasklet( final CancelableJdbcTemplate jdbcTemplate, final TransactionTemplate transactionTemplate, final GenerationCacheHelper generationCacheHelper, final CohortDefinitionRepository cohortDefinitionRepository, - final SourceService sourceService + final SourceService sourceService, final FeAnalysisEntityRepository 
feAnalysisRepository ) { super(LoggerFactory.getLogger(GenerateCohortTasklet.class), jdbcTemplate, transactionTemplate); this.generationCacheHelper = generationCacheHelper; this.cohortDefinitionRepository = cohortDefinitionRepository; this.sourceService = sourceService; + this.feAnalysisRepository = feAnalysisRepository; } @Override protected String[] prepareQueries(ChunkContext chunkContext, CancelableJdbcTemplate jdbcTemplate) { + Map jobParams = chunkContext.getStepContext().getJobParameters(); + + String[] defaultQueries = prepareQueriesDefault(jobParams, jdbcTemplate); + + Boolean demographicStat = jobParams.get(DEMOGRAPHIC_STATS) == null ? null + : Boolean.valueOf((String) jobParams.get(DEMOGRAPHIC_STATS)); + + if (demographicStat != null && demographicStat.booleanValue()) { + String[] demographicsQueries = prepareQueriesDemographic(chunkContext, jdbcTemplate); + return Stream.concat(Arrays.stream(defaultQueries), Arrays.stream(demographicsQueries)).toArray(String[]::new); + } + + return defaultQueries; + } + + private String[] prepareQueriesDemographic(ChunkContext chunkContext, CancelableJdbcTemplate jdbcTemplate) { + Map jobParams = chunkContext.getStepContext().getJobParameters(); + CohortCharacterizationEntity cohortCharacterization = new CohortCharacterizationEntity(); + + Integer cohortDefinitionId = Integer.valueOf(jobParams.get(COHORT_DEFINITION_ID).toString()); + CohortDefinition cohortDefinition = cohortDefinitionRepository.findOneWithDetail(cohortDefinitionId); + + cohortCharacterization.setCohortDefinitions(new HashSet<>(Arrays.asList(cohortDefinition))); + + // Get FE Analysis Demographic (Gender, Age, Race,) + Set feAnalysis = feAnalysisRepository.findByListIds(Arrays.asList(70, 72, 74, 77)); + +// Set ccFeAnalysis = feAnalysis.stream().map(a -> { +// CcFeAnalysisEntity ccA = new CcFeAnalysisEntity(); +// ccA.setCohortCharacterization(cohortCharacterization); +// ccA.setFeatureAnalysis(a); +// return ccA; +// }).collect(Collectors.toSet()); + + 
cohortCharacterization.setFeatureAnalyses(feAnalysis); + + final Long jobId = chunkContext.getStepContext().getStepExecution().getJobExecution().getId(); + + final Integer sourceId = Integer.valueOf(jobParams.get(SOURCE_ID).toString()); + final Source source = sourceService.findBySourceId(sourceId); + + final String cohortTable = jobParams.get(TARGET_TABLE).toString(); + final String sessionId = jobParams.get(SESSION_ID).toString(); + + final String tempSchema = SourceUtils.getTempQualifier(source); + + boolean includeAnnual = false; + boolean includeTemporal = false; + + CCQueryBuilder ccQueryBuilder = new CCQueryBuilder(cohortCharacterization, cohortTable, sessionId, + SourceUtils.getCdmQualifier(source), SourceUtils.getResultsQualifier(source), + SourceUtils.getVocabularyQualifier(source), tempSchema, jobId); + String sql = ccQueryBuilder.build(); + + /* + * There is an issue with temp tables on sql server: Temp tables scope is + * session or stored procedure. To execute PreparedStatement sql server + * uses stored procedure sp_executesql and this is the reason why + * multiple PreparedStatements cannot share the same local temporary + * table. + * + * On the other side, temp tables cannot be re-used in the same + * PreparedStatement, e.g. temp table cannot be created, used, dropped and + * created again in the same PreparedStatement because sql optimizator + * detects object already exists and fails. When is required to re-use + * temp table it should be separated to several PreparedStatements. + * + * An option to use global temp tables also doesn't work since such tables + * can be not supported / disabled. + * + * Therefore, there are two ways: - either precisely group SQLs into + * statements so that temp tables aren't re-used in a single statement, - + * or use ‘permanent temporary tables’ + * + * The second option looks better since such SQL could be exported and + * executed manually, which is not the case with the first option. 
+ */ + if (ImmutableList.of(DBMSType.MS_SQL_SERVER.getOhdsiDB(), DBMSType.PDW.getOhdsiDB()) + .contains(source.getSourceDialect())) { + sql = sql.replaceAll("#", tempSchema + "." + sessionId + "_").replaceAll("tempdb\\.\\.", ""); + } + if (source.getSourceDialect().equals("spark")) { + try { + sql = BigQuerySparkTranslate.sparkHandleInsert(sql, source.getSourceConnection()); + } catch (SQLException e) { + e.printStackTrace(); + } + } + + final String translatedSql = SqlTranslate.translateSql(sql, source.getSourceDialect(), sessionId, tempSchema); + return SqlSplit.splitSql(translatedSql); + } + + private String[] prepareQueriesDefault(Map jobParams, CancelableJdbcTemplate jdbcTemplate) { + Integer cohortDefinitionId = Integer.valueOf(jobParams.get(COHORT_DEFINITION_ID).toString()); + Integer sourceId = Integer.parseInt(jobParams.get(SOURCE_ID).toString()); + String targetSchema = jobParams.get(TARGET_DATABASE_SCHEMA).toString(); + String sessionId = jobParams.getOrDefault(SESSION_ID, SessionUtils.sessionId()).toString(); + + CohortDefinition cohortDefinition = cohortDefinitionRepository.findOneWithDetail(cohortDefinitionId); + Source source = sourceService.findBySourceId(sourceId); + + CohortGenerationRequestBuilder generationRequestBuilder = new CohortGenerationRequestBuilder(sessionId, + targetSchema); + + int designHash = this.generationCacheHelper.computeHash(cohortDefinition.getDetails().getExpression()); + CohortGenerationUtils.insertInclusionRules(cohortDefinition, source, designHash, targetSchema, sessionId, + jdbcTemplate); + + GenerationCacheHelper.CacheResult res = generationCacheHelper.computeCacheIfAbsent(cohortDefinition, source, + generationRequestBuilder, + (resId, sqls) -> generationCacheHelper.runCancelableCohortGeneration(jdbcTemplate, stmtCancel, sqls)); - Map jobParams = chunkContext.getStepContext().getJobParameters(); - - Integer cohortDefinitionId = Integer.valueOf(jobParams.get(COHORT_DEFINITION_ID).toString()); - Integer sourceId = 
Integer.parseInt(jobParams.get(SOURCE_ID).toString()); - String targetSchema = jobParams.get(TARGET_DATABASE_SCHEMA).toString(); - String sessionId = jobParams.getOrDefault(SESSION_ID, SessionUtils.sessionId()).toString(); - - CohortDefinition cohortDefinition = cohortDefinitionRepository.findOneWithDetail(cohortDefinitionId); - Source source = sourceService.findBySourceId(sourceId); - - CohortGenerationRequestBuilder generationRequestBuilder = new CohortGenerationRequestBuilder( - sessionId, - targetSchema - ); - - int designHash = this.generationCacheHelper.computeHash(cohortDefinition.getDetails().getExpression()); - CohortGenerationUtils.insertInclusionRules(cohortDefinition, source, designHash, targetSchema, sessionId, jdbcTemplate); - - GenerationCacheHelper.CacheResult res = generationCacheHelper.computeCacheIfAbsent( - cohortDefinition, - source, - generationRequestBuilder, - (resId, sqls) -> generationCacheHelper.runCancelableCohortGeneration(jdbcTemplate, stmtCancel, sqls) - ); - - String sql = SqlRender.renderSql( - copyGenerationIntoCohortTableSql, - new String[]{ RESULTS_DATABASE_SCHEMA, COHORT_DEFINITION_ID, DESIGN_HASH }, - new String[]{ targetSchema, cohortDefinition.getId().toString(), res.getIdentifier().toString() } - ); - sql = SqlTranslate.translateSql(sql, source.getSourceDialect()); - return SqlSplit.splitSql(sql); + String sql = SqlRender.renderSql(copyGenerationIntoCohortTableSql, + new String[] { RESULTS_DATABASE_SCHEMA, COHORT_DEFINITION_ID, DESIGN_HASH }, + new String[] { targetSchema, cohortDefinition.getId().toString(), res.getIdentifier().toString() }); + sql = SqlTranslate.translateSql(sql, source.getSourceDialect()); + return SqlSplit.splitSql(sql); } } diff --git a/src/main/java/org/ohdsi/webapi/cohortdefinition/GenerationJobExecutionListener.java b/src/main/java/org/ohdsi/webapi/cohortdefinition/GenerationJobExecutionListener.java index aca36adc80..fdda4f1958 100644 --- 
a/src/main/java/org/ohdsi/webapi/cohortdefinition/GenerationJobExecutionListener.java +++ b/src/main/java/org/ohdsi/webapi/cohortdefinition/GenerationJobExecutionListener.java @@ -88,6 +88,7 @@ public void afterJob(JobExecution je) { CohortGenerationInfo info = findBySourceId(df, sourceId); setExecutionDurationIfPossible(je, info); info.setStatus(GenerationStatus.COMPLETE); + info.setCcGenerateId(je.getId()); if (je.getStatus() == BatchStatus.FAILED || je.getStatus() == BatchStatus.STOPPED) { info.setIsValid(false); diff --git a/src/main/java/org/ohdsi/webapi/cohortdefinition/InclusionRuleReport.java b/src/main/java/org/ohdsi/webapi/cohortdefinition/InclusionRuleReport.java index 78533be7b7..1d0ec8514c 100644 --- a/src/main/java/org/ohdsi/webapi/cohortdefinition/InclusionRuleReport.java +++ b/src/main/java/org/ohdsi/webapi/cohortdefinition/InclusionRuleReport.java @@ -17,6 +17,8 @@ import java.util.List; +import org.ohdsi.webapi.cohortcharacterization.report.Report; + /** * * @author Chris Knoll @@ -42,5 +44,10 @@ public static class InclusionRuleStatistic public Summary summary; public List inclusionRuleStats; public String treemapData; + public List demographicsStats; + + public Float prevalenceThreshold = 0.01f; + public Boolean showEmptyResults = false; + public int count = 0; } diff --git a/src/main/java/org/ohdsi/webapi/cohortdefinition/converter/CohortGenerationInfoToCohortGenerationInfoDTOConverter.java b/src/main/java/org/ohdsi/webapi/cohortdefinition/converter/CohortGenerationInfoToCohortGenerationInfoDTOConverter.java index 2e8ede00b5..a393957f6a 100644 --- a/src/main/java/org/ohdsi/webapi/cohortdefinition/converter/CohortGenerationInfoToCohortGenerationInfoDTOConverter.java +++ b/src/main/java/org/ohdsi/webapi/cohortdefinition/converter/CohortGenerationInfoToCohortGenerationInfoDTOConverter.java @@ -22,6 +22,8 @@ public CohortGenerationInfoDTO convert(CohortGenerationInfo info) { dto.setStartTime(info.getStartTime()); dto.setStatus(info.getStatus()); 
dto.setIsValid(info.isIsValid()); + dto.setCcGenerateId(info.getCcGenerateId()); + dto.setIsDemographic(info.isDemographic()); return dto; } diff --git a/src/main/java/org/ohdsi/webapi/cohortdefinition/dto/CohortGenerationInfoDTO.java b/src/main/java/org/ohdsi/webapi/cohortdefinition/dto/CohortGenerationInfoDTO.java index f3611a1d61..2cafd7a0b2 100644 --- a/src/main/java/org/ohdsi/webapi/cohortdefinition/dto/CohortGenerationInfoDTO.java +++ b/src/main/java/org/ohdsi/webapi/cohortdefinition/dto/CohortGenerationInfoDTO.java @@ -44,6 +44,18 @@ public class CohortGenerationInfoDTO { private Long recordCount; private UserDTO createdBy; + + private Long ccGenerateId; + private boolean isDemographic; + + public boolean getIsDemographic() { + return isDemographic; + } + + public void setIsDemographic(boolean isDemographic) { + this.isDemographic = isDemographic; + } + public CohortGenerationInfoId getId() { return id; @@ -124,4 +136,13 @@ public UserDTO getCreatedBy() { public void setCreatedBy(UserDTO createdBy) { this.createdBy = createdBy; } + + public Long getCcGenerateId() { + return ccGenerateId; + } + + public void setCcGenerateId(Long ccGenerateId) { + this.ccGenerateId = ccGenerateId; + } + } diff --git a/src/main/java/org/ohdsi/webapi/feanalysis/repository/FeAnalysisEntityRepository.java b/src/main/java/org/ohdsi/webapi/feanalysis/repository/FeAnalysisEntityRepository.java index 3864addd60..97d4551b8c 100644 --- a/src/main/java/org/ohdsi/webapi/feanalysis/repository/FeAnalysisEntityRepository.java +++ b/src/main/java/org/ohdsi/webapi/feanalysis/repository/FeAnalysisEntityRepository.java @@ -5,6 +5,7 @@ import org.springframework.data.repository.query.Param; import java.util.List; +import java.util.Set; public interface FeAnalysisEntityRepository extends BaseFeAnalysisEntityRepository { @Query("Select fe FROM FeAnalysisEntity fe WHERE fe.name LIKE ?1 ESCAPE '\\'") @@ -12,4 +13,7 @@ public interface FeAnalysisEntityRepository extends BaseFeAnalysisEntityReposito 
@Query("SELECT COUNT(fe) FROM FeAnalysisEntity fe WHERE fe.name = :name and fe.id <> :id") int getCountFeWithSameName(@Param("id") Integer id, @Param("name") String name); + + @Query("SELECT fe FROM FeAnalysisEntity fe WHERE fe.id IN :ids") + Set findByListIds(@Param("ids") List ids); } diff --git a/src/main/java/org/ohdsi/webapi/service/CohortDefinitionService.java b/src/main/java/org/ohdsi/webapi/service/CohortDefinitionService.java index 4de4e872e6..83e8bac7fa 100644 --- a/src/main/java/org/ohdsi/webapi/service/CohortDefinitionService.java +++ b/src/main/java/org/ohdsi/webapi/service/CohortDefinitionService.java @@ -25,17 +25,28 @@ import org.commonmark.parser.Parser; import org.commonmark.renderer.html.HtmlRenderer; import org.ohdsi.analysis.Utils; +import org.ohdsi.analysis.cohortcharacterization.design.StandardFeatureAnalysisType; import org.ohdsi.circe.check.Checker; import org.ohdsi.circe.cohortdefinition.CohortExpression; import org.ohdsi.circe.cohortdefinition.CohortExpressionQueryBuilder; import org.ohdsi.circe.cohortdefinition.ConceptSet; import org.ohdsi.circe.cohortdefinition.printfriendly.MarkdownRender; +import org.ohdsi.circe.helper.ResourceHelper; +import org.ohdsi.featureExtraction.FeatureExtraction; import org.ohdsi.sql.SqlRender; +import org.ohdsi.sql.SqlTranslate; import org.ohdsi.webapi.Constants; import org.ohdsi.webapi.check.CheckResult; import org.ohdsi.webapi.check.checker.cohort.CohortChecker; import org.ohdsi.webapi.check.warning.Warning; import org.ohdsi.webapi.check.warning.WarningUtils; +import org.ohdsi.webapi.cohortcharacterization.dto.CcDistributionStat; +import org.ohdsi.webapi.cohortcharacterization.dto.CcPrevalenceStat; +import org.ohdsi.webapi.cohortcharacterization.dto.CcResult; +import org.ohdsi.webapi.cohortcharacterization.dto.ExecutionResultRequest; +import org.ohdsi.webapi.cohortcharacterization.report.AnalysisItem; +import org.ohdsi.webapi.cohortcharacterization.report.AnalysisResultItem; +import 
org.ohdsi.webapi.cohortcharacterization.report.Report; import org.ohdsi.webapi.cohortdefinition.CleanupCohortTasklet; import org.ohdsi.webapi.cohortdefinition.CohortDefinition; import org.ohdsi.webapi.cohortdefinition.CohortDefinitionDetails; @@ -55,6 +66,8 @@ import org.ohdsi.webapi.common.generation.GenerateSqlResult; import org.ohdsi.webapi.common.sensitiveinfo.CohortGenerationSensitiveInfoService; import org.ohdsi.webapi.conceptset.ConceptSetExport; +import org.ohdsi.webapi.feanalysis.domain.FeAnalysisEntity; +import org.ohdsi.webapi.feanalysis.repository.FeAnalysisEntityRepository; import org.ohdsi.webapi.job.JobExecutionResource; import org.ohdsi.webapi.job.JobTemplate; import org.ohdsi.webapi.security.PermissionService; @@ -67,10 +80,6 @@ import org.ohdsi.webapi.tag.domain.HasTags; import org.ohdsi.webapi.tag.dto.TagNameListRequestDTO; import org.ohdsi.webapi.util.*; -import org.ohdsi.webapi.util.ExceptionUtils; -import org.ohdsi.webapi.util.NameUtils; -import org.ohdsi.webapi.util.PreparedStatementRenderer; -import org.ohdsi.webapi.util.SessionUtils; import org.ohdsi.webapi.versioning.domain.CohortVersion; import org.ohdsi.webapi.versioning.domain.Version; import org.ohdsi.webapi.versioning.domain.VersionBase; @@ -124,19 +133,26 @@ import java.util.Collections; import java.util.Date; import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; import javax.ws.rs.core.Response.ResponseBuilder; +import static org.ohdsi.analysis.cohortcharacterization.design.CcResultType.DISTRIBUTION; +import static org.ohdsi.analysis.cohortcharacterization.design.CcResultType.PREVALENCE; import static org.ohdsi.webapi.Constants.Params.COHORT_DEFINITION_ID; import static org.ohdsi.webapi.Constants.Params.JOB_NAME; import static 
org.ohdsi.webapi.Constants.Params.SOURCE_ID; import org.ohdsi.webapi.source.SourceService; +import org.ohdsi.webapi.sqlrender.SourceAwareSqlRender; + import static org.ohdsi.webapi.util.SecurityUtils.whitelist; /** @@ -150,6 +166,36 @@ public class CohortDefinitionService extends AbstractDaoService implements HasTags { private static final CohortExpressionQueryBuilder queryBuilder = new CohortExpressionQueryBuilder(); + private static final int DEMOGRAPHIC_MODE = 2; + private static final String DEMOGRAPHIC_DOMAIN = "DEMOGRAPHICS"; + private static final String[] PARAMETERS_RESULTS_FILTERED = { "cohort_characterization_generation_id", + "threshold_level", "analysis_ids", "cohort_ids", "vocabulary_schema" }; + private final List executionPrevalenceHeaderLines = new ArrayList() { + { + add(new String[] { "Analysis ID", "Analysis name", "Strata ID", "Strata name", "Cohort ID", "Cohort name", + "Covariate ID", "Covariate name", "Covariate short name", "Count", "Percent" }); + } + }; + private final List executionDistributionHeaderLines = new ArrayList() { + { + add(new String[] { "Analysis ID", "Analysis name", "Strata ID", "Strata name", "Cohort ID", "Cohort name", + "Covariate ID", "Covariate name", "Covariate short name", "Value field", + "Missing Means Zero", "Count", "Avg", "StdDev", "Min", "P10", "P25", "Median", "P75", "P90", + "Max" }); + } + }; + private final List executionComparativeHeaderLines = new ArrayList() { + { + add(new String[] { "Analysis ID", "Analysis name", "Strata ID", "Strata name", "Target cohort ID", + "Target cohort name", "Comparator cohort ID", "Comparator cohort name", "Covariate ID", + "Covariate name", "Covariate short name", "Target count", "Target percent", + "Comparator count", "Comparator percent", "Std. 
Diff Of Mean" }); + } + }; + private Map prespecAnalysisMap = FeatureExtraction + .getNameToPrespecAnalysis(); + private final String QUERY_RESULTS = ResourceHelper + .GetResourceAsString("/resources/cohortcharacterizations/sql/queryResults.sql"); @Autowired private CohortDefinitionRepository cohortDefinitionRepository; @@ -205,7 +251,13 @@ public class CohortDefinitionService extends AbstractDaoService implements HasTa @Autowired private VersionService versionService; - @Value("${security.defaultGlobalReadPermissions}") + @Autowired + private FeAnalysisEntityRepository feAnalysisRepository; + + @Autowired + private SourceAwareSqlRender sourceAwareSqlRender; + + @Value("${security.defaultGlobalReadPermissions}") private boolean defaultGlobalReadPermissions; private final MarkdownRender markdownPF = new MarkdownRender(); @@ -293,6 +345,220 @@ private List getInclusionRuleStatist return getSourceJdbcTemplate(source).query(psr.getSql(), psr.getSetter(), inclusionRuleStatisticMapper); } + private List getDemographicStatistics(int id, Source source, + int modeId, long ccGenerateId) { + ExecutionResultRequest params = new ExecutionResultRequest(); + + // Get FE Analysis Demographic (Gender, Age, Race,) + Set featureAnalyses = feAnalysisRepository.findByListIds(Arrays.asList(70, 72, 74, 77)); + + params.setCohortIds(Arrays.asList(id)); + params.setAnalysisIds(featureAnalyses.stream().map(this::mapFeatureAnalysisId).collect(Collectors.toList())); + params.setDomainIds(Arrays.asList(DEMOGRAPHIC_DOMAIN)); + + List ccResults = findResults(ccGenerateId, params, source); + Map analysisMap = new HashMap<>(); + ccResults.stream().peek(cc -> { + if (StandardFeatureAnalysisType.PRESET.toString().equals(cc.getFaType())) { + featureAnalyses.stream().filter(fa -> Objects.equals(fa.getDesign(), cc.getAnalysisName())).findFirst() + .ifPresent(v -> cc.setAnalysisId(v.getId())); + } + }).forEach(ccResult -> { + if (ccResult instanceof CcPrevalenceStat) { + 
analysisMap.putIfAbsent(ccResult.getAnalysisId(), new AnalysisItem()); + AnalysisItem analysisItem = analysisMap.get(ccResult.getAnalysisId()); + analysisItem.setType(ccResult.getResultType()); + analysisItem.setName(ccResult.getAnalysisName()); + analysisItem.setFaType(ccResult.getFaType()); + List results = analysisItem.getOrCreateCovariateItem( + ((CcPrevalenceStat) ccResult).getCovariateId(), ccResult.getStrataId()); + results.add(ccResult); + } + }); + + CohortDefinition cohortDefinition = cohortDefinitionRepository.findOne(id); + List reports = prepareReportData(analysisMap, + new HashSet(Arrays.asList(cohortDefinition)), featureAnalyses); + + return reports; + } + + private List prepareReportData(Map analysisMap, Set cohortDefs, + Set featureAnalyses) { + // Create map to get cohort name by its id + final Map definitionMap = cohortDefs.stream() + .collect(Collectors.toMap(CohortDefinition::getId, Function.identity())); + // Create map to get feature analyses by its name + final Map feAnalysisMap = featureAnalyses.stream() + .collect(Collectors.toMap(this::mapFeatureName, entity -> entity.getDomain().toString())); + + List reports = new ArrayList<>(); + try { + // list to accumulate results from simple reports + List simpleResultSummary = new ArrayList<>(); + // list to accumulate results from comparative reports + List comparativeResultSummary = new ArrayList<>(); + // do not create summary reports when only one analyses is present + boolean ignoreSummary = analysisMap.keySet().size() == 1; + for (Integer analysisId : analysisMap.keySet()) { + analysisMap.putIfAbsent(analysisId, new AnalysisItem()); + AnalysisItem analysisItem = analysisMap.get(analysisId); + AnalysisResultItem resultItem = analysisItem.getSimpleItems(definitionMap, feAnalysisMap); + Report simpleReport = new Report(analysisItem.getName(), analysisId, resultItem); + simpleReport.faType = analysisItem.getFaType(); + simpleReport.domainId = feAnalysisMap.get(analysisItem.getName()); + + if 
(PREVALENCE.equals(analysisItem.getType())) { + simpleReport.header = executionPrevalenceHeaderLines; + simpleReport.resultType = PREVALENCE; + // Summary comparative reports are only available for + // prevalence type + simpleResultSummary.add(resultItem); + } else if (DISTRIBUTION.equals(analysisItem.getType())) { + simpleReport.header = executionDistributionHeaderLines; + simpleReport.resultType = DISTRIBUTION; + } + reports.add(simpleReport); + + // comparative mode + if (definitionMap.size() == 2) { + Iterator iter = definitionMap.values().iterator(); + CohortDefinition firstCohortDef = iter.next(); + CohortDefinition secondCohortDef = iter.next(); + AnalysisResultItem comparativeResultItem = analysisItem.getComparativeItems(firstCohortDef, + secondCohortDef, feAnalysisMap); + Report comparativeReport = new Report(analysisItem.getName(), analysisId, comparativeResultItem); + comparativeReport.header = executionComparativeHeaderLines; + comparativeReport.isComparative = true; + comparativeReport.faType = analysisItem.getFaType(); + comparativeReport.domainId = feAnalysisMap.get(analysisItem.getName()); + if (PREVALENCE.equals(analysisItem.getType())) { + comparativeReport.resultType = PREVALENCE; + // Summary comparative reports are only available for + // prevalence type + comparativeResultSummary.add(comparativeResultItem); + } else if (DISTRIBUTION.equals(analysisItem.getType())) { + comparativeReport.resultType = DISTRIBUTION; + } + reports.add(comparativeReport); + } + } + if (!ignoreSummary) { + // summary comparative reports are only available for prevalence + // type + if (!simpleResultSummary.isEmpty()) { + Report simpleSummaryData = new Report("All prevalence covariates", simpleResultSummary); + simpleSummaryData.header = executionPrevalenceHeaderLines; + simpleSummaryData.isSummary = true; + simpleSummaryData.resultType = PREVALENCE; + reports.add(simpleSummaryData); + } + // comparative mode + if (!comparativeResultSummary.isEmpty()) { + Report 
comparativeSummaryData = new Report("All prevalence covariates", comparativeResultSummary); + comparativeSummaryData.header = executionComparativeHeaderLines; + comparativeSummaryData.isSummary = true; + comparativeSummaryData.isComparative = true; + comparativeSummaryData.resultType = PREVALENCE; + reports.add(comparativeSummaryData); + } + } + + return reports; + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + + private String mapFeatureName(FeAnalysisEntity entity) { + + if (StandardFeatureAnalysisType.PRESET == entity.getType()) { + return entity.getDesign().toString(); + } + return entity.getName(); + } + + private List findResults(final Long generationId, ExecutionResultRequest params, Source source) { + return executeFindResults(generationId, params, QUERY_RESULTS, getGenerationResults(source), source); + } + + private List executeFindResults(final Long generationId, ExecutionResultRequest params, String query, + RowMapper rowMapper, Source source) { + String analysis = params.getAnalysisIds().stream().map(String::valueOf).collect(Collectors.joining(",")); + String cohorts = params.getCohortIds().stream().map(String::valueOf).collect(Collectors.joining(",")); + String generationResults = sourceAwareSqlRender.renderSql(source.getSourceId(), query, + PARAMETERS_RESULTS_FILTERED, + new String[] { String.valueOf(generationId), String.valueOf(params.getThresholdValuePct()), analysis, + cohorts, SourceUtils.getVocabularyQualifier(source) }); + final String tempSchema = SourceUtils.getTempQualifier(source); + String translatedSql = SqlTranslate.translateSql(generationResults, source.getSourceDialect(), + SessionUtils.sessionId(), tempSchema); + return this.getSourceJdbcTemplate(source).query(translatedSql, rowMapper); + } + + private RowMapper getGenerationResults(Source source) { + return (rs, rowNum) -> { + final String type = rs.getString("type"); + if (StringUtils.equals(type, DISTRIBUTION.toString())) { + final CcDistributionStat 
distributionStat = new CcDistributionStat(); + gatherForPrevalence(distributionStat, rs, source); + gatherForDistribution(distributionStat, rs); + return distributionStat; + } else if (StringUtils.equals(type, PREVALENCE.toString())) { + final CcPrevalenceStat prevalenceStat = new CcPrevalenceStat(); + gatherForPrevalence(prevalenceStat, rs, source); + return prevalenceStat; + } + return null; + }; + } + + private void gatherForPrevalence(final CcPrevalenceStat stat, final ResultSet rs, Source source) + throws SQLException { + stat.setFaType(rs.getString("fa_type")); + stat.setSourceKey(source.getSourceKey()); + stat.setCohortId(rs.getInt("cohort_definition_id")); + stat.setAnalysisId(rs.getInt("analysis_id")); + stat.setAnalysisName(rs.getString("analysis_name")); + stat.setResultType(PREVALENCE); + stat.setCovariateId(rs.getLong("covariate_id")); + stat.setCovariateName(rs.getString("covariate_name")); + stat.setConceptName(rs.getString("concept_name")); + stat.setConceptId(rs.getLong("concept_id")); + stat.setAvg(rs.getDouble("avg_value")); + stat.setCount(rs.getLong("count_value")); + stat.setStrataId(rs.getLong("strata_id")); + stat.setStrataName(rs.getString("strata_name")); + } + + private void gatherForDistribution(final CcDistributionStat stat, final ResultSet rs) throws SQLException { + stat.setResultType(DISTRIBUTION); + stat.setAvg(rs.getDouble("avg_value")); + stat.setStdDev(rs.getDouble("stdev_value")); + stat.setMin(rs.getDouble("min_value")); + stat.setP10(rs.getDouble("p10_value")); + stat.setP25(rs.getDouble("p25_value")); + stat.setMedian(rs.getDouble("median_value")); + stat.setP75(rs.getDouble("p75_value")); + stat.setP90(rs.getDouble("p90_value")); + stat.setMax(rs.getDouble("max_value")); + stat.setAggregateId(rs.getInt("aggregate_id")); + stat.setAggregateName(rs.getString("aggregate_name")); + stat.setMissingMeansZero(rs.getInt("missing_means_zero") == 1); + } + + private Integer mapFeatureAnalysisId(FeAnalysisEntity feAnalysis) { + + if 
(feAnalysis.isPreset()) { + return prespecAnalysisMap.values().stream() + .filter(p -> Objects.equals(p.analysisName, feAnalysis.getDesign())).findFirst() + .orElseThrow(() -> new IllegalArgumentException( + String.format("Preset analysis with id=%s does not exist", feAnalysis.getId()))).analysisId; + } else { + return feAnalysis.getId(); + } + } + private int countSetBits(long n) { int count = 0; while (n > 0) { @@ -570,13 +836,13 @@ public CohortDTO saveCohortDefinition(@PathParam("id") final int id, CohortDTO d @Produces(MediaType.APPLICATION_JSON) @Path("/{id}/generate/{sourceKey}") @Transactional - public JobExecutionResource generateCohort(@PathParam("id") final int id, @PathParam("sourceKey") final String sourceKey) { - + public JobExecutionResource generateCohort(@PathParam("id") final int id, + @PathParam("sourceKey") final String sourceKey, + @QueryParam("demographic") Boolean demographicStat) { Source source = getSourceRepository().findBySourceKey(sourceKey); CohortDefinition currentDefinition = this.cohortDefinitionRepository.findOne(id); UserEntity user = userRepository.findByLogin(security.getSubject()); - return cohortGenerationService.generateCohortViaJob(user, currentDefinition, source); - } + return cohortGenerationService.generateCohortViaJob(user, currentDefinition, source, demographicStat); } /** * Cancel a cohort generation task @@ -819,7 +1085,7 @@ public Response exportConceptSets(@PathParam("id") final int id) { public InclusionRuleReport getInclusionRuleReport( @PathParam("id") final int id, @PathParam("sourceKey") final String sourceKey, - @DefaultValue("0") @QueryParam("mode") int modeId) { + @DefaultValue("0") @QueryParam("mode") int modeId, @QueryParam("ccGenerateId") String ccGenerateId) { Source source = this.getSourceRepository().findBySourceKey(sourceKey); @@ -827,23 +1093,37 @@ public InclusionRuleReport getInclusionRuleReport( List inclusionRuleStats = getInclusionRuleStatistics(whitelist(id), source, modeId); String treemapData 
= getInclusionRuleTreemapData(whitelist(id), inclusionRuleStats.size(), source, modeId); - InclusionRuleReport report = new InclusionRuleReport(); - report.summary = summary; - report.inclusionRuleStats = inclusionRuleStats; - report.treemapData = treemapData; + InclusionRuleReport report = new InclusionRuleReport(); + report.summary = summary; + report.inclusionRuleStats = inclusionRuleStats; + report.treemapData = treemapData; + + if (DEMOGRAPHIC_MODE == modeId) { + // Use equals(), not ==, for String content comparison; clients may send the literal "null" + if (ccGenerateId != null && !"null".equals(ccGenerateId)) { + List listDemoDetail = getDemographicStatistics(whitelist(id), source, modeId, + Long.valueOf(ccGenerateId)); + + report.demographicsStats = listDemoDetail; + report.count = 4; + report.showEmptyResults = false; + report.prevalenceThreshold = 0.01f; + } + } return report; } - /** - * Checks the cohort definition for logic issues - * - * This method runs a series of logical checks on a cohort definition and returns the set of warning, info and error messages. - * - * @summary Check Cohort Definition - * @param expression The cohort definition expression - * @return The cohort check result - */ + /** + * Checks the cohort definition for logic issues + * + * This method runs a series of logical checks on a cohort definition and + * returns the set of warning, info and error messages. 
+ * + * @summary Check Cohort Definition + * @param expression + * The cohort definition expression + * @return The cohort check result + */ @POST @Path("/check") @Produces(MediaType.APPLICATION_JSON) diff --git a/src/main/java/org/ohdsi/webapi/service/CohortGenerationService.java b/src/main/java/org/ohdsi/webapi/service/CohortGenerationService.java index 89ec980407..a72e266287 100644 --- a/src/main/java/org/ohdsi/webapi/service/CohortGenerationService.java +++ b/src/main/java/org/ohdsi/webapi/service/CohortGenerationService.java @@ -1,12 +1,18 @@ package org.ohdsi.webapi.service; import org.ohdsi.webapi.GenerationStatus; +import org.ohdsi.webapi.cohortcharacterization.CreateCohortTableTasklet; +import org.ohdsi.webapi.cohortcharacterization.DropCohortTableListener; +import org.ohdsi.webapi.cohortcharacterization.GenerateLocalCohortTasklet; import org.ohdsi.webapi.cohortdefinition.CohortDefinition; import org.ohdsi.webapi.cohortdefinition.CohortDefinitionRepository; import org.ohdsi.webapi.cohortdefinition.CohortGenerationInfo; import org.ohdsi.webapi.cohortdefinition.CohortGenerationInfoRepository; import org.ohdsi.webapi.cohortdefinition.GenerateCohortTasklet; import org.ohdsi.webapi.cohortdefinition.GenerationJobExecutionListener; +import org.ohdsi.webapi.common.generation.AutoremoveJobListener; +import org.ohdsi.webapi.common.generation.GenerationUtils; +import org.ohdsi.webapi.feanalysis.repository.FeAnalysisEntityRepository; import org.ohdsi.webapi.generationcache.GenerationCacheHelper; import org.ohdsi.webapi.job.GeneratesNotification; import org.ohdsi.webapi.job.JobExecutionResource; @@ -14,6 +20,7 @@ import org.ohdsi.webapi.shiro.Entities.UserRepository; import org.ohdsi.webapi.source.Source; import org.ohdsi.webapi.source.SourceService; +import org.ohdsi.webapi.sqlrender.SourceAwareSqlRender; import org.ohdsi.webapi.util.SessionUtils; import org.ohdsi.webapi.util.SourceUtils; import org.ohdsi.webapi.util.TempTableCleanupManager; @@ -27,20 +34,27 @@ import 
org.springframework.batch.repeat.exception.ExceptionHandler; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.DependsOn; +import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.stereotype.Component; +import org.springframework.transaction.support.TransactionTemplate; import javax.annotation.PostConstruct; + +import java.util.Arrays; import java.util.Calendar; import java.util.List; import java.util.Objects; import static org.ohdsi.webapi.Constants.GENERATE_COHORT; +import static org.ohdsi.webapi.Constants.Params.COHORT_CHARACTERIZATION_ID; import static org.ohdsi.webapi.Constants.Params.COHORT_DEFINITION_ID; import static org.ohdsi.webapi.Constants.Params.GENERATE_STATS; import static org.ohdsi.webapi.Constants.Params.JOB_NAME; import static org.ohdsi.webapi.Constants.Params.SESSION_ID; import static org.ohdsi.webapi.Constants.Params.SOURCE_ID; import static org.ohdsi.webapi.Constants.Params.TARGET_DATABASE_SCHEMA; +import static org.ohdsi.webapi.Constants.Params.TARGET_TABLE; +import static org.ohdsi.webapi.Constants.Params.DEMOGRAPHIC_STATS; @Component @DependsOn("flyway") @@ -53,6 +67,10 @@ public class CohortGenerationService extends AbstractDaoService implements Gener private final JobService jobService; private final SourceService sourceService; private final GenerationCacheHelper generationCacheHelper; + private final FeAnalysisEntityRepository feAnalysisRepository; + private final SourceAwareSqlRender sourceAwareSqlRender; + private TransactionTemplate transactionTemplate; + private StepBuilderFactory stepBuilderFactory; @Autowired public CohortGenerationService(CohortDefinitionRepository cohortDefinitionRepository, @@ -61,7 +79,10 @@ public CohortGenerationService(CohortDefinitionRepository cohortDefinitionReposi StepBuilderFactory stepBuilders, JobService jobService, SourceService sourceService, - GenerationCacheHelper generationCacheHelper) { + GenerationCacheHelper 
generationCacheHelper, + FeAnalysisEntityRepository feAnalysisRepository, + TransactionTemplate transactionTemplate, StepBuilderFactory stepBuilderFactory, + SourceAwareSqlRender sourceAwareSqlRender) { this.cohortDefinitionRepository = cohortDefinitionRepository; this.cohortGenerationInfoRepository = cohortGenerationInfoRepository; this.jobBuilders = jobBuilders; @@ -69,24 +90,35 @@ public CohortGenerationService(CohortDefinitionRepository cohortDefinitionReposi this.jobService = jobService; this.sourceService = sourceService; this.generationCacheHelper = generationCacheHelper; + this.feAnalysisRepository = feAnalysisRepository; + this.transactionTemplate = transactionTemplate; + this.stepBuilderFactory = stepBuilderFactory; + this.sourceAwareSqlRender = sourceAwareSqlRender; } - public JobExecutionResource generateCohortViaJob(UserEntity userEntity, CohortDefinition cohortDefinition, Source source) { - - CohortGenerationInfo info = cohortDefinition.getGenerationInfoList().stream() - .filter(val -> Objects.equals(val.getId().getSourceId(), source.getSourceId())).findFirst() - .orElse(new CohortGenerationInfo(cohortDefinition, source.getSourceId())); + public JobExecutionResource generateCohortViaJob(UserEntity userEntity, CohortDefinition cohortDefinition, + Source source, Boolean demographicStat) { + CohortGenerationInfo info = cohortDefinition.getGenerationInfoList().stream() + .filter(val -> Objects.equals(val.getId().getSourceId(), source.getSourceId())).findFirst() + .orElse(new CohortGenerationInfo(cohortDefinition, source.getSourceId())); - info.setCreatedBy(userEntity); + info.setCreatedBy(userEntity); + info.setIsDemographic(demographicStat); - cohortDefinition.getGenerationInfoList().add(info); + cohortDefinition.getGenerationInfoList().add(info); - info.setStatus(GenerationStatus.PENDING) - .setStartTime(Calendar.getInstance().getTime()); + info.setStatus(GenerationStatus.PENDING) + .setStartTime(Calendar.getInstance().getTime()); - 
cohortDefinitionRepository.save(cohortDefinition); + cohortDefinitionRepository.save(cohortDefinition); + // the line below is essential to access the Cohort definition details in GenerateLocalCohortTasklet.generateCohort + // and avoid org.hibernate.LazyInitializationException: + // could not initialize proxy [org.ohdsi.webapi.cohortdefinition.CohortDefinitionDetails#38] - no Session + // the workaround doesn't look pure in the same time refactoring doesn't look minor + // as a lot of components are instantiated by the new operator + cohortDefinition.getDetails().getExpression(); - return runGenerateCohortJob(cohortDefinition, source); + return runGenerateCohortJob(cohortDefinition, source, demographicStat); } private Job buildGenerateCohortJob(CohortDefinition cohortDefinition, Source source, JobParameters jobParameters) { @@ -98,7 +130,7 @@ private Job buildGenerateCohortJob(CohortDefinition cohortDefinition, Source sou getTransactionTemplate(), generationCacheHelper, cohortDefinitionRepository, - sourceService + sourceService, feAnalysisRepository ); ExceptionHandler exceptionHandler = new GenerationTaskExceptionHandler(new TempTableCleanupManager(getSourceJdbcTemplate(source), @@ -121,10 +153,77 @@ private Job buildGenerateCohortJob(CohortDefinition cohortDefinition, Source sou return generateJobBuilder.build(); } - private JobExecutionResource runGenerateCohortJob(CohortDefinition cohortDefinition, Source source) { - final JobParametersBuilder jobParametersBuilder = getJobParametersBuilder(source, cohortDefinition); - Job job = buildGenerateCohortJob(cohortDefinition, source, jobParametersBuilder.toJobParameters()); - return jobService.runJob(job, jobParametersBuilder.toJobParameters()); + public Job buildJobForCohortGenerationWithDemographic( + CohortDefinition cohortDefinition, + Source source, + JobParametersBuilder builder) { + JobParameters jobParameters = builder.toJobParameters(); + addSessionParams(builder, jobParameters.getString(SESSION_ID)); + + 
CreateCohortTableTasklet createCohortTableTasklet = new CreateCohortTableTasklet(getSourceJdbcTemplate(source), transactionTemplate, sourceService, sourceAwareSqlRender); + Step createCohortTableStep = stepBuilderFactory.get(GENERATE_COHORT + ".createCohortTable") + .tasklet(createCohortTableTasklet) + .build(); + + log.info("Beginning generate cohort for cohort definition id: {}", cohortDefinition.getId()); + + GenerateLocalCohortTasklet generateLocalCohortTasklet = new GenerateLocalCohortTasklet( + transactionTemplate, + getSourceJdbcTemplate(source), + this, + sourceService, + chunkContext -> { + return Arrays.asList(cohortDefinition); + }, + generationCacheHelper, + false + ); + Step generateLocalCohortStep = stepBuilderFactory.get(GENERATE_COHORT + ".generateCohort") + .tasklet(generateLocalCohortTasklet) + .build(); + + GenerateCohortTasklet generateTasklet = new GenerateCohortTasklet(getSourceJdbcTemplate(source), + getTransactionTemplate(), generationCacheHelper, cohortDefinitionRepository, sourceService, + feAnalysisRepository); + + ExceptionHandler exceptionHandler = new GenerationTaskExceptionHandler(new TempTableCleanupManager( + getSourceJdbcTemplate(source), getTransactionTemplate(), source.getSourceDialect(), + jobParameters.getString(SESSION_ID), SourceUtils.getTempQualifierOrNull(source))); + + Step generateCohortStep = stepBuilders.get("cohortDefinition.generateCohort").tasklet(generateTasklet) + .exceptionHandler(exceptionHandler).build(); + + DropCohortTableListener dropCohortTableListener = new DropCohortTableListener(getSourceJdbcTemplate(source), transactionTemplate, sourceService, sourceAwareSqlRender); + + SimpleJobBuilder generateJobBuilder = jobBuilders.get(GENERATE_COHORT) + .start(createCohortTableStep) + .next(generateLocalCohortStep) + .next(generateCohortStep) + .listener(dropCohortTableListener); + + generateJobBuilder.listener(new GenerationJobExecutionListener(sourceService, cohortDefinitionRepository, 
this.getTransactionTemplateRequiresNew(), + this.getSourceJdbcTemplate(source))); + + return generateJobBuilder.build(); + } + + protected void addSessionParams(JobParametersBuilder builder, String sessionId) { + builder.addString(TARGET_TABLE, GenerationUtils.getTempCohortTableName(sessionId)); + } + + private JobExecutionResource runGenerateCohortJob(CohortDefinition cohortDefinition, Source source, + Boolean demographic) { + final JobParametersBuilder jobParametersBuilder = getJobParametersBuilder(source, cohortDefinition); + + if (demographic != null && demographic) { + jobParametersBuilder.addString(DEMOGRAPHIC_STATS, Boolean.TRUE.toString()); + Job job = buildJobForCohortGenerationWithDemographic(cohortDefinition, source, jobParametersBuilder); + return jobService.runJob(job, jobParametersBuilder.toJobParameters()); + } else { + Job job = buildGenerateCohortJob(cohortDefinition, source, jobParametersBuilder.toJobParameters()); + return jobService.runJob(job, jobParametersBuilder.toJobParameters()); + } + } private JobParametersBuilder getJobParametersBuilder(Source source, CohortDefinition cohortDefinition) { diff --git a/src/main/resources/db/migration/postgresql/V2.15.0.20240506095654__extend_cohort_generation_info_demographics.sql b/src/main/resources/db/migration/postgresql/V2.15.0.20240506095654__extend_cohort_generation_info_demographics.sql new file mode 100644 index 0000000000..de63adbd6c --- /dev/null +++ b/src/main/resources/db/migration/postgresql/V2.15.0.20240506095654__extend_cohort_generation_info_demographics.sql @@ -0,0 +1,2 @@ +ALTER TABLE ${ohdsiSchema}.cohort_generation_info ADD is_demographic BOOLEAN NOT NULL DEFAULT FALSE; +ALTER TABLE ${ohdsiSchema}.cohort_generation_info ADD cc_generate_id INTEGER NULL; \ No newline at end of file