123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271 |
- package com.hb.proj.api.controller;
- import java.util.ArrayList;
- import java.util.Arrays;
- import java.util.Collections;
- import java.util.Date;
- import java.util.HashMap;
- import java.util.List;
- import java.util.Map;
- import org.springframework.beans.factory.annotation.Autowired;
- import org.springframework.validation.annotation.Validated;
- import org.springframework.web.bind.annotation.PostMapping;
- import org.springframework.web.bind.annotation.RequestBody;
- import org.springframework.web.bind.annotation.RequestMapping;
- import org.springframework.web.bind.annotation.RequestParam;
- import org.springframework.web.bind.annotation.RestController;
- import com.hb.proj.diagnose.FeatureExtractor;
- import com.hb.proj.diagnose.XGBoostDataUtil;
- import com.hb.proj.diagnose.XGBoostEngine;
- import com.hb.proj.model.DiagnoseStandard;
- import com.hb.proj.model.DiagramFeatureSample;
- import com.hb.proj.model.DiagramSample;
- import com.hb.proj.model.Hyperparam;
- import com.hb.proj.model.TrainLog;
- import com.hb.proj.model.TrainParam;
- import com.hb.proj.rep.DiagramSampleService;
- import com.hb.proj.third.RedisMsgProcessor;
- import com.hb.proj.utils.JacksonUtils;
- import com.hb.proj.utils.RespVO;
- import com.hb.proj.utils.RespVOBuilder;
- import com.hb.xframework.dao.util.PageModel;
- import com.hb.xframework.util.DateUtil;
- import com.hb.xframework.util.MapUtils;
- import jakarta.validation.Valid;
- import jakarta.validation.constraints.NotBlank;
- import jakarta.validation.constraints.NotNull;
- @RestController
- @RequestMapping("/api")
- @Validated
- public class APIController {
-
- @Autowired
- private DiagramSampleService sampleService;
-
- @Autowired
- private XGBoostEngine xgbEngine;
-
- @Autowired
- private RedisMsgProcessor msgProcessor;
-
- /**
- * 获得批量诊断队列的当前大小
- * @return
- */
- @RequestMapping("/buffer/size")
- public RespVO<Object> getBufferSize(){
- return RespVOBuilder.ok(msgProcessor.getQueueSize());
- }
-
-
- /**
- * 获得功图绘制数据(样本功图)
- * @param wellId
- * @param paramCode
- * @param action
- * @param refTime
- * @return
- */
- @RequestMapping("/sample/getDrawData")
- public RespVO<Object> getDrawData(@NotBlank(message = "缺少井名") String wellName,@NotNull(message = "缺少功图时间") Date diagramTime){
-
- DiagramFeatureSample sample=sampleService.find(wellName, diagramTime);
- if(sample==null) {
- return RespVOBuilder.error("无数据");
- }
- return RespVOBuilder.ok(buildData(sample));
- }
-
-
- private Map<String,Object> buildData(DiagramFeatureSample sample){
- String smpTime=DateUtil.format(sample.getDiagramTime(),"yyyy-MM-dd HH:mm:ss");
- String[] dtstr=smpTime.split(" ");
- Map<String,Object> base=new HashMap<String,Object>();
- base.put("xTitle", "位移");
- base.put("yTitle", "载荷");
- base.put("xUnit", "m");
- base.put("yUnit", "kN");
- base.put("title", "时间: "+dtstr[1]);
- base.put("type", "GT");
- base.put("upMax", null);
- base.put("downMax", null);
- base.put("glbMax", sample.getLoadMax());
- base.put("glbMin", sample.getLoadMin());
- base.put("balance", null);
- //base.put("stroke", 3);
- base.put("turnIndex", null);
-
- /*
- base.put("yFixedMax", wparam.getYaxisMax());
- base.put("yFixedMin", wparam.getYaxisMin());
- base.put("xFixedMax", wparam.getXaxisMax());
- base.put("xFixedMin", wparam.getXaxisMin());*/
-
- String[] descTxt= {"隶属单位:","井名:"+sample.getWellName(),"测试日期:"+dtstr[0],"冲次:"};
-
- Map<String,Object> serial=new HashMap<String,Object>();
- List<Float> serials=convert(sample.getDisps());
- Float stroke=Collections.max(serials);
- base.put("stroke", stroke);
- serial.put("xvals", serials);
- serial.put("yvals", convert(sample.getLoads()));
- serial.put("smpTime", smpTime);
- serial.put("dataId", sample.getSampleId());
-
- return MapUtils.build("base",base,"serial",serial,"descTxt",descTxt);
-
- }
-
- private List<Float> convert(String vals){
- String[] ary=vals.split(",");
- List<Float> datas=new ArrayList<>(ary.length);
- for(String d : ary) {
- datas.add(Float.parseFloat(d));
- }
- return datas;
- }
-
-
- /**
- * 加载指定诊断标准的样本功图
- */
- @RequestMapping("/sample/load")
- public RespVO<PageModel<DiagramFeatureSample>> load(@NotBlank(message="缺少诊断标准编号") String diagnoseId,
- @RequestParam(value = "currentPage", defaultValue = "1") Integer pageNum,
- @RequestParam(value = "pageSize", defaultValue = "20") Integer pageSize) {
- return RespVOBuilder.ok(sampleService.load(diagnoseId, pageNum, pageSize));
- }
-
-
- /**
- * 删除指定单个样本功图
- * @param sampleId
- * @return
- */
- @PostMapping("/sample/del")
- public RespVO<Object> del(@NotNull(message="缺少样本编号") Integer sampleId){
- sampleService.delSample(sampleId);
- return RespVOBuilder.ok();
- }
-
- /**
- * 增加样品功图
- * @param sample
- * @return
- */
- @PostMapping("/sample/add")
- public RespVO<Object> add(@Valid @RequestBody DiagramSample srcSample){
- if(sampleService.exists(srcSample.getWellName(),srcSample.getDiagramTime())) {
- return RespVOBuilder.error("该样本已添加过");
- }
- DiagramFeatureSample sample=FeatureExtractor.extract(srcSample);
- sampleService.add(sample);
- return RespVOBuilder.ok();
- }
-
- /**
- * 训练模型
- * @param trainParam
- * @return
- */
- @PostMapping("/sample/train")
- public RespVO<Object> train(@Valid TrainParam trainParam){
- List<DiagramFeatureSample> samples=sampleService.loadSamples();
- if(samples==null || samples.size()<5) {
- return RespVOBuilder.error("样本数据太少");
- }
-
- Map<String,Object> standardMapping=sampleService.loadDiagnoseStandard();
-
- //对样本的类别标签进行编号
- List<DiagramFeatureSample> trainSamples=new ArrayList<>();
- List<DiagramFeatureSample> validSamples=new ArrayList<>();
- int label=0;
- for(DiagramFeatureSample smp : samples) {
- if(!standardMapping.containsKey(smp.getDiagnoseId())) {
- continue;
- }
- label=((Number)standardMapping.get(smp.getDiagnoseId())).intValue();
- smp.setCategoryTag(label);
-
- if(DiagramFeatureSample.TYPE_VALID.equals(smp.getSampleType())) {
- validSamples.add(smp);
- }
- else {
- trainSamples.add(smp);
- }
- }
-
-
- if(trainSamples.size()==0) {
- return RespVOBuilder.error("未找到符合要求的样本数据");
- }
-
- //未明确指定验证样本集时,以训练集的前20%作为验证集,其余作为训练集
- if(validSamples.size()==0) {
- int vcount=(int)(Math.ceil(trainSamples.size()*0.2));
- validSamples.addAll(trainSamples.subList(0, vcount));
- trainSamples=trainSamples.subList(vcount,trainSamples.size());
- }
-
- Hyperparam hyperparam=new Hyperparam();
- hyperparam.setNumClass(standardMapping.size());
- hyperparam.setEta(trainParam.getEta());
- hyperparam.setMaxDepth(trainParam.getMaxDepth());
-
- xgbEngine.train(hyperparam, trainParam.getTrainRound(), XGBoostDataUtil.build(trainSamples, true), XGBoostDataUtil.build(validSamples, true));
-
- //存储训练日志
- TrainLog trainLog=new TrainLog(JacksonUtils.getJSON(standardMapping));
- sampleService.saveTrainLog(trainLog);
-
- return RespVOBuilder.ok();
- }
-
- /**
- * 预测样本
- * @return
- */
- @PostMapping("/sample/predict")
- public RespVO<Object> predict(){
- List<DiagramFeatureSample> testSamples=sampleService.loadSamples("test");
- float[][] rst=xgbEngine.predict(XGBoostDataUtil.build(testSamples, false));
- return RespVOBuilder.ok(rst);
- }
-
- /**
- * 单一功图诊断
- * @return
- */
- @PostMapping("/diagram/diagnose")
- public RespVO<Object> diagnose(@Valid @RequestBody DiagramSample srcSample){
- DiagramFeatureSample sample=FeatureExtractor.extract(srcSample);
- float[][] rst=xgbEngine.predict(XGBoostDataUtil.build(Arrays.asList(sample), false));
- int clsLabel=(int)rst[0][0];
-
- Map<Integer,DiagnoseStandard> labelStdMapping=sampleService.loadLabelStdMapping();
- return RespVOBuilder.ok(labelStdMapping.get(clsLabel));
- }
-
-
- @PostMapping("/sample/predict2")
- public RespVO<Object> diagnose2(){
- /*
- List<DiagramFeatureSample> trainSamples=sampleService.loadSamples("train");
- List<DiagramFeatureSample> validSamples=sampleService.loadSamples("valid");
- List<DiagramFeatureSample> testSamples=sampleService.loadSamples("test");
-
- float[][] rst=XGBootst.trainAndPredict(3, 20, XGBoostDataUtil.build(trainSamples, true), XGBoostDataUtil.build(validSamples, true), XGBoostDataUtil.build(testSamples, false));
- return RespVOBuilder.ok(rst);
- */
- return RespVOBuilder.ok();
- }
-
-
-
- }
|