diff --git a/ciao-controller/controller_test.go b/ciao-controller/controller_test.go index 67a73050e..46b13cc7a 100644 --- a/ciao-controller/controller_test.go +++ b/ciao-controller/controller_test.go @@ -782,8 +782,8 @@ func TestMain(m *testing.M) { config := &ssntp.Config{ URI: "localhost", - CAcert: *caCert, - Cert: *cert, + CAcert: ssntp.DefaultCACert, + Cert: ssntp.RoleToDefaultCertName(ssntp.Controller), } context.client, err = newSSNTPClient(context, config) diff --git a/ciao-controller/main.go b/ciao-controller/main.go index f10c81081..1fbf9e6fe 100644 --- a/ciao-controller/main.go +++ b/ciao-controller/main.go @@ -35,8 +35,8 @@ type controller struct { } var singleMachine = flag.Bool("single", false, "Enable single machine test") -var cert = flag.String("cert", ssntp.RoleToDefaultCertName(ssntp.Controller), "Client certificate") -var caCert = flag.String("cacert", ssntp.DefaultCACert, "CA certificate") +var cert = flag.String("cert", "", "Client certificate") +var caCert = flag.String("cacert", "", "CA certificate") var serverURL = flag.String("url", "", "Server URL") var identityURL = "identity:35357" var serviceUser = "csr" diff --git a/ciao-launcher/main.go b/ciao-launcher/main.go index baf1cb94f..8c1ad5b35 100644 --- a/ciao-launcher/main.go +++ b/ciao-launcher/main.go @@ -95,8 +95,8 @@ var simulate bool var maxInstances = int(math.MaxInt32) func init() { - flag.StringVar(&serverCertPath, "cacert", "/etc/pki/ciao/CAcert-server-localhost.pem", "Client certificate") - flag.StringVar(&clientCertPath, "cert", "/etc/pki/ciao/cert-client-localhost.pem", "CA certificate") + flag.StringVar(&serverCertPath, "cacert", "", "Client certificate") + flag.StringVar(&clientCertPath, "cert", "", "CA certificate") flag.Var(&networking, "network", "Can be none, cn (compute node) or nn (network node)") flag.BoolVar(&hardReset, "hard-reset", false, "Kill and delete all instances, reset networking and exit") flag.BoolVar(&simulate, "simulation", false, "Launcher simulation") diff --git a/ssntp/ssntp.go b/ssntp/ssntp.go index 3d080f0b2..8117d9394 100644 --- a/ssntp/ssntp.go +++ b/ssntp/ssntp.go @@ -27,6 +27,7 @@ import ( "io/ioutil" "log" "os" + "path/filepath" "strings" "sync" "syscall" @@ -710,6 +711,9 @@ const defaultServerCert = "/etc/pki/ciao/cert-Server-localhost.pem" const defaultClientCert = "/etc/pki/ciao/client.pem" const defaultSchedulerCert = "/etc/pki/ciao/cert-Scheduler-localhost.pem" +// Default CIAO certs path +const ciaoCertsPath = "/etc/pki/ciao/*" + // RoleToDefaultCertName returns default certificate names for each SSNTP role func RoleToDefaultCertName(role Role) string { switch role { @@ -1151,13 +1155,88 @@ func (config *Config) port() uint32 { return port } -func (config *Config) setCerts() { - if config.CAcert == "" { - config.CAcert = DefaultCACert +func loadCertificate(certPath string) (*x509.Certificate, error) { + certPEM, err := ioutil.ReadFile(certPath) + if err != nil { + return nil, err + } + block, _ := pem.Decode(certPEM) + if block == nil { + return nil, fmt.Errorf("Failed to parse certificate PEM") + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %v", err) + } + return cert, nil +} + +func getDefaultCertificate() (cacert, cert string, err error) { + certs := []string{} + + files, err := filepath.Glob(ciaoCertsPath) + if err != nil { + return "", "", err + } + +certsLoop: + for _, file := range files { + cert, err := loadCertificate(file) + if err != nil { + continue certsLoop + } + if cacert == "" { + if cert.IsCA == true { + cacert = file + continue certsLoop + } + } + role := GetRoleFromOIDs(cert.UnknownExtKeyUsage) + if role != UNKNOWN { + certs = append(certs, file) + } } - if config.Cert == "" { - config.Cert = defaultClientCert + if len(certs) > 1 { + _, err := os.Stat(DefaultCACert) + if os.IsNotExist(err) { + return "", "", fmt.Errorf("More than one cert files at: %s", ciaoCertsPath) + } + + _, err = os.Stat(defaultClientCert) + if os.IsNotExist(err) { + return "", "", fmt.Errorf("More than one cert files at: %s", ciaoCertsPath) + } + + return DefaultCACert, defaultClientCert, nil + } else if len(certs) == 0 { + return "", "", fmt.Errorf("%s Certificates are not found", ciaoCertsPath) + } + + certPEM, err := ioutil.ReadFile(cacert) + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(certPEM) + vOpts := x509.VerifyOptions{Roots: certPool} + + clientCert, err := loadCertificate(certs[0]) + if err != nil { + return "", "", err + } + _, err = clientCert.Verify(vOpts) + if err != nil { + return "", "", err + } + + return cacert, certs[0], nil +} + +func (config *Config) setCerts() { + var err error + if config.CAcert == "" { + config.CAcert, config.Cert, err = getDefaultCertificate() + if err != nil { + log.Fatal(err) + } } }