Skip to content

Commit

Permalink
fix saved model path (#1718)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yancey1989 authored Jan 16, 2020
1 parent 8c6d30b commit 8f936f0
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 19 deletions.
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ require (

github.com/fortytw2/leaktest v1.3.0
github.com/go-delve/delve v1.3.2 // indirect
github.com/go-openapi/spec v0.19.4 // indirect
github.com/go-openapi/spec v0.19.5 // indirect
github.com/go-sql-driver/mysql v1.4.1
github.com/golang/protobuf v1.3.2
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
github.com/kr/pty v1.1.5 // indirect
github.com/mattn/go-colorable v0.1.4 // indirect
github.com/mattn/go-isatty v0.0.11 // indirect
github.com/mattn/go-runewidth v0.0.7 // indirect
Expand All @@ -31,6 +32,7 @@ require (
github.com/sirupsen/logrus v1.4.2
github.com/soniakeys/quant v1.0.0 // indirect
github.com/spf13/cobra v0.0.5 // indirect
github.com/stretchr/objx v0.2.0 // indirect
github.com/stretchr/testify v1.4.0
go.starlark.net v0.0.0-20191218235703-9fcb808a6221 // indirect
golang.org/x/arch v0.0.0-20191126211547-368ea8f32fff // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ github.com/go-openapi/jsonreference v0.19.2/go.mod h1:jMjeRr2HHw6nAVajTXJ4eiUwoh
github.com/go-openapi/spec v0.0.0-20160808142527-6aced65f8501/go.mod h1:J8+jY1nAiCcj+friV/PDoE1/3eeccG9LYBs0tYvLOWc=
github.com/go-openapi/spec v0.19.4 h1:ixzUSnHTd6hCemgtAJgluaTSGYpLNpJY4mA2DIkdOAo=
github.com/go-openapi/spec v0.19.4/go.mod h1:FpwSN1ksY1eteniUU7X0N/BgJ7a4WvBFVA8Lj9mJglo=
github.com/go-openapi/spec v0.19.5 h1:Xm0Ao53uqnk9QE/LlYV5DEU09UAgpliA85QoT9LzqPw=
github.com/go-openapi/spec v0.19.5/go.mod h1:Hm2Jr4jv8G1ciIAo+frC/Ft+rR2kQDh8JHKHb3gWUSk=
github.com/go-openapi/swag v0.0.0-20160704191624-1d0bd113de87/go.mod h1:DXUve3Dpr1UfpPtxFw+EFuQ41HhCWZfha5jSVRG7C7I=
github.com/go-openapi/swag v0.19.2/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
github.com/go-openapi/swag v0.19.5 h1:lTz6Ys4CmqqCQmZPBlbQENR1/GucA2bzYTE12Pw4tFY=
Expand Down
43 changes: 33 additions & 10 deletions pkg/sql/alisa_submitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ type alisaSubmitter struct {
}

func (s *alisaSubmitter) submitAlisaTask(code, resourceName string) error {
_, dSName, err := database.ParseURL(s.Session.DbConnStr)
_, dsName, err := database.ParseURL(s.Session.DbConnStr)
if err != nil {
return err
}
cfg, e := goalisa.ParseDSN(dSName)
cfg, e := goalisa.ParseDSN(dsName)
if e != nil {
return e
}
Expand All @@ -59,6 +59,22 @@ func (s *alisaSubmitter) submitAlisaTask(code, resourceName string) error {
return e
}

func (s *alisaSubmitter) getModelPath(modelName string) (string, error) {
_, dsName, err := database.ParseURL(s.Session.DbConnStr)
if err != nil {
return "", err
}
cfg, err := goalisa.ParseDSN(dsName)
if err != nil {
return "", err
}
userID := s.Session.UserId
if userID == "" {
userID = "unkown"
}
return strings.Join([]string{cfg.Project, userID, modelName}, "/"), nil
}

func (s *alisaSubmitter) ExecuteTrain(ts *ir.TrainStmt) (e error) {
ts.TmpTrainTable, ts.TmpValidateTable, e = createTempTrainAndValTable(ts.Select, ts.ValidationSelect, s.Session.DbConnStr)
if e != nil {
Expand All @@ -71,12 +87,17 @@ func (s *alisaSubmitter) ExecuteTrain(ts *ir.TrainStmt) (e error) {
return e
}

paiCmd, e := getPAIcmd(cc, ts.Into, ts.TmpTrainTable, ts.TmpValidateTable, "")
modelPath, e := s.getModelPath(ts.Into)
if e != nil {
return e
}

code, e := pai.TFTrainAndSave(ts, s.Session, ts.Into)
paiCmd, e := getPAIcmd(cc, ts.Into, modelPath, ts.TmpTrainTable, ts.TmpValidateTable, "")
if e != nil {
return e
}

code, e := pai.TFTrainAndSave(ts, s.Session, modelPath)
if e != nil {
return e
}
Expand Down Expand Up @@ -121,13 +142,15 @@ func (s *alisaSubmitter) ExecutePredict(ps *ir.PredictStmt) error {
if e != nil {
return e
}

paiCmd, e := getPAIcmd(cc, ps.Using, ps.TmpPredictTable, "", ps.ResultTable)
modelPath, e := s.getModelPath(ps.Using)
if e != nil {
return e
}

code, e := pai.TFLoadAndPredict(ps, s.Session, ps.Using)
paiCmd, e := getPAIcmd(cc, ps.Using, modelPath, ps.TmpPredictTable, "", ps.ResultTable)
if e != nil {
return e
}
code, e := pai.TFLoadAndPredict(ps, s.Session, modelPath)
if e != nil {
return e
}
Expand Down Expand Up @@ -198,14 +221,14 @@ func odpsTables(table string) (string, error) {
return fmt.Sprintf("odps://%s/tables/%s", parts[0], parts[1]), nil
}

func getPAIcmd(cc *pai.ClusterConfig, modelName, trainTable, valTable, resTable string) (string, error) {
func getPAIcmd(cc *pai.ClusterConfig, modelName, ossModelPath, trainTable, valTable, resTable string) (string, error) {
jobName := strings.Replace(strings.Join([]string{"sqlflow", modelName}, "_"), ".", "_", 0)
cfString, err := json.Marshal(cc)
if err != nil {
return "", err
}
cfQuote := strconv.Quote(string(cfString))
ckpDir, err := pai.FormatCkptDir(modelName)
ckpDir, err := pai.FormatCkptDir(ossModelPath)
if err != nil {
return "", err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/alisa_submitter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ func TestGetPAICmd(t *testing.T) {
}
os.Setenv("SQLFLOW_OSS_CHECKPOINT_DIR", "oss://bucket/?role_arn=xxx&host=xxx")
defer os.Unsetenv("SQLFLOW_OSS_CHECKPOINT_DIR")
paiCmd, err := getPAIcmd(cc, "my_model", "testdb.test", "", "testdb.result")
paiCmd, err := getPAIcmd(cc, "my_model", "project/12345/my_model", "testdb.test", "", "testdb.result")
a.NoError(err)
ckpDir, err := pai.FormatCkptDir("my_model")
ckpDir, err := pai.FormatCkptDir("project/12345/my_model")
a.NoError(err)
expected := fmt.Sprintf("pai -name tensorflow1120 -DjobName=sqlflow_my_model -Dtags=dnn -Dscript=file://@@task.tar.gz -DentryFile=entry.py -Dtables=odps://testdb/tables/test -Doutputs=odps://testdb/tables/result -DcheckpointDir=\"%s\"", ckpDir)
a.Equal(expected, paiCmd)
Expand Down
11 changes: 5 additions & 6 deletions pkg/sql/codegen/pai/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ func FormatCkptDir(modelName string) (string, error) {
}
ossDir := strings.Join([]string{strings.TrimRight(ossURIParts[0], "/"), modelName}, "/")
// Form URI like: oss://bucket/your/path/modelname/?args=...
ossCkptDir = strings.Join([]string{ossDir + "/", ossURIParts[1]}, "?")
return ossCkptDir, nil
return strings.Join([]string{ossDir + "/", ossURIParts[1]}, "?"), nil
}

// wrapper generates a Python program for submit TensorFlow tasks to PAI.
Expand Down Expand Up @@ -228,15 +227,15 @@ func Train(ir *ir.TrainStmt, session *pb.Session, modelName, cwd string) (string
}

// TFTrainAndSave generates PAI-TF train program.
func TFTrainAndSave(ir *ir.TrainStmt, session *pb.Session, modelName string) (string, error) {
func TFTrainAndSave(ir *ir.TrainStmt, session *pb.Session, modelPath string) (string, error) {
code, err := tensorflow.Train(ir, session)
if err != nil {
return "", err
}

// append code snippet to save model
var tpl = template.Must(template.New("SaveModel").Parse(tfSaveModelTmplText))
ckptDir, err := FormatCkptDir(ir.Into)
ckptDir, err := FormatCkptDir(modelPath)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -332,9 +331,9 @@ func Predict(ir *ir.PredictStmt, session *pb.Session, modelName, cwd string) (st
}

// TFLoadAndPredict generates PAI-TF prediction program.
func TFLoadAndPredict(ir *ir.PredictStmt, session *pb.Session, modelName string) (string, error) {
func TFLoadAndPredict(ir *ir.PredictStmt, session *pb.Session, modelPath string) (string, error) {
var tpl = template.Must(template.New("Predict").Parse(tfPredictTmplText))
ossModelDir, err := FormatCkptDir(modelName)
ossModelDir, err := FormatCkptDir(modelPath)
if err != nil {
return "", err
}
Expand Down

0 comments on commit 8f936f0

Please sign in to comment.